Skip to content

Commit

Permalink
Extract fingerprint into an exported wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Anaethelion committed Dec 2, 2021
1 parent 5010506 commit 5646885
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 24 deletions.
63 changes: 39 additions & 24 deletions elastictransport/elastictransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type Config struct {

ConnectionPoolFunc func([]*Connection, Selector) ConnectionPool

CertificateFingerprint string
CertificateFingerprints []string
}

// Client represents the HTTP client.
Expand Down Expand Up @@ -132,29 +132,9 @@ func New(cfg Config) (*Client, error) {
cfg.Transport = http.DefaultTransport
}

if transport, ok := cfg.Transport.(*http.Transport); ok {
if cfg.CertificateFingerprint != "" {
transport.DialTLS = func(network, addr string) (net.Conn, error) {
fingerprint, _ := hex.DecodeString(cfg.CertificateFingerprint)

c, err := tls.Dial(network, addr, &tls.Config{InsecureSkipVerify: true})
if err != nil {
return nil, err
}

// Retrieve the connection state from the remote server.
cState := c.ConnectionState()
for _, cert := range cState.PeerCertificates {
// Compute digest for each certificate.
digest := sha256.Sum256(cert.Raw)

// Provided fingerprint should match at least one certificate from remote before we continue.
if bytes.Compare(digest[0:], fingerprint) == 0 {
return c, nil
}
}
return nil, fmt.Errorf("fingerprint mismatch, provided: %s", cfg.CertificateFingerprint)
}
if len(cfg.CertificateFingerprints) > 0 {
if transport, ok := cfg.Transport.(*http.Transport); ok {
transport.DialTLS = WrapDialTLS(transport.DialTLS, cfg.CertificateFingerprints)
}
}

Expand Down Expand Up @@ -522,3 +502,38 @@ func (c *Client) logRoundTrip(
}
c.logger.LogRoundTrip(req, &dupRes, err, start, dur) // errcheck exclude
}

func WrapDialTLS(dialTls func(network, addr string) (net.Conn, error), certificateFingerprints []string) func(network, addr string) (net.Conn, error) {
return func(network, addr string) (net.Conn, error) {
c, err := tls.Dial(network, addr, &tls.Config{InsecureSkipVerify: false})

if _, ok := err.(x509.UnknownAuthorityError); ok {
c, err = tls.Dial(network, addr, &tls.Config{InsecureSkipVerify: true})
if err != nil {
return nil, err
}
// Retrieve the connection state from the remote server.
cState := c.ConnectionState()
for _, cert := range cState.PeerCertificates {
// Compute digest for each certificate.
digest := sha256.Sum256(cert.Raw)

// Provided fingerprint should match at least one certificate from remote before we continue.
for _, certificateFingerPrint := range certificateFingerprints {
fingerprint, _ := hex.DecodeString(certificateFingerPrint)
if bytes.Compare(digest[0:], fingerprint) == 0 {
if dialTls != nil {
return dialTls(network, addr)
}
return c, nil
}
}
}
}

if err != nil {
return nil, err
}
return c, nil
}
}
46 changes: 46 additions & 0 deletions elastictransport/elastictransport_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ import (
"bytes"
"compress/gzip"
"context"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
Expand Down Expand Up @@ -990,3 +992,47 @@ func TestRequestCompression(t *testing.T) {
})
}
}

func TestFingerprint(t *testing.T) {
body := []byte(`{"body": true"}"`)
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Elastic-Product", "Elasticsearch")
w.Write(body)
}))
defer server.Close()

u, _ := url.Parse(server.URL)

config := Config{
URLs: []*url.URL{u},
DisableRetry: true,
}

t.Run("Self signed test certificate only", func(t *testing.T) {
// Without certificate and authority, client should fail on TLS
transport, _ := New(config)
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)

_, err := transport.Perform(req)
if _, ok := err.(x509.UnknownAuthorityError); !ok {
t.Fatalf("Uknown error, expected UnknownAuthorityError, got: %s", err)
}
})

t.Run("Self signed test certificate with fingerprint", func(t *testing.T) {
// We add the fingerprint corresponding ton testcert.LocalhostCert
config.CertificateFingerprints = append(config.CertificateFingerprints, "448F628A8A65AA18560E53A80C53ACB38C51B427DF0334082349141147DC9BF6")
transport, _ := New(config)
req, _ := http.NewRequest(http.MethodGet, server.URL, nil)
res, err := transport.Perform(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()

data, _ := ioutil.ReadAll(res.Body)
if bytes.Compare(data, body) != 0 {
t.Fatalf("unexpected payload returned: expected: %s, got: %s", body, data)
}
})
}

0 comments on commit 5646885

Please sign in to comment.