diff --git a/internal/bootstrap/hostsresolver.go b/internal/bootstrap/hostsresolver.go index cd9883b73..13a2b527a 100644 --- a/internal/bootstrap/hostsresolver.go +++ b/internal/bootstrap/hostsresolver.go @@ -3,7 +3,6 @@ package bootstrap import ( "context" "fmt" - "io" "io/fs" "net/netip" @@ -21,16 +20,11 @@ type HostsResolver struct { } // NewHostsResolver is the resolver based on system hosts files. -func NewHostsResolver(rs ...io.Reader) (hr *HostsResolver, err error) { - hosts, err := netutil.NewHosts(rs...) - if err != nil { - return nil, fmt.Errorf("parsing hosts: %w", err) - } - +func NewHostsResolver(hosts *netutil.Hosts) (hr *HostsResolver) { hr = &HostsResolver{} _, hr.addrs = hosts.Mappings() - return hr, nil + return hr } // NewDefaultHostsResolver returns a resolver based on system hosts files @@ -47,14 +41,12 @@ func NewDefaultHostsResolver(rootFSys fs.FS) (hr *HostsResolver, err error) { for _, name := range paths { err = parseHostsFile(rootFSys, hosts, name) if err != nil { + // Don't wrap the error since it's already informative enough as is. return nil, err } } - hr = &HostsResolver{} - _, hr.addrs = hosts.Mappings() - - return hr, nil + return NewHostsResolver(hosts), nil } // parseHostsFile reads a single hosts file from fsys and parses it into hosts. diff --git a/internal/bootstrap/hostsresolver_test.go b/internal/bootstrap/hostsresolver_test.go index 9b35cd2be..a50a3ca53 100644 --- a/internal/bootstrap/hostsresolver_test.go +++ b/internal/bootstrap/hostsresolver_test.go @@ -7,8 +7,8 @@ import ( "testing" "github.com/AdguardTeam/dnsproxy/internal/bootstrap" + "github.com/AdguardTeam/dnsproxy/internal/netutil" "github.com/AdguardTeam/golibs/testutil" - "github.com/AdguardTeam/golibs/testutil/fakeio" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -24,9 +24,11 @@ func TestHostsResolver_LookupNetIP(t *testing.T) { v6Addr = netip.MustParseAddr("::1") ) - hr, err := bootstrap.NewHostsResolver(strings.NewReader(hostsData)) + hosts, err := netutil.NewHosts(strings.NewReader(hostsData)) require.NoError(t, err) + hr := bootstrap.NewHostsResolver(hosts) + testCases := []struct { name string host string @@ -94,15 +96,3 @@ func TestHostsResolver_LookupNetIP(t *testing.T) { testutil.AssertErrorMsg(t, `unsupported network "ip5"`, err) }) } - -func TestNewHostsResolver_error(t *testing.T) { - r := &fakeio.Reader{ - OnRead: func(_ []byte) (n int, err error) { - return 0, assert.AnError - }, - } - - hr, err := bootstrap.NewHostsResolver(r) - assert.ErrorIs(t, err, assert.AnError) - assert.Nil(t, hr) -} diff --git a/internal/netutil/hosts.go b/internal/netutil/hosts.go index 402b35c26..2c8c2eb78 100644 --- a/internal/netutil/hosts.go +++ b/internal/netutil/hosts.go @@ -25,6 +25,14 @@ type orderedSet[K string | netip.Addr] struct { vals []K } +// add adds val to os if it's not already there. +func (os *orderedSet[K]) add(key, val K) { + if _, ok := os.set[key]; !ok { + os.set[key] = unit{} + os.vals = append(os.vals, val) + } +} + // Convenience aliases for [orderedSet]. type ( namesSet = orderedSet[string] @@ -37,13 +45,13 @@ type ( // // TODO(e.burkov): Think of storing only slices. type Hosts struct { - // addrs maps each address to its hosts in original case and in original + // names maps each address to its names in original case and in original // adding order without duplicates. - addrs map[netip.Addr]*namesSet + names map[netip.Addr]*namesSet - // names maps each host to its addresses in original adding order without + // addrs maps each host to its addresses in original adding order without // duplicates. - names map[string]*addrsSet + addrs map[string]*addrsSet } // type check @@ -53,8 +61,8 @@ var _ hostsfile.HandleSet = (*Hosts)(nil) // optional, the error is only returned in case of parsing error. func NewHosts(readers ...io.Reader) (h *Hosts, err error) { h = &Hosts{ - addrs: map[netip.Addr]*namesSet{}, - names: map[string]*addrsSet{}, + names: map[netip.Addr]*namesSet{}, + addrs: map[string]*addrsSet{}, } for i, r := range readers { @@ -71,34 +79,25 @@ var _ hostsfile.HandleSet = (*Hosts)(nil) // Add implements the [hostsfile.Set] interface for *Hosts. func (h *Hosts) Add(rec *hostsfile.Record) { - knownNames := h.addrs[rec.Addr] - if knownNames == nil { - knownNames = &namesSet{set: set[string]{}} - h.addrs[rec.Addr] = knownNames + names := h.names[rec.Addr] + if names == nil { + names = &namesSet{set: set[string]{}} + h.names[rec.Addr] = names } for _, name := range rec.Names { lowered := strings.ToLower(name) + names.add(lowered, name) - if _, ok := knownNames.set[lowered]; !ok { - knownNames.set[lowered] = unit{} - knownNames.vals = append(knownNames.vals, name) - } - - knownAddrs := h.names[lowered] - if knownAddrs == nil { - h.names[lowered] = &addrsSet{ - vals: []netip.Addr{rec.Addr}, - set: set[netip.Addr]{rec.Addr: {}}, + addrs := h.addrs[lowered] + if addrs == nil { + addrs = &addrsSet{ + vals: []netip.Addr{}, + set: set[netip.Addr]{}, } - - continue - } - - if _, ok := knownAddrs.set[rec.Addr]; !ok { - knownAddrs.set[rec.Addr] = unit{} - knownAddrs.vals = append(knownAddrs.vals, rec.Addr) + h.addrs[lowered] = addrs } + addrs.add(rec.Addr, rec.Addr) } } @@ -111,7 +110,7 @@ func (h *Hosts) HandleInvalid(srcName string, _ []byte, err error) { return } - if err = errors.Unwrap(lineErr); errors.Is(err, hostsfile.ErrEmptyLine) { + if errors.Is(err, hostsfile.ErrEmptyLine) { // Ignore empty lines and comments. return } @@ -122,7 +121,7 @@ func (h *Hosts) HandleInvalid(srcName string, _ []byte, err error) { // ByAddr returns each host for addr in original case, in original adding order // without duplicates. It returns nil if h doesn't contain the addr. func (h *Hosts) ByAddr(addr netip.Addr) (hosts []string) { - if hostsSet, ok := h.addrs[addr]; ok { + if hostsSet, ok := h.names[addr]; ok { return hostsSet.vals } @@ -132,7 +131,7 @@ func (h *Hosts) ByAddr(addr netip.Addr) (hosts []string) { // ByName returns each address for host in original adding order without // duplicates. It returns nil if h doesn't contain the host. func (h *Hosts) ByName(host string) (addrs []netip.Addr) { - if addrsSet, ok := h.names[strings.ToLower(host)]; ok { + if addrsSet, ok := h.addrs[strings.ToLower(host)]; ok { return addrsSet.vals } @@ -141,14 +140,14 @@ func (h *Hosts) ByName(host string) (addrs []netip.Addr) { // Mappings returns a deep clone of the internal mappings. func (h *Hosts) Mappings() (names map[netip.Addr][]string, addrs map[string][]netip.Addr) { - names = make(map[netip.Addr][]string, len(h.addrs)) - addrs = make(map[string][]netip.Addr, len(h.names)) + names = make(map[netip.Addr][]string, len(h.names)) + addrs = make(map[string][]netip.Addr, len(h.addrs)) - for addr, namesSet := range h.addrs { + for addr, namesSet := range h.names { names[addr] = slices.Clone(namesSet.vals) } - for name, addrsSet := range h.names { + for name, addrsSet := range h.addrs { addrs[name] = slices.Clone(addrsSet.vals) } diff --git a/internal/netutil/hosts_test.go b/internal/netutil/hosts_test.go index fa677f957..6f1694ced 100644 --- a/internal/netutil/hosts_test.go +++ b/internal/netutil/hosts_test.go @@ -61,6 +61,12 @@ func TestHosts(t *testing.T) { } ) + t.Run("Mappings", func(t *testing.T) { + names, addrs := h.Mappings() + assert.Equal(t, wantAddrs, names) + assert.Equal(t, wantHosts, addrs) + }) + t.Run("ByAddr", func(t *testing.T) { t.Parallel() diff --git a/upstream/upstreamresolver.go b/upstream/upstreamresolver.go index 220c8c3fe..3d607fb7d 100644 --- a/upstream/upstreamresolver.go +++ b/upstream/upstreamresolver.go @@ -16,14 +16,11 @@ import ( // Resolver is an alias for [bootstrap.Resolver] to avoid the import cycle. type Resolver = bootstrap.Resolver -// NewUpstreamResolver creates a Resolver. resolverAddress format is the same -// as in the [AddressToUpstream], except that it also shouldn't need a -// bootstrap, i.e. have an IP address in hostname, or be a DNSCrypt. -// resolverAddress must not be empty, use another [Resolver] instead, e.g. -// [net.Resolver]. -// -// sorting the resolved addresses is the caller's responsibility. Otherwise, it -// creates an [Upstream] using opts. +// NewUpstreamResolver creates an upstream that can be used as [Resolver]. +// resolverAddress format is the same as in the [AddressToUpstream], except that +// it also shouldn't need a bootstrap, i.e. have an IP address in hostname, or +// be a DNSCrypt. resolverAddress must not be empty, use another [Resolver] +// instead, e.g. [net.Resolver]. func NewUpstreamResolver(resolverAddress string, opts *Options) (r Resolver, err error) { upsOpts := &Options{}