Skip to content

Commit

Permalink
opt: optimize Conn.Next and Conn.Peek (#654)
Browse files Browse the repository at this point in the history
Besides, add tests of partial read.
  • Loading branch information
panjf2000 authored Nov 12, 2024
1 parent bdd3fb6 commit a0c9787
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 118 deletions.
4 changes: 2 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,13 +591,13 @@ func (ev *clientEventsForWake) OnTraffic(c Conn) (action Action) {
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err)
buf, err = c.Next(-1)
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.Emptyf(ev.tester, buf, "expected an empty slice, but got: %v", buf)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
buf, err = c.Peek(10)
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err)
buf, err = c.Peek(-1)
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.Emptyf(ev.tester, buf, "expected an empty slice, but got: %v", buf)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
n, err = c.Discard(10)
assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, n)
Expand Down
10 changes: 7 additions & 3 deletions connection_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,16 +318,18 @@ func (c *conn) Next(n int) (buf []byte, err error) {
} else if n <= 0 {
n = totalLen
}

if c.inboundBuffer.IsEmpty() {
buf = c.buffer[:n]
c.buffer = c.buffer[n:]
return
}

head, tail := c.inboundBuffer.Peek(n)
defer c.inboundBuffer.Discard(n) //nolint:errcheck
c.loop.cache.Reset()
c.loop.cache.Write(head)
if len(head) >= n {
if len(head) == n {
return c.loop.cache.Bytes(), err
}
c.loop.cache.Write(tail)
Expand All @@ -348,12 +350,14 @@ func (c *conn) Peek(n int) (buf []byte, err error) {
} else if n <= 0 {
n = totalLen
}

if c.inboundBuffer.IsEmpty() {
return c.buffer[:n], err
}

head, tail := c.inboundBuffer.Peek(n)
if len(head) >= n {
return head[:n], err
if len(head) == n {
return head, err
}
c.loop.cache.Reset()
c.loop.cache.Write(head)
Expand Down
164 changes: 82 additions & 82 deletions connection_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ type netErr struct {
}

type tcpConn struct {
c *conn
buf *bbPool.ByteBuffer
c *conn
b *bbPool.ByteBuffer
}

type udpConn struct {
Expand All @@ -59,35 +59,34 @@ type conn struct {
}

func packTCPConn(c *conn, buf []byte) *tcpConn {
tc := &tcpConn{c: c, buf: bbPool.Get()}
_, _ = tc.buf.Write(buf)
return tc
}

func unpackTCPConn(tc *tcpConn) {
tc.c.buffer = tc.buf
tc.buf = nil
b := bbPool.Get()
_, _ = b.Write(buf)
return &tcpConn{c: c, b: b}
}

func resetTCPConn(tc *tcpConn) {
bbPool.Put(tc.c.buffer)
tc.c.buffer = nil
func unpackTCPConn(tc *tcpConn) *conn {
if tc.c.buffer == nil { // the connection has been closed
return nil
}
_, _ = tc.c.buffer.Write(tc.b.B)
bbPool.Put(tc.b)
tc.b = nil
return tc.c
}

func packUDPConn(c *conn, buf []byte) *udpConn {
uc := &udpConn{c}
_, _ = uc.c.buffer.Write(buf)
return uc
_, _ = c.buffer.Write(buf)
return &udpConn{c}
}

func newTCPConn(nc net.Conn, el *eventloop) (c *conn) {
c = &conn{
loop: el,
rawConn: nc,
return &conn{
loop: el,
buffer: bbPool.Get(),
rawConn: nc,
localAddr: nc.LocalAddr(),
remoteAddr: nc.RemoteAddr(),
}
c.localAddr = c.rawConn.LocalAddr()
c.remoteAddr = c.rawConn.RemoteAddr()
return
}

func (c *conn) release() {
Expand Down Expand Up @@ -118,18 +117,11 @@ func (c *conn) resetBuffer() {
}

func (c *conn) Read(p []byte) (n int, err error) {
if c.buffer == nil {
if len(p) == 0 {
return 0, nil
}
return 0, io.ErrShortBuffer
}

if c.inboundBuffer.IsEmpty() {
n = copy(p, c.buffer.B)
c.buffer.B = c.buffer.B[n:]
if n == 0 && len(p) > 0 {
err = io.EOF
err = io.ErrShortBuffer
}
return
}
Expand All @@ -144,13 +136,6 @@ func (c *conn) Read(p []byte) (n int, err error) {
}

func (c *conn) Next(n int) (buf []byte, err error) {
if c.buffer == nil {
if n <= 0 {
return nil, nil
}
return nil, io.ErrShortBuffer
}

inBufferLen := c.inboundBuffer.Buffered()
if totalLen := inBufferLen + c.buffer.Len(); n > totalLen {
return nil, io.ErrShortBuffer
Expand All @@ -166,7 +151,7 @@ func (c *conn) Next(n int) (buf []byte, err error) {
defer c.inboundBuffer.Discard(n) //nolint:errcheck
c.loop.cache.Reset()
c.loop.cache.Write(head)
if len(head) >= n {
if len(head) == n {
return c.loop.cache.Bytes(), err
}
c.loop.cache.Write(tail)
Expand All @@ -181,13 +166,6 @@ func (c *conn) Next(n int) (buf []byte, err error) {
}

func (c *conn) Peek(n int) (buf []byte, err error) {
if c.buffer == nil {
if n <= 0 {
return nil, nil
}
return nil, io.ErrShortBuffer
}

inBufferLen := c.inboundBuffer.Buffered()
if totalLen := inBufferLen + c.buffer.Len(); n > totalLen {
return nil, io.ErrShortBuffer
Expand All @@ -198,8 +176,8 @@ func (c *conn) Peek(n int) (buf []byte, err error) {
return c.buffer.B[:n], err
}
head, tail := c.inboundBuffer.Peek(n)
if len(head) >= n {
return head[:n], err
if len(head) == n {
return head, err
}
c.loop.cache.Reset()
c.loop.cache.Write(head)
Expand All @@ -214,10 +192,6 @@ func (c *conn) Peek(n int) (buf []byte, err error) {
}

func (c *conn) Discard(n int) (int, error) {
if c.buffer == nil {
return 0, nil
}

inBufferLen := c.inboundBuffer.Buffered()
tempBufferLen := c.buffer.Len()
if inBufferLen+tempBufferLen < n || n <= 0 {
Expand Down Expand Up @@ -435,13 +409,24 @@ func (c *conn) SetKeepAlivePeriod(d time.Duration) error {
// func (c *conn) Gfd() gfd.GFD { return gfd.GFD{} }

func (c *conn) AsyncWrite(buf []byte, cb AsyncCallback) error {
if cb == nil {
cb = func(c Conn, err error) error { return nil }
}
_, err := c.Write(buf)
c.loop.ch <- func() error {
return cb(c, err)

callback := func() error {
if cb != nil {
_ = cb(c, err)
}
return err
}

select {
case c.loop.ch <- callback:
default:
// If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop.
go func() {
c.loop.ch <- callback
}()
}

return nil
}

Expand All @@ -460,46 +445,61 @@ func (c *conn) AsyncWritev(bs [][]byte, cb AsyncCallback) error {
}

func (c *conn) Wake(cb AsyncCallback) error {
if cb == nil {
cb = func(c Conn, err error) error { return nil }
}
c.loop.ch <- func() (err error) {
defer func() {
defer func() {
if err == nil {
err = cb(c, nil)
return
}
_ = cb(c, err)
}()
wakeFn := func() (err error) {
err = c.loop.wake(c)
if cb != nil {
_ = cb(c, err)
}
return
}

select {
case c.loop.ch <- wakeFn:
default:
// If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop.
go func() {
c.loop.ch <- wakeFn
}()
return c.loop.wake(c)
}

return nil
}

func (c *conn) Close() error {
c.loop.ch <- func() error {
err := c.loop.close(c, nil)
return err
closeFn := func() error {
return c.loop.close(c, nil)
}

select {
case c.loop.ch <- closeFn:
default:
// If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop.
go func() {
c.loop.ch <- closeFn
}()
}

return nil
}

func (c *conn) CloseWithCallback(cb AsyncCallback) error {
if cb == nil {
cb = func(c Conn, err error) error { return nil }
}
c.loop.ch <- func() (err error) {
defer func() {
if err == nil {
err = cb(c, nil)
return
}
closeFn := func() (err error) {
err = c.loop.close(c, nil)
if cb != nil {
_ = cb(c, err)
}
return
}

select {
case c.loop.ch <- closeFn:
default:
// If the event-loop channel is full, asynchronize this operation to avoid blocking the eventloop.
go func() {
c.loop.ch <- closeFn
}()
return c.loop.close(c, nil)
}

return nil
}

Expand Down
4 changes: 1 addition & 3 deletions eventloop_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ func (el *eventloop) run() (err error) {
case *openConn:
err = el.open(v)
case *tcpConn:
unpackTCPConn(v)
err = el.read(v.c)
resetTCPConn(v)
err = el.read(unpackTCPConn(v))
case *udpConn:
err = el.readUDP(v.c)
case func() error:
Expand Down
Loading

0 comments on commit a0c9787

Please sign in to comment.