diff --git a/netconn.go b/netconn.go index e397d7a2..c43d3a29 100644 --- a/netconn.go +++ b/netconn.go @@ -22,6 +22,8 @@ import ( // reading/writing goroutines are interrupted but the connection is kept alive. // // The Addr methods will return a mock net.Addr. +// +// A received StatusNormalClosure close frame will be translated to EOF when reading. func NetConn(c *Conn) net.Conn { nc := &netConn{ c: c, @@ -47,6 +49,7 @@ type netConn struct { readTimer *time.Timer readContext context.Context + eofed bool reader io.Reader } @@ -66,9 +69,18 @@ func (c *netConn) Write(p []byte) (int, error) { } func (c *netConn) Read(p []byte) (int, error) { + if c.eofed { + return 0, io.EOF + } + if c.reader == nil { typ, r, err := c.c.Reader(c.readContext) if err != nil { + var ce CloseError + if xerrors.As(err, &ce) && (ce.Code == StatusNormalClosure) { + c.eofed = true + return 0, io.EOF + } return 0, err } if typ != MessageBinary { @@ -81,6 +93,7 @@ func (c *netConn) Read(p []byte) (int, error) { n, err := c.reader.Read(p) if err == io.EOF { c.reader = nil + err = nil } return n, err } diff --git a/websocket_test.go b/websocket_test.go index 2112ff7e..1dc5283b 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -130,11 +130,13 @@ func TestHandshake(t *testing.T) { nc := websocket.NetConn(c) defer nc.Close() - nc.SetWriteDeadline(time.Now().Add(time.Second * 10)) + nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - _, err = nc.Write([]byte("hello")) - if err != nil { - return err + for i := 0; i < 3; i++ { + _, err = nc.Write([]byte("hello")) + if err != nil { + return err + } } return nil @@ -151,16 +153,39 @@ func TestHandshake(t *testing.T) { nc := websocket.NetConn(c) defer nc.Close() - nc.SetReadDeadline(time.Now().Add(time.Second * 10)) + nc.SetReadDeadline(time.Now().Add(time.Second * 15)) - p := make([]byte, len("hello")) - _, err = io.ReadFull(nc, p) - if err != nil { + read := func() error { + p := make([]byte, len("hello")) + // We do not use io.ReadFull here as it masks EOFs. + // See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024 + _, err = nc.Read(p) + if err != nil { + return err + } + + if string(p) != "hello" { + return xerrors.Errorf("unexpected payload %q received", string(p)) + } + return nil + } + + for i := 0; i < 3; i++ { + err = read() + if err != nil { + return err + } + } + + // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. + err = read() + if err != io.EOF { return err } - if string(p) != "hello" { - return xerrors.Errorf("unexpected payload %q received", string(p)) + err = read() + if err != io.EOF { + return err } return nil