diff --git a/internal/socket/socket.go b/internal/socket/socket.go index a6d9a9448..9ddce28a0 100644 --- a/internal/socket/socket.go +++ b/internal/socket/socket.go @@ -27,24 +27,33 @@ import ( ) // Option is used for setting an option on socket. -type Option struct { - SetSockOpt func(int, int) error - Opt int +type Option[T int | string] struct { + SetSockOpt func(int, T) error + Opt T +} + +func execSockOpts[T int | string](fd int, opts []Option[T]) error { + for _, opt := range opts { + if err := opt.SetSockOpt(fd, opt.Opt); err != nil { + return err + } + } + return nil } // TCPSocket calls the internal tcpSocket. -func TCPSocket(proto, addr string, passive bool, sockOpts ...Option) (int, net.Addr, error) { - return tcpSocket(proto, addr, passive, sockOpts...) +func TCPSocket(proto, addr string, passive bool, sockOptInts []Option[int], sockOptStrs []Option[string]) (int, net.Addr, error) { + return tcpSocket(proto, addr, passive, sockOptInts, sockOptStrs) } // UDPSocket calls the internal udpSocket. -func UDPSocket(proto, addr string, connect bool, sockOpts ...Option) (int, net.Addr, error) { - return udpSocket(proto, addr, connect, sockOpts...) +func UDPSocket(proto, addr string, connect bool, sockOptInts []Option[int], sockOptStrs []Option[string]) (int, net.Addr, error) { + return udpSocket(proto, addr, connect, sockOptInts, sockOptStrs) } // UnixSocket calls the internal udsSocket. -func UnixSocket(proto, addr string, passive bool, sockOpts ...Option) (int, net.Addr, error) { - return udsSocket(proto, addr, passive, sockOpts...) +func UnixSocket(proto, addr string, passive bool, sockOptInts []Option[int], sockOptStrs []Option[string]) (int, net.Addr, error) { + return udsSocket(proto, addr, passive, sockOptInts, sockOptStrs) } // Accept accepts the next incoming socket along with setting diff --git a/internal/socket/sockopts_bsd.go b/internal/socket/sockopts_bsd.go new file mode 100644 index 000000000..92b1b79ec --- /dev/null +++ b/internal/socket/sockopts_bsd.go @@ -0,0 +1,26 @@ +// Copyright (c) 2024 The Gnet Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build dragonfly || freebsd || netbsd || openbsd +// +build dragonfly freebsd netbsd openbsd + +package socket + +import errorx "github.com/panjf2000/gnet/v2/pkg/errors" + +// SetBindToDevice is not implemented on *BSD because there is +// no equivalent of Linux's SO_BINDTODEVICE. +func SetBindToDevice(_ int, _ string) error { + return errorx.ErrUnsupportedOp +} diff --git a/internal/socket/sockopts_darwin.go b/internal/socket/sockopts_darwin.go index 338857cc5..5a2f69217 100644 --- a/internal/socket/sockopts_darwin.go +++ b/internal/socket/sockopts_darwin.go @@ -19,6 +19,8 @@ import ( "os" "golang.org/x/sys/unix" + + errorx "github.com/panjf2000/gnet/v2/pkg/errors" ) // SetKeepAlivePeriod sets whether the operating system should send @@ -52,3 +54,9 @@ func SetKeepAlivePeriod(fd, secs int) error { return os.NewSyscallError("setsockopt", unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_KEEPCNT, 5)) } + +// SetBindToDevice is not implemented on macOS because there is +// no equivalent of Linux's SO_BINDTODEVICE. +func SetBindToDevice(_ int, _ string) error { + return errorx.ErrUnsupportedOp +} diff --git a/internal/socket/sockopts_linux.go b/internal/socket/sockopts_linux.go new file mode 100644 index 000000000..f110f16b8 --- /dev/null +++ b/internal/socket/sockopts_linux.go @@ -0,0 +1,30 @@ +// Copyright (c) 2024 The Gnet Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package socket + +import ( + "os" + + "golang.org/x/sys/unix" +) + +// SetBindToDevice binds the socket to a specific network interface. +// +// SO_BINDTODEVICE on Linux works in both directions: only process packets +// received from the particular interface along with sending them through +// that interface, instead of following the default route. +func SetBindToDevice(fd int, ifname string) error { + return os.NewSyscallError("setsockopt", unix.BindToDevice(fd, ifname)) +} diff --git a/internal/socket/sockopts_openbsd.go b/internal/socket/sockopts_openbsd.go index 8670f2af0..47820e6a0 100644 --- a/internal/socket/sockopts_openbsd.go +++ b/internal/socket/sockopts_openbsd.go @@ -14,11 +14,11 @@ package socket -import "golang.org/x/sys/unix" +import errorx "github.com/panjf2000/gnet/v2/pkg/errors" // SetKeepAlivePeriod sets whether the operating system should send // keep-alive messages on the connection and sets period between TCP keep-alive probes. func SetKeepAlivePeriod(_, _ int) error { // OpenBSD has no user-settable per-socket TCP keepalive options. - return unix.ENOPROTOOPT + return errorx.ErrUnsupportedOp } diff --git a/internal/socket/sockopts_posix.go b/internal/socket/sockopts_posix.go index 61fc8d31f..5da7a7206 100644 --- a/internal/socket/sockopts_posix.go +++ b/internal/socket/sockopts_posix.go @@ -39,13 +39,13 @@ func SetNoDelay(fd, noDelay int) error { // SetRecvBuffer sets the size of the operating system's // receive buffer associated with the connection. func SetRecvBuffer(fd, size int) error { - return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, size) + return os.NewSyscallError("setsockopt", unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, size)) } // SetSendBuffer sets the size of the operating system's // transmit buffer associated with the connection. func SetSendBuffer(fd, size int) error { - return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, size) + return os.NewSyscallError("setsockopt", unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, size)) } // SetReuseAddr enables SO_REUSEADDR option on socket. @@ -55,7 +55,7 @@ func SetReuseAddr(fd, reuseAddr int) error { // SetIPv6Only restricts a IPv6 socket to only process IPv6 requests or both IPv4 and IPv6 requests. func SetIPv6Only(fd, ipv6only int) error { - return unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, ipv6only) + return os.NewSyscallError("setsockopt", unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, ipv6only)) } // SetLinger sets the behavior of Close on a connection which still @@ -79,7 +79,7 @@ func SetLinger(fd, sec int) error { l.Onoff = 0 l.Linger = 0 } - return unix.SetsockoptLinger(fd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l) + return os.NewSyscallError("setsockopt", unix.SetsockoptLinger(fd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l)) } // SetMulticastMembership returns with a socket option function based on the IP diff --git a/internal/socket/tcp_socket.go b/internal/socket/tcp_socket.go index 21d4af32e..e086469f2 100644 --- a/internal/socket/tcp_socket.go +++ b/internal/socket/tcp_socket.go @@ -83,7 +83,7 @@ func determineTCPProto(proto string, addr *net.TCPAddr) (string, error) { // tcpSocket creates an endpoint for communication and returns a file descriptor that refers to that endpoint. // Argument `reusePort` indicates whether the SO_REUSEPORT flag will be assigned. -func tcpSocket(proto, addr string, passive bool, sockOpts ...Option) (fd int, netAddr net.Addr, err error) { +func tcpSocket(proto, addr string, passive bool, sockOptInts []Option[int], sockOptStrs []Option[string]) (fd int, netAddr net.Addr, err error) { var ( family int ipv6only bool @@ -114,10 +114,11 @@ func tcpSocket(proto, addr string, passive bool, sockOpts ...Option) (fd int, ne } } - for _, sockOpt := range sockOpts { - if err = sockOpt.SetSockOpt(fd, sockOpt.Opt); err != nil { - return - } + if err = execSockOpts(fd, sockOptInts); err != nil { + return + } + if err = execSockOpts(fd, sockOptStrs); err != nil { + return } if passive { diff --git a/internal/socket/udp_socket.go b/internal/socket/udp_socket.go index 6205c986b..0e524457b 100644 --- a/internal/socket/udp_socket.go +++ b/internal/socket/udp_socket.go @@ -81,7 +81,7 @@ func determineUDPProto(proto string, addr *net.UDPAddr) (string, error) { // udpSocket creates an endpoint for communication and returns a file descriptor that refers to that endpoint. // Argument `reusePort` indicates whether the SO_REUSEPORT flag will be assigned. -func udpSocket(proto, addr string, connect bool, sockOpts ...Option) (fd int, netAddr net.Addr, err error) { +func udpSocket(proto, addr string, connect bool, sockOptInts []Option[int], sockOptStrs []Option[string]) (fd int, netAddr net.Addr, err error) { var ( family int ipv6only bool @@ -117,10 +117,11 @@ func udpSocket(proto, addr string, connect bool, sockOpts ...Option) (fd int, ne return } - for _, sockOpt := range sockOpts { - if err = sockOpt.SetSockOpt(fd, sockOpt.Opt); err != nil { - return - } + if err = execSockOpts(fd, sockOptInts); err != nil { + return + } + if err = execSockOpts(fd, sockOptStrs); err != nil { + return } if connect { diff --git a/internal/socket/unix_socket.go b/internal/socket/unix_socket.go index 688672d09..87cc889b4 100644 --- a/internal/socket/unix_socket.go +++ b/internal/socket/unix_socket.go @@ -45,7 +45,7 @@ func GetUnixSockAddr(proto, addr string) (sa unix.Sockaddr, family int, unixAddr // udsSocket creates an endpoint for communication and returns a file descriptor that refers to that endpoint. // Argument `reusePort` indicates whether the SO_REUSEPORT flag will be assigned. -func udsSocket(proto, addr string, passive bool, sockOpts ...Option) (fd int, netAddr net.Addr, err error) { +func udsSocket(proto, addr string, passive bool, sockOptInts []Option[int], sockOptStrs []Option[string]) (fd int, netAddr net.Addr, err error) { var ( family int sa unix.Sockaddr @@ -70,10 +70,11 @@ func udsSocket(proto, addr string, passive bool, sockOpts ...Option) (fd int, ne } }() - for _, sockOpt := range sockOpts { - if err = sockOpt.SetSockOpt(fd, sockOpt.Opt); err != nil { - return - } + if err = execSockOpts(fd, sockOptInts); err != nil { + return + } + if err = execSockOpts(fd, sockOptStrs); err != nil { + return } if passive { diff --git a/listener_unix.go b/listener_unix.go index 18fde857a..2e2711df8 100644 --- a/listener_unix.go +++ b/listener_unix.go @@ -36,7 +36,8 @@ type listener struct { fd int addr net.Addr address, network string - sockOpts []socket.Option + sockOptInts []socket.Option[int] + sockOptStrs []socket.Option[string] pollAttachment *netpoll.PollAttachment // listener attachment for poller } @@ -52,14 +53,14 @@ func (ln *listener) dup() (int, error) { func (ln *listener) normalize() (err error) { switch ln.network { case "tcp", "tcp4", "tcp6": - ln.fd, ln.addr, err = socket.TCPSocket(ln.network, ln.address, true, ln.sockOpts...) + ln.fd, ln.addr, err = socket.TCPSocket(ln.network, ln.address, true, ln.sockOptInts, ln.sockOptStrs) ln.network = "tcp" case "udp", "udp4", "udp6": - ln.fd, ln.addr, err = socket.UDPSocket(ln.network, ln.address, false, ln.sockOpts...) + ln.fd, ln.addr, err = socket.UDPSocket(ln.network, ln.address, false, ln.sockOptInts, ln.sockOptStrs) ln.network = "udp" case "unix": _ = os.RemoveAll(ln.address) - ln.fd, ln.addr, err = socket.UnixSocket(ln.network, ln.address, true, ln.sockOpts...) + ln.fd, ln.addr, err = socket.UnixSocket(ln.network, ln.address, true, ln.sockOptInts, ln.sockOptStrs) default: err = errors.ErrUnsupportedProtocol } @@ -79,37 +80,44 @@ func (ln *listener) close() { } func initListener(network, addr string, options *Options) (l *listener, err error) { - var sockOpts []socket.Option + var ( + sockOptInts []socket.Option[int] + sockOptStrs []socket.Option[string] + ) if options.ReusePort || strings.HasPrefix(network, "udp") { - sockOpt := socket.Option{SetSockOpt: socket.SetReuseport, Opt: 1} - sockOpts = append(sockOpts, sockOpt) + sockOpt := socket.Option[int]{SetSockOpt: socket.SetReuseport, Opt: 1} + sockOptInts = append(sockOptInts, sockOpt) } if options.ReuseAddr { - sockOpt := socket.Option{SetSockOpt: socket.SetReuseAddr, Opt: 1} - sockOpts = append(sockOpts, sockOpt) + sockOpt := socket.Option[int]{SetSockOpt: socket.SetReuseAddr, Opt: 1} + sockOptInts = append(sockOptInts, sockOpt) } if options.TCPNoDelay == TCPNoDelay && strings.HasPrefix(network, "tcp") { - sockOpt := socket.Option{SetSockOpt: socket.SetNoDelay, Opt: 1} - sockOpts = append(sockOpts, sockOpt) + sockOpt := socket.Option[int]{SetSockOpt: socket.SetNoDelay, Opt: 1} + sockOptInts = append(sockOptInts, sockOpt) } if options.SocketRecvBuffer > 0 { - sockOpt := socket.Option{SetSockOpt: socket.SetRecvBuffer, Opt: options.SocketRecvBuffer} - sockOpts = append(sockOpts, sockOpt) + sockOpt := socket.Option[int]{SetSockOpt: socket.SetRecvBuffer, Opt: options.SocketRecvBuffer} + sockOptInts = append(sockOptInts, sockOpt) } if options.SocketSendBuffer > 0 { - sockOpt := socket.Option{SetSockOpt: socket.SetSendBuffer, Opt: options.SocketSendBuffer} - sockOpts = append(sockOpts, sockOpt) + sockOpt := socket.Option[int]{SetSockOpt: socket.SetSendBuffer, Opt: options.SocketSendBuffer} + sockOptInts = append(sockOptInts, sockOpt) } if strings.HasPrefix(network, "udp") { udpAddr, err := net.ResolveUDPAddr(network, addr) if err == nil && udpAddr.IP.IsMulticast() { if sockoptFn := socket.SetMulticastMembership(network, udpAddr); sockoptFn != nil { - sockOpt := socket.Option{SetSockOpt: sockoptFn, Opt: options.MulticastInterfaceIndex} - sockOpts = append(sockOpts, sockOpt) + sockOpt := socket.Option[int]{SetSockOpt: sockoptFn, Opt: options.MulticastInterfaceIndex} + sockOptInts = append(sockOptInts, sockOpt) } } } - l = &listener{network: network, address: addr, sockOpts: sockOpts} + if options.BindToDevice != "" { + sockOpt := socket.Option[string]{SetSockOpt: socket.SetBindToDevice, Opt: options.BindToDevice} + sockOptStrs = append(sockOptStrs, sockOpt) + } + l = &listener{network: network, address: addr, sockOptInts: sockOptInts, sockOptStrs: sockOptStrs} err = l.normalize() return } diff --git a/options.go b/options.go index 545532443..86d5883af 100644 --- a/options.go +++ b/options.go @@ -68,6 +68,12 @@ type Options struct { // MulticastInterfaceIndex is the index of the interface name where the multicast UDP addresses will be bound to. MulticastInterfaceIndex int + // BindToDevice is the name of the interface to which the listening socket will be bound. + // + // It is only available on Linux at the moment, an error will therefore be returned when + // setting this option on non-linux platforms. + BindToDevice string + // ============================= Options for both server-side and client-side ============================= // ReadBufferCap is the maximum number of bytes that can be read from the remote when the readable event comes. @@ -95,7 +101,7 @@ type Options struct { // Ticker indicates whether the ticker has been set up. Ticker bool - // TCPKeepAlive enable the TCP keep-alive mechanism (SO_KEEPALIVE) and set its value + // TCPKeepAlive enables the TCP keep-alive mechanism (SO_KEEPALIVE) and set its value // on TCP_KEEPIDLE, 1/5 of its value on TCP_KEEPINTVL, and 5 on TCP_KEEPCNT. TCPKeepAlive time.Duration @@ -270,6 +276,16 @@ func WithMulticastInterfaceIndex(idx int) Option { } } +// WithBindToDevice sets the name of the interface to which the listening socket will be bound. +// +// It is only available on Linux at the moment, an error will therefore be returned when +// setting this option on non-linux platforms. +func WithBindToDevice(iface string) Option { + return func(opts *Options) { + opts.BindToDevice = iface + } +} + // WithEdgeTriggeredIO enables the edge-triggered I/O for the underlying epoll/kqueue event-loop. func WithEdgeTriggeredIO(et bool) Option { return func(opts *Options) { diff --git a/os_unix_test.go b/os_unix_test.go index 696ec1514..02b2d1107 100644 --- a/os_unix_test.go +++ b/os_unix_test.go @@ -10,6 +10,7 @@ import ( "fmt" "math/rand" "net" + "runtime" "sync" "sync/atomic" "testing" @@ -19,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sys/unix" + errorx "github.com/panjf2000/gnet/v2/pkg/errors" "github.com/panjf2000/gnet/v2/pkg/logging" ) @@ -27,7 +29,7 @@ var ( NetDial = net.Dial ) -// NOTE: TestServeMulticast can fail with "write: no buffer space available" on wifi interface. +// NOTE: TestServeMulticast can fail with "write: no buffer space available" on Wi-Fi interface. func TestServeMulticast(t *testing.T) { t.Run("IPv4", func(t *testing.T) { // 224.0.0.169 is an unassigned address from the Local Network Control Block @@ -191,6 +193,145 @@ func TestMulticastBindIPv6(t *testing.T) { assert.NoError(t, err) } +func getInterfaceIP(ifname string, ipv4 bool) (net.IP, error) { + iface, err := net.InterfaceByName(ifname) + if err != nil { + return nil, err + } + // Get all unicast addresses for this interface + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + // Loop through the addresses and find the first IPv4 address + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + // Check if the IP is IPv4. + if ip != nil && (ip.To4() != nil) == ipv4 { + return ip, nil + } + } + return nil, errors.New("no valid IP address found") +} + +type testBindToDeviceServer struct { + BuiltinEventEngine + tester *testing.T + data []byte + packets atomic.Int32 + expectedPackets int32 + network string + loopBackIP net.IP + eth0IP net.IP + broadcastIP net.IP + zone string +} + +func (s *testBindToDeviceServer) OnTraffic(c Conn) (action Action) { + b, err := c.Next(-1) + assert.NoError(s.tester, err) + assert.EqualValues(s.tester, s.data, b) + _, err = c.Write(b) + assert.NoError(s.tester, err) + s.packets.Add(1) + return +} + +func (s *testBindToDeviceServer) OnShutdown(_ Engine) { + assert.EqualValues(s.tester, s.packets.Load(), s.expectedPackets) +} + +func (s *testBindToDeviceServer) OnTick() (delay time.Duration, action Action) { + // Send a packet to the loopback interface, it should never make its way to the server + // because we've bound the server to eth0. + lp, err := findLoopbackInterface() + assert.NoError(s.tester, err) + c, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.loopBackIP, Port: 9999, Zone: lp.Name}) + assert.NoError(s.tester, err) + defer c.Close() + _, err = c.Write(s.data) + assert.NoError(s.tester, err) + + // Send a packet to the eth0 interface, it should reach the server. + c4, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.eth0IP, Port: 9999, Zone: s.zone}) + assert.NoError(s.tester, err) + defer c4.Close() + _, err = c4.Write(s.data) + assert.NoError(s.tester, err) + buf := make([]byte, len(s.data)) + _, err = c4.Read(buf) + assert.NoError(s.tester, err) + assert.EqualValues(s.tester, s.data, buf, len(s.data), len(buf)) + + // Send a packet to the broadcast address, it should reach the server. + c6, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.broadcastIP, Port: 9999, Zone: s.zone}) + assert.NoError(s.tester, err) + defer c6.Close() + _, err = c6.Write(s.data) + assert.NoError(s.tester, err) + + return time.Second, Shutdown +} + +func TestBindToDevice(t *testing.T) { + if runtime.GOOS != "linux" { + err := Run(&testBindToDeviceServer{}, "udp://:9999", WithBindToDevice("eth0")) + assert.ErrorIs(t, err, errorx.ErrUnsupportedOp) + t.Skip("skipping the subsequent tests on non-linux OS") + } + + dev := "eth0" + data := []byte("hello") + t.Run("IPv4", func(t *testing.T) { + t.Run("UDP", func(t *testing.T) { + ip, err := getInterfaceIP(dev, true) + assert.NoError(t, err) + ts := &testBindToDeviceServer{ + tester: t, + data: data, + expectedPackets: 2, + network: "udp", + loopBackIP: net.IPv4(127, 0, 0, 1), + eth0IP: ip, + broadcastIP: net.IPv4bcast, + zone: dev, + } + require.NoError(t, err) + err = Run(ts, "udp://0.0.0.0:9999", + WithTicker(true), + WithBindToDevice(dev)) + assert.NoError(t, err) + }) + }) + t.Run("IPv6", func(t *testing.T) { + t.Run("UDP", func(t *testing.T) { + ip, err := getInterfaceIP(dev, false) + assert.NoError(t, err) + ts := &testBindToDeviceServer{ + tester: t, + data: data, + expectedPackets: 2, + network: "udp6", + loopBackIP: net.IPv6loopback, + eth0IP: ip, + broadcastIP: net.IPv6linklocalallnodes, + zone: dev, + } + require.NoError(t, err) + err = Run(ts, "udp6://[::]:9999", + WithTicker(true), + WithBindToDevice(dev)) + assert.NoError(t, err) + }) + }) +} + /* func TestEngineAsyncWrite(t *testing.T) { t.Run("tcp", func(t *testing.T) {