From af0fd9d45e6e56b045f8e8556aa8fe917cbc6259 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 19 Oct 2023 15:19:48 -0700 Subject: [PATCH] examples/chat: Fix race condition Tricky tricky. --- internal/examples/chat/chat.go | 39 +++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/internal/examples/chat/chat.go b/internal/examples/chat/chat.go index 78a5696a..8b1e30c1 100644 --- a/internal/examples/chat/chat.go +++ b/internal/examples/chat/chat.go @@ -5,6 +5,7 @@ import ( "errors" "io" "log" + "net" "net/http" "sync" "time" @@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, nil) - if err != nil { - cs.logf("%v", err) - return - } - defer c.CloseNow() - - err = cs.subscribe(r.Context(), c) + err := cs.subscribe(r.Context(), w, r) if errors.Is(err, context.Canceled) { return } @@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. -func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - +func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + var mu sync.Mutex + var c *websocket.Conn + var closed bool s := &subscriber{ msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { - c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + mu.Lock() + defer mu.Unlock() + closed = true + if c != nil { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + } }, } cs.addSubscriber(s) defer cs.deleteSubscriber(s) + c2, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + mu.Lock() + if closed { + mu.Unlock() + return net.ErrClosed + } + c = c2 + mu.Unlock() + defer c.CloseNow() + + ctx = c.CloseRead(ctx) + for { select { case msg := <-s.msgs: