Skip to content

Commit

Permalink
test: add test for SO_BINDTODEVICE with TCP (#652)
Browse files Browse the repository at this point in the history
  • Loading branch information
panjf2000 authored Nov 9, 2024
1 parent 872b71e commit bdd3fb6
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 38 deletions.
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func startGnetClient(t *testing.T, cli *Client, network, addr string, multicore,
}
if netDial {
var netConn net.Conn
netConn, err = NetDial(network, addr)
netConn, err = stdDial(network, addr)
require.NoError(t, err)
c, err = cli.EnrollContext(netConn, handler)
} else {
Expand Down
120 changes: 84 additions & 36 deletions os_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
Expand All @@ -27,7 +28,7 @@ import (

var (
SysClose = unix.Close
NetDial = net.Dial
stdDial = net.Dial
)

// NOTE: TestServeMulticast can fail with "write: no buffer space available" on Wi-Fi interface.
Expand Down Expand Up @@ -244,20 +245,31 @@ func getInterfaceIP(ifname string, ipv4 bool) (net.IP, error) {
return nil, errors.New("no valid IP address found")
}

type testBindToDeviceServer struct {
type testBindToDeviceServer[T interface{ *net.TCPAddr | *net.UDPAddr }] struct {
BuiltinEventEngine
tester *testing.T
data []byte
packets atomic.Int32
expectedPackets int32
network string
loopBackIP net.IP
eth0IP net.IP
broadcastIP net.IP
zone string
loopBackAddr T
eth0Addr T
broadcastAddr T
}

func netDial[T *net.TCPAddr | *net.UDPAddr](network string, a T) (net.Conn, error) {
addr := any(a)
switch v := addr.(type) {
case *net.TCPAddr:
return net.DialTCP(network, nil, v)
case *net.UDPAddr:
return net.DialUDP(network, nil, v)
default:
return nil, errors.New("unsupported address type")
}
}

func (s *testBindToDeviceServer) OnTraffic(c Conn) (action Action) {
func (s *testBindToDeviceServer[T]) OnTraffic(c Conn) (action Action) {
b, err := c.Next(-1)
assert.NoError(s.tester, err)
assert.EqualValues(s.tester, s.data, b)
Expand All @@ -267,30 +279,34 @@ func (s *testBindToDeviceServer) OnTraffic(c Conn) (action Action) {
return
}

func (s *testBindToDeviceServer) OnShutdown(_ Engine) {
func (s *testBindToDeviceServer[T]) OnShutdown(_ Engine) {
assert.EqualValues(s.tester, s.expectedPackets, s.packets.Load())
}

func (s *testBindToDeviceServer) OnTick() (delay time.Duration, action Action) {
func (s *testBindToDeviceServer[T]) OnTick() (delay time.Duration, action Action) {
// Send a packet to the loopback interface, it should never make its way to the server
// because we've bound the server to eth0.
lp, err := findLoopbackInterface()
assert.NoError(s.tester, err)
c, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.loopBackIP, Port: 9999, Zone: lp.Name})
assert.NoError(s.tester, err)
defer c.Close()
_, err = c.Write(s.data)
assert.NoError(s.tester, err)
c, err := netDial(s.network, s.loopBackAddr)
if strings.HasPrefix(s.network, "tcp") {
assert.ErrorContains(s.tester, err, "connection refused")
} else {
assert.NoError(s.tester, err)
defer c.Close()
_, err = c.Write(s.data)
assert.NoError(s.tester, err)
}

// Send a packet to the broadcast address, it should reach the server.
c6, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.broadcastIP, Port: 9999, Zone: s.zone})
assert.NoError(s.tester, err)
defer c6.Close()
_, err = c6.Write(s.data)
assert.NoError(s.tester, err)
if s.broadcastAddr != nil {
// Send a packet to the broadcast address, it should reach the server.
c6, err := netDial(s.network, s.broadcastAddr)
assert.NoError(s.tester, err)
defer c6.Close()
_, err = c6.Write(s.data)
assert.NoError(s.tester, err)
}

// Send a packet to the eth0 interface, it should reach the server.
c4, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.eth0IP, Port: 9999, Zone: s.zone})
c4, err := netDial(s.network, s.eth0Addr)
assert.NoError(s.tester, err)
defer c4.Close()
_, err = c4.Write(s.data)
Expand All @@ -305,28 +321,44 @@ func (s *testBindToDeviceServer) OnTick() (delay time.Duration, action Action) {

func TestBindToDevice(t *testing.T) {
if runtime.GOOS != "linux" {
err := Run(&testBindToDeviceServer{}, "udp://:9999", WithBindToDevice("eth0"))
err := Run(&testBindToDeviceServer[*net.UDPAddr]{}, "tcp://:9999", WithBindToDevice("eth0"))
assert.ErrorIs(t, err, errorx.ErrUnsupportedOp)
return
}

lp, err := findLoopbackInterface()
assert.NoError(t, err)
dev, err := detectLinuxEthernetInterfaceName()
assert.NoErrorf(t, err, "no testable Ethernet interface found")
t.Logf("detected Ethernet interface: %s", dev)
data := []byte("hello")
t.Run("IPv4", func(t *testing.T) {
t.Run("UDP", func(t *testing.T) {
ip, err := getInterfaceIP(dev, true)
ip, err := getInterfaceIP(dev, true)
assert.NoError(t, err)
t.Run("TCP", func(t *testing.T) {
ts := &testBindToDeviceServer[*net.TCPAddr]{
tester: t,
data: data,
expectedPackets: 1,
network: "tcp",
loopBackAddr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9999, Zone: ""},
eth0Addr: &net.TCPAddr{IP: ip, Port: 9999, Zone: ""},
}
require.NoError(t, err)
err = Run(ts, "tcp://0.0.0.0:9999",
WithTicker(true),
WithBindToDevice(dev))
assert.NoError(t, err)
ts := &testBindToDeviceServer{
})
t.Run("UDP", func(t *testing.T) {
ts := &testBindToDeviceServer[*net.UDPAddr]{
tester: t,
data: data,
expectedPackets: 2,
network: "udp",
loopBackIP: net.IPv4(127, 0, 0, 1),
eth0IP: ip,
broadcastIP: net.IPv4bcast,
zone: dev,
loopBackAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9999, Zone: ""},
eth0Addr: &net.UDPAddr{IP: ip, Port: 9999, Zone: ""},
broadcastAddr: &net.UDPAddr{IP: net.IPv4bcast, Port: 9999, Zone: ""},
}
require.NoError(t, err)
err = Run(ts, "udp://0.0.0.0:9999",
Expand All @@ -336,18 +368,34 @@ func TestBindToDevice(t *testing.T) {
})
})
t.Run("IPv6", func(t *testing.T) {
t.Run("TCP", func(t *testing.T) {
ip, err := getInterfaceIP(dev, false)
assert.NoError(t, err)
ts := &testBindToDeviceServer[*net.TCPAddr]{
tester: t,
data: data,
expectedPackets: 1,
network: "tcp6",
loopBackAddr: &net.TCPAddr{IP: net.IPv6loopback, Port: 9999, Zone: lp.Name},
eth0Addr: &net.TCPAddr{IP: ip, Port: 9999, Zone: dev},
}
require.NoError(t, err)
err = Run(ts, "tcp6://[::]:9999",
WithTicker(true),
WithBindToDevice(dev))
assert.NoError(t, err)
})
t.Run("UDP", func(t *testing.T) {
ip, err := getInterfaceIP(dev, false)
assert.NoError(t, err)
ts := &testBindToDeviceServer{
ts := &testBindToDeviceServer[*net.UDPAddr]{
tester: t,
data: data,
expectedPackets: 2,
network: "udp6",
loopBackIP: net.IPv6loopback,
eth0IP: ip,
broadcastIP: net.IPv6linklocalallnodes,
zone: dev,
loopBackAddr: &net.UDPAddr{IP: net.IPv6loopback, Port: 9999, Zone: lp.Name},
eth0Addr: &net.UDPAddr{IP: ip, Port: 9999, Zone: dev},
broadcastAddr: &net.UDPAddr{IP: net.IPv6linklocalallnodes, Port: 9999, Zone: dev},
}
require.NoError(t, err)
err = Run(ts, "udp6://[::]:9999",
Expand Down
2 changes: 1 addition & 1 deletion os_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func SysClose(fd int) error {
return syscall.CloseHandle(syscall.Handle(fd))
}

func NetDial(network, addr string) (net.Conn, error) {
func stdDial(network, addr string) (net.Conn, error) {
if network == "unix" {
laddr, _ := net.ResolveUnixAddr(network, unixAddr(addr))
raddr, _ := net.ResolveUnixAddr(network, addr)
Expand Down

0 comments on commit bdd3fb6

Please sign in to comment.