Skip to content

Commit

Permalink
Merge pull request #83 from zayarhtet/main
Browse files Browse the repository at this point in the history
add H2 and H2C client with the interface
  • Loading branch information
zombi-HU authored Dec 11, 2024
2 parents 72e5414 + 405e03c commit 0d9b54d
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 36 deletions.
120 changes: 84 additions & 36 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,37 +141,6 @@ type Client struct {
msgpackUsage msgpackUsage
}

var h2CTransport = http2.Transport{
AllowHTTP: true,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
// Skip TLS dial
return net.DialTimeout(network, addr, 2*time.Second)
},
}

var h2Transport = http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
dialer := net.Dialer{Timeout: 2 * time.Second}
cn, err := tls.DialWithDialer(&dialer, network, addr, cfg)
if err != nil {
return nil, err
}
if err := cn.Handshake(); err != nil {
return nil, err
}
if !cfg.InsecureSkipVerify {
if err := cn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
}
state := cn.ConnectionState()
if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2.NextProtoTLS)
}
return cn, nil
},
}

// NewClient creates a RESTful client instance.
// The instance has a semi-permanent transport TCP connection.
func NewClient() *Client {
Expand Down Expand Up @@ -223,26 +192,105 @@ func NewClientWInterface(networkInterface string) *Client {

// NewH2Client creates a RESTful client instance, forced to use HTTP2 with TLS (H2) (a.k.a. prior knowledge).
func NewH2Client() *Client {
return NewH2ClientWInterface("")
}

// NewH2CClient creates a RESTful client instance, forced to use HTTP2 Cleartext (H2C).
func NewH2CClient() *Client {
return NewH2CClientWInterface("")
}

// NewH2ClientWInterface creates a RESTful client instance with the http2 protocol bound to that network interface.
// The instance has a semi-permanent transport TCP connection.
func NewH2ClientWInterface(networkInterface string) *Client {
c := &Client{Kind: KindH2}
var rt http.RoundTripper = &h2Transport
var rt http.RoundTripper = getH2Transport(networkInterface)
if isTraced && tracer.GetOTel() {
rt = otelhttp.NewTransport(rt)
}
c.Client = &http.Client{Transport: rt}
return c
}

// NewH2CClient creates a RESTful client instance, forced to use HTTP2 Cleartext (H2C).
func NewH2CClient() *Client {
// NewH2ClientWInterface creates a RESTful client instance with the http2 clear text protocol bound to that network interface.
// In other words, the http2 clear text is the http2 but without TLS handshake.
// The instance has a semi-permanent transport TCP connection.
func NewH2CClientWInterface(networkInterface string) *Client {
c := &Client{Kind: KindH2C}
var rt http.RoundTripper = &h2CTransport
var rt http.RoundTripper = getH2CTransport(networkInterface)
if isTraced && tracer.GetOTel() {
rt = otelhttp.NewTransport(rt)
}
c.Client = &http.Client{Transport: rt}
return c
}

func getH2Transport(iface string) *http2.Transport {
return &http2.Transport{
DialTLS: getDialTLSCallback(iface, true),
}
}

func getH2CTransport(iface string) *http2.Transport {
return &http2.Transport{
AllowHTTP: true,
DialTLS: getDialTLSCallback(iface, false),
}
}

func getDialTLSCallback(iface string, withTLS bool) func(string,string,*tls.Config) (net.Conn, error) {
return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
dialer := net.Dialer{Timeout: 2 * time.Second}

var conn net.Conn
var err error
if iface != "" {
IPs := getIPFromInterface(iface)
if IPs.IPv4 != nil {
dialer.LocalAddr = IPs.IPv4
conn, err = dialWithDialer(&dialer, network, addr, cfg, withTLS)
}

// Try IPv6 if IPv4 is unavailable or connection fails.
if IPs.IPv4 == nil || (IPs.IPv6 != nil && err != nil && !errDeadlineOrCancel(err)) {
dialer.LocalAddr = IPs.IPv6
conn, err = dialWithDialer(&dialer, network, addr, cfg, withTLS)
}
} else {
conn, err = dialWithDialer(&dialer, network, addr, cfg, withTLS)
}

if err != nil {
return nil, err
}

// Skip TLS dial if it is the H2C
if withTLS {
if err := conn.(*tls.Conn).Handshake(); err != nil {
return nil, err
}
if !cfg.InsecureSkipVerify {
if err := conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
}
state := conn.(*tls.Conn).ConnectionState()
if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2.NextProtoTLS)
}
}
return conn, nil
}
}

func dialWithDialer(dialer *net.Dialer, network, addr string, cfg *tls.Config, withTLS bool) (net.Conn, error) {
if withTLS {
return tls.DialWithDialer(dialer, network, addr, cfg)
} else {
return dialer.Dial(network, addr)
}
}

// UserAgent to be sent as User-Agent HTTP header. If not set then default Go client settings are used.
func (c *Client) UserAgent(userAgent string) *Client {
c.userAgent = userAgent
Expand Down Expand Up @@ -377,7 +425,7 @@ func (c *Client) SetOauth2Conf(config oauth2.Config, tokenClient *http.Client, g

// SetOauth2H2 makes OAuth2 token client communicate using h2 transport with Authorization Server.
func (c *Client) SetOauth2H2() *Client {
c.oauth2.client = &http.Client{Timeout: 10 * time.Second, Transport: &h2Transport}
c.oauth2.client = &http.Client{Timeout: 10 * time.Second, Transport: getH2Transport("")}
return c
}

Expand Down
95 changes: 95 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,24 @@ package restful

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"golang.org/x/oauth2"
)

Expand Down Expand Up @@ -855,3 +861,92 @@ func TestCientInterface(t *testing.T) {
c := NewClientWInterface(theUsedInterface)
assert.NotNil(t, c)
}

func startH2Server(mux *http.ServeMux, wg *sync.WaitGroup) *http.Server {
defer wg.Done()
server := &http.Server{
Addr: "localhost:8443",
Handler: mux,
TLSConfig: &tls.Config{
NextProtos: []string{"h2"},
},
}

go func() {
if err := server.ListenAndServeTLS("test_certs/tls.crt", "test_certs/tls.key"); err != nil && err != http.ErrServerClosed {
fmt.Printf("Failed to start server: %v", err)
}
}()
return server
}

func startH2CServer(mux *http.ServeMux, wg *sync.WaitGroup) *http.Server {
defer wg.Done()
server := &http.Server{
Addr: "localhost:8440",
Handler: h2c.NewHandler(mux, &http2.Server{}),
}

go func() {
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Printf("Failed to start server: %v", err)
}
}()
return server
}

func TestClients(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
response := map[string]string{"message": "Hello, world!"}
json.NewEncoder(w).Encode(response)
})
var wg sync.WaitGroup
wg.Add(2)
h2Server := startH2Server(mux, &wg)
h2cServer := startH2CServer(mux, &wg)
defer func() {
h2Server.Close()
h2cServer.Close()
}()

h2Client := NewH2Client()
h2Client.Client.Transport.(*http2.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true,}
h2cClient := NewH2CClient()

wg.Wait()

tests := []struct {
name string
client *Client
serverURL string
}{
{
name: "HTTP/2 Client (H2)",
client: h2Client,
serverURL: "https://localhost:8443", // H2 server
},
{
name: "HTTP/2 Cleartext Client (H2C)",
client: h2cClient,
serverURL: "http://localhost:8440", // H2C server
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

var resp any
err := test.client.Get(context.Background(), test.serverURL, &resp)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}

b, _ := json.Marshal(resp)
if string(b) != "{\"message\":\"Hello, world!\"}" {
t.Fatalf("Unexpected response: %s", b)
}
})
}
}

0 comments on commit 0d9b54d

Please sign in to comment.