Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Sep 19, 2023
1 parent 9429f58 commit 4b3e59b
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 68 deletions.
16 changes: 4 additions & 12 deletions internal/bootstrap/hostsresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package bootstrap
import (
"context"
"fmt"
"io"
"io/fs"
"net/netip"

Expand All @@ -21,16 +20,11 @@ type HostsResolver struct {
}

// NewHostsResolver is the resolver based on system hosts files.
func NewHostsResolver(rs ...io.Reader) (hr *HostsResolver, err error) {
hosts, err := netutil.NewHosts(rs...)
if err != nil {
return nil, fmt.Errorf("parsing hosts: %w", err)
}

func NewHostsResolver(hosts *netutil.Hosts) (hr *HostsResolver) {
hr = &HostsResolver{}
_, hr.addrs = hosts.Mappings()

return hr, nil
return hr
}

// NewDefaultHostsResolver returns a resolver based on system hosts files
Expand All @@ -47,14 +41,12 @@ func NewDefaultHostsResolver(rootFSys fs.FS) (hr *HostsResolver, err error) {
for _, name := range paths {
err = parseHostsFile(rootFSys, hosts, name)
if err != nil {
// Don't wrap the error since it's already informative enough as is.
return nil, err
}
}

hr = &HostsResolver{}
_, hr.addrs = hosts.Mappings()

return hr, nil
return NewHostsResolver(hosts), nil
}

// parseHostsFile reads a single hosts file from fsys and parses it into hosts.
Expand Down
18 changes: 4 additions & 14 deletions internal/bootstrap/hostsresolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"testing"

"github.com/AdguardTeam/dnsproxy/internal/bootstrap"
"github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakeio"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -24,9 +24,11 @@ func TestHostsResolver_LookupNetIP(t *testing.T) {
v6Addr = netip.MustParseAddr("::1")
)

hr, err := bootstrap.NewHostsResolver(strings.NewReader(hostsData))
hosts, err := netutil.NewHosts(strings.NewReader(hostsData))
require.NoError(t, err)

hr := bootstrap.NewHostsResolver(hosts)

testCases := []struct {
name string
host string
Expand Down Expand Up @@ -94,15 +96,3 @@ func TestHostsResolver_LookupNetIP(t *testing.T) {
testutil.AssertErrorMsg(t, `unsupported network "ip5"`, err)
})
}

func TestNewHostsResolver_error(t *testing.T) {
r := &fakeio.Reader{
OnRead: func(_ []byte) (n int, err error) {
return 0, assert.AnError
},
}

hr, err := bootstrap.NewHostsResolver(r)
assert.ErrorIs(t, err, assert.AnError)
assert.Nil(t, hr)
}
67 changes: 33 additions & 34 deletions internal/netutil/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ type orderedSet[K string | netip.Addr] struct {
vals []K
}

// add adds val to os if it's not already there.
func (os *orderedSet[K]) add(key, val K) {
if _, ok := os.set[key]; !ok {
os.set[key] = unit{}
os.vals = append(os.vals, val)
}
}

// Convenience aliases for [orderedSet].
type (
namesSet = orderedSet[string]
Expand All @@ -37,13 +45,13 @@ type (
//
// TODO(e.burkov): Think of storing only slices.
type Hosts struct {
// addrs maps each address to its hosts in original case and in original
// names maps each address to its names in original case and in original
// adding order without duplicates.
addrs map[netip.Addr]*namesSet
names map[netip.Addr]*namesSet

// names maps each host to its addresses in original adding order without
// addrs maps each host to its addresses in original adding order without
// duplicates.
names map[string]*addrsSet
addrs map[string]*addrsSet
}

// type check
Expand All @@ -53,8 +61,8 @@ var _ hostsfile.HandleSet = (*Hosts)(nil)
// optional, the error is only returned in case of parsing error.
func NewHosts(readers ...io.Reader) (h *Hosts, err error) {
h = &Hosts{
addrs: map[netip.Addr]*namesSet{},
names: map[string]*addrsSet{},
names: map[netip.Addr]*namesSet{},
addrs: map[string]*addrsSet{},
}

for i, r := range readers {
Expand All @@ -71,34 +79,25 @@ var _ hostsfile.HandleSet = (*Hosts)(nil)

// Add implements the [hostsfile.Set] interface for *Hosts.
func (h *Hosts) Add(rec *hostsfile.Record) {
knownNames := h.addrs[rec.Addr]
if knownNames == nil {
knownNames = &namesSet{set: set[string]{}}
h.addrs[rec.Addr] = knownNames
names := h.names[rec.Addr]
if names == nil {
names = &namesSet{set: set[string]{}}
h.names[rec.Addr] = names
}

for _, name := range rec.Names {
lowered := strings.ToLower(name)
names.add(lowered, name)

if _, ok := knownNames.set[lowered]; !ok {
knownNames.set[lowered] = unit{}
knownNames.vals = append(knownNames.vals, name)
}

knownAddrs := h.names[lowered]
if knownAddrs == nil {
h.names[lowered] = &addrsSet{
vals: []netip.Addr{rec.Addr},
set: set[netip.Addr]{rec.Addr: {}},
addrs := h.addrs[lowered]
if addrs == nil {
addrs = &addrsSet{
vals: []netip.Addr{},
set: set[netip.Addr]{},
}

continue
}

if _, ok := knownAddrs.set[rec.Addr]; !ok {
knownAddrs.set[rec.Addr] = unit{}
knownAddrs.vals = append(knownAddrs.vals, rec.Addr)
h.addrs[lowered] = addrs
}
addrs.add(rec.Addr, rec.Addr)
}
}

Expand All @@ -111,7 +110,7 @@ func (h *Hosts) HandleInvalid(srcName string, _ []byte, err error) {
return
}

if err = errors.Unwrap(lineErr); errors.Is(err, hostsfile.ErrEmptyLine) {
if errors.Is(err, hostsfile.ErrEmptyLine) {
// Ignore empty lines and comments.
return
}
Expand All @@ -122,7 +121,7 @@ func (h *Hosts) HandleInvalid(srcName string, _ []byte, err error) {
// ByAddr returns each host for addr in original case, in original adding order
// without duplicates. It returns nil if h doesn't contain the addr.
func (h *Hosts) ByAddr(addr netip.Addr) (hosts []string) {
if hostsSet, ok := h.addrs[addr]; ok {
if hostsSet, ok := h.names[addr]; ok {
return hostsSet.vals
}

Expand All @@ -132,7 +131,7 @@ func (h *Hosts) ByAddr(addr netip.Addr) (hosts []string) {
// ByName returns each address for host in original adding order without
// duplicates. It returns nil if h doesn't contain the host.
func (h *Hosts) ByName(host string) (addrs []netip.Addr) {
if addrsSet, ok := h.names[strings.ToLower(host)]; ok {
if addrsSet, ok := h.addrs[strings.ToLower(host)]; ok {
return addrsSet.vals
}

Expand All @@ -141,14 +140,14 @@ func (h *Hosts) ByName(host string) (addrs []netip.Addr) {

// Mappings returns a deep clone of the internal mappings.
func (h *Hosts) Mappings() (names map[netip.Addr][]string, addrs map[string][]netip.Addr) {
names = make(map[netip.Addr][]string, len(h.addrs))
addrs = make(map[string][]netip.Addr, len(h.names))
names = make(map[netip.Addr][]string, len(h.names))
addrs = make(map[string][]netip.Addr, len(h.addrs))

for addr, namesSet := range h.addrs {
for addr, namesSet := range h.names {
names[addr] = slices.Clone(namesSet.vals)
}

for name, addrsSet := range h.names {
for name, addrsSet := range h.addrs {
addrs[name] = slices.Clone(addrsSet.vals)
}

Expand Down
6 changes: 6 additions & 0 deletions internal/netutil/hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ func TestHosts(t *testing.T) {
}
)

t.Run("Mappings", func(t *testing.T) {
names, addrs := h.Mappings()
assert.Equal(t, wantAddrs, names)
assert.Equal(t, wantHosts, addrs)
})

t.Run("ByAddr", func(t *testing.T) {
t.Parallel()

Expand Down
13 changes: 5 additions & 8 deletions upstream/upstreamresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@ import (
// Resolver is an alias for [bootstrap.Resolver] to avoid the import cycle.
type Resolver = bootstrap.Resolver

// NewUpstreamResolver creates a Resolver. resolverAddress format is the same
// as in the [AddressToUpstream], except that it also shouldn't need a
// bootstrap, i.e. have an IP address in hostname, or be a DNSCrypt.
// resolverAddress must not be empty, use another [Resolver] instead, e.g.
// [net.Resolver].
//
// sorting the resolved addresses is the caller's responsibility. Otherwise, it
// creates an [Upstream] using opts.
// NewUpstreamResolver creates an upstream that can be used as [Resolver].
// resolverAddress format is the same as in the [AddressToUpstream], except that
// it also shouldn't need a bootstrap, i.e. have an IP address in hostname, or
// be a DNSCrypt. resolverAddress must not be empty, use another [Resolver]
// instead, e.g. [net.Resolver].
func NewUpstreamResolver(resolverAddress string, opts *Options) (r Resolver, err error) {
upsOpts := &Options{}

Expand Down

0 comments on commit 4b3e59b

Please sign in to comment.