From f9ced2ed32eb7b082acbe3e682cf3b0dc7cbefc8 Mon Sep 17 00:00:00 2001 From: Hank Donnay Date: Wed, 17 May 2023 11:01:56 -0500 Subject: [PATCH] httputil: improve ctlLocalOnly errors Go1.20 introduced a Context-enabled version of the control function, so switch to that while adding more specific error messages. Signed-off-by: Hank Donnay --- internal/httputil/client.go | 48 ++++++++++++++++---------- internal/httputil/client_test.go | 59 ++++++++++++++++++++++---------- 2 files changed, 71 insertions(+), 36 deletions(-) diff --git a/internal/httputil/client.go b/internal/httputil/client.go index ad218a0cfa..e2ad309e65 100644 --- a/internal/httputil/client.go +++ b/internal/httputil/client.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/http/cookiejar" + "net/netip" "os" "path/filepath" "strings" @@ -26,7 +27,7 @@ func NewClient(ctx context.Context, localOnly bool) (*http.Client, error) { dialer := &net.Dialer{} // Set a control function if we're restricting subnets. if localOnly { - dialer.Control = ctlLocalOnly + dialer.ControlContext = ctlLocalOnly } tr.DialContext = dialer.DialContext @@ -42,36 +43,47 @@ func NewClient(ctx context.Context, localOnly bool) (*http.Client, error) { }, nil } -func ctlLocalOnly(network, address string, _ syscall.RawConn) error { - // Future-proof for QUIC by allowing UDP here. - if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") { +func ctlLocalOnly(_ context.Context, network, address string, _ syscall.RawConn) error { + // Now that this has a Context'd version, we could jam a policy engine in + // here if someone really feeling froggy. + switch { + case strings.HasPrefix(network, "tcp"): // OK + case strings.HasPrefix(network, "udp"): // OK + case strings.HasPrefix(network, "unix"): + // Local by definition. + return nil + default: return &net.AddrError{ Addr: network + "!" + address, - Err: "disallowed by policy", + Err: fmt.Sprintf("disallowed by policy: network %q", network), } } - host, _, err := net.SplitHostPort(address) + + ap, err := netip.ParseAddrPort(address) if err != nil { return &net.AddrError{ - Addr: network + "!" + address, - Err: "martian address", + Addr: address, + Err: fmt.Sprintf("unable to parse address: %v", err), } } - addr := net.ParseIP(host) - if addr == nil { + switch addr := ap.Addr(); { + case addr.IsMulticast(): + // Assert this is a unicast address. + // There was a draft RFC for handling HTTP/3 over multicast QUIC, but it's expired so this seems OK to do. return &net.AddrError{ - Addr: network + "!" + address, - Err: "martian address", + Addr: ap.String(), + Err: "disallowed by policy: address is multicast", } - } - if !addr.IsPrivate() && - !addr.IsLoopback() && - !addr.IsLinkLocalUnicast() { + case addr.IsLoopback(): // OK + case addr.IsLinkLocalUnicast(): // OK + case addr.IsPrivate(): // OK + default: return &net.AddrError{ - Addr: network + "!" + address, - Err: "disallowed by policy", + Addr: ap.String(), + Err: "disallowed by policy: not loopback, link-local, or private", } } + return nil } diff --git a/internal/httputil/client_test.go b/internal/httputil/client_test.go index 287c31b2ad..5e0330d776 100644 --- a/internal/httputil/client_test.go +++ b/internal/httputil/client_test.go @@ -1,9 +1,12 @@ package httputil import ( + "context" "errors" "net" "testing" + + "github.com/google/go-cmp/cmp" ) func TestLocalOnly(t *testing.T) { @@ -12,46 +15,66 @@ func TestLocalOnly(t *testing.T) { Addr string Err *net.AddrError }{ + {Network: "tcp4", Addr: "192.168.0.1:443"}, + {Network: "tcp4", Addr: "10.0.0.1:80"}, + {Network: "tcp4", Addr: "127.0.0.1:443"}, + {Network: "tcp6", Addr: "[fe80::]:443"}, + {Network: "unix", Addr: "/run/sock"}, { - Network: "tcp4", - Addr: "192.168.0.1:443", - Err: nil, + Network: "ip6", + Addr: "::1", + Err: &net.AddrError{ + Addr: "ip6!::1", + Err: `disallowed by policy: network "ip6"`, + }, }, { Network: "tcp4", - Addr: "127.0.0.1:443", - Err: nil, + Addr: "127.256:443", + Err: &net.AddrError{ + Addr: "127.256:443", + Err: `unable to parse address: ParseAddr("127.256"): IPv4 field has value >255`, + }, }, { - Network: "tcp6", - Addr: "[fe80::]:443", - Err: nil, + Network: "tcp4", + Addr: "224.0.0.1:443", + Err: &net.AddrError{ + Addr: "224.0.0.1:443", + Err: "disallowed by policy: address is multicast", + }, }, { Network: "tcp4", Addr: "8.8.8.8:443", Err: &net.AddrError{ - Addr: "tcp4!8.8.8.8:443", - Err: "disallowed by policy", + Addr: "8.8.8.8:443", + Err: "disallowed by policy: not loopback, link-local, or private", }, }, { Network: "tcp6", Addr: "[2000::]:443", Err: &net.AddrError{ - Addr: "tcp6![2000::]:443", - Err: "disallowed by policy", + Addr: "[2000::]:443", + Err: "disallowed by policy: not loopback, link-local, or private", }, }, } + // CtlLocalOnly doesn't emit logs, don't bother with zlog. + ctx := context.Background() for _, tc := range tt { t.Logf("%s!%s", tc.Network, tc.Addr) - var nErr *net.AddrError - got := ctlLocalOnly(tc.Network, tc.Addr, nil) - if errors.As(got, &nErr) { - if tc.Err.Err != nErr.Err || tc.Err.Addr != nErr.Addr { - t.Errorf("got: %v, want: %v", got, tc.Err) - } + var got *net.AddrError + err := ctlLocalOnly(ctx, tc.Network, tc.Addr, nil) + switch { + case err == nil: + case !errors.As(err, &got): + t.Errorf("returned error not *net.AddrError, is %T", got) + continue + } + if want := tc.Err; !cmp.Equal(got, want) { + t.Error(cmp.Diff(got, want)) } } }