diff --git a/Makefile b/Makefile index 03c96756abc..bb35ec896dd 100644 --- a/Makefile +++ b/Makefile @@ -216,7 +216,6 @@ generate-mocks: install-mock-generators mockery --name 'Storage' --dir=module/executiondatasync/tracker --case=underscore --output="module/executiondatasync/tracker/mock" --outpkg="mocktracker" mockery --name 'ScriptExecutor' --dir=module/execution --case=underscore --output="module/execution/mock" --outpkg="mock" mockery --name 'StorageSnapshot' --dir=fvm/storage/snapshot --case=underscore --output="fvm/storage/snapshot/mock" --outpkg="mock" - mockery --name 'WebsocketConnection' --dir=engine/access/rest/websockets --case=underscore --output="engine/access/rest/websockets/mock" --outpkg="mock" #temporarily make insecure/ a non-module to allow mockery to create mocks mv insecure/go.mod insecure/go2.mod diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go index 8236913dd5f..1cb1a74c183 100644 --- a/engine/access/rest/websockets/config.go +++ b/engine/access/rest/websockets/config.go @@ -34,15 +34,11 @@ const ( type Config struct { MaxSubscriptionsPerConnection uint64 MaxResponsesPerSecond uint64 - SendMessageTimeout time.Duration - MaxRequestSize int64 } func NewDefaultWebsocketConfig() Config { return Config{ MaxSubscriptionsPerConnection: 1000, MaxResponsesPerSecond: 1000, - SendMessageTimeout: 10 * time.Second, - MaxRequestSize: 1024, } } diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index b999641df08..bffa57350c0 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -1,9 +1,83 @@ +// Package websockets provides a number of abstractions for managing WebSocket connections. +// It supports handling client subscriptions, sending messages, and maintaining +// the lifecycle of WebSocket connections with robust keepalive mechanisms. +// +// Overview +// +// The architecture of this package consists of three main components: +// +// 1. **Connection**: Responsible for providing a channel that allows the client +// to communicate with the server. It encapsulates WebSocket-level operations +// such as sending and receiving messages. +// 2. **Data Providers**: Standalone units responsible for fetching data from +// the blockchain (protocol). These providers act as sources of data that are +// sent to clients based on their subscriptions. +// 3. **Controller**: Acts as a mediator between the connection and data providers. +// It governs client subscriptions, handles client requests and responses, +// validates messages, and manages error handling. The controller ensures smooth +// coordination between the client and the data-fetching units. +// +// Basically, it is an N:1:1 approach: N data providers, 1 controller, 1 websocket connection. +// This allows a client to receive messages from different subscriptions over a single connection. +// +// ### Controller Details +// +// The `Controller` is the core component that coordinates the interactions between +// the client and data providers. It achieves this through three routines that run +// in parallel (writer, reader, and keepalive routine). If any of the three routines +// fails with an error, the remaining routines will be canceled using the provided +// context to ensure proper cleanup and termination. +// +// 1. **Reader Routine**: +// - Reads messages from the client WebSocket connection. +// - Parses and validates the messages. +// - Handles the messages by triggering the appropriate actions, such as subscribing +// to a topic or unsubscribing from an existing subscription. +// - Ensures proper validation of message formats and data before passing them to +// the internal handlers. +// +// 2. **Writer Routine**: +// - Listens to the `multiplexedStream`, which is a channel filled by data providers +// with messages that clients have subscribed to. +// - Writes these messages to the client WebSocket connection. +// - Ensures the outgoing messages respect the required deadlines to maintain the +// stability of the connection. +// +// 3. **Keepalive Routine**: +// - Periodically sends a WebSocket ping control message to the client to indicate +// that the controller and all its subscriptions are working as expected. +// - Ensures the connection remains clean and avoids timeout scenarios due to +// inactivity. +// - Resets the connection's read deadline whenever a pong message is received. +// +// Example +// +// Usage typically involves creating a `Controller` instance and invoking its +// `HandleConnection` method to manage a single WebSocket connection: +// +// logger := zerolog.New(os.Stdout) +// config := websockets.Config{/* configuration options */} +// conn := /* a WebsocketConnection implementation */ +// factory := /* a DataProviderFactory implementation */ +// +// controller := websockets.NewWebSocketController(logger, config, conn, factory) +// ctx := context.Background() +// controller.HandleConnection(ctx) +// +// +// Package Constants +// +// This package expects constants like `PongWait` and `WriteWait` for controlling +// the read/write deadlines. They need to be defined in your application as appropriate. + package websockets import ( "context" "encoding/json" + "errors" "fmt" + "sync" "time" "github.com/google/uuid" @@ -21,10 +95,41 @@ type Controller struct { config Config conn WebsocketConnection - communicationChannel chan interface{} // Channel for sending messages to the client. + // The `multiplexedStream` is a core channel used for communication between the + // `Controller` and Data Providers. Its lifecycle is as follows: + // + // 1. **Data Providers**: + // - Data providers write their data into this channel, which is consumed by + // the writer routine to send messages to the client. + // 2. **Reader Routine**: + // - Writes OK/error responses to the channel as a result of processing client messages. + // 3. **Writer Routine**: + // - Reads messages from this channel and forwards them to the client WebSocket connection. + // + // 4. **Channel Closing**: + // The intention to close the channel comes from the reader-from-this-channel routines (controller's routines), + // not the writer-to-this-channel routines (data providers). + // Therefore, we have to signal the data providers to stop writing, wait for them to finish write operations, + // and only after that we can close the channel. + // + // - The `Controller` is responsible for starting and managing the lifecycle of the channel. + // - If an unrecoverable error occurs in any of the three routines (reader, writer, or keepalive), + // the parent context is canceled. This triggers data providers to stop their work. + // - The `multiplexedStream` will not be closed until all data providers signal that + // they have stopped writing to it via the `dataProvidersGroup` wait group. + // + // 5. **Edge Case - Writer Routine Finished Before Providers**: + // - If the writer routine finishes before all data providers, a separate draining routine + // ensures that the `multiplexedStream` is fully drained to prevent deadlocks. + // All remaining messages in this case will be discarded. + // + // This design ensures that the channel is only closed when it is safe to do so, avoiding + // issues such as sending on a closed channel while maintaining proper cleanup. + multiplexedStream chan interface{} dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] dataProviderFactory dp.DataProviderFactory + dataProvidersGroup *sync.WaitGroup } func NewWebSocketController( @@ -34,12 +139,13 @@ func NewWebSocketController( dataProviderFactory dp.DataProviderFactory, ) *Controller { return &Controller{ - logger: logger.With().Str("component", "websocket-controller").Logger(), - config: config, - conn: conn, - communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? - dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), - dataProviderFactory: dataProviderFactory, + logger: logger.With().Str("component", "websocket-controller").Logger(), + config: config, + conn: conn, + multiplexedStream: make(chan interface{}), + dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), + dataProviderFactory: dataProviderFactory, + dataProvidersGroup: &sync.WaitGroup{}, } } @@ -51,34 +157,29 @@ func NewWebSocketController( func (c *Controller) HandleConnection(ctx context.Context) { defer c.shutdownConnection() - // configuring the connection with appropriate read/write deadlines and handlers. err := c.configureKeepalive() if err != nil { - // TODO: add error handling here - c.logger.Error().Err(err).Msg("error configuring keepalive connection") - + c.logger.Error().Err(err).Msg("error configuring connection") return } - //TODO: spin up a response limit tracker routine - - // for track all goroutines and error handling g, gCtx := errgroup.WithContext(ctx) - g.Go(func() error { - return c.readMessages(gCtx) - }) - g.Go(func() error { return c.keepalive(gCtx) }) - g.Go(func() error { return c.writeMessages(gCtx) }) + g.Go(func() error { + return c.readMessages(gCtx) + }) if err = g.Wait(); err != nil { - //TODO: add error handling here + if errors.Is(err, websocket.ErrCloseSent) { + return + } + c.logger.Error().Err(err).Msg("error detected in one of the goroutines") } } @@ -103,6 +204,7 @@ func (c *Controller) configureKeepalive() error { if err := c.conn.SetReadDeadline(time.Now().Add(PongWait)); err != nil { return fmt.Errorf("failed to set the initial read deadline: %w", err) } + // Establish a Pong handler which sets the handler for pong messages received from the peer. c.conn.SetPongHandler(func(string) error { return c.conn.SetReadDeadline(time.Now().Add(PongWait)) @@ -111,214 +213,267 @@ func (c *Controller) configureKeepalive() error { return nil } -// writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. -// The communication channel is filled by data providers. Besides, the response limit tracker is involved in -// write message regulation -// -// No errors are expected during normal operation. All errors are considered benign. -func (c *Controller) writeMessages(ctx context.Context) error { +// keepalive sends a ping message periodically to keep the WebSocket connection alive +// and avoid timeouts. +func (c *Controller) keepalive(ctx context.Context) error { + pingTicker := time.NewTicker(PingPeriod) + defer pingTicker.Stop() + for { select { case <-ctx.Done(): return nil - case msg, ok := <-c.communicationChannel: - if !ok { - return fmt.Errorf("communication channel closed, no error occurred") - } - // TODO: handle 'response per second' limits - - // Specifies a timeout for the write operation. If the write - // isn't completed within this duration, it fails with a timeout error. - // SetWriteDeadline ensures the write operation does not block indefinitely - // if the client is slow or unresponsive. This prevents resource exhaustion - // and allows the server to gracefully handle timeouts for delayed writes. - if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { - return fmt.Errorf("failed to set the write deadline: %w", err) - } - err := c.conn.WriteJSON(msg) + case <-pingTicker.C: + err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) if err != nil { - return fmt.Errorf("failed to write message to connection: %w", err) + if errors.Is(err, websocket.ErrCloseSent) { + return err + } + + return fmt.Errorf("error sending ping: %w", err) } } } } -// readMessages continuously reads messages from a client WebSocket connection, -// processes each message, and handles actions based on the message type. -// -// No errors are expected during normal operation. All errors are considered benign. -func (c *Controller) readMessages(ctx context.Context) error { +// writeMessages reads a messages from multiplexed stream and passes them on to a client WebSocket connection. +// The multiplexed stream channel is filled by data providers +func (c *Controller) writeMessages(ctx context.Context) error { + defer func() { + // drain the channel as some providers may still send data to it after this routine shutdowns + // so, in order to not run into deadlock there should be at least 1 reader on the channel + go func() { + for range c.multiplexedStream { + } + }() + }() + for { select { case <-ctx.Done(): return nil - default: - msg, err := c.readMessage() - if err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { - return nil - } - return fmt.Errorf("failed to read message from client: %w", err) + case message, ok := <-c.multiplexedStream: + if !ok { + return nil } - _, validatedMsg, err := c.parseAndValidateMessage(msg) - if err != nil { - //TODO: write error to error channel - return fmt.Errorf("failed to parse and validate client message: %w", err) + if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { + return fmt.Errorf("failed to set the write deadline: %w", err) } - if err := c.handleAction(ctx, validatedMsg); err != nil { - //TODO: write error to error channel - return fmt.Errorf("failed to handle message action: %w", err) + if err := c.conn.WriteJSON(message); err != nil { + return err } } } } -func (c *Controller) readMessage() (json.RawMessage, error) { - var message json.RawMessage - if err := c.conn.ReadJSON(&message); err != nil { - return nil, fmt.Errorf("error reading JSON from client: %w", err) +// readMessages continuously reads messages from a client WebSocket connection, +// validates each message, and processes it based on the message type. +func (c *Controller) readMessages(ctx context.Context) error { + for { + var message json.RawMessage + if err := c.conn.ReadJSON(&message); err != nil { + if errors.Is(err, websocket.ErrCloseSent) { + return err + } + + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidMessage, "error reading message", "", "", "")) + continue + } + + err := c.handleMessage(ctx, message) + if err != nil { + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidMessage, "error parsing message", "", "", "")) + continue + } } - return message, nil } -func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.BaseMessageRequest, interface{}, error) { +func (c *Controller) handleMessage(ctx context.Context, message json.RawMessage) error { var baseMsg models.BaseMessageRequest if err := json.Unmarshal(message, &baseMsg); err != nil { - return models.BaseMessageRequest{}, nil, fmt.Errorf("error unmarshalling base message: %w", err) + return fmt.Errorf("error unmarshalling base message: %w", err) } - var validatedMsg interface{} switch baseMsg.Action { - case "subscribe": + case models.SubscribeAction: var subscribeMsg models.SubscribeMessageRequest if err := json.Unmarshal(message, &subscribeMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) + return fmt.Errorf("error unmarshalling subscribe message: %w", err) } - //TODO: add validation logic for `topic` field - validatedMsg = subscribeMsg + c.handleSubscribe(ctx, subscribeMsg) - case "unsubscribe": + case models.UnsubscribeAction: var unsubscribeMsg models.UnsubscribeMessageRequest if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) + return fmt.Errorf("error unmarshalling unsubscribe message: %w", err) } - validatedMsg = unsubscribeMsg + c.handleUnsubscribe(ctx, unsubscribeMsg) - case "list_subscriptions": + case models.ListSubscriptionsAction: var listMsg models.ListSubscriptionsMessageRequest if err := json.Unmarshal(message, &listMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) + return fmt.Errorf("error unmarshalling list subscriptions message: %w", err) } - validatedMsg = listMsg + c.handleListSubscriptions(ctx, listMsg) default: - return baseMsg, nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) + c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") + return fmt.Errorf("unknown action type: %s", baseMsg.Action) } - return baseMsg, validatedMsg, nil -} - -func (c *Controller) handleAction(ctx context.Context, message interface{}) error { - switch msg := message.(type) { - case models.SubscribeMessageRequest: - c.handleSubscribe(ctx, msg) - case models.UnsubscribeMessageRequest: - c.handleUnsubscribe(ctx, msg) - case models.ListSubscriptionsMessageRequest: - c.handleListSubscriptions(ctx, msg) - default: - return fmt.Errorf("unknown message type: %T", msg) - } return nil } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { - dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.communicationChannel) + // register new provider + provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { - // TODO: handle error here - c.logger.Error().Err(err).Msgf("error while creating data provider for topic: %s", msg.Topic) + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidArgument, "error creating data provider", msg.ClientMessageID, models.SubscribeAction, ""), + ) + return } + c.dataProviders.Add(provider.ID(), provider) - c.dataProviders.Add(dp.ID(), dp) - - //TODO: return correct OK response to client - response := models.SubscribeMessageResponse{ + // write OK response to client + responseOk := models.SubscribeMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - Success: true, + ClientMessageID: msg.ClientMessageID, + Success: true, + SubscriptionID: provider.ID().String(), }, - Topic: dp.Topic(), - ID: dp.ID().String(), } + c.writeResponse(ctx, responseOk) - c.communicationChannel <- response - + // run provider + c.dataProvidersGroup.Add(1) go func() { - err := dp.Run() + err = provider.Run() if err != nil { - //TODO: Log or handle the error from Run - c.logger.Error().Err(err).Msgf("error while running data provider for topic: %s", msg.Topic) + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(SubscriptionError, "subscription finished with error", "", "", ""), + ) } + + c.dataProvidersGroup.Done() + c.dataProviders.Remove(provider.ID()) }() } -func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) { - id, err := uuid.Parse(msg.ID) +func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { + id, err := uuid.Parse(msg.SubscriptionID) if err != nil { - c.logger.Debug().Err(err).Msg("error parsing message ID") - //TODO: return an error response to client - c.communicationChannel <- err + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidArgument, "error parsing subscription ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), + ) return } - dp, ok := c.dataProviders.Get(id) - if ok { - dp.Close() - c.dataProviders.Remove(id) + provider, ok := c.dataProviders.Get(id) + if !ok { + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(NotFound, "subscription not found", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), + ) + return } + + provider.Close() + c.dataProviders.Remove(id) + + responseOk := models.UnsubscribeMessageResponse{ + BaseMessageResponse: models.BaseMessageResponse{ + ClientMessageID: msg.ClientMessageID, + Success: true, + SubscriptionID: msg.SubscriptionID, + }, + } + c.writeResponse(ctx, responseOk) } func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { - //TODO: return a response to client + var subs []*models.SubscriptionEntry + err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error { + subs = append(subs, &models.SubscriptionEntry{ + ID: id.String(), + Topic: provider.Topic(), + }) + return nil + }) + + if err != nil { + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(NotFound, "error listing subscriptions", msg.ClientMessageID, models.ListSubscriptionsAction, ""), + ) + return + } + + responseOk := models.ListSubscriptionsMessageResponse{ + Success: true, + ClientMessageID: msg.ClientMessageID, + Subscriptions: subs, + } + c.writeResponse(ctx, responseOk) } func (c *Controller) shutdownConnection() { - defer func() { - if err := c.conn.Close(); err != nil { - c.logger.Error().Err(err).Msg("error closing connection") - } - // TODO: safe closing communicationChannel will be included as a part of PR #6642 - }() + err := c.conn.Close() + if err != nil { + c.logger.Debug().Err(err).Msg("error closing connection") + } - err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { - dp.Close() + err = c.dataProviders.ForEach(func(_ uuid.UUID, provider dp.DataProvider) error { + provider.Close() return nil }) if err != nil { - c.logger.Error().Err(err).Msg("error closing data provider") + c.logger.Debug().Err(err).Msg("error closing data provider") } c.dataProviders.Clear() + c.dataProvidersGroup.Wait() + close(c.multiplexedStream) } -// keepalive sends a ping message periodically to keep the WebSocket connection alive -// and avoid timeouts. -// -// No errors are expected during normal operation. All errors are considered benign. -func (c *Controller) keepalive(ctx context.Context) error { - pingTicker := time.NewTicker(PingPeriod) - defer pingTicker.Stop() +func (c *Controller) writeErrorResponse(ctx context.Context, err error, msg models.BaseMessageResponse) { + c.logger.Debug().Err(err).Msg(msg.Error.Message) + c.writeResponse(ctx, msg) +} - for { - select { - case <-ctx.Done(): - return nil - case <-pingTicker.C: - err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) - if err != nil { - return fmt.Errorf("failed to write ping message: %w", err) - } - } +func (c *Controller) writeResponse(ctx context.Context, response interface{}) { + select { + case <-ctx.Done(): + return + case c.multiplexedStream <- response: + } +} + +func wrapErrorMessage(code Code, message string, msgId string, action string, subscriptionID string) models.BaseMessageResponse { + return models.BaseMessageResponse{ + ClientMessageID: msgId, + Success: false, + SubscriptionID: subscriptionID, + Error: models.ErrorMessage{ + Code: int(code), + Message: message, + Action: action, + }, } } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 58228a46b46..9707dbb8205 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/require" - streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" @@ -20,123 +18,565 @@ import ( dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" dpmock "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers/mock" - connectionmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" + connmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) -// ControllerSuite is a test suite for the WebSocket Controller. -type ControllerSuite struct { +// WsControllerSuite is a test suite for the WebSocket Controller. +type WsControllerSuite struct { suite.Suite - logger zerolog.Logger - config Config - - connection *connectionmock.WebsocketConnection - dataProviderFactory *dpmock.DataProviderFactory - - streamApi *streammock.API - streamConfig backend.Config + logger zerolog.Logger + wsConfig Config } func TestControllerSuite(t *testing.T) { - suite.Run(t, new(ControllerSuite)) + suite.Run(t, new(WsControllerSuite)) } // SetupTest initializes the test suite with required dependencies. -func (s *ControllerSuite) SetupTest() { +func (s *WsControllerSuite) SetupTest() { s.logger = unittest.Logger() - s.config = NewDefaultWebsocketConfig() - - s.connection = connectionmock.NewWebsocketConnection(s.T()) - s.dataProviderFactory = dpmock.NewDataProviderFactory(s.T()) - - s.streamApi = streammock.NewAPI(s.T()) - s.streamConfig = backend.Config{} + s.wsConfig = NewDefaultWebsocketConfig() } // TestSubscribeRequest tests the subscribe to topic flow. // We emulate a request message from a client, and a response message from a controller. -func (s *ControllerSuite) TestSubscribeRequest() { +func (s *WsControllerSuite) TestSubscribeRequest() { s.T().Run("Happy path", func(t *testing.T) { + t.Parallel() + conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.config, conn, dataProviderFactory) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + done := make(chan struct{}) + dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). - Run(func(args mock.Arguments) {}). + Run(func(args mock.Arguments) { + <-done + }). Return(nil). Once() - subscribeRequest := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: dp.BlocksTopic, - Arguments: nil, + request := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + ClientMessageID: uuid.New().String(), + Action: models.SubscribeAction, + }, + Topic: dp.BlocksTopic, + Arguments: nil, } + requestJson, err := json.Marshal(request) + require.NoError(t, err) // Simulate receiving the subscription request from the client conn. On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { - requestMsg, ok := args.Get(0).(*json.RawMessage) + msg, ok := args.Get(0).(*json.RawMessage) require.True(t, ok) - subscribeRequestMessage, err := json.Marshal(subscribeRequest) - require.NoError(t, err) - *requestMsg = subscribeRequestMessage + *msg = requestJson }). Return(nil). Once() - // Channel to signal the test flow completion - done := make(chan struct{}, 1) - - // Simulate writing a successful subscription response back to the client conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) + response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) require.True(t, response.Success) - close(done) // Signal that response has been sent + require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, id.String(), response.SubscriptionID) + return websocket.ErrCloseSent - }).Once() + }) + + s.expectCloseConnection(conn, done) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) + }) + + s.T().Run("Parse and validate error", func(t *testing.T) { + t.Parallel() - // Simulate client closing connection after receiving the response + conn, dataProviderFactory, _ := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + type Request struct { + Action string `json:"action"` + } + + subscribeRequest := Request{ + Action: "SubscribeBlocks", + } + subscribeRequestJson, err := json.Marshal(subscribeRequest) + require.NoError(t, err) + + // Simulate receiving the subscription request from the client conn. On("ReadJSON", mock.Anything). - Return(func(interface{}) error { + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = subscribeRequestJson + }). + Return(nil). + Once() + + done := make(chan struct{}) + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.False(t, response.Success) + require.NotEmpty(t, response.Error) + require.Equal(t, int(InvalidMessage), response.Error.Code) + return websocket.ErrCloseSent + }) + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + }) + + s.T().Run("Error creating data provider", func(t *testing.T) { + t.Parallel() + + conn, dataProviderFactory, _ := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("error creating data provider")). + Once() + + done := make(chan struct{}) + s.expectSubscribeRequest(t, conn) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.False(t, response.Success) + require.NotEmpty(t, response.Error) + require.Equal(t, int(InvalidArgument), response.Error.Code) + + return websocket.ErrCloseSent + }) + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + }) + + s.T().Run("Provider execution error", func(t *testing.T) { + t.Parallel() + + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProvider.On("ID").Return(uuid.New()) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(fmt.Errorf("error running data provider")). + Once() + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + done := make(chan struct{}) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.False(t, response.Success) + require.NotEmpty(t, response.Error) + require.Equal(t, int(SubscriptionError), response.Error.Code) + + return websocket.ErrCloseSent + }) + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) + }) +} + +func (s *WsControllerSuite) TestUnsubscribeRequest() { + s.T().Run("Happy path", func(t *testing.T) { + t.Parallel() + + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + done := make(chan struct{}) + + dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() + dataProvider. + On("Run"). + Run(func(args mock.Arguments) { <-done + }). + Return(nil). + Once() + + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + ClientMessageID: uuid.New().String(), + Action: models.UnsubscribeAction, + }, + SubscriptionID: id.String(), + } + requestJson, err := json.Marshal(request) + require.NoError(t, err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.UnsubscribeMessageResponse) + require.True(t, ok) + require.True(t, response.Success) + require.Empty(t, response.Error) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) + return websocket.ErrCloseSent - }).Once() + }). + Once() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + s.expectCloseConnection(conn, done) - controller.HandleConnection(ctx) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) + }) + + s.T().Run("Invalid subscription uuid", func(t *testing.T) { + t.Parallel() + + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + done := make(chan struct{}) + + dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() + dataProvider. + On("Run"). + Run(func(args mock.Arguments) { + <-done + }). + Return(nil). + Once() + + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + ClientMessageID: uuid.New().String(), + Action: models.UnsubscribeAction, + }, + SubscriptionID: "invalid-uuid", + } + requestJson, err := json.Marshal(request) + require.NoError(t, err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.False(t, response.Success) + require.NotEmpty(t, response.Error) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, int(InvalidArgument), response.Error.Code) + + return websocket.ErrCloseSent + }). + Once() + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) + }) + + s.T().Run("Unsubscribe from unknown subscription", func(t *testing.T) { + t.Parallel() + + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + done := make(chan struct{}) + + dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() + dataProvider. + On("Run"). + Run(func(args mock.Arguments) { + <-done + }). + Return(nil). + Once() + + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + ClientMessageID: uuid.New().String(), + Action: models.UnsubscribeAction, + }, + SubscriptionID: uuid.New().String(), + } + requestJson, err := json.Marshal(request) + require.NoError(t, err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.False(t, response.Success) + require.NotEmpty(t, response.Error) + + require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, int(NotFound), response.Error.Code) + + return websocket.ErrCloseSent + }). + Once() + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) + }) +} + +func (s *WsControllerSuite) TestListSubscriptions() { + s.T().Run("Happy path", func(t *testing.T) { + + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + done := make(chan struct{}) + + id := uuid.New() + topic := dp.BlocksTopic + dataProvider.On("ID").Return(id) + dataProvider.On("Topic").Return(topic) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() + dataProvider. + On("Run"). + Run(func(args mock.Arguments) { + <-done + }). + Return(nil). + Once() + + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) + + request := models.ListSubscriptionsMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + ClientMessageID: uuid.New().String(), + Action: models.ListSubscriptionsAction, + }, + } + requestJson, err := json.Marshal(request) + require.NoError(t, err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) + + response, ok := msg.(models.ListSubscriptionsMessageResponse) + require.True(t, ok) + require.True(t, response.Success) + require.Empty(t, response.Error) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) + require.Equal(t, 1, len(response.Subscriptions)) + require.Equal(t, id.String(), response.Subscriptions[0].ID) + require.Equal(t, topic, response.Subscriptions[0].Topic) + + return websocket.ErrCloseSent + }). + Once() + + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) } // TestSubscribeBlocks tests the functionality for streaming blocks to a subscriber. -func (s *ControllerSuite) TestSubscribeBlocks() { +func (s *WsControllerSuite) TestSubscribeBlocks() { s.T().Run("Stream one block", func(t *testing.T) { + t.Parallel() + conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.config, conn, dataProviderFactory) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() // Simulate data provider write a block to the controller expectedBlock := unittest.BlockFixture() dataProvider. On("Run", mock.Anything). Run(func(args mock.Arguments) { - controller.communicationChannel <- expectedBlock + controller.multiplexedStream <- expectedBlock }). Return(nil). Once() - done := make(chan struct{}, 1) - s.expectSubscriptionRequest(conn, done) - s.expectSubscriptionResponse(conn, true) + done := make(chan struct{}) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -144,24 +584,40 @@ func (s *ControllerSuite) TestSubscribeBlocks() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) + block, ok := msg.(flow.Block) require.True(t, ok) actualBlock = block + require.Equal(t, expectedBlock, actualBlock) - close(done) return websocket.ErrCloseSent - }).Once() + }) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + s.expectCloseConnection(conn, done) - controller.HandleConnection(ctx) - require.Equal(t, expectedBlock, actualBlock) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) s.T().Run("Stream many blocks", func(t *testing.T) { + t.Parallel() + conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.config, conn, dataProviderFactory) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() // Simulate data provider writes some blocks to the controller expectedBlocks := unittest.BlockFixtures(100) @@ -169,15 +625,15 @@ func (s *ControllerSuite) TestSubscribeBlocks() { On("Run", mock.Anything). Run(func(args mock.Arguments) { for _, block := range expectedBlocks { - controller.communicationChannel <- *block + controller.multiplexedStream <- *block } }). Return(nil). Once() - done := make(chan struct{}, 1) - s.expectSubscriptionRequest(conn, done) - s.expectSubscriptionResponse(conn, true) + done := make(chan struct{}) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -194,6 +650,7 @@ func (s *ControllerSuite) TestSubscribeBlocks() { i += 1 if i == len(expectedBlocks) { + require.Equal(t, expectedBlocks, actualBlocks) close(done) return websocket.ErrCloseSent } @@ -202,111 +659,56 @@ func (s *ControllerSuite) TestSubscribeBlocks() { }). Times(len(expectedBlocks)) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - controller.HandleConnection(ctx) - require.Equal(t, expectedBlocks, actualBlocks) - }) -} - -// newControllerMocks initializes mock WebSocket connection, data provider, and data provider factory. -// The mocked functions are expected to be called in a case when a test is expected to reach WriteJSON function. -func newControllerMocks(t *testing.T) (*connectionmock.WebsocketConnection, *dpmock.DataProviderFactory, *dpmock.DataProvider) { - conn := connectionmock.NewWebsocketConnection(t) - conn.On("Close").Return(nil).Once() - conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() - conn.On("SetWriteDeadline", mock.Anything).Return(nil) - conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() - - id := uuid.New() - topic := dp.BlocksTopic - dataProvider := dpmock.NewDataProvider(t) - dataProvider.On("ID").Return(id) - dataProvider.On("Close").Return(nil) - dataProvider.On("Topic").Return(topic) - - factory := dpmock.NewDataProviderFactory(t) - factory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(dataProvider, nil). - Once() - - return conn, factory, dataProvider -} - -// expectSubscriptionRequest mocks the client's subscription request. -func (s *ControllerSuite) expectSubscriptionRequest(conn *connectionmock.WebsocketConnection, done <-chan struct{}) { - requestMessage := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: dp.BlocksTopic, - } - - // The very first message from a client is a request to subscribe to some topic - conn.On("ReadJSON", mock.Anything). - Run(func(args mock.Arguments) { - reqMsg, ok := args.Get(0).(*json.RawMessage) - require.True(s.T(), ok) - msg, err := json.Marshal(requestMessage) - require.NoError(s.T(), err) - *reqMsg = msg - }). - Return(nil). - Once() + s.expectCloseConnection(conn, done) - // In the default case, no further communication is expected from the client. - // We wait for the writer routine to signal completion, allowing us to close the connection gracefully - conn. - On("ReadJSON", mock.Anything). - Return(func(msg interface{}) error { - <-done - return websocket.ErrCloseSent - }) -} + controller.HandleConnection(context.Background()) -// expectSubscriptionResponse mocks the subscription response sent to the client. -func (s *ControllerSuite) expectSubscriptionResponse(conn *connectionmock.WebsocketConnection, success bool) { - conn.On("WriteJSON", mock.Anything). - Run(func(args mock.Arguments) { - response, ok := args.Get(0).(models.SubscribeMessageResponse) - require.True(s.T(), ok) - require.Equal(s.T(), success, response.Success) - }). - Return(nil). - Once() + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) + }) } // TestConfigureKeepaliveConnection ensures that the WebSocket connection is configured correctly. -func (s *ControllerSuite) TestConfigureKeepaliveConnection() { - controller := s.initializeController() +func (s *WsControllerSuite) TestConfigureKeepaliveConnection() { + s.T().Run("Happy path", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil) - // Mock configureConnection to succeed - s.mockConnectionSetup() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) - // Call configureKeepalive and check for errors - err := controller.configureKeepalive() - s.Require().NoError(err, "configureKeepalive should not return an error") + err := controller.configureKeepalive() + s.Require().NoError(err, "configureKeepalive should not return an error") - // Assert expectations - s.connection.AssertExpectations(s.T()) + conn.AssertExpectations(t) + }) } -// TestControllerShutdown ensures that HandleConnection shuts down gracefully when an error occurs. -func (s *ControllerSuite) TestControllerShutdown() { - s.T().Run("keepalive routine failed", func(*testing.T) { - controller := s.initializeController() +func (s *WsControllerSuite) TestControllerShutdown() { + s.T().Run("Keepalive routine failed", func(t *testing.T) { + t.Parallel() + + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() - // Mock configureConnection to succeed - s.mockConnectionSetup() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) // Mock keepalive to return an error - done := make(chan struct{}, 1) - s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(func(int, time.Time) error { - close(done) - return websocket.ErrCloseSent - }).Once() + done := make(chan struct{}) + conn. + On("WriteControl", websocket.PingMessage, mock.Anything). + Return(func(int, time.Time) error { + close(done) + return websocket.ErrCloseSent + }). + Once() - s.connection. + conn. On("ReadJSON", mock.Anything). Return(func(interface{}) error { <-done @@ -314,194 +716,251 @@ func (s *ControllerSuite) TestControllerShutdown() { }). Once() - s.connection.On("Close").Return(nil).Once() + controller.HandleConnection(context.Background()) + conn.AssertExpectations(t) + }) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - controller.HandleConnection(ctx) + s.T().Run("Read routine failed", func(t *testing.T) { + t.Parallel() - // Ensure all expectations are met - s.connection.AssertExpectations(s.T()) - }) + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() - s.T().Run("read routine failed", func(*testing.T) { - controller := s.initializeController() - // Mock configureConnection to succeed - s.mockConnectionSetup() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) - s.connection. + conn. On("ReadJSON", mock.Anything). Return(func(_ interface{}) error { - return assert.AnError + return websocket.ErrCloseSent //TODO: this should be assert.AnError and test should be rewritten }). Once() - s.connection.On("Close").Return(nil).Once() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - controller.HandleConnection(ctx) - - // Ensure all expectations are met - s.connection.AssertExpectations(s.T()) + controller.HandleConnection(context.Background()) + conn.AssertExpectations(t) }) - s.T().Run("write routine failed", func(*testing.T) { - controller := s.initializeController() + s.T().Run("Write routine failed", func(t *testing.T) { + t.Parallel() - // Mock configureConnection to succeed - s.mockConnectionSetup() - blocksDataProvider := s.mockBlockDataProviderSetup(uuid.New()) + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) - done := make(chan struct{}, 1) - requestMessage := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: dp.BlocksTopic, - Arguments: nil, - } - msg, err := json.Marshal(requestMessage) - s.Require().NoError(err) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() - // Mocks `ReadJSON(v interface{}) error` which accepts an uninitialize interface that - // receives the contents of the read message. This logic mocks that behavior, setting - // the target with the value `msg` - s.connection. - On("ReadJSON", mock.Anything). + id := uuid.New() + dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() + + dataProvider. + On("Run", mock.Anything). Run(func(args mock.Arguments) { - reqMsg, ok := args.Get(0).(*json.RawMessage) - s.Require().True(ok) - *reqMsg = msg + controller.multiplexedStream <- unittest.BlockFixture() }). Return(nil). Once() - s.connection. - On("ReadJSON", mock.Anything). - Return(func(interface{}) error { - <-done - return websocket.ErrCloseSent - }) + done := make(chan struct{}) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) - s.connection.On("SetWriteDeadline", mock.Anything).Return(nil).Once() - s.connection. + conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { close(done) return assert.AnError }) - s.connection.On("Close").Return(nil).Once() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - controller.HandleConnection(ctx) + s.expectCloseConnection(conn, done) + + controller.HandleConnection(context.Background()) // Ensure all expectations are met - s.connection.AssertExpectations(s.T()) - s.dataProviderFactory.AssertExpectations(s.T()) - blocksDataProvider.AssertExpectations(s.T()) + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) - s.T().Run("context closed", func(*testing.T) { - controller := s.initializeController() + s.T().Run("Context cancelled", func(t *testing.T) { + t.Parallel() - // Mock configureConnection to succeed - s.mockConnectionSetup() + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() - s.connection.On("Close").Return(nil).Once() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) ctx, cancel := context.WithCancel(context.Background()) - cancel() + conn.On("ReadJSON", mock.Anything).Return(func(_ interface{}) error { + <-ctx.Done() + return websocket.ErrCloseSent + }).Once() + cancel() controller.HandleConnection(ctx) - // Ensure all expectations are met - s.connection.AssertExpectations(s.T()) + conn.AssertExpectations(t) }) } -// TestKeepaliveHappyCase tests the behavior of the keepalive function. -func (s *ControllerSuite) TestKeepaliveHappyCase() { - // Create a context for the test - ctx := context.Background() +func (s *WsControllerSuite) TestKeepaliveRoutine() { + s.T().Run("Successfully pings connection n times", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil) - controller := s.initializeController() - s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(nil) + done := make(chan struct{}) + i := 0 + expectedCalls := 2 + conn. + On("WriteControl", websocket.PingMessage, mock.Anything). + Return(func(int, time.Time) error { + if i == expectedCalls { + close(done) + return websocket.ErrCloseSent + } - // Start the keepalive process in a separate goroutine - go func() { - err := controller.keepalive(ctx) - s.Require().NoError(err) - }() + i += 1 + return nil + }). + Times(expectedCalls + 1) - // Use Eventually to wait for some ping messages - expectedCalls := 3 // expected 3 ping messages for 30 seconds - s.Require().Eventually(func() bool { - return len(s.connection.Calls) == expectedCalls - }, time.Duration(expectedCalls)*PongWait, 1*time.Second, "not all ping messages were sent") + conn.On("ReadJSON", mock.Anything).Return(func(_ interface{}) error { + <-done + return websocket.ErrCloseSent + }) - s.connection.On("Close").Return(nil).Once() - controller.shutdownConnection() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller.HandleConnection(context.Background()) - // Assert that the ping was sent - s.connection.AssertExpectations(s.T()) -} + conn.AssertExpectations(t) + }) -// TestKeepaliveError tests the behavior of the keepalive function when there is an error in writing the ping. -func (s *ControllerSuite) TestKeepaliveError() { - controller := s.initializeController() + s.T().Run("Error on write to connection", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + conn. + On("WriteControl", websocket.PingMessage, mock.Anything). + Return(websocket.ErrCloseSent). //TODO: change to assert.AnError and rewrite test + Once() - // Setup the mock connection with an error - s.connection.On("WriteControl", websocket.PingMessage, mock.Anything).Return(assert.AnError).Once() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := controller.keepalive(ctx) + s.Require().Error(err) + s.Require().ErrorIs(websocket.ErrCloseSent, err) - expectedError := fmt.Errorf("failed to write ping message: %w", assert.AnError) - // Start the keepalive process - err := controller.keepalive(ctx) - s.Require().Error(err) - s.Require().Equal(expectedError, err) + conn.AssertExpectations(t) + }) + + s.T().Run("Context cancelled", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Immediately cancel the context + + // Start the keepalive process with the context canceled + err := controller.keepalive(ctx) + s.Require().NoError(err) - // Assert expectations - s.connection.AssertExpectations(s.T()) + conn.AssertExpectations(t) // Should not invoke WriteMessage after context cancellation + }) } -// TestKeepaliveContextCancel tests the behavior of keepalive when the context is canceled before a ping is sent and -// no ping message is sent after the context is canceled. -func (s *ControllerSuite) TestKeepaliveContextCancel() { - controller := s.initializeController() +// newControllerMocks initializes mock WebSocket connection, data provider, and data provider factory +func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.DataProviderFactory, *dpmock.DataProvider) { + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil) + conn.On("SetWriteDeadline", mock.Anything).Return(nil) - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Immediately cancel the context + dataProvider := dpmock.NewDataProvider(t) + factory := dpmock.NewDataProviderFactory(t) - // Start the keepalive process with the context canceled - err := controller.keepalive(ctx) - s.Require().NoError(err) + return conn, factory, dataProvider +} + +// expectSubscribeRequest mocks the client's subscription request. +func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock.WebsocketConnection) string { + request := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + ClientMessageID: uuid.New().String(), + Action: models.SubscribeAction, + }, + Topic: dp.BlocksTopic, + } + requestJson, err := json.Marshal(request) + require.NoError(t, err) + + // The very first message from a client is a request to subscribe to some topic + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() - // Assert expectations - s.connection.AssertExpectations(s.T()) // Should not invoke WriteMessage after context cancellation + return request.ClientMessageID } -// initializeController initializes the WebSocket controller. -func (s *ControllerSuite) initializeController() *Controller { - return NewWebSocketController(s.logger, s.config, s.connection, s.dataProviderFactory) +// expectSubscribeResponse mocks the subscription response sent to the client. +func (s *WsControllerSuite) expectSubscribeResponse(t *testing.T, conn *connmock.WebsocketConnection, msgId string) { + conn. + On("WriteJSON", mock.Anything). + Run(func(args mock.Arguments) { + response, ok := args.Get(0).(models.SubscribeMessageResponse) + require.True(t, ok) + require.Equal(t, msgId, response.ClientMessageID) + require.Equal(t, true, response.Success) + }). + Return(nil). + Once() } -// mockDataProviderSetup is a helper which mocks a blocks data provider setup. -func (s *ControllerSuite) mockBlockDataProviderSetup(id uuid.UUID) *dpmock.DataProvider { - dataProvider := dpmock.NewDataProvider(s.T()) - dataProvider.On("ID").Return(id).Twice() - dataProvider.On("Close").Return(nil).Once() - dataProvider.On("Topic").Return(dp.BlocksTopic).Once() - s.dataProviderFactory.On("NewDataProvider", mock.Anything, dp.BlocksTopic, mock.Anything, mock.Anything). - Return(dataProvider, nil).Once() - dataProvider.On("Run").Return(nil).Once() - - return dataProvider +func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnection, done <-chan struct{}) { + // In the default case, no further communication is expected from the client. + // We wait for the writer routine to signal completion, allowing us to close the connection gracefully + conn. + On("ReadJSON", mock.Anything). + Return(func(msg interface{}) error { + <-done + return websocket.ErrCloseSent + }). + Once() + + s.expectKeepaliveRoutineShutdown(conn, done) } -// mockConnectionSetup is a helper which mocks connection setup for SetReadDeadline and SetPongHandler. -func (s *ControllerSuite) mockConnectionSetup() { - s.connection.On("SetReadDeadline", mock.Anything).Return(nil).Once() - s.connection.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() +func (s *WsControllerSuite) expectKeepaliveRoutineShutdown(conn *connmock.WebsocketConnection, done <-chan struct{}) { + // We use Maybe() because a test may finish faster than keepalive routine trigger WriteControl + conn. + On("WriteControl", websocket.PingMessage, mock.Anything). + Return(func(int, time.Time) error { + select { + case <-done: + return websocket.ErrCloseSent + default: + return nil + } + }). + Maybe() } diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go index cf1ee1313d9..0ee040cd4ac 100644 --- a/engine/access/rest/websockets/data_providers/base_provider.go +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -46,7 +46,6 @@ func (b *baseDataProvider) Topic() string { // Close terminates the data provider. // // No errors are expected during normal operations. -func (b *baseDataProvider) Close() error { +func (b *baseDataProvider) Close() { b.cancel() - return nil } diff --git a/engine/access/rest/websockets/data_providers/data_provider.go b/engine/access/rest/websockets/data_providers/data_provider.go index 08dc497808b..ab48ebeb9f9 100644 --- a/engine/access/rest/websockets/data_providers/data_provider.go +++ b/engine/access/rest/websockets/data_providers/data_provider.go @@ -14,7 +14,7 @@ type DataProvider interface { // Close terminates the data provider. // // No errors are expected during normal operations. - Close() error + Close() // Run starts processing the subscription and handles responses. // // The separation of the data provider's creation and its Run() method diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index 26aade4e090..30d3ce01a48 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -31,7 +31,7 @@ type DataProviderFactory interface { // and configuration parameters. // // No errors are expected during normal operations. - NewDataProvider(ctx context.Context, topic string, arguments models.Arguments, ch chan<- interface{}) (DataProvider, error) + NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) } var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider.go b/engine/access/rest/websockets/data_providers/mock/data_provider.go index 3fe8bc5d15b..48debb23ae3 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider.go @@ -13,21 +13,8 @@ type DataProvider struct { } // Close provides a mock function with given fields: -func (_m *DataProvider) Close() error { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Close") - } - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 +func (_m *DataProvider) Close() { + _m.Called() } // ID provides a mock function with given fields: diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go index c2e46e58d1d..af49cb4e687 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go @@ -16,9 +16,9 @@ type DataProviderFactory struct { mock.Mock } -// NewDataProvider provides a mock function with given fields: ctx, topic, arguments, ch -func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string, arguments models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { - ret := _m.Called(ctx, topic, arguments, ch) +// NewDataProvider provides a mock function with given fields: ctx, topic, args, ch +func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { + ret := _m.Called(ctx, topic, args, ch) if len(ret) == 0 { panic("no return value specified for NewDataProvider") @@ -27,10 +27,10 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string var r0 data_providers.DataProvider var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { - return rf(ctx, topic, arguments, ch) + return rf(ctx, topic, args, ch) } if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { - r0 = rf(ctx, topic, arguments, ch) + r0 = rf(ctx, topic, args, ch) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(data_providers.DataProvider) @@ -38,7 +38,7 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string } if rf, ok := ret.Get(1).(func(context.Context, string, models.Arguments, chan<- interface{}) error); ok { - r1 = rf(ctx, topic, arguments, ch) + r1 = rf(ctx, topic, args, ch) } else { r1 = ret.Error(1) } diff --git a/engine/access/rest/websockets/error_codes.go b/engine/access/rest/websockets/error_codes.go new file mode 100644 index 00000000000..fd206bed0b3 --- /dev/null +++ b/engine/access/rest/websockets/error_codes.go @@ -0,0 +1,10 @@ +package websockets + +type Code int + +const ( + InvalidMessage Code = iota + InvalidArgument + NotFound + SubscriptionError +) diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index f56d62fda8f..cdcd72eb1ed 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -1,13 +1,21 @@ package models +const ( + SubscribeAction = "subscribe" + UnsubscribeAction = "unsubscribe" + ListSubscriptionsAction = "list_subscription" +) + // BaseMessageRequest represents a base structure for incoming messages. type BaseMessageRequest struct { - Action string `json:"action"` // Action type of the request + Action string `json:"action"` // subscribe, unsubscribe or list_subscriptions + ClientMessageID string `json:"message_id"` // ClientMessageID is a uuid generated by client to identify request/response uniquely } // BaseMessageResponse represents a base structure for outgoing messages. type BaseMessageResponse struct { - Action string `json:"action,omitempty"` // Action type of the response - Success bool `json:"success"` // Indicates success or failure - ErrorMessage string `json:"error_message,omitempty"` // Error message, if any + SubscriptionID string `json:"subscription_id"` + ClientMessageID string `json:"message_id,omitempty"` // ClientMessageID may be empty in case we send msg by ourselves (e.g. error occurred) + Success bool `json:"success"` + Error ErrorMessage `json:"error,omitempty"` } diff --git a/engine/access/rest/websockets/models/error_message.go b/engine/access/rest/websockets/models/error_message.go new file mode 100644 index 00000000000..d5c0670926f --- /dev/null +++ b/engine/access/rest/websockets/models/error_message.go @@ -0,0 +1,7 @@ +package models + +type ErrorMessage struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action,omitempty"` +} diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go index 26174869585..4893a34b09d 100644 --- a/engine/access/rest/websockets/models/list_subscriptions.go +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -8,6 +8,8 @@ type ListSubscriptionsMessageRequest struct { // ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. // It contains a list of active subscriptions for the current WebSocket connection. type ListSubscriptionsMessageResponse struct { - BaseMessageResponse - Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` + ClientMessageID string `json:"message_id"` + Success bool `json:"success"` + Error ErrorMessage `json:"error,omitempty"` + Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` } diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe_message.go similarity index 80% rename from engine/access/rest/websockets/models/subscribe.go rename to engine/access/rest/websockets/models/subscribe_message.go index 03b37aee5f1..532e4c6a987 100644 --- a/engine/access/rest/websockets/models/subscribe.go +++ b/engine/access/rest/websockets/models/subscribe_message.go @@ -12,6 +12,4 @@ type SubscribeMessageRequest struct { // SubscribeMessageResponse represents the response to a subscription request. type SubscribeMessageResponse struct { BaseMessageResponse - Topic string `json:"topic"` // Topic of the subscription - ID string `json:"id"` // Unique subscription ID } diff --git a/engine/access/rest/websockets/models/unsubscribe.go b/engine/access/rest/websockets/models/unsubscribe_message.go similarity index 75% rename from engine/access/rest/websockets/models/unsubscribe.go rename to engine/access/rest/websockets/models/unsubscribe_message.go index 2024bb922e0..1402189a601 100644 --- a/engine/access/rest/websockets/models/unsubscribe.go +++ b/engine/access/rest/websockets/models/unsubscribe_message.go @@ -3,11 +3,10 @@ package models // UnsubscribeMessageRequest represents a request to unsubscribe from a topic. type UnsubscribeMessageRequest struct { BaseMessageRequest - ID string `json:"id"` // Unique subscription ID + SubscriptionID string `json:"id"` } // UnsubscribeMessageResponse represents the response to an unsubscription request. type UnsubscribeMessageResponse struct { BaseMessageResponse - ID string `json:"id"` // Unique subscription ID }