From a0c9787b61f2da9abf5dd7c131b416d2d593b85b Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Tue, 12 Nov 2024 23:35:30 +0800 Subject: [PATCH] opt: optimize Conn.Next and Conn.Peek (#654) Besides, add tests of partial read. --- client_test.go | 4 +- connection_unix.go | 10 ++- connection_windows.go | 164 +++++++++++++++++++++--------------------- eventloop_windows.go | 4 +- gnet_test.go | 78 ++++++++++++-------- 5 files changed, 142 insertions(+), 118 deletions(-) diff --git a/client_test.go b/client_test.go index 47a454a99..e51b4fcd0 100644 --- a/client_test.go +++ b/client_test.go @@ -591,13 +591,13 @@ func (ev *clientEventsForWake) OnTraffic(c Conn) (action Action) { assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err) buf, err = c.Next(-1) - assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) + assert.Emptyf(ev.tester, buf, "expected an empty slice, but got: %v", buf) assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err) buf, err = c.Peek(10) assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err) buf, err = c.Peek(-1) - assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf) + assert.Emptyf(ev.tester, buf, "expected an empty slice, but got: %v", buf) assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err) n, err = c.Discard(10) assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, n) diff --git a/connection_unix.go b/connection_unix.go index 3ccac4ab1..d399c9361 100644 --- a/connection_unix.go +++ b/connection_unix.go @@ -318,16 +318,18 @@ func (c *conn) Next(n int) (buf []byte, err error) { } else if n <= 0 { n = totalLen } + if c.inboundBuffer.IsEmpty() { buf = c.buffer[:n] c.buffer = c.buffer[n:] return } + head, tail := c.inboundBuffer.Peek(n) defer c.inboundBuffer.Discard(n) //nolint:errcheck c.loop.cache.Reset() c.loop.cache.Write(head) - if len(head) >= n { + if len(head) == n { return c.loop.cache.Bytes(), err } c.loop.cache.Write(tail) @@ -348,12 +350,14 @@ func (c *conn) Peek(n int) (buf []byte, err error) { } else if n <= 0 { n = totalLen } + if c.inboundBuffer.IsEmpty() { return c.buffer[:n], err } + head, tail := c.inboundBuffer.Peek(n) - if len(head) >= n { - return head[:n], err + if len(head) == n { + return head, err } c.loop.cache.Reset() c.loop.cache.Write(head) diff --git a/connection_windows.go b/connection_windows.go index 745028434..de8e2b939 100644 --- a/connection_windows.go +++ b/connection_windows.go @@ -34,8 +34,8 @@ type netErr struct { } type tcpConn struct { - c *conn - buf *bbPool.ByteBuffer + c *conn + b *bbPool.ByteBuffer } type udpConn struct { @@ -59,35 +59,34 @@ type conn struct { } func packTCPConn(c *conn, buf []byte) *tcpConn { - tc := &tcpConn{c: c, buf: bbPool.Get()} - _, _ = tc.buf.Write(buf) - return tc -} - -func unpackTCPConn(tc *tcpConn) { - tc.c.buffer = tc.buf - tc.buf = nil + b := bbPool.Get() + _, _ = b.Write(buf) + return &tcpConn{c: c, b: b} } -func resetTCPConn(tc *tcpConn) { - bbPool.Put(tc.c.buffer) - tc.c.buffer = nil +func unpackTCPConn(tc *tcpConn) *conn { + if tc.c.buffer == nil { // the connection has been closed + return nil + } + _, _ = tc.c.buffer.Write(tc.b.B) + bbPool.Put(tc.b) + tc.b = nil + return tc.c } func packUDPConn(c *conn, buf []byte) *udpConn { - uc := &udpConn{c} - _, _ = uc.c.buffer.Write(buf) - return uc + _, _ = c.buffer.Write(buf) + return &udpConn{c} } func newTCPConn(nc net.Conn, el *eventloop) (c *conn) { - c = &conn{ - loop: el, - rawConn: nc, + return &conn{ + loop: el, + buffer: bbPool.Get(), + rawConn: nc, + localAddr: nc.LocalAddr(), + remoteAddr: nc.RemoteAddr(), } - c.localAddr = c.rawConn.LocalAddr() - c.remoteAddr = c.rawConn.RemoteAddr() - return } func (c *conn) release() { @@ -118,18 +117,11 @@ func (c *conn) resetBuffer() { } func (c *conn) Read(p []byte) (n int, err error) { - if c.buffer == nil { - if len(p) == 0 { - return 0, nil - } - return 0, io.ErrShortBuffer - } - if c.inboundBuffer.IsEmpty() { n = copy(p, c.buffer.B) c.buffer.B = c.buffer.B[n:] if n == 0 && len(p) > 0 { - err = io.EOF + err = io.ErrShortBuffer } return } @@ -144,13 +136,6 @@ func (c *conn) Read(p []byte) (n int, err error) { } func (c *conn) Next(n int) (buf []byte, err error) { - if c.buffer == nil { - if n <= 0 { - return nil, nil - } - return nil, io.ErrShortBuffer - } - inBufferLen := c.inboundBuffer.Buffered() if totalLen := inBufferLen + c.buffer.Len(); n > totalLen { return nil, io.ErrShortBuffer @@ -166,7 +151,7 @@ func (c *conn) Next(n int) (buf []byte, err error) { defer c.inboundBuffer.Discard(n) //nolint:errcheck c.loop.cache.Reset() c.loop.cache.Write(head) - if len(head) >= n { + if len(head) == n { return c.loop.cache.Bytes(), err } c.loop.cache.Write(tail) @@ -181,13 +166,6 @@ func (c *conn) Next(n int) (buf []byte, err error) { } func (c *conn) Peek(n int) (buf []byte, err error) { - if c.buffer == nil { - if n <= 0 { - return nil, nil - } - return nil, io.ErrShortBuffer - } - inBufferLen := c.inboundBuffer.Buffered() if totalLen := inBufferLen + c.buffer.Len(); n > totalLen { return nil, io.ErrShortBuffer @@ -198,8 +176,8 @@ func (c *conn) Peek(n int) (buf []byte, err error) { return c.buffer.B[:n], err } head, tail := c.inboundBuffer.Peek(n) - if len(head) >= n { - return head[:n], err + if len(head) == n { + return head, err } c.loop.cache.Reset() c.loop.cache.Write(head) @@ -214,10 +192,6 @@ func (c *conn) Peek(n int) (buf []byte, err error) { } func (c *conn) Discard(n int) (int, error) { - if c.buffer == nil { - return 0, nil - } - inBufferLen := c.inboundBuffer.Buffered() tempBufferLen := c.buffer.Len() if inBufferLen+tempBufferLen < n || n <= 0 { @@ -435,13 +409,24 @@ func (c *conn) SetKeepAlivePeriod(d time.Duration) error { // func (c *conn) Gfd() gfd.GFD { return gfd.GFD{} } func (c *conn) AsyncWrite(buf []byte, cb AsyncCallback) error { - if cb == nil { - cb = func(c Conn, err error) error { return nil } - } _, err := c.Write(buf) - c.loop.ch <- func() error { - return cb(c, err) + + callback := func() error { + if cb != nil { + _ = cb(c, err) + } + return err + } + + select { + case c.loop.ch <- callback: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + go func() { + c.loop.ch <- callback + }() } + return nil } @@ -460,46 +445,61 @@ func (c *conn) AsyncWritev(bs [][]byte, cb AsyncCallback) error { } func (c *conn) Wake(cb AsyncCallback) error { - if cb == nil { - cb = func(c Conn, err error) error { return nil } - } - c.loop.ch <- func() (err error) { - defer func() { - defer func() { - if err == nil { - err = cb(c, nil) - return - } - _ = cb(c, err) - }() + wakeFn := func() (err error) { + err = c.loop.wake(c) + if cb != nil { + _ = cb(c, err) + } + return + } + + select { + case c.loop.ch <- wakeFn: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + go func() { + c.loop.ch <- wakeFn }() - return c.loop.wake(c) } + return nil } func (c *conn) Close() error { - c.loop.ch <- func() error { - err := c.loop.close(c, nil) - return err + closeFn := func() error { + return c.loop.close(c, nil) } + + select { + case c.loop.ch <- closeFn: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + go func() { + c.loop.ch <- closeFn + }() + } + return nil } func (c *conn) CloseWithCallback(cb AsyncCallback) error { - if cb == nil { - cb = func(c Conn, err error) error { return nil } - } - c.loop.ch <- func() (err error) { - defer func() { - if err == nil { - err = cb(c, nil) - return - } + closeFn := func() (err error) { + err = c.loop.close(c, nil) + if cb != nil { _ = cb(c, err) + } + return + } + + select { + case c.loop.ch <- closeFn: + default: + // If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop. + go func() { + c.loop.ch <- closeFn }() - return c.loop.close(c, nil) } + return nil } diff --git a/eventloop_windows.go b/eventloop_windows.go index 906d8924f..6f6f72254 100644 --- a/eventloop_windows.go +++ b/eventloop_windows.go @@ -71,9 +71,7 @@ func (el *eventloop) run() (err error) { case *openConn: err = el.open(v) case *tcpConn: - unpackTCPConn(v) - err = el.read(v.c) - resetTCPConn(v) + err = el.read(unpackTCPConn(v)) case *udpConn: err = el.readUDP(v.c) case func() error: diff --git a/gnet_test.go b/gnet_test.go index 201c6b9de..b6755e038 100644 --- a/gnet_test.go +++ b/gnet_test.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "errors" "io" + "math" "math/rand" "net" "path/filepath" @@ -1542,7 +1543,8 @@ type simServer struct { multicore bool nclients int packetSize int - packetBatch int + batchWrite int + batchRead int started int32 connected int32 disconnected int32 @@ -1579,7 +1581,7 @@ func (s *simServer) OnClose(_ Conn, err error) (action Action) { func (s *simServer) OnTraffic(c Conn) (action Action) { codec := c.Context().(*testCodec) var packets [][]byte - for { + for i := 0; i < s.batchRead; i++ { data, err := codec.Decode(c) if errors.Is(err, errIncompletePacket) { break @@ -1596,6 +1598,10 @@ func (s *simServer) OnTraffic(c Conn) (action Action) { } else if n == 1 { _, _ = c.Write(packets[0]) } + if len(packets) == s.batchRead && c.InboundBuffered() > 0 { + err := c.Wake(nil) // wake up the connection manually to avoid missing the leftover data + assert.NoError(s.tester, err) + } return } @@ -1603,7 +1609,7 @@ func (s *simServer) OnTick() (delay time.Duration, action Action) { if atomic.CompareAndSwapInt32(&s.started, 0, 1) { for i := 0; i < s.nclients; i++ { go func() { - runSimClient(s.tester, s.network, s.addr, s.packetSize, s.packetBatch) + runSimClient(s.tester, s.network, s.addr, s.packetSize, s.batchWrite) }() } } @@ -1651,11 +1657,14 @@ func (codec testCodec) Encode(buf []byte) ([]byte, error) { return data, nil } -func (codec *testCodec) Decode(c Conn) ([]byte, error) { +func (codec testCodec) Decode(c Conn) ([]byte, error) { bodyOffset := magicNumberSize + bodySize - buf, _ := c.Peek(bodyOffset) - if len(buf) < bodyOffset { - return nil, errIncompletePacket + buf, err := c.Peek(bodyOffset) + if err != nil { + if errors.Is(err, io.ErrShortBuffer) { + err = errIncompletePacket + } + return nil, err } if !bytes.Equal(magicNumberBytes, buf[:magicNumberSize]) { @@ -1664,13 +1673,18 @@ func (codec *testCodec) Decode(c Conn) ([]byte, error) { bodyLen := binary.BigEndian.Uint32(buf[magicNumberSize:bodyOffset]) msgLen := bodyOffset + int(bodyLen) - if c.InboundBuffered() < msgLen { - return nil, errIncompletePacket + buf, err = c.Peek(msgLen) + if err != nil { + if errors.Is(err, io.ErrShortBuffer) { + err = errIncompletePacket + } + return nil, err } - buf, _ = c.Peek(msgLen) + body := make([]byte, bodyLen) + copy(body, buf[bodyOffset:msgLen]) _, _ = c.Discard(msgLen) - return buf[bodyOffset:msgLen], nil + return body, nil } func (codec testCodec) Unpack(buf []byte) ([]byte, error) { @@ -1693,41 +1707,48 @@ func (codec testCodec) Unpack(buf []byte) ([]byte, error) { } func TestSimServer(t *testing.T) { + t.Run("packet-size=64,batch=200", func(t *testing.T) { + runSimServer(t, ":7200", true, 10, 64, 200, -1) + }) t.Run("packet-size=128,batch=100", func(t *testing.T) { - runSimServer(t, ":7200", false, 10, 128, 100) + runSimServer(t, ":7201", false, 10, 128, 100, 10) }) t.Run("packet-size=256,batch=50", func(t *testing.T) { - runSimServer(t, ":7201", true, 10, 256, 50) + runSimServer(t, ":7202", true, 10, 256, 50, -1) }) t.Run("packet-size=512,batch=30", func(t *testing.T) { - runSimServer(t, ":7202", false, 10, 512, 30) + runSimServer(t, ":7203", false, 10, 512, 30, 3) }) t.Run("packet-size=1024,batch=20", func(t *testing.T) { - runSimServer(t, ":7203", true, 10, 1024, 20) + runSimServer(t, ":7204", true, 10, 1024, 20, -1) }) t.Run("packet-size=64*1024,batch=10", func(t *testing.T) { - runSimServer(t, ":7204", false, 10, 64*1024, 10) + runSimServer(t, ":7205", false, 10, 64*1024, 10, 1) }) t.Run("packet-size=128*1024,batch=5", func(t *testing.T) { - runSimServer(t, ":7205", true, 10, 128*1024, 5) + runSimServer(t, ":7206", true, 10, 128*1024, 5, -1) }) t.Run("packet-size=512*1024,batch=3", func(t *testing.T) { - runSimServer(t, ":7206", false, 10, 512*1024, 3) + runSimServer(t, ":7207", false, 10, 512*1024, 3, 1) }) t.Run("packet-size=1024*1024,batch=2", func(t *testing.T) { - runSimServer(t, ":7207", true, 10, 1024*1024, 2) + runSimServer(t, ":7208", true, 10, 1024*1024, 2, -1) }) } -func runSimServer(t *testing.T, addr string, et bool, nclients, packetSize, packetBatch int) { +func runSimServer(t *testing.T, addr string, et bool, nclients, packetSize, batchWrite, batchRead int) { ts := &simServer{ - tester: t, - network: "tcp", - addr: addr, - multicore: true, - nclients: nclients, - packetSize: packetSize, - packetBatch: packetBatch, + tester: t, + network: "tcp", + addr: addr, + multicore: true, + nclients: nclients, + packetSize: packetSize, + batchWrite: batchWrite, + batchRead: batchRead, + } + if batchRead < 0 { + ts.batchRead = math.MaxInt32 // unlimited read batch } err := Run(ts, ts.network+"://"+ts.addr, @@ -1789,6 +1810,7 @@ func batchSendAndRecv(t *testing.T, c net.Conn, rd *bufio.Reader, packetSize, ba for i, req := range requests { rsp, err := codec.Unpack(respPacket[i*packetLen:]) require.NoError(t, err) - require.Equalf(t, req, rsp, "request and response mismatch, packet size: %d, batch: %d", packetSize, batch) + require.Equalf(t, req, rsp, "request and response mismatch, packet size: %d, batch: %d, round: %d", + packetSize, batch, i) } }