From 53192efc4587fc96d20ae9bd0322d15de2554c81 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sat, 23 Mar 2024 18:54:17 +0800 Subject: [PATCH] feat: enable OnOpen for connected UDP socket Fixes #549 --- acceptor_windows.go | 2 +- client_test.go | 82 +++++++++++++++++++++++++++++++++---------- client_unix.go | 8 ++++- client_windows.go | 7 ++-- connection_unix.go | 6 +++- connection_windows.go | 6 ++++ eventloop_unix.go | 15 ++++++-- eventloop_windows.go | 15 +++++--- gnet_test.go | 7 ++-- 9 files changed, 116 insertions(+), 32 deletions(-) diff --git a/acceptor_windows.go b/acceptor_windows.go index 13d551f82..25717e59f 100644 --- a/acceptor_windows.go +++ b/acceptor_windows.go @@ -69,7 +69,7 @@ func (eng *engine) listen() (err error) { } el := eng.eventLoops.next(tc.RemoteAddr()) c := newTCPConn(tc, el) - el.ch <- c + el.ch <- &openConn{c: c} go func(c *conn, tc net.Conn, el *eventloop) { var buffer [0x10000]byte for { diff --git a/client_test.go b/client_test.go index f6b4fd6e8..0a4212f60 100644 --- a/client_test.go +++ b/client_test.go @@ -4,9 +4,11 @@ package gnet import ( + "bytes" "io" "math/rand" "net" + "sync" "sync/atomic" "testing" "time" @@ -41,6 +43,13 @@ func (ev *clientEvents) OnBoot(e Engine) Action { return None } +var pingMsg = []byte("PING\r\n") + +func (ev *clientEvents) OnOpen(Conn) (out []byte, action Action) { + out = pingMsg + return +} + func (ev *clientEvents) OnClose(Conn, error) Action { if ev.svr != nil { if atomic.AddInt32(&ev.svr.clientActive, -1) == 0 { @@ -53,7 +62,7 @@ func (ev *clientEvents) OnClose(Conn, error) Action { func (ev *clientEvents) OnTraffic(c Conn) (action Action) { handler := c.Context().(*connHandler) if handler.network == "udp" { - ev.packetLen = 1024 + ev.packetLen = datagramLen } buf, err := c.Next(-1) assert.NoError(ev.tester, err) @@ -190,19 +199,20 @@ func TestServeWithGnetClient(t *testing.T) { type testClientServer struct { *BuiltinEventEngine - client *Client - tester *testing.T - eng Engine - network string - addr string - multicore bool - async bool - nclients int - started int32 - connected int32 - clientActive int32 - disconnected int32 - workerPool *goPool.Pool + client *Client + tester *testing.T + eng Engine + network string + addr string + multicore bool + async bool + nclients int + started int32 + connected int32 + clientActive int32 + disconnected int32 + workerPool *goPool.Pool + udpReadHeader int32 } func (s *testClientServer) OnBoot(eng Engine) (action Action) { @@ -211,7 +221,7 @@ func (s *testClientServer) OnBoot(eng Engine) (action Action) { } func (s *testClientServer) OnOpen(c Conn) (out []byte, action Action) { - c.SetContext(c) + c.SetContext(&sync.Once{}) atomic.AddInt32(&s.connected, 1) require.NotNil(s.tester, c.LocalAddr(), "nil local addr") require.NotNil(s.tester, c.RemoteAddr(), "nil remote addr") @@ -223,7 +233,7 @@ func (s *testClientServer) OnClose(c Conn, err error) (action Action) { logging.Debugf("error occurred on closed, %v\n", err) } if s.network != "udp" { - require.Equal(s.tester, c.Context(), c, "invalid context") + require.IsType(s.tester, c.Context(), new(sync.Once), "invalid context") } atomic.AddInt32(&s.disconnected, 1) @@ -236,7 +246,25 @@ func (s *testClientServer) OnClose(c Conn, err error) (action Action) { return } +func (s *testClientServer) OnShutdown(Engine) { + if s.network == "udp" { + require.EqualValues(s.tester, int32(s.nclients), atomic.LoadInt32(&s.udpReadHeader)) + } +} + func (s *testClientServer) OnTraffic(c Conn) (action Action) { + readHeader := func() { + ping := make([]byte, len(pingMsg)) + n, err := io.ReadFull(c, ping) + require.NoError(s.tester, err) + require.EqualValues(s.tester, len(pingMsg), n) + require.Equal(s.tester, string(pingMsg), string(ping), "bad header") + } + v := c.Context() + if v != nil && s.network != "udp" { + v.(*sync.Once).Do(readHeader) + } + if s.async { buf := bbPool.Get() _, _ = c.WriteTo(buf) @@ -247,14 +275,30 @@ func (s *testClientServer) OnTraffic(c Conn) (action Action) { _ = c.OutboundBuffered() _, _ = c.Discard(1) } + if s.network == "udp" && bytes.Equal(buf.Bytes(), pingMsg) { + atomic.AddInt32(&s.udpReadHeader, 1) + buf.Reset() + } _ = s.workerPool.Submit( func() { - _ = c.AsyncWrite(buf.Bytes(), nil) + if buf.Len() > 0 { + err := c.AsyncWrite(buf.Bytes(), nil) + require.NoError(s.tester, err) + } }) return } + buf, _ := c.Next(-1) - _, _ = c.Write(buf) + if s.network == "udp" && bytes.Equal(buf, pingMsg) { + atomic.AddInt32(&s.udpReadHeader, 1) + buf = nil + } + if len(buf) > 0 { + n, err := c.Write(buf) + require.NoError(s.tester, err) + require.EqualValues(s.tester, len(buf), n) + } return } @@ -343,7 +387,7 @@ func startGnetClient(t *testing.T, cli *Client, network, addr string, multicore, for time.Since(start) < duration { reqData := make([]byte, streamLen) if network == "udp" { - reqData = reqData[:1024] + reqData = reqData[:datagramLen] } _, err = rand.Read(reqData) require.NoError(t, err) diff --git a/client_unix.go b/client_unix.go index fc4470b3f..64890b6c7 100644 --- a/client_unix.go +++ b/client_unix.go @@ -228,10 +228,16 @@ func (cli *Client) EnrollContext(c net.Conn, ctx interface{}) (Conn, error) { return nil, errorx.ErrUnsupportedProtocol } gc.SetContext(ctx) - err = cli.el.poller.UrgentTrigger(cli.el.register, gc) + + connOpened := make(chan struct{}) + ccb := &connWithCallback{c: gc, cb: func() { + close(connOpened) + }} + err = cli.el.poller.UrgentTrigger(cli.el.register, ccb) if err != nil { gc.Close() return nil, err } + <-connOpened return gc, nil } diff --git a/client_windows.go b/client_windows.go index b856bf7c3..07d2294b2 100644 --- a/client_windows.go +++ b/client_windows.go @@ -147,6 +147,7 @@ func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) { } func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err error) { + connOpened := make(chan struct{}) switch v := nc.(type) { case *net.TCPConn: if cli.opts.TCPNoDelay == TCPNoDelay { @@ -165,7 +166,7 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err c := newTCPConn(nc, cli.el) c.SetContext(ctx) - cli.el.ch <- c + cli.el.ch <- &openConn{c: c, cb: func() { close(connOpened) }} go func(c *conn, tc net.Conn, el *eventloop) { var buffer [0x10000]byte for { @@ -181,7 +182,7 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err case *net.UnixConn: c := newTCPConn(nc, cli.el) c.SetContext(ctx) - cli.el.ch <- c + cli.el.ch <- &openConn{c: c, cb: func() { close(connOpened) }} go func(c *conn, uc net.Conn, el *eventloop) { var buffer [0x10000]byte for { @@ -204,6 +205,7 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err c := newUDPConn(cli.el, nc.LocalAddr(), nc.RemoteAddr()) c.SetContext(ctx) c.rawConn = nc + cli.el.ch <- &openConn{c: c, isDatagram: true, cb: func() { close(connOpened) }} go func(uc net.Conn, el *eventloop) { var buffer [0x10000]byte for { @@ -222,5 +224,6 @@ func (cli *Client) EnrollContext(nc net.Conn, ctx interface{}) (gc Conn, err err return nil, errorx.ErrUnsupportedProtocol } + <-connOpened return } diff --git a/connection_unix.go b/connection_unix.go index 44cd351d4..ba7cffa77 100644 --- a/connection_unix.go +++ b/connection_unix.go @@ -84,6 +84,7 @@ func newUDPConn(fd int, el *eventloop, localAddr net.Addr, sa unix.Sockaddr, con } func (c *conn) release() { + c.opened = false c.ctx = nil c.buffer = nil if addr, ok := c.localAddr.(*net.TCPAddr); ok && c.localAddr != c.loop.ln.addr && len(addr.Zone) > 0 { @@ -102,7 +103,6 @@ func (c *conn) release() { c.remoteAddr = nil c.pollAttachment.FD, c.pollAttachment.Callback = 0, nil if !c.isDatagram { - c.opened = false c.peer = nil c.inboundBuffer.Done() c.outboundBuffer.Release() @@ -110,6 +110,10 @@ func (c *conn) release() { } func (c *conn) open(buf []byte) error { + if c.isDatagram && c.peer == nil { + return unix.Send(c.fd, buf, 0) + } + n, err := unix.Write(c.fd, buf) if err != nil && err == unix.EAGAIN { _, _ = c.outboundBuffer.Write(buf) diff --git a/connection_windows.go b/connection_windows.go index c41a3f94f..61619b5bb 100644 --- a/connection_windows.go +++ b/connection_windows.go @@ -42,6 +42,12 @@ type udpConn struct { c *conn } +type openConn struct { + c *conn + cb func() + isDatagram bool +} + type conn struct { ctx interface{} // user-defined context loop *eventloop // owner event-loop diff --git a/eventloop_unix.go b/eventloop_unix.go index d434267cd..6cb1cd73c 100644 --- a/eventloop_unix.go +++ b/eventloop_unix.go @@ -61,8 +61,19 @@ func (el *eventloop) closeConns() { }) } +type connWithCallback struct { + c Conn + cb func() +} + func (el *eventloop) register(itf interface{}) error { - c := itf.(*conn) + c, ok := itf.(*conn) + if !ok { + ccb := itf.(*connWithCallback) + c = ccb.c.(*conn) + defer ccb.cb() + } + if err := el.poller.AddRead(&c.pollAttachment); err != nil { _ = unix.Close(c.fd) c.release() @@ -71,7 +82,7 @@ func (el *eventloop) register(itf interface{}) error { el.connections.addConn(c, el.idx) - if c.isDatagram { + if c.isDatagram && c.peer != nil { return nil } return el.open(c) diff --git a/eventloop_windows.go b/eventloop_windows.go index 5ae95ca3d..074318c58 100644 --- a/eventloop_windows.go +++ b/eventloop_windows.go @@ -67,7 +67,7 @@ func (el *eventloop) run() (err error) { err = v case *netErr: err = el.close(v.c, v.err) - case *conn: + case *openConn: err = el.open(v) case *tcpConn: unpackTCPConn(v) @@ -90,9 +90,16 @@ func (el *eventloop) run() (err error) { return nil } -func (el *eventloop) open(c *conn) error { - el.connections[c] = struct{}{} - el.incConn(1) +func (el *eventloop) open(oc *openConn) error { + if oc.cb != nil { + defer oc.cb() + } + + c := oc.c + if !oc.isDatagram { + el.connections[c] = struct{}{} + el.incConn(1) + } out, action := el.eventHandler.OnOpen(c) if out != nil { diff --git a/gnet_test.go b/gnet_test.go index 36bba8c73..e6150023d 100644 --- a/gnet_test.go +++ b/gnet_test.go @@ -25,7 +25,10 @@ import ( goPool "github.com/panjf2000/gnet/v2/pkg/pool/goroutine" ) -var streamLen = 1024 * 1024 +var ( + datagramLen = 1024 + streamLen = 1024 * 1024 +) func TestServe(t *testing.T) { // start an engine @@ -415,7 +418,7 @@ func startClient(t *testing.T, network, addr string, multicore, async bool) { for time.Since(start) < duration { reqData := make([]byte, streamLen) if network == "udp" { - reqData = reqData[:1024] + reqData = reqData[:datagramLen] } _, err = rand.Read(reqData) require.NoError(t, err)