Skip to content

Commit

Permalink
examples/chat: Fix race condition
Browse files Browse the repository at this point in the history
Tricky tricky.
  • Loading branch information
nhooyr committed Oct 19, 2023
1 parent ff3ea39 commit af0fd9d
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions internal/examples/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"io"
"log"
"net"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// subscribeHandler accepts the WebSocket connection and then subscribes
// it to all future messages.
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
cs.logf("%v", err)
return
}
defer c.CloseNow()

err = cs.subscribe(r.Context(), c)
err := cs.subscribe(r.Context(), w, r)
if errors.Is(err, context.Canceled) {
return
}
Expand Down Expand Up @@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
//
// It uses CloseRead to keep reading from the connection to process control
// messages and cancel the context if the connection drops.
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
ctx = c.CloseRead(ctx)

func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
var mu sync.Mutex
var c *websocket.Conn
var closed bool
s := &subscriber{
msgs: make(chan []byte, cs.subscriberMessageBuffer),
closeSlow: func() {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
mu.Lock()
defer mu.Unlock()
closed = true
if c != nil {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
}
},
}
cs.addSubscriber(s)
defer cs.deleteSubscriber(s)

c2, err := websocket.Accept(w, r, nil)
if err != nil {
return err
}
mu.Lock()
if closed {
mu.Unlock()
return net.ErrClosed
}
c = c2
mu.Unlock()
defer c.CloseNow()

ctx = c.CloseRead(ctx)

for {
select {
case msg := <-s.msgs:
Expand Down

0 comments on commit af0fd9d

Please sign in to comment.