From 2c17c96933a764b9a6205c35c8d32a28da2d1fe3 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Wed, 4 Dec 2024 18:20:08 +0200 Subject: [PATCH 01/26] change error handling in reader and writer routines --- engine/access/rest/websockets/controller.go | 123 +++++++++----------- 1 file changed, 55 insertions(+), 68 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 07f76c93fba..fcbff1b6299 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -21,7 +21,9 @@ type Controller struct { config Config conn WebsocketConnection - communicationChannel chan interface{} // Channel for sending messages to the client. + // data channel which data providers write messages to. + // writer routine reads from this channel and writes messages to connection + multiplexedStream chan interface{} dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] dataProviderFactory dp.DataProviderFactory @@ -34,12 +36,12 @@ func NewWebSocketController( dataProviderFactory dp.DataProviderFactory, ) *Controller { return &Controller{ - logger: logger.With().Str("component", "websocket-controller").Logger(), - config: config, - conn: conn, - communicationChannel: make(chan interface{}), //TODO: should it be buffered chan? - dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), - dataProviderFactory: dataProviderFactory, + logger: logger.With().Str("component", "websocket-controller").Logger(), + config: config, + conn: conn, + multiplexedStream: make(chan interface{}), + dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), + dataProviderFactory: dataProviderFactory, } } @@ -49,37 +51,28 @@ func NewWebSocketController( // Parameters: // - ctx: The context for controlling cancellation and timeouts. func (c *Controller) HandleConnection(ctx context.Context) { + defer c.shutdownConnection() - // configuring the connection with appropriate read/write deadlines and handlers. err := c.configureKeepalive() if err != nil { - // TODO: add error handling here c.logger.Error().Err(err).Msg("error configuring connection") - c.shutdownConnection() return } - //TODO: spin up a response limit tracker routine - - // for track all goroutines and error handling g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { - return c.readMessagesFromClient(gCtx) + return c.readMessages(gCtx) }) - g.Go(func() error { return c.keepalive(gCtx) }) - g.Go(func() error { - return c.writeMessagesToClient(gCtx) + return c.writeMessages(gCtx) }) if err = g.Wait(); err != nil { - //TODO: add error handling here c.logger.Error().Err(err).Msg("error detected in one of the goroutines") - c.shutdownConnection() } } @@ -103,6 +96,7 @@ func (c *Controller) configureKeepalive() error { if err := c.conn.SetReadDeadline(time.Now().Add(PongWait)); err != nil { return fmt.Errorf("failed to set the initial read deadline: %w", err) } + // Establish a Pong handler which sets the handler for pong messages received from the peer. c.conn.SetPongHandler(func(string) error { return c.conn.SetReadDeadline(time.Now().Add(PongWait)) @@ -111,73 +105,63 @@ func (c *Controller) configureKeepalive() error { return nil } -// writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection. +// writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation // // Expected errors during normal operation: // - context.Canceled if the client disconnected -func (c *Controller) writeMessagesToClient(ctx context.Context) error { +func (c *Controller) writeMessages(ctx context.Context) error { for { select { case <-ctx.Done(): return ctx.Err() - case msg, ok := <-c.communicationChannel: + case msg, ok := <-c.multiplexedStream: if !ok { - err := fmt.Errorf("communication channel closed, no error occurred") - return err + return nil } - // TODO: handle 'response per second' limits - // Specifies a timeout for the write operation. If the write - // isn't completed within this duration, it fails with a timeout error. - // SetWriteDeadline ensures the write operation does not block indefinitely - // if the client is slow or unresponsive. This prevents resource exhaustion - // and allows the server to gracefully handle timeouts for delayed writes. if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { - c.logger.Error().Err(err).Msg("failed to set the write deadline") - return err + return fmt.Errorf("failed to set the write deadline: %w", err) } + err := c.conn.WriteJSON(msg) if err != nil { - c.logger.Error().Err(err).Msg("error writing to connection") - return err + if IsCloseError(err) { + return nil + } + c.logger.Error().Err(err).Msg("failed to write msg to connection") } } } } -// readMessagesFromClient continuously reads messages from a client WebSocket connection, +// readMessages continuously reads messages from a client WebSocket connection, // processes each message, and handles actions based on the message type. // // Expected errors during normal operation: // - context.Canceled if the client disconnected -func (c *Controller) readMessagesFromClient(ctx context.Context) error { +func (c *Controller) readMessages(ctx context.Context) error { for { - select { - case <-ctx.Done(): - c.logger.Info().Msg("context canceled, stopping read message loop") - return ctx.Err() - default: - msg, err := c.readMessage() - if err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) { - return nil - } - c.logger.Warn().Err(err).Msg("error reading message from client") - return fmt.Errorf("failed to read message from client: %w", err) + msg, err := c.readMessage() + if err != nil { + if IsCloseError(err) { + return nil } - baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg) - if err != nil { - c.logger.Debug().Err(err).Msg("error parsing and validating client message") - return fmt.Errorf("failed to parse and validate client message: %w", err) - } + c.logger.Error().Err(err).Msg("error reading message") + continue + } - if err := c.handleAction(ctx, validatedMsg); err != nil { - c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action") - return fmt.Errorf("failed to handle message action: %w", err) - } + validatedMsg, err := c.parseAndValidateMessage(msg) + if err != nil { + c.logger.Error().Err(err).Msg("failed to parse message") + continue + } + + if err := c.handleAction(ctx, validatedMsg); err != nil { + c.logger.Error().Err(err).Msg("failed to handle action") + continue } } } @@ -190,10 +174,10 @@ func (c *Controller) readMessage() (json.RawMessage, error) { return message, nil } -func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.BaseMessageRequest, interface{}, error) { +func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface{}, error) { var baseMsg models.BaseMessageRequest if err := json.Unmarshal(message, &baseMsg); err != nil { - return models.BaseMessageRequest{}, nil, fmt.Errorf("error unmarshalling base message: %w", err) + return nil, fmt.Errorf("error unmarshalling base message: %w", err) } var validatedMsg interface{} @@ -201,7 +185,7 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.Ba case "subscribe": var subscribeMsg models.SubscribeMessageRequest if err := json.Unmarshal(message, &subscribeMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) + return nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) } //TODO: add validation logic for `topic` field validatedMsg = subscribeMsg @@ -209,23 +193,23 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (models.Ba case "unsubscribe": var unsubscribeMsg models.UnsubscribeMessageRequest if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) + return nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) } validatedMsg = unsubscribeMsg case "list_subscriptions": var listMsg models.ListSubscriptionsMessageRequest if err := json.Unmarshal(message, &listMsg); err != nil { - return baseMsg, nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) + return nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) } validatedMsg = listMsg default: c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") - return baseMsg, nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) + return nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) } - return baseMsg, validatedMsg, nil + return validatedMsg, nil } func (c *Controller) handleAction(ctx context.Context, message interface{}) error { @@ -243,7 +227,7 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { - dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.communicationChannel) + dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { // TODO: handle error here c.logger.Error().Err(err).Msgf("error while creating data provider for topic: %s", msg.Topic) @@ -252,7 +236,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.dataProviders.Add(dp.ID(), dp) //TODO: return OK response to client - c.communicationChannel <- msg + c.multiplexedStream <- msg go func() { err := dp.Run() @@ -268,7 +252,7 @@ func (c *Controller) handleUnsubscribe(_ context.Context, msg models.Unsubscribe if err != nil { c.logger.Debug().Err(err).Msg("error parsing message ID") //TODO: return an error response to client - c.communicationChannel <- err + c.multiplexedStream <- err return } @@ -288,7 +272,6 @@ func (c *Controller) shutdownConnection() { if err := c.conn.Close(); err != nil { c.logger.Error().Err(err).Msg("error closing connection") } - // TODO: safe closing communicationChannel will be included as a part of PR #6642 }() err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { @@ -326,3 +309,7 @@ func (c *Controller) keepalive(ctx context.Context) error { } } } + +func IsCloseError(err error) bool { + return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure, websocket.CloseGoingAway) +} From 665cdb08c675bf543019976feae82d299057e643 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Wed, 11 Dec 2024 11:46:21 +0200 Subject: [PATCH 02/26] Handle different data flows in controller --- engine/access/rest/websockets/controller.go | 193 +++++++++++++----- .../access/rest/websockets/controller_test.go | 81 +++++--- .../rest/websockets/data_providers/factory.go | 2 +- .../rest/websockets/models/base_message.go | 2 + .../rest/websockets/models/subscribe.go | 3 +- 5 files changed, 196 insertions(+), 85 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index fcbff1b6299..249ccc880cd 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -3,7 +3,9 @@ package websockets import ( "context" "encoding/json" + "errors" "fmt" + "sync" "time" "github.com/google/uuid" @@ -27,6 +29,7 @@ type Controller struct { dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] dataProviderFactory dp.DataProviderFactory + dataProvidersGroup *sync.WaitGroup } func NewWebSocketController( @@ -42,6 +45,7 @@ func NewWebSocketController( multiplexedStream: make(chan interface{}), dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](), dataProviderFactory: dataProviderFactory, + dataProvidersGroup: &sync.WaitGroup{}, } } @@ -72,6 +76,10 @@ func (c *Controller) HandleConnection(ctx context.Context) { }) if err = g.Wait(); err != nil { + if errors.Is(err, websocket.ErrCloseSent) { + return + } + c.logger.Error().Err(err).Msg("error detected in one of the goroutines") } } @@ -116,21 +124,17 @@ func (c *Controller) writeMessages(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case msg, ok := <-c.multiplexedStream: + case message, ok := <-c.multiplexedStream: if !ok { - return nil + return fmt.Errorf("multiplexed stream closed") } if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { return fmt.Errorf("failed to set the write deadline: %w", err) } - err := c.conn.WriteJSON(msg) - if err != nil { - if IsCloseError(err) { - return nil - } - c.logger.Error().Err(err).Msg("failed to write msg to connection") + if err := c.conn.WriteJSON(message); err != nil { + return err } } } @@ -143,37 +147,32 @@ func (c *Controller) writeMessages(ctx context.Context) error { // - context.Canceled if the client disconnected func (c *Controller) readMessages(ctx context.Context) error { for { - msg, err := c.readMessage() - if err != nil { - if IsCloseError(err) { - return nil + var message json.RawMessage + if err := c.conn.ReadJSON(&message); err != nil { + if errors.Is(err, websocket.ErrCloseSent) { + return err } + c.writeBaseErrorResponse(ctx, err, "") c.logger.Error().Err(err).Msg("error reading message") continue } - validatedMsg, err := c.parseAndValidateMessage(msg) + validatedMsg, err := c.parseAndValidateMessage(message) if err != nil { + c.writeBaseErrorResponse(ctx, err, "") c.logger.Error().Err(err).Msg("failed to parse message") continue } - if err := c.handleAction(ctx, validatedMsg); err != nil { + if err = c.handleAction(ctx, validatedMsg); err != nil { + c.writeBaseErrorResponse(ctx, err, "") c.logger.Error().Err(err).Msg("failed to handle action") continue } } } -func (c *Controller) readMessage() (json.RawMessage, error) { - var message json.RawMessage - if err := c.conn.ReadJSON(&message); err != nil { - return nil, fmt.Errorf("error reading JSON from client: %w", err) - } - return message, nil -} - func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface{}, error) { var baseMsg models.BaseMessageRequest if err := json.Unmarshal(message, &baseMsg); err != nil { @@ -187,7 +186,6 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface if err := json.Unmarshal(message, &subscribeMsg); err != nil { return nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) } - //TODO: add validation logic for `topic` field validatedMsg = subscribeMsg case "unsubscribe": @@ -219,7 +217,7 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro case models.UnsubscribeMessageRequest: c.handleUnsubscribe(ctx, msg) case models.ListSubscriptionsMessageRequest: - c.handleListSubscriptions(ctx, msg) + c.handleListSubscriptions(ctx) default: return fmt.Errorf("unknown message type: %T", msg) } @@ -227,62 +225,112 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { - dp, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) + // register new provider + provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { - // TODO: handle error here - c.logger.Error().Err(err).Msgf("error while creating data provider for topic: %s", msg.Topic) + c.writeBaseErrorResponse(ctx, err, "subscribe") + c.logger.Error().Err(err).Msg("error creating data provider") + return } - c.dataProviders.Add(dp.ID(), dp) - - //TODO: return OK response to client - c.multiplexedStream <- msg + c.dataProviders.Add(provider.ID(), provider) + c.writeSubscribeOkResponse(ctx, provider.ID()) + // run provider + c.dataProvidersGroup.Add(1) go func() { - err := dp.Run() + err = provider.Run() if err != nil { - //TODO: Log or handle the error from Run + c.writeBaseErrorResponse(ctx, err, "") c.logger.Error().Err(err).Msgf("error while running data provider for topic: %s", msg.Topic) } + + c.dataProvidersGroup.Done() + c.dataProviders.Remove(provider.ID()) }() } -func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) { +func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { id, err := uuid.Parse(msg.ID) if err != nil { + c.writeBaseErrorResponse(ctx, err, "unsubscribe") c.logger.Debug().Err(err).Msg("error parsing message ID") - //TODO: return an error response to client - c.multiplexedStream <- err return } - dp, ok := c.dataProviders.Get(id) - if ok { - dp.Close() - c.dataProviders.Remove(id) + provider, ok := c.dataProviders.Get(id) + if !ok { + c.writeBaseErrorResponse(ctx, err, "unsubscribe") + c.logger.Debug().Err(err).Msg("no active subscription with such ID found") + return } + + err = provider.Close() + if err != nil { + c.writeBaseErrorResponse(ctx, err, "unsubscribe") + return + } + + c.dataProviders.Remove(id) + c.writeUnsubscribeOkResponse(ctx, id) } -func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { - //TODO: return a response to client +func (c *Controller) handleListSubscriptions(ctx context.Context) { + var subs []*models.SubscriptionEntry + + err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error { + subs = append(subs, &models.SubscriptionEntry{ + ID: id.String(), + Topic: provider.Topic(), + }) + return nil + }) + + if err != nil { + c.writeBaseErrorResponse(ctx, err, "list_subscriptions") + c.logger.Debug().Err(err).Msg("error listing subscriptions") + return + } + + resp := models.ListSubscriptionsMessageResponse{ + BaseMessageResponse: models.BaseMessageResponse{ + Success: true, + }, + Subscriptions: subs, + } + c.writeResponse(ctx, resp) } func (c *Controller) shutdownConnection() { - defer func() { - if err := c.conn.Close(); err != nil { - c.logger.Error().Err(err).Msg("error closing connection") + err := c.conn.Close() + if err != nil { + c.logger.Error().Err(err).Msg("error closing connection") + } + + err = c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { + //TODO: why did i think it's a good idea to return error in Close()? it's messy now + err = dp.Close() + if err != nil { + c.logger.Error().Err(err).Msg("error closing data provider") } - }() - err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { - dp.Close() return nil }) + if err != nil { c.logger.Error().Err(err).Msg("error closing data provider") } c.dataProviders.Clear() + + // drain the channel as some providers may still send data to it during shutdown + go func() { + for range c.multiplexedStream { + } + }() + + c.dataProvidersGroup.Wait() + close(c.multiplexedStream) } // keepalive sends a ping message periodically to keep the WebSocket connection alive @@ -301,15 +349,56 @@ func (c *Controller) keepalive(ctx context.Context) error { case <-pingTicker.C: err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) if err != nil { - // Log error and exit the loop on failure - c.logger.Debug().Err(err).Msg("failed to send ping") + if errors.Is(err, websocket.ErrCloseSent) { + return err + } + c.writeBaseErrorResponse(ctx, err, "") + c.logger.Debug().Err(err).Msg("failed to send ping") return fmt.Errorf("failed to write ping message: %w", err) } } } } -func IsCloseError(err error) bool { - return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure, websocket.CloseGoingAway) +func (c *Controller) writeBaseErrorResponse(ctx context.Context, err error, action string) { + request := models.BaseMessageResponse{ + Action: action, + Success: false, + ErrorMessage: err.Error(), + } + + c.writeResponse(ctx, request) +} + +func (c *Controller) writeSubscribeOkResponse(ctx context.Context, id uuid.UUID) { + request := models.SubscribeMessageResponse{ + BaseMessageResponse: models.BaseMessageResponse{ + Action: "subscribe", + Success: true, + }, + ID: id.String(), + } + + c.writeResponse(ctx, request) +} + +func (c *Controller) writeUnsubscribeOkResponse(ctx context.Context, id uuid.UUID) { + request := models.UnsubscribeMessageResponse{ + BaseMessageResponse: models.BaseMessageResponse{ + Action: "unsubscribe", + Success: true, + }, + ID: id.String(), + } + + c.writeResponse(ctx, request) +} + +func (c *Controller) writeResponse(ctx context.Context, response interface{}) { + select { + case <-ctx.Done(): + return + case c.multiplexedStream <- response: + } } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 4e36aacc3bc..0b54b42c36c 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "testing" + "time" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -46,7 +47,7 @@ func TestWsControllerSuite(t *testing.T) { func (s *WsControllerSuite) TestSubscribeRequest() { s.T().Run("Happy path", func(t *testing.T) { conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProvider. On("Run"). @@ -59,16 +60,16 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Topic: "blocks", Arguments: nil, } + subscribeRequestJson, err := json.Marshal(subscribeRequest) + require.NoError(t, err) // Simulate receiving the subscription request from the client conn. On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { - requestMsg, ok := args.Get(0).(*json.RawMessage) + msg, ok := args.Get(0).(*json.RawMessage) require.True(t, ok) - subscribeRequestMessage, err := json.Marshal(subscribeRequest) - require.NoError(t, err) - *requestMsg = subscribeRequestMessage + *msg = subscribeRequestJson }). Return(nil). Once() @@ -97,27 +98,31 @@ func (s *WsControllerSuite) TestSubscribeRequest() { controller.HandleConnection(context.Background()) }) + + s.T().Run("Parse request message error", func(t *testing.T) { + + }) } // TestSubscribeBlocks tests the functionality for streaming blocks to a subscriber. func (s *WsControllerSuite) TestSubscribeBlocks() { s.T().Run("Stream one block", func(t *testing.T) { conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) // Simulate data provider write a block to the controller expectedBlock := unittest.BlockFixture() dataProvider. On("Run", mock.Anything). Run(func(args mock.Arguments) { - controller.communicationChannel <- expectedBlock + controller.multiplexedStream <- expectedBlock }). Return(nil). Once() done := make(chan struct{}, 1) - s.expectSubscriptionRequest(conn, done) - s.expectSubscriptionResponse(conn, true) + s.expectSubscribeRequest(conn, done) + s.expectSubscribeResponse(conn, true) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -139,7 +144,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { s.T().Run("Stream many blocks", func(t *testing.T) { conn, dataProviderFactory, dataProvider := newControllerMocks(t) - controller := NewWebSocketController(s.logger, s.wsConfig, dataProviderFactory, conn) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) // Simulate data provider writes some blocks to the controller expectedBlocks := unittest.BlockFixtures(100) @@ -147,15 +152,15 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { On("Run", mock.Anything). Run(func(args mock.Arguments) { for _, block := range expectedBlocks { - controller.communicationChannel <- *block + controller.multiplexedStream <- *block } }). Return(nil). Once() done := make(chan struct{}, 1) - s.expectSubscriptionRequest(conn, done) - s.expectSubscriptionResponse(conn, true) + s.expectSubscribeRequest(conn, done) + s.expectSubscribeResponse(conn, true) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -189,14 +194,15 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { // The mocked functions are expected to be called in a case when a test is expected to reach WriteJSON function. func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.DataProviderFactory, *dpmock.DataProvider) { conn := connmock.NewWebsocketConnection(t) - conn.On("Close").Return(nil).Once() + conn.On("Close").Return(nil) + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil) + conn.On("SetWriteDeadline", mock.Anything).Return(nil) id := uuid.New() - topic := "blocks" dataProvider := dpmock.NewDataProvider(t) dataProvider.On("ID").Return(id) - dataProvider.On("Close").Return(nil) - dataProvider.On("Topic").Return(topic) + //dataProvider.On("Close").Return(nil). factory := dpmock.NewDataProviderFactory(t) factory. @@ -207,21 +213,22 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da return conn, factory, dataProvider } -// expectSubscriptionRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketConnection, done <-chan struct{}) { - requestMessage := models.SubscribeMessageRequest{ +// expectSubscribeRequest mocks the client's subscription request. +func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection, done <-chan struct{}) { + subscribeRequest := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, Topic: "blocks", } + subscribeRequestJson, err := json.Marshal(subscribeRequest) + require.NoError(s.T(), err) - // The very first message from a client is a request to subscribe to some topic - conn.On("ReadJSON", mock.Anything). + // The very first message from a client is a subscribeRequest to subscribe to some topic + conn. + On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { - reqMsg, ok := args.Get(0).(*json.RawMessage) + msg, ok := args.Get(0).(*json.RawMessage) require.True(s.T(), ok) - msg, err := json.Marshal(requestMessage) - require.NoError(s.T(), err) - *reqMsg = msg + *msg = subscribeRequestJson }). Return(nil). Once() @@ -231,14 +238,16 @@ func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketCo conn. On("ReadJSON", mock.Anything). Return(func(msg interface{}) error { - <-done + for range done { + } return websocket.ErrCloseSent }) } -// expectSubscriptionResponse mocks the subscription response sent to the client. -func (s *WsControllerSuite) expectSubscriptionResponse(conn *connmock.WebsocketConnection, success bool) { - conn.On("WriteJSON", mock.Anything). +// expectSubscribeResponse mocks the subscription response sent to the client. +func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConnection, success bool) { + conn. + On("WriteJSON", mock.Anything). Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(s.T(), ok) @@ -247,3 +256,15 @@ func (s *WsControllerSuite) expectSubscriptionResponse(conn *connmock.WebsocketC Return(nil). Once() } + +func (s *WsControllerSuite) expectKeepaliveClose(conn *connmock.WebsocketConnection, done <-chan struct{}) { + // first ping will be sent in 9 seconds, so there's no point mocking it + conn. + On("WriteControl", websocket.PingMessage, mock.Anything). + Return(func(int, time.Time) error { + for range done { + } + return websocket.ErrCloseSent + }). + Once() +} diff --git a/engine/access/rest/websockets/data_providers/factory.go b/engine/access/rest/websockets/data_providers/factory.go index 72f4a6b7633..f46abc50ed6 100644 --- a/engine/access/rest/websockets/data_providers/factory.go +++ b/engine/access/rest/websockets/data_providers/factory.go @@ -30,7 +30,7 @@ type DataProviderFactory interface { // and configuration parameters. // // No errors are expected during normal operations. - NewDataProvider(ctx context.Context, topic string, arguments models.Arguments, ch chan<- interface{}) (DataProvider, error) + NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (DataProvider, error) } var _ DataProviderFactory = (*DataProviderFactoryImpl)(nil) diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index f56d62fda8f..38a030358bd 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -11,3 +11,5 @@ type BaseMessageResponse struct { Success bool `json:"success"` // Indicates success or failure ErrorMessage string `json:"error_message,omitempty"` // Error message, if any } + +// TODO: add Action enum? subscribe, unsubscribe, list diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe.go index 95ad17e3708..1b4a28470d2 100644 --- a/engine/access/rest/websockets/models/subscribe.go +++ b/engine/access/rest/websockets/models/subscribe.go @@ -12,6 +12,5 @@ type SubscribeMessageRequest struct { // SubscribeMessageResponse represents the response to a subscription request. type SubscribeMessageResponse struct { BaseMessageResponse - Topic string `json:"topic"` // Topic of the subscription - ID string `json:"id"` // Unique subscription ID + ID string `json:"id"` // Unique subscription ID } From b62f8db11cc89cd08713e70dc8f4386844de3baf Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Wed, 11 Dec 2024 16:22:11 +0200 Subject: [PATCH 03/26] cover all actions with tests --- engine/access/rest/websockets/controller.go | 7 +- .../access/rest/websockets/controller_test.go | 422 +++++++++++++++++- 2 files changed, 405 insertions(+), 24 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 249ccc880cd..fb88106d86b 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -18,6 +18,10 @@ import ( "github.com/onflow/flow-go/utils/concurrentmap" ) +var ( + ErrUnmarshalMessage = errors.New("failed to unmarshal message") +) + type Controller struct { logger zerolog.Logger config Config @@ -260,7 +264,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri provider, ok := c.dataProviders.Get(id) if !ok { - c.writeBaseErrorResponse(ctx, err, "unsubscribe") + c.writeBaseErrorResponse(ctx, fmt.Errorf("could not find data provider with such id"), "unsubscribe") c.logger.Debug().Err(err).Msg("no active subscription with such ID found") return } @@ -295,6 +299,7 @@ func (c *Controller) handleListSubscriptions(ctx context.Context) { resp := models.ListSubscriptionsMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ Success: true, + Action: "list_subscriptions", }, Subscriptions: subs, } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 0b54b42c36c..002d7e28fb8 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -3,6 +3,7 @@ package websockets import ( "context" "encoding/json" + "fmt" "testing" "time" @@ -49,18 +50,25 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) dataProvider. On("Run"). Run(func(args mock.Arguments) {}). Return(nil). Once() - subscribeRequest := models.SubscribeMessageRequest{ + request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, Topic: "blocks", Arguments: nil, } - subscribeRequestJson, err := json.Marshal(subscribeRequest) + requestJson, err := json.Marshal(request) require.NoError(t, err) // Simulate receiving the subscription request from the client @@ -69,38 +77,391 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Run(func(args mock.Arguments) { msg, ok := args.Get(0).(*json.RawMessage) require.True(t, ok) - *msg = subscribeRequestJson + *msg = requestJson }). Return(nil). Once() - // Channel to signal the test flow completion done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) - // Simulate writing a successful subscription response back to the client conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) + require.Equal(t, request.Action, response.Action) require.True(t, response.Success) + require.Equal(t, id.String(), response.ID) + close(done) // Signal that response has been sent return websocket.ErrCloseSent }) - // Simulate client closing connection after receiving the response + controller.HandleConnection(context.Background()) + }) + + s.T().Run("Parse and validate error", func(t *testing.T) { + conn, dataProviderFactory, _ := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + type Request struct { + Action string `json:"action"` + } + + subscribeRequest := Request{ + Action: "SubscribeBlocks", + } + subscribeRequestJson, err := json.Marshal(subscribeRequest) + require.NoError(t, err) + + // Simulate receiving the subscription request from the client conn. On("ReadJSON", mock.Anything). - Return(func(interface{}) error { - <-done + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = subscribeRequestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Empty(t, response.Action) + require.False(t, response.Success) + require.NotNil(t, response.ErrorMessage) //TODO: add kinds of errors for different data flows + + s.T().Log(response.ErrorMessage) + + close(done) // Signal that response has been sent + return websocket.ErrCloseSent + }) + + controller.HandleConnection(context.Background()) + }) + + s.T().Run("Error creating data provider", func(t *testing.T) { + conn, dataProviderFactory, _ := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(nil, fmt.Errorf("error creating data provider")). + Once() + + done := make(chan struct{}, 1) + s.expectSubscribeRequest(conn) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, "subscribe", response.Action) + require.False(t, response.Success) + require.Equal(t, response.ErrorMessage, "error creating data provider") + + s.T().Log(response.ErrorMessage) + + close(done) // Signal that response has been sent return websocket.ErrCloseSent }) controller.HandleConnection(context.Background()) }) - s.T().Run("Parse request message error", func(t *testing.T) { - + s.T().Run("Run error", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProvider. + On("ID"). + Return(uuid.New()) + + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(fmt.Errorf("error running data provider")). + Once() + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + done := make(chan struct{}, 1) + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, "", response.Action) + require.False(t, response.Success) + require.NotNil(t, response.ErrorMessage) //TODO: add kinds of errors for different data flows + + s.T().Log(response.ErrorMessage) + + close(done) // Signal that response has been sent + return websocket.ErrCloseSent + }) + + controller.HandleConnection(context.Background()) + }) +} + +func (s *WsControllerSuite) TestUnsubscribeRequest() { + s.T().Run("Happy path", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, + ID: id.String(), + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.UnsubscribeMessageResponse) + require.True(t, ok) + require.Equal(t, request.Action, response.Action) + require.True(t, response.Success) + require.Empty(t, response.ErrorMessage) + require.Equal(t, request.ID, response.ID) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) + }) + + s.T().Run("Invalid subscription uuid", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, + ID: "invalid-uuid", + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, request.Action, response.Action) + require.False(t, response.Success) + require.NotEmpty(t, response.ErrorMessage) + + s.T().Log(response.ErrorMessage) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + controller.HandleConnection(context.Background()) + }) + + s.T().Run("Unsubscribe from unknown subscription", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.UnsubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, + ID: uuid.New().String(), + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.BaseMessageResponse) + require.True(t, ok) + require.Equal(t, request.Action, response.Action) + require.False(t, response.Success) + require.NotEmpty(t, response.ErrorMessage) + + s.T().Log(response.ErrorMessage) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + controller.HandleConnection(context.Background()) + }) +} + +func (s *WsControllerSuite) TestListSubscriptions() { + s.T().Run("Happy path", func(t *testing.T) { + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + topic := "blocks" + dataProvider.On("ID").Return(id) + dataProvider.On("Topic").Return(topic) + dataProvider. + On("Run"). + Run(func(args mock.Arguments) {}). + Return(nil). + Once() + + s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, true) + + request := models.ListSubscriptionsMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{Action: "list_subscriptions"}, + } + requestJson, err := json.Marshal(request) + require.NoError(s.T(), err) + + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() + + done := make(chan struct{}, 1) + s.expectCloseConnection(conn, done) + + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + response, ok := msg.(models.ListSubscriptionsMessageResponse) + require.True(t, ok) + require.Equal(t, 1, len(response.Subscriptions)) + require.Equal(t, id.String(), response.Subscriptions[0].ID) + require.Equal(t, topic, response.Subscriptions[0].Topic) + require.Equal(t, response.Action, "list_subscriptions") + require.True(t, response.Success) + require.Empty(t, response.ErrorMessage) + + close(done) + return websocket.ErrCloseSent + }). + Once() + + controller.HandleConnection(context.Background()) }) } @@ -110,6 +471,14 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + // Simulate data provider write a block to the controller expectedBlock := unittest.BlockFixture() dataProvider. @@ -121,8 +490,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn, done) + s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, true) + s.expectCloseConnection(conn, done) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -146,6 +516,15 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() + + id := uuid.New() + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil).Maybe() + // Simulate data provider writes some blocks to the controller expectedBlocks := unittest.BlockFixtures(100) dataProvider. @@ -159,8 +538,9 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn, done) + s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, true) + s.expectCloseConnection(conn, done) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -199,22 +579,14 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da conn.On("SetReadDeadline", mock.Anything).Return(nil) conn.On("SetWriteDeadline", mock.Anything).Return(nil) - id := uuid.New() dataProvider := dpmock.NewDataProvider(t) - dataProvider.On("ID").Return(id) - //dataProvider.On("Close").Return(nil). - factory := dpmock.NewDataProviderFactory(t) - factory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(dataProvider, nil). - Once() return conn, factory, dataProvider } // expectSubscribeRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection, done <-chan struct{}) { +func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection) { subscribeRequest := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, Topic: "blocks", @@ -232,7 +604,9 @@ func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConne }). Return(nil). Once() +} +func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnection, done <-chan struct{}) { // In the default case, no further communication is expected from the client. // We wait for the writer routine to signal completion, allowing us to close the connection gracefully conn. @@ -241,7 +615,8 @@ func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConne for range done { } return websocket.ErrCloseSent - }) + }). + Once() } // expectSubscribeResponse mocks the subscription response sent to the client. @@ -251,6 +626,7 @@ func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConn Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(s.T(), ok) + require.Equal(s.T(), "subscribe", response.Action) require.Equal(s.T(), success, response.Success) }). Return(nil). @@ -266,5 +642,5 @@ func (s *WsControllerSuite) expectKeepaliveClose(conn *connmock.WebsocketConnect } return websocket.ErrCloseSent }). - Once() + Maybe() } From 4c573c13b7a62344aca554e1126a1bae2aaa2991 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 12 Dec 2024 16:02:52 +0200 Subject: [PATCH 04/26] fix tests shutdown --- .../access/rest/websockets/controller_test.go | 84 ++++++++++++------- .../mock/data_provider_factory.go | 12 +-- 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 002d7e28fb8..c83469d61b9 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -56,10 +56,16 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() id := uuid.New() + done := make(chan struct{}, 1) + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). - Run(func(args mock.Arguments) {}). + Run(func(args mock.Arguments) { + for range done { + } + }). Return(nil). Once() @@ -82,7 +88,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Return(nil). Once() - done := make(chan struct{}, 1) s.expectCloseConnection(conn, done) conn. @@ -182,10 +187,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn, dataProviderFactory, dataProvider := newControllerMocks(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) - dataProvider. - On("ID"). - Return(uuid.New()) - + dataProvider.On("ID").Return(uuid.New()) + dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). Run(func(args mock.Arguments) {}). @@ -199,7 +202,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { done := make(chan struct{}, 1) s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, true) + s.expectSubscribeResponse(conn) s.expectCloseConnection(conn, done) conn. @@ -232,16 +235,21 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() id := uuid.New() + done := make(chan struct{}, 1) + dataProvider.On("ID").Return(id) - dataProvider.On("Close").Return(nil) + dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). - Run(func(args mock.Arguments) {}). + Run(func(args mock.Arguments) { + for range done { + } + }). Return(nil). Once() s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, true) + s.expectSubscribeResponse(conn) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, @@ -260,9 +268,6 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - done := make(chan struct{}, 1) - s.expectCloseConnection(conn, done) - conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { @@ -278,6 +283,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { }). Once() + s.expectCloseConnection(conn, done) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) defer cancel() controller.HandleConnection(ctx) @@ -293,15 +299,21 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() id := uuid.New() + done := make(chan struct{}, 1) + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). - Run(func(args mock.Arguments) {}). + Run(func(args mock.Arguments) { + for range done { + } + }). Return(nil). Once() s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, true) + s.expectSubscribeResponse(conn) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, @@ -320,9 +332,6 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - done := make(chan struct{}, 1) - s.expectCloseConnection(conn, done) - conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { @@ -339,6 +348,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { }). Once() + s.expectCloseConnection(conn, done) controller.HandleConnection(context.Background()) }) @@ -352,15 +362,21 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() id := uuid.New() + done := make(chan struct{}, 1) + dataProvider.On("ID").Return(id) + dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). - Run(func(args mock.Arguments) {}). + Run(func(args mock.Arguments) { + for range done { + } + }). Return(nil). Once() s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, true) + s.expectSubscribeResponse(conn) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, @@ -379,9 +395,6 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - done := make(chan struct{}, 1) - s.expectCloseConnection(conn, done) - conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { @@ -398,6 +411,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { }). Once() + s.expectCloseConnection(conn, done) controller.HandleConnection(context.Background()) }) } @@ -412,18 +426,24 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(dataProvider, nil). Once() + done := make(chan struct{}, 1) + id := uuid.New() topic := "blocks" dataProvider.On("ID").Return(id) dataProvider.On("Topic").Return(topic) + dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). - Run(func(args mock.Arguments) {}). + Run(func(args mock.Arguments) { + for range done { + } + }). Return(nil). Once() s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, true) + s.expectSubscribeResponse(conn) request := models.ListSubscriptionsMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{Action: "list_subscriptions"}, @@ -441,9 +461,6 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(nil). Once() - done := make(chan struct{}, 1) - s.expectCloseConnection(conn, done) - conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { @@ -461,6 +478,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { }). Once() + s.expectCloseConnection(conn, done) controller.HandleConnection(context.Background()) }) } @@ -478,6 +496,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { id := uuid.New() dataProvider.On("ID").Return(id) + // data provider might finish by its own or controller will close it via Close() + dataProvider.On("Close").Return(nil).Maybe() // Simulate data provider write a block to the controller expectedBlock := unittest.BlockFixture() @@ -491,7 +511,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { done := make(chan struct{}, 1) s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, true) + s.expectSubscribeResponse(conn) s.expectCloseConnection(conn, done) // Expect a valid block to be passed to WriteJSON. @@ -539,7 +559,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { done := make(chan struct{}, 1) s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, true) + s.expectSubscribeResponse(conn) s.expectCloseConnection(conn, done) i := 0 @@ -620,14 +640,14 @@ func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnec } // expectSubscribeResponse mocks the subscription response sent to the client. -func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConnection, success bool) { +func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConnection) { conn. On("WriteJSON", mock.Anything). Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(s.T(), ok) require.Equal(s.T(), "subscribe", response.Action) - require.Equal(s.T(), success, response.Success) + require.Equal(s.T(), true, response.Success) }). Return(nil). Once() diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go index c2e46e58d1d..af49cb4e687 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider_factory.go @@ -16,9 +16,9 @@ type DataProviderFactory struct { mock.Mock } -// NewDataProvider provides a mock function with given fields: ctx, topic, arguments, ch -func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string, arguments models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { - ret := _m.Called(ctx, topic, arguments, ch) +// NewDataProvider provides a mock function with given fields: ctx, topic, args, ch +func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string, args models.Arguments, ch chan<- interface{}) (data_providers.DataProvider, error) { + ret := _m.Called(ctx, topic, args, ch) if len(ret) == 0 { panic("no return value specified for NewDataProvider") @@ -27,10 +27,10 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string var r0 data_providers.DataProvider var r1 error if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) (data_providers.DataProvider, error)); ok { - return rf(ctx, topic, arguments, ch) + return rf(ctx, topic, args, ch) } if rf, ok := ret.Get(0).(func(context.Context, string, models.Arguments, chan<- interface{}) data_providers.DataProvider); ok { - r0 = rf(ctx, topic, arguments, ch) + r0 = rf(ctx, topic, args, ch) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(data_providers.DataProvider) @@ -38,7 +38,7 @@ func (_m *DataProviderFactory) NewDataProvider(ctx context.Context, topic string } if rf, ok := ret.Get(1).(func(context.Context, string, models.Arguments, chan<- interface{}) error); ok { - r1 = rf(ctx, topic, arguments, ch) + r1 = rf(ctx, topic, args, ch) } else { r1 = ret.Error(1) } From 78ca825c3f3fc700411a1310b68d81d3b4781737 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 12 Dec 2024 16:14:00 +0200 Subject: [PATCH 05/26] remove old mockery cmds --- Makefile | 3 --- 1 file changed, 3 deletions(-) diff --git a/Makefile b/Makefile index 664c97992cc..bb35ec896dd 100644 --- a/Makefile +++ b/Makefile @@ -216,9 +216,6 @@ generate-mocks: install-mock-generators mockery --name 'Storage' --dir=module/executiondatasync/tracker --case=underscore --output="module/executiondatasync/tracker/mock" --outpkg="mocktracker" mockery --name 'ScriptExecutor' --dir=module/execution --case=underscore --output="module/execution/mock" --outpkg="mock" mockery --name 'StorageSnapshot' --dir=fvm/storage/snapshot --case=underscore --output="fvm/storage/snapshot/mock" --outpkg="mock" - mockery --name 'DataProvider' --dir=engine/access/rest/websockets/data_provider --case=underscore --output="engine/access/rest/websockets/data_provider/mock" --outpkg="mock" - mockery --name 'Factory' --dir=engine/access/rest/websockets/data_provider --case=underscore --output="engine/access/rest/websockets/data_provider/mock" --outpkg="mock" - mockery --name 'WebsocketConnection' --dir=engine/access/rest/websockets --case=underscore --output="engine/access/rest/websockets/mock" --outpkg="mock" #temporarily make insecure/ a non-module to allow mockery to create mocks mv insecure/go.mod insecure/go2.mod From 09bf64dc5f5b416450fb18ceab056c237f4c2ec3 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 12 Dec 2024 18:29:06 +0200 Subject: [PATCH 06/26] refactor places where we return error --- engine/access/rest/websockets/controller.go | 164 ++++++++------ .../access/rest/websockets/controller_test.go | 200 +++++++++++------- engine/access/rest/websockets/error_codes.go | 14 ++ .../rest/websockets/models/base_message.go | 15 +- engine/access/rest/websockets/models/error.go | 8 + .../rest/websockets/models/unsubscribe.go | 4 +- 6 files changed, 252 insertions(+), 153 deletions(-) create mode 100644 engine/access/rest/websockets/error_codes.go create mode 100644 engine/access/rest/websockets/models/error.go diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index fb88106d86b..854df073710 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -157,21 +157,27 @@ func (c *Controller) readMessages(ctx context.Context) error { return err } - c.writeBaseErrorResponse(ctx, err, "") - c.logger.Error().Err(err).Msg("error reading message") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(ConnectionRead, "error reading from conn", "", "", "")) continue } validatedMsg, err := c.parseAndValidateMessage(message) if err != nil { - c.writeBaseErrorResponse(ctx, err, "") - c.logger.Error().Err(err).Msg("failed to parse message") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidMessage, "error parsing message", "", "", "")) continue } if err = c.handleAction(ctx, validatedMsg); err != nil { - c.writeBaseErrorResponse(ctx, err, "") - c.logger.Error().Err(err).Msg("failed to handle action") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidMessage, "error handling action", "", "", "")) continue } } @@ -185,21 +191,21 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface var validatedMsg interface{} switch baseMsg.Action { - case "subscribe": + case models.SubscribeAction: var subscribeMsg models.SubscribeMessageRequest if err := json.Unmarshal(message, &subscribeMsg); err != nil { return nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) } validatedMsg = subscribeMsg - case "unsubscribe": + case models.UnsubscribeAction: var unsubscribeMsg models.UnsubscribeMessageRequest if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { return nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) } validatedMsg = unsubscribeMsg - case "list_subscriptions": + case models.ListSubscriptionsAction: var listMsg models.ListSubscriptionsMessageRequest if err := json.Unmarshal(message, &listMsg); err != nil { return nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) @@ -221,7 +227,7 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro case models.UnsubscribeMessageRequest: c.handleUnsubscribe(ctx, msg) case models.ListSubscriptionsMessageRequest: - c.handleListSubscriptions(ctx) + c.handleListSubscriptions(ctx, msg) default: return fmt.Errorf("unknown message type: %T", msg) } @@ -232,21 +238,35 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe // register new provider provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) if err != nil { - c.writeBaseErrorResponse(ctx, err, "subscribe") - c.logger.Error().Err(err).Msg("error creating data provider") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidArgument, "error creating data provider", msg.MessageID, models.SubscribeAction, ""), + ) return } - c.dataProviders.Add(provider.ID(), provider) - c.writeSubscribeOkResponse(ctx, provider.ID()) + + // write OK response to client + responseOk := models.SubscribeMessageResponse{ + BaseMessageResponse: models.BaseMessageResponse{ + MessageID: msg.MessageID, + Success: true, + }, + ID: provider.ID().String(), + } + c.writeResponse(ctx, responseOk) // run provider c.dataProvidersGroup.Add(1) go func() { err = provider.Run() if err != nil { - c.writeBaseErrorResponse(ctx, err, "") - c.logger.Error().Err(err).Msgf("error while running data provider for topic: %s", msg.Topic) + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(RunError, "data provider finished with error", msg.MessageID, "", ""), + ) } c.dataProvidersGroup.Done() @@ -255,33 +275,50 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe } func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.UnsubscribeMessageRequest) { - id, err := uuid.Parse(msg.ID) + id, err := uuid.Parse(msg.SubscriptionID) if err != nil { - c.writeBaseErrorResponse(ctx, err, "unsubscribe") - c.logger.Debug().Err(err).Msg("error parsing message ID") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InvalidArgument, "error parsing message ID", msg.MessageID, models.UnsubscribeAction, msg.SubscriptionID), + ) return } provider, ok := c.dataProviders.Get(id) if !ok { - c.writeBaseErrorResponse(ctx, fmt.Errorf("could not find data provider with such id"), "unsubscribe") - c.logger.Debug().Err(err).Msg("no active subscription with such ID found") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(NotFound, "provider not found", msg.MessageID, models.UnsubscribeAction, msg.SubscriptionID), + ) return } err = provider.Close() if err != nil { - c.writeBaseErrorResponse(ctx, err, "unsubscribe") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(InternalError, "provider close error", msg.MessageID, models.UnsubscribeAction, msg.SubscriptionID), + ) return } c.dataProviders.Remove(id) - c.writeUnsubscribeOkResponse(ctx, id) + + responseOk := models.UnsubscribeMessageResponse{ + BaseMessageResponse: models.BaseMessageResponse{ + MessageID: msg.MessageID, + Success: true, + }, + SubscriptionID: msg.SubscriptionID, + } + c.writeResponse(ctx, responseOk) } -func (c *Controller) handleListSubscriptions(ctx context.Context) { +func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.ListSubscriptionsMessageRequest) { var subs []*models.SubscriptionEntry - err := c.dataProviders.ForEach(func(id uuid.UUID, provider dp.DataProvider) error { subs = append(subs, &models.SubscriptionEntry{ ID: id.String(), @@ -291,39 +328,42 @@ func (c *Controller) handleListSubscriptions(ctx context.Context) { }) if err != nil { - c.writeBaseErrorResponse(ctx, err, "list_subscriptions") - c.logger.Debug().Err(err).Msg("error listing subscriptions") + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(NotFound, "error looking for subscription", msg.MessageID, models.ListSubscriptionsAction, ""), + ) return } - resp := models.ListSubscriptionsMessageResponse{ + responseOk := models.ListSubscriptionsMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - Success: true, - Action: "list_subscriptions", + Success: true, + MessageID: msg.MessageID, }, Subscriptions: subs, } - c.writeResponse(ctx, resp) + c.writeResponse(ctx, responseOk) } func (c *Controller) shutdownConnection() { err := c.conn.Close() if err != nil { - c.logger.Error().Err(err).Msg("error closing connection") + c.logger.Debug().Err(err).Msg("error closing connection") } err = c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { //TODO: why did i think it's a good idea to return error in Close()? it's messy now err = dp.Close() if err != nil { - c.logger.Error().Err(err).Msg("error closing data provider") + c.logger.Debug().Err(err).Msg("error closing data provider") } return nil }) if err != nil { - c.logger.Error().Err(err).Msg("error closing data provider") + c.logger.Debug().Err(err).Msg("error closing data provider") } c.dataProviders.Clear() @@ -358,46 +398,19 @@ func (c *Controller) keepalive(ctx context.Context) error { return err } - c.writeBaseErrorResponse(ctx, err, "") - c.logger.Debug().Err(err).Msg("failed to send ping") - return fmt.Errorf("failed to write ping message: %w", err) + c.writeErrorResponse( + ctx, + err, + wrapErrorMessage(ConnectionWrite, "error sending ping", "", "", "")) + return fmt.Errorf("error sending ping: %w", err) } } } } -func (c *Controller) writeBaseErrorResponse(ctx context.Context, err error, action string) { - request := models.BaseMessageResponse{ - Action: action, - Success: false, - ErrorMessage: err.Error(), - } - - c.writeResponse(ctx, request) -} - -func (c *Controller) writeSubscribeOkResponse(ctx context.Context, id uuid.UUID) { - request := models.SubscribeMessageResponse{ - BaseMessageResponse: models.BaseMessageResponse{ - Action: "subscribe", - Success: true, - }, - ID: id.String(), - } - - c.writeResponse(ctx, request) -} - -func (c *Controller) writeUnsubscribeOkResponse(ctx context.Context, id uuid.UUID) { - request := models.UnsubscribeMessageResponse{ - BaseMessageResponse: models.BaseMessageResponse{ - Action: "unsubscribe", - Success: true, - }, - ID: id.String(), - } - - c.writeResponse(ctx, request) +func (c *Controller) writeErrorResponse(ctx context.Context, err error, msg models.BaseMessageResponse) { + c.logger.Debug().Err(err).Msg(msg.Error.Message) + c.writeResponse(ctx, msg) } func (c *Controller) writeResponse(ctx context.Context, response interface{}) { @@ -407,3 +420,16 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) { case c.multiplexedStream <- response: } } + +func wrapErrorMessage(code Code, message string, msgId string, action string, subscriptionID string) models.BaseMessageResponse { + return models.BaseMessageResponse{ + MessageID: msgId, + Success: false, + Error: models.ErrorMessage{ + Code: int(code), + Message: message, + Action: action, + SubscriptionID: subscriptionID, + }, + } +} diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index c83469d61b9..7e27428ac28 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -70,9 +70,12 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() request := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: "blocks", - Arguments: nil, + BaseMessageRequest: models.BaseMessageRequest{ + MessageID: uuid.New().String(), + Action: models.SubscribeAction, + }, + Topic: "blocks", + Arguments: nil, } requestJson, err := json.Marshal(request) require.NoError(t, err) @@ -93,17 +96,20 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) // Signal that response has been sent + response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) - require.Equal(t, request.Action, response.Action) require.True(t, response.Success) + require.Equal(t, request.MessageID, response.MessageID) require.Equal(t, id.String(), response.ID) - close(done) // Signal that response has been sent return websocket.ErrCloseSent }) - controller.HandleConnection(context.Background()) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) }) s.T().Run("Parse and validate error", func(t *testing.T) { @@ -137,19 +143,20 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) // Signal that response has been sent + response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.Empty(t, response.Action) require.False(t, response.Success) - require.NotNil(t, response.ErrorMessage) //TODO: add kinds of errors for different data flows - - s.T().Log(response.ErrorMessage) + require.NotEmpty(t, response.Error) + require.Equal(t, int(InvalidMessage), response.Error.Code) - close(done) // Signal that response has been sent return websocket.ErrCloseSent }) - controller.HandleConnection(context.Background()) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) }) s.T().Run("Error creating data provider", func(t *testing.T) { @@ -168,19 +175,20 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) // Signal that response has been sent + response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.Equal(t, "subscribe", response.Action) require.False(t, response.Success) - require.Equal(t, response.ErrorMessage, "error creating data provider") - - s.T().Log(response.ErrorMessage) + require.NotEmpty(t, response.Error) + require.Equal(t, int(InvalidArgument), response.Error.Code) - close(done) // Signal that response has been sent return websocket.ErrCloseSent }) - controller.HandleConnection(context.Background()) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) }) s.T().Run("Run error", func(t *testing.T) { @@ -201,26 +209,27 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) s.expectCloseConnection(conn, done) conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) // Signal that response has been sent + response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.Equal(t, "", response.Action) require.False(t, response.Success) - require.NotNil(t, response.ErrorMessage) //TODO: add kinds of errors for different data flows + require.NotEmpty(t, response.Error) + require.Equal(t, int(RunError), response.Error.Code) - s.T().Log(response.ErrorMessage) - - close(done) // Signal that response has been sent return websocket.ErrCloseSent }) - controller.HandleConnection(context.Background()) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) }) } @@ -248,12 +257,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) request := models.UnsubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, - ID: id.String(), + BaseMessageRequest: models.BaseMessageRequest{ + MessageID: uuid.New().String(), + Action: models.UnsubscribeAction, + }, + SubscriptionID: id.String(), } requestJson, err := json.Marshal(request) require.NoError(s.T(), err) @@ -271,19 +283,21 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) + response, ok := msg.(models.UnsubscribeMessageResponse) require.True(t, ok) - require.Equal(t, request.Action, response.Action) require.True(t, response.Success) - require.Empty(t, response.ErrorMessage) - require.Equal(t, request.ID, response.ID) + require.Empty(t, response.Error) + require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, request.SubscriptionID, response.SubscriptionID) - close(done) return websocket.ErrCloseSent }). Once() s.expectCloseConnection(conn, done) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) defer cancel() controller.HandleConnection(ctx) @@ -312,12 +326,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) request := models.UnsubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, - ID: "invalid-uuid", + BaseMessageRequest: models.BaseMessageRequest{ + MessageID: uuid.New().String(), + Action: models.UnsubscribeAction, + }, + SubscriptionID: "invalid-uuid", } requestJson, err := json.Marshal(request) require.NoError(s.T(), err) @@ -335,21 +352,24 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) + response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.Equal(t, request.Action, response.Action) require.False(t, response.Success) - require.NotEmpty(t, response.ErrorMessage) + require.NotEmpty(t, response.Error) + require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, int(InvalidArgument), response.Error.Code) - s.T().Log(response.ErrorMessage) - - close(done) return websocket.ErrCloseSent }). Once() s.expectCloseConnection(conn, done) - controller.HandleConnection(context.Background()) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) }) s.T().Run("Unsubscribe from unknown subscription", func(t *testing.T) { @@ -375,12 +395,15 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) request := models.UnsubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "unsubscribe"}, - ID: uuid.New().String(), + BaseMessageRequest: models.BaseMessageRequest{ + MessageID: uuid.New().String(), + Action: models.UnsubscribeAction, + }, + SubscriptionID: uuid.New().String(), } requestJson, err := json.Marshal(request) require.NoError(s.T(), err) @@ -398,21 +421,25 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) + response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) - require.Equal(t, request.Action, response.Action) require.False(t, response.Success) - require.NotEmpty(t, response.ErrorMessage) + require.NotEmpty(t, response.Error) - s.T().Log(response.ErrorMessage) + require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, int(NotFound), response.Error.Code) - close(done) return websocket.ErrCloseSent }). Once() s.expectCloseConnection(conn, done) - controller.HandleConnection(context.Background()) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) }) } @@ -442,11 +469,14 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(nil). Once() - s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) request := models.ListSubscriptionsMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "list_subscriptions"}, + BaseMessageRequest: models.BaseMessageRequest{ + MessageID: uuid.New().String(), + Action: models.ListSubscriptionsAction, + }, } requestJson, err := json.Marshal(request) require.NoError(s.T(), err) @@ -464,22 +494,26 @@ func (s *WsControllerSuite) TestListSubscriptions() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) + response, ok := msg.(models.ListSubscriptionsMessageResponse) require.True(t, ok) + require.True(t, response.Success) + require.Empty(t, response.Error) + require.Equal(t, request.MessageID, response.MessageID) require.Equal(t, 1, len(response.Subscriptions)) require.Equal(t, id.String(), response.Subscriptions[0].ID) require.Equal(t, topic, response.Subscriptions[0].Topic) - require.Equal(t, response.Action, "list_subscriptions") - require.True(t, response.Success) - require.Empty(t, response.ErrorMessage) - close(done) return websocket.ErrCloseSent }). Once() s.expectCloseConnection(conn, done) - controller.HandleConnection(context.Background()) + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) }) } @@ -510,8 +544,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) s.expectCloseConnection(conn, done) // Expect a valid block to be passed to WriteJSON. @@ -520,15 +554,19 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { + defer close(done) + block, ok := msg.(flow.Block) require.True(t, ok) actualBlock = block - close(done) return websocket.ErrCloseSent }) - controller.HandleConnection(context.Background()) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) + require.Equal(t, expectedBlock, actualBlock) }) @@ -558,8 +596,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) s.expectCloseConnection(conn, done) i := 0 @@ -585,7 +623,10 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { }). Times(len(expectedBlocks)) - controller.HandleConnection(context.Background()) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + controller.HandleConnection(ctx) + require.Equal(t, expectedBlocks, actualBlocks) }) } @@ -606,24 +647,29 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da } // expectSubscribeRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection) { - subscribeRequest := models.SubscribeMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"}, - Topic: "blocks", +func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection) string { + request := models.SubscribeMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + MessageID: uuid.New().String(), + Action: models.SubscribeAction, + }, + Topic: "blocks", } - subscribeRequestJson, err := json.Marshal(subscribeRequest) + requestJson, err := json.Marshal(request) require.NoError(s.T(), err) - // The very first message from a client is a subscribeRequest to subscribe to some topic + // The very first message from a client is a request to subscribe to some topic conn. On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { msg, ok := args.Get(0).(*json.RawMessage) require.True(s.T(), ok) - *msg = subscribeRequestJson + *msg = requestJson }). Return(nil). Once() + + return request.MessageID } func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnection, done <-chan struct{}) { @@ -640,13 +686,13 @@ func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnec } // expectSubscribeResponse mocks the subscription response sent to the client. -func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConnection) { +func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConnection, msgId string) { conn. On("WriteJSON", mock.Anything). Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(s.T(), ok) - require.Equal(s.T(), "subscribe", response.Action) + require.Equal(s.T(), msgId, response.MessageID) require.Equal(s.T(), true, response.Success) }). Return(nil). diff --git a/engine/access/rest/websockets/error_codes.go b/engine/access/rest/websockets/error_codes.go new file mode 100644 index 00000000000..fa4d1521a3b --- /dev/null +++ b/engine/access/rest/websockets/error_codes.go @@ -0,0 +1,14 @@ +package websockets + +type Code int + +const ( + Ok Code = iota + ConnectionRead + ConnectionWrite + InvalidMessage + NotFound + InvalidArgument + RunError + InternalError +) diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index 38a030358bd..88b15ded6a3 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -2,14 +2,19 @@ package models // BaseMessageRequest represents a base structure for incoming messages. type BaseMessageRequest struct { - Action string `json:"action"` // Action type of the request + Action string `json:"action"` // subscribe, unsubscribe or list_subscriptions + MessageID string `json:"message_id"` // MessageID is a uuid generated by client to identify request/response uniquely } // BaseMessageResponse represents a base structure for outgoing messages. type BaseMessageResponse struct { - Action string `json:"action,omitempty"` // Action type of the response - Success bool `json:"success"` // Indicates success or failure - ErrorMessage string `json:"error_message,omitempty"` // Error message, if any + MessageID string `json:"message_id,omitempty"` // MessageID may be empty in case we send msg by ourselves (e.g. error occurred) + Success bool `json:"success"` + Error ErrorMessage `json:"error,omitempty"` } -// TODO: add Action enum? subscribe, unsubscribe, list +const ( + SubscribeAction = "subscribe" + UnsubscribeAction = "unsubscribe" + ListSubscriptionsAction = "list_subscription" +) diff --git a/engine/access/rest/websockets/models/error.go b/engine/access/rest/websockets/models/error.go new file mode 100644 index 00000000000..a49e90594d5 --- /dev/null +++ b/engine/access/rest/websockets/models/error.go @@ -0,0 +1,8 @@ +package models + +type ErrorMessage struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action,omitempty"` + SubscriptionID string `json:"subscription_id,omitempty"` +} diff --git a/engine/access/rest/websockets/models/unsubscribe.go b/engine/access/rest/websockets/models/unsubscribe.go index 2024bb922e0..b0b8b8f8e0d 100644 --- a/engine/access/rest/websockets/models/unsubscribe.go +++ b/engine/access/rest/websockets/models/unsubscribe.go @@ -3,11 +3,11 @@ package models // UnsubscribeMessageRequest represents a request to unsubscribe from a topic. type UnsubscribeMessageRequest struct { BaseMessageRequest - ID string `json:"id"` // Unique subscription ID + SubscriptionID string `json:"id"` } // UnsubscribeMessageResponse represents the response to an unsubscription request. type UnsubscribeMessageResponse struct { BaseMessageResponse - ID string `json:"id"` // Unique subscription ID + SubscriptionID string `json:"id"` } From 849b5a82ff15758ff332489c0f4d066030575cc5 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 13 Dec 2024 13:01:25 +0200 Subject: [PATCH 07/26] remove comments --- engine/access/rest/websockets/controller_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 7e27428ac28..60cea94a3a8 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -96,7 +96,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { - defer close(done) // Signal that response has been sent + defer close(done) response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) @@ -143,7 +143,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { - defer close(done) // Signal that response has been sent + defer close(done) response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) @@ -175,7 +175,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { - defer close(done) // Signal that response has been sent + defer close(done) response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) @@ -216,7 +216,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { - defer close(done) // Signal that response has been sent + defer close(done) response, ok := msg.(models.BaseMessageResponse) require.True(t, ok) From 29200425ba679cf72388d54cc66b7b92fcf79313 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 13 Dec 2024 18:00:08 +0200 Subject: [PATCH 08/26] remove log msg --- engine/access/rest/websockets/controller_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 737c78bb706..13766346e32 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -795,7 +795,6 @@ func (s *WsControllerSuite) TestKeepaliveHappyCase() { conn. On("WriteControl", websocket.PingMessage, mock.Anything). Return(func(int, time.Time) error { - s.T().Log("---WRITE CONTROL") if i == expectedCalls { close(done) return websocket.ErrCloseSent From e9795fbd5358c4bf7fb8ad49c7c3d29efbc4a5e7 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 13 Dec 2024 18:15:06 +0200 Subject: [PATCH 09/26] remove unnecessary func --- engine/access/rest/websockets/controller.go | 36 +++++++-------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index a88f9e52e2c..4d86563cc25 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -18,10 +18,6 @@ import ( "github.com/onflow/flow-go/utils/concurrentmap" ) -var ( - ErrUnmarshalMessage = errors.New("failed to unmarshal message") -) - type Controller struct { logger zerolog.Logger config Config @@ -164,7 +160,7 @@ func (c *Controller) readMessages(ctx context.Context) error { continue } - validatedMsg, err := c.parseAndValidateMessage(message) + err := c.parseAndValidateMessage(ctx, message) if err != nil { c.writeErrorResponse( ctx, @@ -172,21 +168,13 @@ func (c *Controller) readMessages(ctx context.Context) error { wrapErrorMessage(InvalidMessage, "error parsing message", "", "", "")) continue } - - if err = c.handleAction(ctx, validatedMsg); err != nil { - c.writeErrorResponse( - ctx, - err, - wrapErrorMessage(InvalidMessage, "error handling action", "", "", "")) - continue - } } } -func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface{}, error) { +func (c *Controller) parseAndValidateMessage(ctx context.Context, message json.RawMessage) error { var baseMsg models.BaseMessageRequest if err := json.Unmarshal(message, &baseMsg); err != nil { - return nil, fmt.Errorf("error unmarshalling base message: %w", err) + return fmt.Errorf("error unmarshalling base message: %w", err) } var validatedMsg interface{} @@ -194,33 +182,34 @@ func (c *Controller) parseAndValidateMessage(message json.RawMessage) (interface case models.SubscribeAction: var subscribeMsg models.SubscribeMessageRequest if err := json.Unmarshal(message, &subscribeMsg); err != nil { - return nil, fmt.Errorf("error unmarshalling subscribe message: %w", err) + return fmt.Errorf("error unmarshalling subscribe message: %w", err) } validatedMsg = subscribeMsg case models.UnsubscribeAction: var unsubscribeMsg models.UnsubscribeMessageRequest if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { - return nil, fmt.Errorf("error unmarshalling unsubscribe message: %w", err) + return fmt.Errorf("error unmarshalling unsubscribe message: %w", err) } validatedMsg = unsubscribeMsg case models.ListSubscriptionsAction: var listMsg models.ListSubscriptionsMessageRequest if err := json.Unmarshal(message, &listMsg); err != nil { - return nil, fmt.Errorf("error unmarshalling list subscriptions message: %w", err) + return fmt.Errorf("error unmarshalling list subscriptions message: %w", err) } validatedMsg = listMsg default: c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") - return nil, fmt.Errorf("unknown action type: %s", baseMsg.Action) + return fmt.Errorf("unknown action type: %s", baseMsg.Action) } - return validatedMsg, nil + c.handleAction(ctx, validatedMsg) + return nil } -func (c *Controller) handleAction(ctx context.Context, message interface{}) error { +func (c *Controller) handleAction(ctx context.Context, message interface{}) { switch msg := message.(type) { case models.SubscribeMessageRequest: c.handleSubscribe(ctx, msg) @@ -228,10 +217,7 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro c.handleUnsubscribe(ctx, msg) case models.ListSubscriptionsMessageRequest: c.handleListSubscriptions(ctx, msg) - default: - return fmt.Errorf("unknown message type: %T", msg) } - return nil } func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { @@ -265,7 +251,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(RunError, "data provider finished with error", msg.MessageID, "", ""), + wrapErrorMessage(RunError, "data provider finished with error", "", "", ""), ) } From c00a12cc724fb89a8498430104f622911c065b0a Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 13 Dec 2024 19:45:16 +0200 Subject: [PATCH 10/26] fix parallel running and turn it on --- .../access/rest/websockets/controller_test.go | 296 +++++++++--------- 1 file changed, 148 insertions(+), 148 deletions(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 13766346e32..947102db14c 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -41,7 +41,7 @@ func TestWsControllerSuite(t *testing.T) { // TestSubscribeRequest tests the subscribe to topic flow. // We emulate a request message from a client, and a response message from a controller. func (s *WsControllerSuite) TestSubscribeRequest() { - //s.T().Parallel() + s.T().Parallel() s.T().Run("Happy path", func(t *testing.T) { t.Parallel() @@ -191,7 +191,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { controller.HandleConnection(context.Background()) }) - s.T().Run("Run error", func(t *testing.T) { + s.T().Run("Provider execution error", func(t *testing.T) { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) @@ -235,7 +235,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { } func (s *WsControllerSuite) TestUnsubscribeRequest() { - //s.T().Parallel() + s.T().Parallel() s.T().Run("Happy path", func(t *testing.T) { t.Parallel() @@ -273,7 +273,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { SubscriptionID: id.String(), } requestJson, err := json.Marshal(request) - require.NoError(s.T(), err) + require.NoError(t, err) conn. On("ReadJSON", mock.Anything). @@ -343,7 +343,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { SubscriptionID: "invalid-uuid", } requestJson, err := json.Marshal(request) - require.NoError(s.T(), err) + require.NoError(t, err) conn. On("ReadJSON", mock.Anything). @@ -413,7 +413,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { SubscriptionID: uuid.New().String(), } requestJson, err := json.Marshal(request) - require.NoError(s.T(), err) + require.NoError(t, err) conn. On("ReadJSON", mock.Anything). @@ -450,81 +450,84 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { } func (s *WsControllerSuite) TestListSubscriptions() { - //s.T().Parallel() + s.T().Parallel() - conn, dataProviderFactory, dataProvider := newControllerMocks(s.T()) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) + s.T().Run("Happy path", func(t *testing.T) { - dataProviderFactory. - On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). - Return(dataProvider, nil). - Once() + conn, dataProviderFactory, dataProvider := newControllerMocks(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) - done := make(chan struct{}, 1) + dataProviderFactory. + On("NewDataProvider", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(dataProvider, nil). + Once() - id := uuid.New() - topic := "blocks" - dataProvider.On("ID").Return(id) - dataProvider.On("Topic").Return(topic) - dataProvider.On("Close").Return(nil).Maybe() - dataProvider. - On("Run"). - Run(func(args mock.Arguments) { - for range done { - } - }). - Return(nil). - Once() + done := make(chan struct{}, 1) - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + id := uuid.New() + topic := "blocks" + dataProvider.On("ID").Return(id) + dataProvider.On("Topic").Return(topic) + dataProvider.On("Close").Return(nil).Maybe() + dataProvider. + On("Run"). + Run(func(args mock.Arguments) { + for range done { + } + }). + Return(nil). + Once() - request := models.ListSubscriptionsMessageRequest{ - BaseMessageRequest: models.BaseMessageRequest{ - MessageID: uuid.New().String(), - Action: models.ListSubscriptionsAction, - }, - } - requestJson, err := json.Marshal(request) - require.NoError(s.T(), err) + msgID := s.expectSubscribeRequest(conn) + s.expectSubscribeResponse(conn, msgID) - conn. - On("ReadJSON", mock.Anything). - Run(func(args mock.Arguments) { - msg, ok := args.Get(0).(*json.RawMessage) - require.True(s.T(), ok) - *msg = requestJson - }). - Return(nil). - Once() + request := models.ListSubscriptionsMessageRequest{ + BaseMessageRequest: models.BaseMessageRequest{ + MessageID: uuid.New().String(), + Action: models.ListSubscriptionsAction, + }, + } + requestJson, err := json.Marshal(request) + require.NoError(t, err) - conn. - On("WriteJSON", mock.Anything). - Return(func(msg interface{}) error { - defer close(done) + conn. + On("ReadJSON", mock.Anything). + Run(func(args mock.Arguments) { + msg, ok := args.Get(0).(*json.RawMessage) + require.True(t, ok) + *msg = requestJson + }). + Return(nil). + Once() - response, ok := msg.(models.ListSubscriptionsMessageResponse) - require.True(s.T(), ok) - require.True(s.T(), response.Success) - require.Empty(s.T(), response.Error) - require.Equal(s.T(), request.MessageID, response.MessageID) - require.Equal(s.T(), 1, len(response.Subscriptions)) - require.Equal(s.T(), id.String(), response.Subscriptions[0].ID) - require.Equal(s.T(), topic, response.Subscriptions[0].Topic) + conn. + On("WriteJSON", mock.Anything). + Return(func(msg interface{}) error { + defer close(done) - return websocket.ErrCloseSent - }). - Once() + response, ok := msg.(models.ListSubscriptionsMessageResponse) + require.True(t, ok) + require.True(t, response.Success) + require.Empty(t, response.Error) + require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, 1, len(response.Subscriptions)) + require.Equal(t, id.String(), response.Subscriptions[0].ID) + require.Equal(t, topic, response.Subscriptions[0].Topic) - s.expectCloseConnection(conn, done) - s.expectKeepaliveClose(conn, done) + return websocket.ErrCloseSent + }). + Once() + + s.expectCloseConnection(conn, done) + s.expectKeepaliveClose(conn, done) - controller.HandleConnection(context.Background()) + controller.HandleConnection(context.Background()) + }) } // TestSubscribeBlocks tests the functionality for streaming blocks to a subscriber. func (s *WsControllerSuite) TestSubscribeBlocks() { - //s.T().Parallel() + s.T().Parallel() s.T().Run("Stream one block", func(t *testing.T) { t.Parallel() @@ -640,25 +643,27 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { // TestConfigureKeepaliveConnection ensures that the WebSocket connection is configured correctly. func (s *WsControllerSuite) TestConfigureKeepaliveConnection() { - //s.T().Parallel() + s.T().Parallel() - conn := connmock.NewWebsocketConnection(s.T()) - conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() - conn.On("SetReadDeadline", mock.Anything).Return(nil) + s.T().Run("Happy path", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil) - factory := dpmock.NewDataProviderFactory(s.T()) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) - err := controller.configureKeepalive() - s.Require().NoError(err, "configureKeepalive should not return an error") + err := controller.configureKeepalive() + s.Require().NoError(err, "configureKeepalive should not return an error") - conn.AssertExpectations(s.T()) + conn.AssertExpectations(t) + }) } func (s *WsControllerSuite) TestControllerShutdown() { - //s.T().Parallel() + s.T().Parallel() - s.T().Run("keepalive routine failed", func(t *testing.T) { + s.T().Run("Keepalive routine failed", func(t *testing.T) { t.Parallel() conn := connmock.NewWebsocketConnection(t) @@ -689,10 +694,10 @@ func (s *WsControllerSuite) TestControllerShutdown() { Once() controller.HandleConnection(context.Background()) - conn.AssertExpectations(s.T()) + conn.AssertExpectations(t) }) - s.T().Run("read routine failed", func(t *testing.T) { + s.T().Run("Read routine failed", func(t *testing.T) { t.Parallel() conn := connmock.NewWebsocketConnection(t) @@ -711,10 +716,10 @@ func (s *WsControllerSuite) TestControllerShutdown() { Once() controller.HandleConnection(context.Background()) - conn.AssertExpectations(s.T()) + conn.AssertExpectations(t) }) - s.T().Run("write routine failed", func(t *testing.T) { + s.T().Run("Write routine failed", func(t *testing.T) { t.Parallel() conn, dataProviderFactory, dataProvider := newControllerMocks(t) @@ -752,12 +757,12 @@ func (s *WsControllerSuite) TestControllerShutdown() { controller.HandleConnection(context.Background()) // Ensure all expectations are met - conn.AssertExpectations(s.T()) - dataProviderFactory.AssertExpectations(s.T()) - dataProvider.AssertExpectations(s.T()) + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) - s.T().Run("context closed", func(t *testing.T) { + s.T().Run("Context cancelled", func(t *testing.T) { t.Parallel() conn := connmock.NewWebsocketConnection(t) @@ -777,88 +782,83 @@ func (s *WsControllerSuite) TestControllerShutdown() { cancel() controller.HandleConnection(ctx) - conn.AssertExpectations(s.T()) + conn.AssertExpectations(t) }) } -func (s *WsControllerSuite) TestKeepaliveHappyCase() { - //s.T().Parallel() - - conn := connmock.NewWebsocketConnection(s.T()) - conn.On("Close").Return(nil).Once() - conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() - conn.On("SetReadDeadline", mock.Anything).Return(nil) - - done := make(chan struct{}, 1) - i := 0 - expectedCalls := 2 - conn. - On("WriteControl", websocket.PingMessage, mock.Anything). - Return(func(int, time.Time) error { - if i == expectedCalls { - close(done) - return websocket.ErrCloseSent - } +func (s *WsControllerSuite) TestKeepaliveRoutine() { + s.T().Parallel() - i += 1 - return nil - }). - Times(expectedCalls + 1) + s.T().Run("Successfully pings connection n times", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + conn.On("Close").Return(nil).Once() + conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() + conn.On("SetReadDeadline", mock.Anything).Return(nil) - conn.On("ReadJSON", mock.Anything).Return(func(_ interface{}) error { - for range done { - } - return websocket.ErrCloseSent - }) + done := make(chan struct{}, 1) + i := 0 + expectedCalls := 2 + conn. + On("WriteControl", websocket.PingMessage, mock.Anything). + Return(func(int, time.Time) error { + if i == expectedCalls { + close(done) + return websocket.ErrCloseSent + } - factory := dpmock.NewDataProviderFactory(s.T()) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) - controller.HandleConnection(context.Background()) + i += 1 + return nil + }). + Times(expectedCalls + 1) - conn.AssertExpectations(s.T()) -} + conn.On("ReadJSON", mock.Anything).Return(func(_ interface{}) error { + for range done { + } + return websocket.ErrCloseSent + }) -// TestKeepaliveError tests the behavior of the keepalive function when there is an error in writing the ping. -func (s *WsControllerSuite) TestKeepaliveError() { - //s.T().Parallel() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + controller.HandleConnection(context.Background()) - conn := connmock.NewWebsocketConnection(s.T()) - conn. - On("WriteControl", websocket.PingMessage, mock.Anything). - Return(websocket.ErrCloseSent). //TODO: change to assert.AnError and rewrite test - Maybe() + conn.AssertExpectations(t) + }) - factory := dpmock.NewDataProviderFactory(s.T()) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + s.T().Run("Error on write to connection", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + conn. + On("WriteControl", websocket.PingMessage, mock.Anything). + Return(websocket.ErrCloseSent). //TODO: change to assert.AnError and rewrite test + Maybe() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) - err := controller.keepalive(ctx) - s.Require().Error(err) - s.Require().ErrorIs(websocket.ErrCloseSent, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - conn.AssertExpectations(s.T()) -} + err := controller.keepalive(ctx) + s.Require().Error(err) + s.Require().ErrorIs(websocket.ErrCloseSent, err) -// TestKeepaliveContextCancel tests the behavior of keepalive when the context is canceled before a ping is sent and -// no ping message is sent after the context is canceled. -func (s *WsControllerSuite) TestKeepaliveContextCancel() { - //s.T().Parallel() + conn.AssertExpectations(t) + }) - conn := connmock.NewWebsocketConnection(s.T()) - factory := dpmock.NewDataProviderFactory(s.T()) - controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) + s.T().Run("Context cancelled", func(t *testing.T) { + conn := connmock.NewWebsocketConnection(t) + factory := dpmock.NewDataProviderFactory(t) + controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Immediately cancel the context + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Immediately cancel the context - // Start the keepalive process with the context canceled - err := controller.keepalive(ctx) - s.Require().Error(err) - s.Require().ErrorIs(context.Canceled, err) //TODO: should be nil + // Start the keepalive process with the context canceled + err := controller.keepalive(ctx) + s.Require().Error(err) + s.Require().ErrorIs(context.Canceled, err) //TODO: should be nil - conn.AssertExpectations(s.T()) // Should not invoke WriteMessage after context cancellation + conn.AssertExpectations(t) // Should not invoke WriteMessage after context cancellation + }) } // newControllerMocks initializes mock WebSocket connection, data provider, and data provider factory. From b41f176a2474229d5601bfd9a3e81379e4df884c Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 13 Dec 2024 20:03:31 +0200 Subject: [PATCH 11/26] use once instead of maybe where needed --- .../access/rest/websockets/controller_test.go | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 947102db14c..13ee1803ade 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -58,6 +58,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { done := make(chan struct{}, 1) dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). @@ -198,6 +199,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, dataProviderFactory) dataProvider.On("ID").Return(uuid.New()) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). @@ -252,6 +254,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { done := make(chan struct{}, 1) dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). @@ -322,6 +325,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { done := make(chan struct{}, 1) dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). @@ -392,6 +396,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { done := make(chan struct{}, 1) dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). @@ -468,6 +473,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { topic := "blocks" dataProvider.On("ID").Return(id) dataProvider.On("Topic").Return(topic) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. On("Run"). @@ -542,7 +548,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { id := uuid.New() dataProvider.On("ID").Return(id) - // data provider might finish by its own or controller will close it via Close() + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() // Simulate data provider write a block to the controller @@ -593,6 +599,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { id := uuid.New() dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() // Simulate data provider writes some blocks to the controller @@ -667,7 +674,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { t.Parallel() conn := connmock.NewWebsocketConnection(t) - conn.On("Close").Return(nil).Maybe() + conn.On("Close").Return(nil).Once() conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() @@ -682,7 +689,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { close(done) return websocket.ErrCloseSent }). - Maybe() + Once() conn. On("ReadJSON", mock.Anything). @@ -701,7 +708,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { t.Parallel() conn := connmock.NewWebsocketConnection(t) - conn.On("Close").Return(nil).Maybe() + conn.On("Close").Return(nil).Once() conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() @@ -732,6 +739,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { id := uuid.New() dataProvider.On("ID").Return(id) + // data provider might finish on its own or controller will close it via Close() dataProvider.On("Close").Return(nil).Maybe() dataProvider. @@ -766,7 +774,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { t.Parallel() conn := connmock.NewWebsocketConnection(t) - conn.On("Close").Return(nil).Maybe() + conn.On("Close").Return(nil).Once() conn.On("SetReadDeadline", mock.Anything).Return(nil).Once() conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() @@ -829,7 +837,7 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { conn. On("WriteControl", websocket.PingMessage, mock.Anything). Return(websocket.ErrCloseSent). //TODO: change to assert.AnError and rewrite test - Maybe() + Once() factory := dpmock.NewDataProviderFactory(t) controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) @@ -930,6 +938,7 @@ func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConn } func (s *WsControllerSuite) expectKeepaliveClose(conn *connmock.WebsocketConnection, done <-chan struct{}) { + // We use maybe as a test may finish faster than keepalive routine trigger WriteControl conn. On("WriteControl", websocket.PingMessage, mock.Anything). Return(func(int, time.Time) error { From ee1cf77df9351a2ee5bda8a598d6c52f33c8c468 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 13 Dec 2024 20:04:42 +0200 Subject: [PATCH 12/26] refactor msg --- engine/access/rest/websockets/controller_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 13ee1803ade..4303ce3b662 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -938,7 +938,7 @@ func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConn } func (s *WsControllerSuite) expectKeepaliveClose(conn *connmock.WebsocketConnection, done <-chan struct{}) { - // We use maybe as a test may finish faster than keepalive routine trigger WriteControl + // We use Maybe() because a test may finish faster than keepalive routine trigger WriteControl conn. On("WriteControl", websocket.PingMessage, mock.Anything). Return(func(int, time.Time) error { From 502c74fed98e02d6f9ca4170855e73d4a02dffef Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Sat, 14 Dec 2024 01:13:32 +0200 Subject: [PATCH 13/26] add assert expectations for mocks in tests --- .../access/rest/websockets/controller_test.go | 100 ++++++++++++------ 1 file changed, 67 insertions(+), 33 deletions(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 4303ce3b662..69c4aef488c 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -91,8 +91,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Return(nil). Once() - s.expectCloseConnection(conn, done) - conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { @@ -107,8 +105,12 @@ func (s *WsControllerSuite) TestSubscribeRequest() { return websocket.ErrCloseSent }) - s.expectKeepaliveClose(conn, done) + s.expectCloseConnection(conn, done) controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) s.T().Run("Parse and validate error", func(t *testing.T) { @@ -139,8 +141,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}, 1) - s.expectCloseConnection(conn, done) - conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { @@ -155,8 +155,12 @@ func (s *WsControllerSuite) TestSubscribeRequest() { return websocket.ErrCloseSent }) - s.expectKeepaliveClose(conn, done) + s.expectCloseConnection(conn, done) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) }) s.T().Run("Error creating data provider", func(t *testing.T) { @@ -172,7 +176,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { done := make(chan struct{}, 1) s.expectSubscribeRequest(conn) - s.expectCloseConnection(conn, done) conn. On("WriteJSON", mock.Anything). @@ -188,8 +191,12 @@ func (s *WsControllerSuite) TestSubscribeRequest() { return websocket.ErrCloseSent }) - s.expectKeepaliveClose(conn, done) + s.expectCloseConnection(conn, done) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) }) s.T().Run("Provider execution error", func(t *testing.T) { @@ -215,7 +222,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { done := make(chan struct{}, 1) msgID := s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, msgID) - s.expectCloseConnection(conn, done) conn. On("WriteJSON", mock.Anything). @@ -231,8 +237,13 @@ func (s *WsControllerSuite) TestSubscribeRequest() { return websocket.ErrCloseSent }) - s.expectKeepaliveClose(conn, done) + s.expectCloseConnection(conn, done) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) } @@ -305,9 +316,12 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() s.expectCloseConnection(conn, done) - s.expectKeepaliveClose(conn, done) controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) s.T().Run("Invalid subscription uuid", func(t *testing.T) { @@ -376,9 +390,12 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() s.expectCloseConnection(conn, done) - s.expectKeepaliveClose(conn, done) controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) s.T().Run("Unsubscribe from unknown subscription", func(t *testing.T) { @@ -448,9 +465,12 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() s.expectCloseConnection(conn, done) - s.expectKeepaliveClose(conn, done) controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) } @@ -525,9 +545,12 @@ func (s *WsControllerSuite) TestListSubscriptions() { Once() s.expectCloseConnection(conn, done) - s.expectKeepaliveClose(conn, done) controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) } @@ -564,7 +587,6 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { done := make(chan struct{}, 1) msgID := s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, msgID) - s.expectCloseConnection(conn, done) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -582,8 +604,13 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { return websocket.ErrCloseSent }) - s.expectKeepaliveClose(conn, done) + s.expectCloseConnection(conn, done) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) s.T().Run("Stream many blocks", func(t *testing.T) { @@ -617,7 +644,6 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { done := make(chan struct{}, 1) msgID := s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, msgID) - s.expectCloseConnection(conn, done) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -643,8 +669,13 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { }). Times(len(expectedBlocks)) - s.expectKeepaliveClose(conn, done) + s.expectCloseConnection(conn, done) + controller.HandleConnection(context.Background()) + + conn.AssertExpectations(t) + dataProviderFactory.AssertExpectations(t) + dataProvider.AssertExpectations(t) }) } @@ -753,7 +784,6 @@ func (s *WsControllerSuite) TestControllerShutdown() { done := make(chan struct{}, 1) msgID := s.expectSubscribeRequest(conn) s.expectSubscribeResponse(conn, msgID) - s.expectCloseConnection(conn, done) conn. On("WriteJSON", mock.Anything). @@ -762,6 +792,8 @@ func (s *WsControllerSuite) TestControllerShutdown() { return assert.AnError }) + s.expectCloseConnection(conn, done) + controller.HandleConnection(context.Background()) // Ensure all expectations are met @@ -910,19 +942,6 @@ func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConne return request.MessageID } -func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnection, done <-chan struct{}) { - // In the default case, no further communication is expected from the client. - // We wait for the writer routine to signal completion, allowing us to close the connection gracefully - conn. - On("ReadJSON", mock.Anything). - Return(func(msg interface{}) error { - for range done { - } - return websocket.ErrCloseSent - }). - Once() -} - // expectSubscribeResponse mocks the subscription response sent to the client. func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConnection, msgId string) { conn. @@ -937,7 +956,22 @@ func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConn Once() } -func (s *WsControllerSuite) expectKeepaliveClose(conn *connmock.WebsocketConnection, done <-chan struct{}) { +func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnection, done <-chan struct{}) { + // In the default case, no further communication is expected from the client. + // We wait for the writer routine to signal completion, allowing us to close the connection gracefully + conn. + On("ReadJSON", mock.Anything). + Return(func(msg interface{}) error { + for range done { + } + return websocket.ErrCloseSent + }). + Once() + + s.expectKeepaliveRoutineShutdown(conn, done) +} + +func (s *WsControllerSuite) expectKeepaliveRoutineShutdown(conn *connmock.WebsocketConnection, done <-chan struct{}) { // We use Maybe() because a test may finish faster than keepalive routine trigger WriteControl conn. On("WriteControl", websocket.PingMessage, mock.Anything). From 6c7116855fc929ff4eb0dcc5810550757bf9ff5f Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Sat, 14 Dec 2024 01:19:24 +0200 Subject: [PATCH 14/26] shuffle code --- engine/access/rest/websockets/controller.go | 58 ++++++++++----------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 4d86563cc25..b6e748463bc 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -65,15 +65,15 @@ func (c *Controller) HandleConnection(ctx context.Context) { g, gCtx := errgroup.WithContext(ctx) - g.Go(func() error { - return c.readMessages(gCtx) - }) g.Go(func() error { return c.keepalive(gCtx) }) g.Go(func() error { return c.writeMessages(gCtx) }) + g.Go(func() error { + return c.readMessages(gCtx) + }) if err = g.Wait(); err != nil { if errors.Is(err, websocket.ErrCloseSent) { @@ -113,6 +113,32 @@ func (c *Controller) configureKeepalive() error { return nil } +// keepalive sends a ping message periodically to keep the WebSocket connection alive +// and avoid timeouts. +// +// Expected errors during normal operation: +// - context.Canceled if the client disconnected +func (c *Controller) keepalive(ctx context.Context) error { + pingTicker := time.NewTicker(PingPeriod) + defer pingTicker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-pingTicker.C: + err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) + if err != nil { + if errors.Is(err, websocket.ErrCloseSent) { + return err + } + + return fmt.Errorf("error sending ping: %w", err) + } + } + } +} + // writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation @@ -364,32 +390,6 @@ func (c *Controller) shutdownConnection() { close(c.multiplexedStream) } -// keepalive sends a ping message periodically to keep the WebSocket connection alive -// and avoid timeouts. -// -// Expected errors during normal operation: -// - context.Canceled if the client disconnected -func (c *Controller) keepalive(ctx context.Context) error { - pingTicker := time.NewTicker(PingPeriod) - defer pingTicker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-pingTicker.C: - err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) - if err != nil { - if errors.Is(err, websocket.ErrCloseSent) { - return err - } - - return fmt.Errorf("error sending ping: %w", err) - } - } - } -} - func (c *Controller) writeErrorResponse(ctx context.Context, err error, msg models.BaseMessageResponse) { c.logger.Debug().Err(err).Msg(msg.Error.Message) c.writeResponse(ctx, msg) From ac529ad402f5b199f47f13c6cef125ce7a1bd1cf Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Sun, 15 Dec 2024 23:42:19 +0200 Subject: [PATCH 15/26] move stream draining to write routine to not run into deadlock. fix parallel tests --- engine/access/rest/websockets/controller.go | 17 ++--- .../access/rest/websockets/controller_test.go | 64 ++++++++++--------- 2 files changed, 42 insertions(+), 39 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index b6e748463bc..351d3d01b82 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -146,6 +146,15 @@ func (c *Controller) keepalive(ctx context.Context) error { // Expected errors during normal operation: // - context.Canceled if the client disconnected func (c *Controller) writeMessages(ctx context.Context) error { + defer func() { + // drain the channel as some providers may still send data to it after this routine shutdowns + // so, in order to not run into deadlock there should be at least 1 reader on the channel + go func() { + for range c.multiplexedStream { + } + }() + }() + for { select { case <-ctx.Done(): @@ -373,19 +382,11 @@ func (c *Controller) shutdownConnection() { return nil }) - if err != nil { c.logger.Debug().Err(err).Msg("error closing data provider") } c.dataProviders.Clear() - - // drain the channel as some providers may still send data to it during shutdown - go func() { - for range c.multiplexedStream { - } - }() - c.dataProvidersGroup.Wait() close(c.multiplexedStream) } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 69c4aef488c..64382e4607a 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -41,7 +41,9 @@ func TestWsControllerSuite(t *testing.T) { // TestSubscribeRequest tests the subscribe to topic flow. // We emulate a request message from a client, and a response message from a controller. func (s *WsControllerSuite) TestSubscribeRequest() { - s.T().Parallel() + // It still fails when run with -race even though we don't share any state + // (I tried changing logger & config to be unique in each test) + //s.T().Parallel() s.T().Run("Happy path", func(t *testing.T) { t.Parallel() @@ -175,7 +177,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}, 1) - s.expectSubscribeRequest(conn) + s.expectSubscribeRequest(t, conn) conn. On("WriteJSON", mock.Anything). @@ -220,8 +222,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() done := make(chan struct{}, 1) - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) conn. On("WriteJSON", mock.Anything). @@ -248,7 +250,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { } func (s *WsControllerSuite) TestUnsubscribeRequest() { - s.T().Parallel() + //s.T().Parallel() s.T().Run("Happy path", func(t *testing.T) { t.Parallel() @@ -276,8 +278,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ @@ -350,8 +352,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ @@ -424,8 +426,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Return(nil). Once() - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ @@ -475,7 +477,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { } func (s *WsControllerSuite) TestListSubscriptions() { - s.T().Parallel() + //s.T().Parallel() s.T().Run("Happy path", func(t *testing.T) { @@ -504,8 +506,8 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(nil). Once() - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) request := models.ListSubscriptionsMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ @@ -556,7 +558,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { // TestSubscribeBlocks tests the functionality for streaming blocks to a subscriber. func (s *WsControllerSuite) TestSubscribeBlocks() { - s.T().Parallel() + //s.T().Parallel() s.T().Run("Stream one block", func(t *testing.T) { t.Parallel() @@ -585,8 +587,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) // Expect a valid block to be passed to WriteJSON. // If we got to this point, the controller executed all its logic properly @@ -642,8 +644,8 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Once() done := make(chan struct{}, 1) - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) i := 0 actualBlocks := make([]*flow.Block, len(expectedBlocks)) @@ -681,7 +683,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { // TestConfigureKeepaliveConnection ensures that the WebSocket connection is configured correctly. func (s *WsControllerSuite) TestConfigureKeepaliveConnection() { - s.T().Parallel() + //s.T().Parallel() s.T().Run("Happy path", func(t *testing.T) { conn := connmock.NewWebsocketConnection(t) @@ -699,7 +701,7 @@ func (s *WsControllerSuite) TestConfigureKeepaliveConnection() { } func (s *WsControllerSuite) TestControllerShutdown() { - s.T().Parallel() + //s.T().Parallel() s.T().Run("Keepalive routine failed", func(t *testing.T) { t.Parallel() @@ -782,8 +784,8 @@ func (s *WsControllerSuite) TestControllerShutdown() { Once() done := make(chan struct{}, 1) - msgID := s.expectSubscribeRequest(conn) - s.expectSubscribeResponse(conn, msgID) + msgID := s.expectSubscribeRequest(t, conn) + s.expectSubscribeResponse(t, conn, msgID) conn. On("WriteJSON", mock.Anything). @@ -827,7 +829,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { } func (s *WsControllerSuite) TestKeepaliveRoutine() { - s.T().Parallel() + //s.T().Parallel() s.T().Run("Successfully pings connection n times", func(t *testing.T) { conn := connmock.NewWebsocketConnection(t) @@ -917,7 +919,7 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da } // expectSubscribeRequest mocks the client's subscription request. -func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConnection) string { +func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock.WebsocketConnection) string { request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ MessageID: uuid.New().String(), @@ -926,14 +928,14 @@ func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConne Topic: "blocks", } requestJson, err := json.Marshal(request) - require.NoError(s.T(), err) + require.NoError(t, err) // The very first message from a client is a request to subscribe to some topic conn. On("ReadJSON", mock.Anything). Run(func(args mock.Arguments) { msg, ok := args.Get(0).(*json.RawMessage) - require.True(s.T(), ok) + require.True(t, ok) *msg = requestJson }). Return(nil). @@ -943,14 +945,14 @@ func (s *WsControllerSuite) expectSubscribeRequest(conn *connmock.WebsocketConne } // expectSubscribeResponse mocks the subscription response sent to the client. -func (s *WsControllerSuite) expectSubscribeResponse(conn *connmock.WebsocketConnection, msgId string) { +func (s *WsControllerSuite) expectSubscribeResponse(t *testing.T, conn *connmock.WebsocketConnection, msgId string) { conn. On("WriteJSON", mock.Anything). Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) - require.True(s.T(), ok) - require.Equal(s.T(), msgId, response.MessageID) - require.Equal(s.T(), true, response.Success) + require.True(t, ok) + require.Equal(t, msgId, response.MessageID) + require.Equal(t, true, response.Success) }). Return(nil). Once() From 911e6cd299210e571113c00775819f6c971556f7 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 16 Dec 2024 12:18:13 +0200 Subject: [PATCH 16/26] Updated return statements to latest changes --- engine/access/rest/websockets/controller.go | 4 ++-- engine/access/rest/websockets/controller_test.go | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 351d3d01b82..f4cc380d98b 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -125,7 +125,7 @@ func (c *Controller) keepalive(ctx context.Context) error { for { select { case <-ctx.Done(): - return ctx.Err() + return nil case <-pingTicker.C: err := c.conn.WriteControl(websocket.PingMessage, time.Now().Add(WriteWait)) if err != nil { @@ -158,7 +158,7 @@ func (c *Controller) writeMessages(ctx context.Context) error { for { select { case <-ctx.Done(): - return ctx.Err() + return nil case message, ok := <-c.multiplexedStream: if !ok { return fmt.Errorf("multiplexed stream closed") diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 64382e4607a..8dc16603604 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -896,8 +896,7 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { // Start the keepalive process with the context canceled err := controller.keepalive(ctx) - s.Require().Error(err) - s.Require().ErrorIs(context.Canceled, err) //TODO: should be nil + s.Require().NoError(err) conn.AssertExpectations(t) // Should not invoke WriteMessage after context cancellation }) From 2dd1ed2489870ca598e5931af6fe8ef55591c761 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 19 Dec 2024 16:18:09 +0200 Subject: [PATCH 17/26] fix comments --- engine/access/rest/websockets/controller_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 21cd5552708..6812de26a7f 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers" dpmock "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers/mock" connmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" "github.com/onflow/flow-go/engine/access/rest/websockets/models" @@ -94,7 +95,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { MessageID: uuid.New().String(), Action: models.SubscribeAction, }, - Topic: "blocks", + Topic: dp.BlocksTopic, Arguments: nil, } requestJson, err := json.Marshal(request) @@ -510,7 +511,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { done := make(chan struct{}, 1) id := uuid.New() - topic := "blocks" + topic := dp.BlocksTopic dataProvider.On("ID").Return(id) dataProvider.On("Topic").Return(topic) // data provider might finish on its own or controller will close it via Close() @@ -942,7 +943,7 @@ func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock. MessageID: uuid.New().String(), Action: models.SubscribeAction, }, - Topic: "blocks", + Topic: dp.BlocksTopic, } requestJson, err := json.Marshal(request) require.NoError(t, err) From f56fa5a4e509e2bc7b0a87aa4b46131dbcbf7c23 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 19 Dec 2024 17:06:34 +0200 Subject: [PATCH 18/26] fix comments --- engine/access/rest/websockets/controller.go | 69 ++++------- .../access/rest/websockets/controller_test.go | 108 +++++++----------- .../data_providers/base_provider.go | 3 +- .../data_providers/data_provider.go | 2 +- .../data_providers/mock/data_provider.go | 17 +-- engine/access/rest/websockets/error_codes.go | 10 +- .../rest/websockets/models/base_message.go | 10 +- 7 files changed, 75 insertions(+), 144 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index f4cc380d98b..47480bc4617 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -176,7 +176,7 @@ func (c *Controller) writeMessages(ctx context.Context) error { } // readMessages continuously reads messages from a client WebSocket connection, -// processes each message, and handles actions based on the message type. +// validates each message, and processes it based on the message type. // // Expected errors during normal operation: // - context.Canceled if the client disconnected @@ -191,7 +191,7 @@ func (c *Controller) readMessages(ctx context.Context) error { c.writeErrorResponse( ctx, err, - wrapErrorMessage(ConnectionRead, "error reading from conn", "", "", "")) + wrapErrorMessage(InvalidMessage, "error reading message", "", "", "")) continue } @@ -212,49 +212,36 @@ func (c *Controller) parseAndValidateMessage(ctx context.Context, message json.R return fmt.Errorf("error unmarshalling base message: %w", err) } - var validatedMsg interface{} switch baseMsg.Action { case models.SubscribeAction: var subscribeMsg models.SubscribeMessageRequest if err := json.Unmarshal(message, &subscribeMsg); err != nil { return fmt.Errorf("error unmarshalling subscribe message: %w", err) } - validatedMsg = subscribeMsg + c.handleSubscribe(ctx, subscribeMsg) case models.UnsubscribeAction: var unsubscribeMsg models.UnsubscribeMessageRequest if err := json.Unmarshal(message, &unsubscribeMsg); err != nil { return fmt.Errorf("error unmarshalling unsubscribe message: %w", err) } - validatedMsg = unsubscribeMsg + c.handleUnsubscribe(ctx, unsubscribeMsg) case models.ListSubscriptionsAction: var listMsg models.ListSubscriptionsMessageRequest if err := json.Unmarshal(message, &listMsg); err != nil { return fmt.Errorf("error unmarshalling list subscriptions message: %w", err) } - validatedMsg = listMsg + c.handleListSubscriptions(ctx, listMsg) default: c.logger.Debug().Str("action", baseMsg.Action).Msg("unknown action type") return fmt.Errorf("unknown action type: %s", baseMsg.Action) } - c.handleAction(ctx, validatedMsg) return nil } -func (c *Controller) handleAction(ctx context.Context, message interface{}) { - switch msg := message.(type) { - case models.SubscribeMessageRequest: - c.handleSubscribe(ctx, msg) - case models.UnsubscribeMessageRequest: - c.handleUnsubscribe(ctx, msg) - case models.ListSubscriptionsMessageRequest: - c.handleListSubscriptions(ctx, msg) - } -} - func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) { // register new provider provider, err := c.dataProviderFactory.NewDataProvider(ctx, msg.Topic, msg.Arguments, c.multiplexedStream) @@ -262,7 +249,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error creating data provider", msg.MessageID, models.SubscribeAction, ""), + wrapErrorMessage(InvalidArgument, "error creating data provider", msg.ClientMessageID, models.SubscribeAction, ""), ) return } @@ -271,8 +258,8 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe // write OK response to client responseOk := models.SubscribeMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - MessageID: msg.MessageID, - Success: true, + ClientMessageID: msg.ClientMessageID, + Success: true, }, ID: provider.ID().String(), } @@ -286,7 +273,7 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe c.writeErrorResponse( ctx, err, - wrapErrorMessage(RunError, "data provider finished with error", "", "", ""), + wrapErrorMessage(SubscriptionError, "subscription finished with error", "", "", ""), ) } @@ -301,7 +288,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error parsing message ID", msg.MessageID, models.UnsubscribeAction, msg.SubscriptionID), + wrapErrorMessage(InvalidArgument, "error parsing message ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), ) return } @@ -311,27 +298,18 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(NotFound, "provider not found", msg.MessageID, models.UnsubscribeAction, msg.SubscriptionID), - ) - return - } - - err = provider.Close() - if err != nil { - c.writeErrorResponse( - ctx, - err, - wrapErrorMessage(InternalError, "provider close error", msg.MessageID, models.UnsubscribeAction, msg.SubscriptionID), + wrapErrorMessage(NotFound, "subscription not found", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), ) return } + provider.Close() c.dataProviders.Remove(id) responseOk := models.UnsubscribeMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - MessageID: msg.MessageID, - Success: true, + ClientMessageID: msg.ClientMessageID, + Success: true, }, SubscriptionID: msg.SubscriptionID, } @@ -352,15 +330,15 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis c.writeErrorResponse( ctx, err, - wrapErrorMessage(NotFound, "error looking for subscription", msg.MessageID, models.ListSubscriptionsAction, ""), + wrapErrorMessage(NotFound, "error listing subscriptions", msg.ClientMessageID, models.ListSubscriptionsAction, ""), ) return } responseOk := models.ListSubscriptionsMessageResponse{ BaseMessageResponse: models.BaseMessageResponse{ - Success: true, - MessageID: msg.MessageID, + Success: true, + ClientMessageID: msg.ClientMessageID, }, Subscriptions: subs, } @@ -373,13 +351,8 @@ func (c *Controller) shutdownConnection() { c.logger.Debug().Err(err).Msg("error closing connection") } - err = c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error { - //TODO: why did i think it's a good idea to return error in Close()? it's messy now - err = dp.Close() - if err != nil { - c.logger.Debug().Err(err).Msg("error closing data provider") - } - + err = c.dataProviders.ForEach(func(_ uuid.UUID, provider dp.DataProvider) error { + provider.Close() return nil }) if err != nil { @@ -406,8 +379,8 @@ func (c *Controller) writeResponse(ctx context.Context, response interface{}) { func wrapErrorMessage(code Code, message string, msgId string, action string, subscriptionID string) models.BaseMessageResponse { return models.BaseMessageResponse{ - MessageID: msgId, - Success: false, + ClientMessageID: msgId, + Success: false, Error: models.ErrorMessage{ Code: int(code), Message: message, diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 6812de26a7f..da1f35bd062 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/require" - streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/rs/zerolog" @@ -22,7 +20,6 @@ import ( dpmock "github.com/onflow/flow-go/engine/access/rest/websockets/data_providers/mock" connmock "github.com/onflow/flow-go/engine/access/rest/websockets/mock" "github.com/onflow/flow-go/engine/access/rest/websockets/models" - "github.com/onflow/flow-go/engine/access/state_stream/backend" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -33,12 +30,6 @@ type WsControllerSuite struct { logger zerolog.Logger wsConfig Config - - connection *connmock.WebsocketConnection - dataProviderFactory *dpmock.DataProviderFactory - - streamApi *streammock.API - streamConfig backend.Config } func TestControllerSuite(t *testing.T) { @@ -49,12 +40,6 @@ func TestControllerSuite(t *testing.T) { func (s *WsControllerSuite) SetupTest() { s.logger = unittest.Logger() s.wsConfig = NewDefaultWebsocketConfig() - - s.connection = connmock.NewWebsocketConnection(s.T()) - s.dataProviderFactory = dpmock.NewDataProviderFactory(s.T()) - - s.streamApi = streammock.NewAPI(s.T()) - s.streamConfig = backend.Config{} } // TestSubscribeRequest tests the subscribe to topic flow. @@ -76,7 +61,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Once() id := uuid.New() - done := make(chan struct{}, 1) + done := make(chan struct{}) dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() @@ -84,16 +69,15 @@ func (s *WsControllerSuite) TestSubscribeRequest() { dataProvider. On("Run"). Run(func(args mock.Arguments) { - for range done { - } + <-done }). Return(nil). Once() request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - MessageID: uuid.New().String(), - Action: models.SubscribeAction, + ClientMessageID: uuid.New().String(), + Action: models.SubscribeAction, }, Topic: dp.BlocksTopic, Arguments: nil, @@ -120,7 +104,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { response, ok := msg.(models.SubscribeMessageResponse) require.True(t, ok) require.True(t, response.Success) - require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) require.Equal(t, id.String(), response.ID) return websocket.ErrCloseSent @@ -161,7 +145,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Return(nil). Once() - done := make(chan struct{}, 1) + done := make(chan struct{}) conn. On("WriteJSON", mock.Anything). Return(func(msg interface{}) error { @@ -195,7 +179,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Return(nil, fmt.Errorf("error creating data provider")). Once() - done := make(chan struct{}, 1) + done := make(chan struct{}) s.expectSubscribeRequest(t, conn) conn. @@ -240,7 +224,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { Return(dataProvider, nil). Once() - done := make(chan struct{}, 1) + done := make(chan struct{}) msgID := s.expectSubscribeRequest(t, conn) s.expectSubscribeResponse(t, conn, msgID) @@ -253,7 +237,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { require.True(t, ok) require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, int(RunError), response.Error.Code) + require.Equal(t, int(SubscriptionError), response.Error.Code) return websocket.ErrCloseSent }) @@ -283,7 +267,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() id := uuid.New() - done := make(chan struct{}, 1) + done := make(chan struct{}) dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() @@ -291,8 +275,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { dataProvider. On("Run"). Run(func(args mock.Arguments) { - for range done { - } + <-done }). Return(nil). Once() @@ -302,8 +285,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - MessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + ClientMessageID: uuid.New().String(), + Action: models.UnsubscribeAction, }, SubscriptionID: id.String(), } @@ -329,7 +312,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.True(t, ok) require.True(t, response.Success) require.Empty(t, response.Error) - require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) require.Equal(t, request.SubscriptionID, response.SubscriptionID) return websocket.ErrCloseSent @@ -357,7 +340,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() id := uuid.New() - done := make(chan struct{}, 1) + done := make(chan struct{}) dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() @@ -365,8 +348,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { dataProvider. On("Run"). Run(func(args mock.Arguments) { - for range done { - } + <-done }). Return(nil). Once() @@ -376,8 +358,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - MessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + ClientMessageID: uuid.New().String(), + Action: models.UnsubscribeAction, }, SubscriptionID: "invalid-uuid", } @@ -403,7 +385,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.True(t, ok) require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) require.Equal(t, int(InvalidArgument), response.Error.Code) return websocket.ErrCloseSent @@ -431,7 +413,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { Once() id := uuid.New() - done := make(chan struct{}, 1) + done := make(chan struct{}) dataProvider.On("ID").Return(id) // data provider might finish on its own or controller will close it via Close() @@ -439,8 +421,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { dataProvider. On("Run"). Run(func(args mock.Arguments) { - for range done { - } + <-done }). Return(nil). Once() @@ -450,8 +431,8 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { request := models.UnsubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - MessageID: uuid.New().String(), - Action: models.UnsubscribeAction, + ClientMessageID: uuid.New().String(), + Action: models.UnsubscribeAction, }, SubscriptionID: uuid.New().String(), } @@ -478,7 +459,7 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { require.False(t, response.Success) require.NotEmpty(t, response.Error) - require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) require.Equal(t, int(NotFound), response.Error.Code) return websocket.ErrCloseSent @@ -508,7 +489,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { Return(dataProvider, nil). Once() - done := make(chan struct{}, 1) + done := make(chan struct{}) id := uuid.New() topic := dp.BlocksTopic @@ -519,8 +500,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { dataProvider. On("Run"). Run(func(args mock.Arguments) { - for range done { - } + <-done }). Return(nil). Once() @@ -530,8 +510,8 @@ func (s *WsControllerSuite) TestListSubscriptions() { request := models.ListSubscriptionsMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - MessageID: uuid.New().String(), - Action: models.ListSubscriptionsAction, + ClientMessageID: uuid.New().String(), + Action: models.ListSubscriptionsAction, }, } requestJson, err := json.Marshal(request) @@ -556,7 +536,7 @@ func (s *WsControllerSuite) TestListSubscriptions() { require.True(t, ok) require.True(t, response.Success) require.Empty(t, response.Error) - require.Equal(t, request.MessageID, response.MessageID) + require.Equal(t, request.ClientMessageID, response.ClientMessageID) require.Equal(t, 1, len(response.Subscriptions)) require.Equal(t, id.String(), response.Subscriptions[0].ID) require.Equal(t, topic, response.Subscriptions[0].Topic) @@ -605,7 +585,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Return(nil). Once() - done := make(chan struct{}, 1) + done := make(chan struct{}) msgID := s.expectSubscribeRequest(t, conn) s.expectSubscribeResponse(t, conn, msgID) @@ -662,7 +642,7 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { Return(nil). Once() - done := make(chan struct{}, 1) + done := make(chan struct{}) msgID := s.expectSubscribeRequest(t, conn) s.expectSubscribeResponse(t, conn, msgID) @@ -734,7 +714,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { controller := NewWebSocketController(s.logger, s.wsConfig, conn, factory) // Mock keepalive to return an error - done := make(chan struct{}, 1) + done := make(chan struct{}) conn. On("WriteControl", websocket.PingMessage, mock.Anything). Return(func(int, time.Time) error { @@ -746,8 +726,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { conn. On("ReadJSON", mock.Anything). Return(func(interface{}) error { - for range done { - } + <-done return websocket.ErrCloseSent }). Once() @@ -802,7 +781,7 @@ func (s *WsControllerSuite) TestControllerShutdown() { Return(nil). Once() - done := make(chan struct{}, 1) + done := make(chan struct{}) msgID := s.expectSubscribeRequest(t, conn) s.expectSubscribeResponse(t, conn, msgID) @@ -856,7 +835,7 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() conn.On("SetReadDeadline", mock.Anything).Return(nil) - done := make(chan struct{}, 1) + done := make(chan struct{}) i := 0 expectedCalls := 2 conn. @@ -873,8 +852,7 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { Times(expectedCalls + 1) conn.On("ReadJSON", mock.Anything).Return(func(_ interface{}) error { - for range done { - } + <-done return websocket.ErrCloseSent }) @@ -921,8 +899,7 @@ func (s *WsControllerSuite) TestKeepaliveRoutine() { }) } -// newControllerMocks initializes mock WebSocket connection, data provider, and data provider factory. -// The mocked functions are expected to be called in a case when a test is expected to reach WriteJSON function. +// newControllerMocks initializes mock WebSocket connection, data provider, and data provider factory func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.DataProviderFactory, *dpmock.DataProvider) { conn := connmock.NewWebsocketConnection(t) conn.On("Close").Return(nil).Once() @@ -940,8 +917,8 @@ func newControllerMocks(t *testing.T) (*connmock.WebsocketConnection, *dpmock.Da func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock.WebsocketConnection) string { request := models.SubscribeMessageRequest{ BaseMessageRequest: models.BaseMessageRequest{ - MessageID: uuid.New().String(), - Action: models.SubscribeAction, + ClientMessageID: uuid.New().String(), + Action: models.SubscribeAction, }, Topic: dp.BlocksTopic, } @@ -959,7 +936,7 @@ func (s *WsControllerSuite) expectSubscribeRequest(t *testing.T, conn *connmock. Return(nil). Once() - return request.MessageID + return request.ClientMessageID } // expectSubscribeResponse mocks the subscription response sent to the client. @@ -969,7 +946,7 @@ func (s *WsControllerSuite) expectSubscribeResponse(t *testing.T, conn *connmock Run(func(args mock.Arguments) { response, ok := args.Get(0).(models.SubscribeMessageResponse) require.True(t, ok) - require.Equal(t, msgId, response.MessageID) + require.Equal(t, msgId, response.ClientMessageID) require.Equal(t, true, response.Success) }). Return(nil). @@ -982,8 +959,7 @@ func (s *WsControllerSuite) expectCloseConnection(conn *connmock.WebsocketConnec conn. On("ReadJSON", mock.Anything). Return(func(msg interface{}) error { - for range done { - } + <-done return websocket.ErrCloseSent }). Once() diff --git a/engine/access/rest/websockets/data_providers/base_provider.go b/engine/access/rest/websockets/data_providers/base_provider.go index cf1ee1313d9..0ee040cd4ac 100644 --- a/engine/access/rest/websockets/data_providers/base_provider.go +++ b/engine/access/rest/websockets/data_providers/base_provider.go @@ -46,7 +46,6 @@ func (b *baseDataProvider) Topic() string { // Close terminates the data provider. // // No errors are expected during normal operations. -func (b *baseDataProvider) Close() error { +func (b *baseDataProvider) Close() { b.cancel() - return nil } diff --git a/engine/access/rest/websockets/data_providers/data_provider.go b/engine/access/rest/websockets/data_providers/data_provider.go index 08dc497808b..ab48ebeb9f9 100644 --- a/engine/access/rest/websockets/data_providers/data_provider.go +++ b/engine/access/rest/websockets/data_providers/data_provider.go @@ -14,7 +14,7 @@ type DataProvider interface { // Close terminates the data provider. // // No errors are expected during normal operations. - Close() error + Close() // Run starts processing the subscription and handles responses. // // The separation of the data provider's creation and its Run() method diff --git a/engine/access/rest/websockets/data_providers/mock/data_provider.go b/engine/access/rest/websockets/data_providers/mock/data_provider.go index 3fe8bc5d15b..48debb23ae3 100644 --- a/engine/access/rest/websockets/data_providers/mock/data_provider.go +++ b/engine/access/rest/websockets/data_providers/mock/data_provider.go @@ -13,21 +13,8 @@ type DataProvider struct { } // Close provides a mock function with given fields: -func (_m *DataProvider) Close() error { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Close") - } - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 +func (_m *DataProvider) Close() { + _m.Called() } // ID provides a mock function with given fields: diff --git a/engine/access/rest/websockets/error_codes.go b/engine/access/rest/websockets/error_codes.go index fa4d1521a3b..fd206bed0b3 100644 --- a/engine/access/rest/websockets/error_codes.go +++ b/engine/access/rest/websockets/error_codes.go @@ -3,12 +3,8 @@ package websockets type Code int const ( - Ok Code = iota - ConnectionRead - ConnectionWrite - InvalidMessage - NotFound + InvalidMessage Code = iota InvalidArgument - RunError - InternalError + NotFound + SubscriptionError ) diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index 88b15ded6a3..6118e2aec89 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -2,15 +2,15 @@ package models // BaseMessageRequest represents a base structure for incoming messages. type BaseMessageRequest struct { - Action string `json:"action"` // subscribe, unsubscribe or list_subscriptions - MessageID string `json:"message_id"` // MessageID is a uuid generated by client to identify request/response uniquely + Action string `json:"action"` // subscribe, unsubscribe or list_subscriptions + ClientMessageID string `json:"message_id"` // ClientMessageID is a uuid generated by client to identify request/response uniquely } // BaseMessageResponse represents a base structure for outgoing messages. type BaseMessageResponse struct { - MessageID string `json:"message_id,omitempty"` // MessageID may be empty in case we send msg by ourselves (e.g. error occurred) - Success bool `json:"success"` - Error ErrorMessage `json:"error,omitempty"` + ClientMessageID string `json:"message_id,omitempty"` // ClientMessageID may be empty in case we send msg by ourselves (e.g. error occurred) + Success bool `json:"success"` + Error ErrorMessage `json:"error,omitempty"` } const ( From 02ad4882faa47231986e0be9b4d74a23a1c8a1e6 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 19 Dec 2024 17:18:19 +0200 Subject: [PATCH 19/26] change error messsages --- engine/access/rest/websockets/controller.go | 2 +- engine/access/rest/websockets/controller_test.go | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 47480bc4617..9eb4c43f433 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -288,7 +288,7 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri c.writeErrorResponse( ctx, err, - wrapErrorMessage(InvalidArgument, "error parsing message ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), + wrapErrorMessage(InvalidArgument, "error parsing subscription ID", msg.ClientMessageID, models.UnsubscribeAction, msg.SubscriptionID), ) return } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index da1f35bd062..9c6fc8d3283 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -972,13 +972,11 @@ func (s *WsControllerSuite) expectKeepaliveRoutineShutdown(conn *connmock.Websoc conn. On("WriteControl", websocket.PingMessage, mock.Anything). Return(func(int, time.Time) error { - for { - select { - case <-done: - return websocket.ErrCloseSent - default: - return nil - } + select { + case <-done: + return websocket.ErrCloseSent + default: + return nil } }). Maybe() From 72de9e5dd47ec8011e819769026cdbe2bbbb9e45 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Thu, 19 Dec 2024 17:29:30 +0200 Subject: [PATCH 20/26] change request-response models --- engine/access/rest/websockets/controller.go | 20 +++++++++---------- .../access/rest/websockets/controller_test.go | 18 +---------------- .../rest/websockets/models/base_message.go | 13 ++++++------ .../models/{error.go => error_message.go} | 0 .../websockets/models/list_subscriptions.go | 6 ++++-- .../{subscribe.go => subscribe_message.go} | 1 - ...{unsubscribe.go => unsubscribe_message.go} | 1 - 7 files changed, 21 insertions(+), 38 deletions(-) rename engine/access/rest/websockets/models/{error.go => error_message.go} (100%) rename engine/access/rest/websockets/models/{subscribe.go => subscribe_message.go} (90%) rename engine/access/rest/websockets/models/{unsubscribe.go => unsubscribe_message.go} (90%) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 9eb4c43f433..22ced8e7ae5 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -260,8 +260,8 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe BaseMessageResponse: models.BaseMessageResponse{ ClientMessageID: msg.ClientMessageID, Success: true, + SubscriptionID: provider.ID().String(), }, - ID: provider.ID().String(), } c.writeResponse(ctx, responseOk) @@ -310,8 +310,8 @@ func (c *Controller) handleUnsubscribe(ctx context.Context, msg models.Unsubscri BaseMessageResponse: models.BaseMessageResponse{ ClientMessageID: msg.ClientMessageID, Success: true, + SubscriptionID: msg.SubscriptionID, }, - SubscriptionID: msg.SubscriptionID, } c.writeResponse(ctx, responseOk) } @@ -336,11 +336,9 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis } responseOk := models.ListSubscriptionsMessageResponse{ - BaseMessageResponse: models.BaseMessageResponse{ - Success: true, - ClientMessageID: msg.ClientMessageID, - }, - Subscriptions: subs, + Success: true, + ClientMessageID: msg.ClientMessageID, + Subscriptions: subs, } c.writeResponse(ctx, responseOk) } @@ -381,11 +379,11 @@ func wrapErrorMessage(code Code, message string, msgId string, action string, su return models.BaseMessageResponse{ ClientMessageID: msgId, Success: false, + SubscriptionID: subscriptionID, Error: models.ErrorMessage{ - Code: int(code), - Message: message, - Action: action, - SubscriptionID: subscriptionID, + Code: int(code), + Message: message, + Action: action, }, } } diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index 9c6fc8d3283..de7bcc7b613 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -45,10 +45,6 @@ func (s *WsControllerSuite) SetupTest() { // TestSubscribeRequest tests the subscribe to topic flow. // We emulate a request message from a client, and a response message from a controller. func (s *WsControllerSuite) TestSubscribeRequest() { - // It still fails when run with -race even though we don't share any state - // (I tried changing logger & config to be unique in each test) - //s.T().Parallel() - s.T().Run("Happy path", func(t *testing.T) { t.Parallel() @@ -105,7 +101,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() { require.True(t, ok) require.True(t, response.Success) require.Equal(t, request.ClientMessageID, response.ClientMessageID) - require.Equal(t, id.String(), response.ID) + require.Equal(t, id.String(), response.SubscriptionID) return websocket.ErrCloseSent }) @@ -253,8 +249,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { } func (s *WsControllerSuite) TestUnsubscribeRequest() { - //s.T().Parallel() - s.T().Run("Happy path", func(t *testing.T) { t.Parallel() @@ -477,8 +471,6 @@ func (s *WsControllerSuite) TestUnsubscribeRequest() { } func (s *WsControllerSuite) TestListSubscriptions() { - //s.T().Parallel() - s.T().Run("Happy path", func(t *testing.T) { conn, dataProviderFactory, dataProvider := newControllerMocks(t) @@ -557,8 +549,6 @@ func (s *WsControllerSuite) TestListSubscriptions() { // TestSubscribeBlocks tests the functionality for streaming blocks to a subscriber. func (s *WsControllerSuite) TestSubscribeBlocks() { - //s.T().Parallel() - s.T().Run("Stream one block", func(t *testing.T) { t.Parallel() @@ -682,8 +672,6 @@ func (s *WsControllerSuite) TestSubscribeBlocks() { // TestConfigureKeepaliveConnection ensures that the WebSocket connection is configured correctly. func (s *WsControllerSuite) TestConfigureKeepaliveConnection() { - //s.T().Parallel() - s.T().Run("Happy path", func(t *testing.T) { conn := connmock.NewWebsocketConnection(t) conn.On("SetPongHandler", mock.AnythingOfType("func(string) error")).Return(nil).Once() @@ -700,8 +688,6 @@ func (s *WsControllerSuite) TestConfigureKeepaliveConnection() { } func (s *WsControllerSuite) TestControllerShutdown() { - //s.T().Parallel() - s.T().Run("Keepalive routine failed", func(t *testing.T) { t.Parallel() @@ -827,8 +813,6 @@ func (s *WsControllerSuite) TestControllerShutdown() { } func (s *WsControllerSuite) TestKeepaliveRoutine() { - //s.T().Parallel() - s.T().Run("Successfully pings connection n times", func(t *testing.T) { conn := connmock.NewWebsocketConnection(t) conn.On("Close").Return(nil).Once() diff --git a/engine/access/rest/websockets/models/base_message.go b/engine/access/rest/websockets/models/base_message.go index 6118e2aec89..cdcd72eb1ed 100644 --- a/engine/access/rest/websockets/models/base_message.go +++ b/engine/access/rest/websockets/models/base_message.go @@ -1,5 +1,11 @@ package models +const ( + SubscribeAction = "subscribe" + UnsubscribeAction = "unsubscribe" + ListSubscriptionsAction = "list_subscription" +) + // BaseMessageRequest represents a base structure for incoming messages. type BaseMessageRequest struct { Action string `json:"action"` // subscribe, unsubscribe or list_subscriptions @@ -8,13 +14,8 @@ type BaseMessageRequest struct { // BaseMessageResponse represents a base structure for outgoing messages. type BaseMessageResponse struct { + SubscriptionID string `json:"subscription_id"` ClientMessageID string `json:"message_id,omitempty"` // ClientMessageID may be empty in case we send msg by ourselves (e.g. error occurred) Success bool `json:"success"` Error ErrorMessage `json:"error,omitempty"` } - -const ( - SubscribeAction = "subscribe" - UnsubscribeAction = "unsubscribe" - ListSubscriptionsAction = "list_subscription" -) diff --git a/engine/access/rest/websockets/models/error.go b/engine/access/rest/websockets/models/error_message.go similarity index 100% rename from engine/access/rest/websockets/models/error.go rename to engine/access/rest/websockets/models/error_message.go diff --git a/engine/access/rest/websockets/models/list_subscriptions.go b/engine/access/rest/websockets/models/list_subscriptions.go index 26174869585..4893a34b09d 100644 --- a/engine/access/rest/websockets/models/list_subscriptions.go +++ b/engine/access/rest/websockets/models/list_subscriptions.go @@ -8,6 +8,8 @@ type ListSubscriptionsMessageRequest struct { // ListSubscriptionsMessageResponse is the structure used to respond to list_subscriptions requests. // It contains a list of active subscriptions for the current WebSocket connection. type ListSubscriptionsMessageResponse struct { - BaseMessageResponse - Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` + ClientMessageID string `json:"message_id"` + Success bool `json:"success"` + Error ErrorMessage `json:"error,omitempty"` + Subscriptions []*SubscriptionEntry `json:"subscriptions,omitempty"` } diff --git a/engine/access/rest/websockets/models/subscribe.go b/engine/access/rest/websockets/models/subscribe_message.go similarity index 90% rename from engine/access/rest/websockets/models/subscribe.go rename to engine/access/rest/websockets/models/subscribe_message.go index 1b4a28470d2..f26fd8fca54 100644 --- a/engine/access/rest/websockets/models/subscribe.go +++ b/engine/access/rest/websockets/models/subscribe_message.go @@ -12,5 +12,4 @@ type SubscribeMessageRequest struct { // SubscribeMessageResponse represents the response to a subscription request. type SubscribeMessageResponse struct { BaseMessageResponse - ID string `json:"id"` // Unique subscription ID } diff --git a/engine/access/rest/websockets/models/unsubscribe.go b/engine/access/rest/websockets/models/unsubscribe_message.go similarity index 90% rename from engine/access/rest/websockets/models/unsubscribe.go rename to engine/access/rest/websockets/models/unsubscribe_message.go index b0b8b8f8e0d..1402189a601 100644 --- a/engine/access/rest/websockets/models/unsubscribe.go +++ b/engine/access/rest/websockets/models/unsubscribe_message.go @@ -9,5 +9,4 @@ type UnsubscribeMessageRequest struct { // UnsubscribeMessageResponse represents the response to an unsubscription request. type UnsubscribeMessageResponse struct { BaseMessageResponse - SubscriptionID string `json:"id"` } From 780b0c28b25c61182cd5ef3475bbcabd81493dbc Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 27 Dec 2024 14:20:37 +0200 Subject: [PATCH 21/26] remove old comments --- engine/access/rest/websockets/controller.go | 15 +++------------ engine/access/rest/websockets/controller_test.go | 1 - .../rest/websockets/models/error_message.go | 7 +++---- 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 22ced8e7ae5..92ea2c2ff48 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -115,9 +115,6 @@ func (c *Controller) configureKeepalive() error { // keepalive sends a ping message periodically to keep the WebSocket connection alive // and avoid timeouts. -// -// Expected errors during normal operation: -// - context.Canceled if the client disconnected func (c *Controller) keepalive(ctx context.Context) error { pingTicker := time.NewTicker(PingPeriod) defer pingTicker.Stop() @@ -142,9 +139,6 @@ func (c *Controller) keepalive(ctx context.Context) error { // writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. // The communication channel is filled by data providers. Besides, the response limit tracker is involved in // write message regulation -// -// Expected errors during normal operation: -// - context.Canceled if the client disconnected func (c *Controller) writeMessages(ctx context.Context) error { defer func() { // drain the channel as some providers may still send data to it after this routine shutdowns @@ -161,7 +155,7 @@ func (c *Controller) writeMessages(ctx context.Context) error { return nil case message, ok := <-c.multiplexedStream: if !ok { - return fmt.Errorf("multiplexed stream closed") + return nil } if err := c.conn.SetWriteDeadline(time.Now().Add(WriteWait)); err != nil { @@ -177,9 +171,6 @@ func (c *Controller) writeMessages(ctx context.Context) error { // readMessages continuously reads messages from a client WebSocket connection, // validates each message, and processes it based on the message type. -// -// Expected errors during normal operation: -// - context.Canceled if the client disconnected func (c *Controller) readMessages(ctx context.Context) error { for { var message json.RawMessage @@ -195,7 +186,7 @@ func (c *Controller) readMessages(ctx context.Context) error { continue } - err := c.parseAndValidateMessage(ctx, message) + err := c.handleMessage(ctx, message) if err != nil { c.writeErrorResponse( ctx, @@ -206,7 +197,7 @@ func (c *Controller) readMessages(ctx context.Context) error { } } -func (c *Controller) parseAndValidateMessage(ctx context.Context, message json.RawMessage) error { +func (c *Controller) handleMessage(ctx context.Context, message json.RawMessage) error { var baseMsg models.BaseMessageRequest if err := json.Unmarshal(message, &baseMsg); err != nil { return fmt.Errorf("error unmarshalling base message: %w", err) diff --git a/engine/access/rest/websockets/controller_test.go b/engine/access/rest/websockets/controller_test.go index de7bcc7b613..9707dbb8205 100644 --- a/engine/access/rest/websockets/controller_test.go +++ b/engine/access/rest/websockets/controller_test.go @@ -152,7 +152,6 @@ func (s *WsControllerSuite) TestSubscribeRequest() { require.False(t, response.Success) require.NotEmpty(t, response.Error) require.Equal(t, int(InvalidMessage), response.Error.Code) - return websocket.ErrCloseSent }) diff --git a/engine/access/rest/websockets/models/error_message.go b/engine/access/rest/websockets/models/error_message.go index a49e90594d5..d5c0670926f 100644 --- a/engine/access/rest/websockets/models/error_message.go +++ b/engine/access/rest/websockets/models/error_message.go @@ -1,8 +1,7 @@ package models type ErrorMessage struct { - Code int `json:"code"` - Message string `json:"message"` - Action string `json:"action,omitempty"` - SubscriptionID string `json:"subscription_id,omitempty"` + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action,omitempty"` } From 2c808a2d1479d147f86c48e2cf10c89350a53ecb Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 27 Dec 2024 14:40:22 +0200 Subject: [PATCH 22/26] Add godoc for controller and websockets package --- engine/access/rest/websockets/controller.go | 74 ++++++++++++++++++++- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 92ea2c2ff48..72b73496c72 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -1,3 +1,72 @@ +// Package websockets provides a number of abstractions for managing WebSocket connections. +// It supports handling client subscriptions, sending messages, and maintaining +// the lifecycle of WebSocket connections with robust keepalive mechanisms. +// +// Overview +// +// The architecture of this package consists of three main components: +// +// 1. **Connection**: Responsible for providing a channel that allows the client +// to communicate with the server. It encapsulates WebSocket-level operations +// such as sending and receiving messages. +// 2. **Data Providers**: Standalone units responsible for fetching data from +// the blockchain (protocol). These providers act as sources of data that are +// sent to clients based on their subscriptions. +// 3. **Controller**: Acts as a mediator between the connection and data providers. +// It governs client subscriptions, handles client requests and responses, +// validates messages, and manages error handling. The controller ensures smooth +// coordination between the client and the data-fetching units. +// +// ### Controller Details +// +// The `Controller` is the core component that coordinates the interactions between +// the client and data providers. It achieves this through three routines that run +// in parallel (writer, reader, and keepalive routine). If any of the three routines +// fails with an error, the remaining routines will be canceled using the provided +// context to ensure proper cleanup and termination. +// +// 1. **Reader Routine**: +// - Reads messages from the client WebSocket connection. +// - Parses and validates the messages. +// - Handles the messages by triggering the appropriate actions, such as subscribing +// to a topic or unsubscribing from an existing subscription. +// - Ensures proper validation of message formats and data before passing them to +// the internal handlers. +// +// 2. **Writer Routine**: +// - Listens to the `multiplexedStream`, which is a channel filled by data providers +// with messages that clients have subscribed to. +// - Writes these messages to the client WebSocket connection. +// - Ensures the outgoing messages respect the required deadlines to maintain the +// stability of the connection. +// +// 3. **Keepalive Routine**: +// - Periodically sends a WebSocket ping control message to the client to indicate +// that the controller and all its subscriptions are working as expected. +// - Ensures the connection remains clean and avoids timeout scenarios due to +// inactivity. +// - Resets the connection's read deadline whenever a pong message is received. +// +// Example +// +// Usage typically involves creating a `Controller` instance and invoking its +// `HandleConnection` method to manage a single WebSocket connection: +// +// logger := zerolog.New(os.Stdout) +// config := websockets.Config{/* configuration options */} +// conn := /* a WebsocketConnection implementation */ +// factory := /* a DataProviderFactory implementation */ +// +// controller := websockets.NewWebSocketController(logger, config, conn, factory) +// ctx := context.Background() +// controller.HandleConnection(ctx) +// +// +// Package Constants +// +// This package expects constants like `PongWait` and `WriteWait` for controlling +// the read/write deadlines. They need to be defined in your application as appropriate. + package websockets import ( @@ -136,9 +205,8 @@ func (c *Controller) keepalive(ctx context.Context) error { } } -// writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection. -// The communication channel is filled by data providers. Besides, the response limit tracker is involved in -// write message regulation +// writeMessages reads a messages from multiplexed stream and passes them on to a client WebSocket connection. +// The multiplexed stream channel is filled by data providers func (c *Controller) writeMessages(ctx context.Context) error { defer func() { // drain the channel as some providers may still send data to it after this routine shutdowns From 963a6e7390abf7ede74917852e1a046905b428c9 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 27 Dec 2024 15:13:42 +0200 Subject: [PATCH 23/26] Remove unused values from config --- engine/access/rest/websockets/config.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/engine/access/rest/websockets/config.go b/engine/access/rest/websockets/config.go index 8236913dd5f..1cb1a74c183 100644 --- a/engine/access/rest/websockets/config.go +++ b/engine/access/rest/websockets/config.go @@ -34,15 +34,11 @@ const ( type Config struct { MaxSubscriptionsPerConnection uint64 MaxResponsesPerSecond uint64 - SendMessageTimeout time.Duration - MaxRequestSize int64 } func NewDefaultWebsocketConfig() Config { return Config{ MaxSubscriptionsPerConnection: 1000, MaxResponsesPerSecond: 1000, - SendMessageTimeout: 10 * time.Second, - MaxRequestSize: 1024, } } From 51a567f74e83840c78764eaf9483d1c3f5da6911 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 27 Dec 2024 15:32:56 +0200 Subject: [PATCH 24/26] add comment about multiplexed stream lifecycly --- engine/access/rest/websockets/controller.go | 27 +++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 72b73496c72..33ded68878b 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -92,8 +92,31 @@ type Controller struct { config Config conn WebsocketConnection - // data channel which data providers write messages to. - // writer routine reads from this channel and writes messages to connection + // The `multiplexedStream` is a core channel used for communication between the + // `Controller` and Data Providers. Its lifecycle is as follows: + // + // 1. **Data Providers**: + // - Data providers write their data into this channel, which is consumed by + // the writer routine to send messages to the client. + // 2. **Reader Routine**: + // - Writes OK/error responses to the channel as a result of processing client messages. + // 3. **Writer Routine**: + // - Reads messages from this channel and forwards them to the client WebSocket connection. + // + // 4. **Channel Closing**: + // - The `Controller` is responsible for starting and managing the lifecycle of the channel. + // - If an unrecoverable error occurs in any of the three routines (reader, writer, or keepalive), + // the parent context is canceled. This triggers data providers to stop their work. + // - The `multiplexedStream` will not be closed until all data providers signal that + // they have stopped writing to it via the `dataProvidersGroup` wait group. + // + // 5. **Edge Case - Writer Routine Finished Before Providers**: + // - If the writer routine finishes before all data providers, a separate draining routine + // ensures that the `multiplexedStream` is fully drained to prevent deadlocks. + // All remaining messages in this case will be discarded. + // + // This design ensures that the channel is only closed when it is safe to do so, avoiding + // issues such as sending on a closed channel while maintaining proper cleanup. multiplexedStream chan interface{} dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider] From af0938953a3c1f2fab8d53b59b55e08de9f590a7 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 27 Dec 2024 15:38:27 +0200 Subject: [PATCH 25/26] improve package godoc --- engine/access/rest/websockets/controller.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 33ded68878b..324af8d47bd 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -17,6 +17,9 @@ // validates messages, and manages error handling. The controller ensures smooth // coordination between the client and the data-fetching units. // +// Basically, it is an N:1:1 approach: N data providers, 1 controller, 1 websocket connection. +// This allows a client to receive messages from different subscriptions over a single connection. +// // ### Controller Details // // The `Controller` is the core component that coordinates the interactions between From 52210b997e981f099370c5bb17b72a6a9d71bab1 Mon Sep 17 00:00:00 2001 From: Illia Malachyn Date: Fri, 27 Dec 2024 15:47:15 +0200 Subject: [PATCH 26/26] add more explanation on how multiplexed stream is closed --- engine/access/rest/websockets/controller.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/engine/access/rest/websockets/controller.go b/engine/access/rest/websockets/controller.go index 324af8d47bd..bffa57350c0 100644 --- a/engine/access/rest/websockets/controller.go +++ b/engine/access/rest/websockets/controller.go @@ -107,6 +107,11 @@ type Controller struct { // - Reads messages from this channel and forwards them to the client WebSocket connection. // // 4. **Channel Closing**: + // The intention to close the channel comes from the reader-from-this-channel routines (controller's routines), + // not the writer-to-this-channel routines (data providers). + // Therefore, we have to signal the data providers to stop writing, wait for them to finish write operations, + // and only after that we can close the channel. + // // - The `Controller` is responsible for starting and managing the lifecycle of the channel. // - If an unrecoverable error occurs in any of the three routines (reader, writer, or keepalive), // the parent context is canceled. This triggers data providers to stop their work.