Skip to content

Commit

Permalink
Fix NetConn read bug
Browse files Browse the repository at this point in the history
  • Loading branch information
nhooyr committed Jul 3, 2019
1 parent 97f63d0 commit 5024792
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
13 changes: 13 additions & 0 deletions netconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
// reading/writing goroutines are interrupted but the connection is kept alive.
//
// The Addr methods will return a mock net.Addr.
//
// A received StatusNormalClosure close frame will be translated to EOF when reading.
func NetConn(c *Conn) net.Conn {
nc := &netConn{
c: c,
Expand All @@ -47,6 +49,7 @@ type netConn struct {

readTimer *time.Timer
readContext context.Context
eofed bool

reader io.Reader
}
Expand All @@ -66,9 +69,18 @@ func (c *netConn) Write(p []byte) (int, error) {
}

func (c *netConn) Read(p []byte) (int, error) {
if c.eofed {
return 0, io.EOF
}

if c.reader == nil {
typ, r, err := c.c.Reader(c.readContext)
if err != nil {
var ce CloseError
if xerrors.As(err, &ce) && (ce.Code == StatusNormalClosure) {
c.eofed = true
return 0, io.EOF
}
return 0, err
}
if typ != MessageBinary {
Expand All @@ -81,6 +93,7 @@ func (c *netConn) Read(p []byte) (int, error) {
n, err := c.reader.Read(p)
if err == io.EOF {
c.reader = nil
err = nil
}
return n, err
}
Expand Down
45 changes: 35 additions & 10 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,13 @@ func TestHandshake(t *testing.T) {
nc := websocket.NetConn(c)
defer nc.Close()

nc.SetWriteDeadline(time.Now().Add(time.Second * 10))
nc.SetWriteDeadline(time.Now().Add(time.Second * 15))

_, err = nc.Write([]byte("hello"))
if err != nil {
return err
for i := 0; i < 3; i++ {
_, err = nc.Write([]byte("hello"))
if err != nil {
return err
}
}

return nil
Expand All @@ -151,16 +153,39 @@ func TestHandshake(t *testing.T) {
nc := websocket.NetConn(c)
defer nc.Close()

nc.SetReadDeadline(time.Now().Add(time.Second * 10))
nc.SetReadDeadline(time.Now().Add(time.Second * 15))

p := make([]byte, len("hello"))
_, err = io.ReadFull(nc, p)
if err != nil {
read := func() error {
p := make([]byte, len("hello"))
// We do not use io.ReadFull here as it masks EOFs.
// See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024
_, err = nc.Read(p)
if err != nil {
return err
}

if string(p) != "hello" {
return xerrors.Errorf("unexpected payload %q received", string(p))
}
return nil
}

for i := 0; i < 3; i++ {
err = read()
if err != nil {
return err
}
}

// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
err = read()
if err != io.EOF {
return err
}

if string(p) != "hello" {
return xerrors.Errorf("unexpected payload %q received", string(p))
err = read()
if err != io.EOF {
return err
}

return nil
Expand Down

0 comments on commit 5024792

Please sign in to comment.