diff --git a/core/capabilities/webapi/outgoing_connector_handler.go b/core/capabilities/webapi/outgoing_connector_handler.go index 42287f07bb9..cdac2b2bb50 100644 --- a/core/capabilities/webapi/outgoing_connector_handler.go +++ b/core/capabilities/webapi/outgoing_connector_handler.go @@ -57,6 +57,11 @@ func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceCo // TODO: handle retries func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context, messageID string, req capabilities.Request) (*api.Message, error) { lggr := logger.With(c.lggr, "messageID", messageID, "workflowID", req.WorkflowID) + + if !c.rateLimiter.AllowWorkflow(req.WorkflowID) { + return nil, errors.New("exceeded limit of gateways requests") + } + // set default timeout if not provided for all outgoing requests if req.TimeoutMs == 0 { req.TimeoutMs = defaultFetchTimeoutMs @@ -114,17 +119,13 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context, } } +// HandleGatewayMessage processes incoming messages from the Gateway, +// which are in response to a HandleSingleNodeRequest call. func (c *OutgoingConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID string, msg *api.Message) { body := &msg.Body l := logger.With(c.lggr, "gatewayID", gatewayID, "method", body.Method, "messageID", msg.Body.MessageId) - var req capabilities.Request - if err := json.Unmarshal(body.Payload, &req); err != nil { - l.Errorw("failed to unmarshal req from payload", "payload", body.Payload) - return - } - - if !c.rateLimiter.Allow(body.Sender, req.WorkflowID) { + if !c.rateLimiter.Allow(body.Sender) { // error is logged here instead of warning because if a message from gateway is rate-limited, // the workflow will eventually fail with timeout as there are no retries in place yet l.Errorw("request rate-limited") diff --git a/core/capabilities/webapi/outgoing_connector_handler_test.go b/core/capabilities/webapi/outgoing_connector_handler_test.go index 2d6d63e7969..4e5ce726bdd 100644 --- a/core/capabilities/webapi/outgoing_connector_handler_test.go +++ b/core/capabilities/webapi/outgoing_connector_handler_test.go @@ -235,7 +235,7 @@ func TestHandleSingleNodeRequest(t *testing.T) { // expect the request body to contain the default timeout connector.EXPECT().SignAndSendToGateway(mock.Anything, "gateway1", expectedBody).Run(func(ctx context.Context, gatewayID string, msg *api.MessageBody) { connectorHandler.HandleGatewayMessage(ctx, "gateway1", gatewayResponse(t, msgID)) - }).Return(nil).Times(2) + }).Return(nil).Times(1) _, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{ URL: testURL, @@ -247,6 +247,7 @@ func TestHandleSingleNodeRequest(t *testing.T) { URL: testURL, }) require.Error(t, err) + require.ErrorContains(t, err, "exceeded limit of gateways requests") }) } diff --git a/core/services/gateway/handlers/common/ratelimiter.go b/core/services/gateway/handlers/common/ratelimiter.go index 972a3b88696..c560f55ae95 100644 --- a/core/services/gateway/handlers/common/ratelimiter.go +++ b/core/services/gateway/handlers/common/ratelimiter.go @@ -55,10 +55,7 @@ func NewRateLimiter(config RateLimiterConfig) (*RateLimiter, error) { } // Allow checks that the sender is not rate limited. -// If a workflow ID is passed, then Allow also requires that the workflow not be rate limited. -// -// Additional IDs beyond the first are ignored. -func (rl *RateLimiter) Allow(sender string, ids ...string) bool { +func (rl *RateLimiter) Allow(sender string) bool { rl.mu.Lock() defer rl.mu.Unlock() @@ -68,20 +65,19 @@ func (rl *RateLimiter) Allow(sender string, ids ...string) bool { rl.perSender[sender] = senderLimiter } - var wfID string - if len(ids) == 1 { - wfID = ids[0] - } + return senderLimiter.Allow() && rl.global.Allow() +} - if wfID == "" { - return senderLimiter.Allow() && rl.global.Allow() - } +// AllowWorkflow checks that the workflow is not rate limited. +func (rl *RateLimiter) AllowWorkflow(ID string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() - wfLimiter, ok := rl.perWorkflow[wfID] + wfLimiter, ok := rl.perWorkflow[ID] if !ok { wfLimiter = rate.NewLimiter(rate.Limit(rl.config.PerWorkflowRPS), rl.config.PerWorkflowBurst) - rl.perWorkflow[wfID] = wfLimiter + rl.perWorkflow[ID] = wfLimiter } - return wfLimiter.Allow() && senderLimiter.Allow() && rl.global.Allow() + return wfLimiter.Allow() && rl.global.Allow() } diff --git a/core/services/gateway/handlers/common/ratelimiter_test.go b/core/services/gateway/handlers/common/ratelimiter_test.go index 310a7115f4c..37ed5080852 100644 --- a/core/services/gateway/handlers/common/ratelimiter_test.go +++ b/core/services/gateway/handlers/common/ratelimiter_test.go @@ -39,9 +39,9 @@ func TestRateLimiter_PerWorkflow(t *testing.T) { } rl, err := common.NewRateLimiter(config) require.NoError(t, err) - require.True(t, rl.Allow("user1"), "workflowID1") - require.True(t, rl.Allow("user2"), "workflowID2") - require.True(t, rl.Allow("user3"), "workflowID1") - require.False(t, rl.Allow("user4"), "workflowID1") - require.False(t, rl.Allow("user5"), "workflowID3") + require.True(t, rl.AllowWorkflow("user1"), "workflowID1") + require.True(t, rl.AllowWorkflow("user2"), "workflowID2") + require.True(t, rl.AllowWorkflow("user3"), "workflowID1") + require.False(t, rl.AllowWorkflow("user4"), "workflowID1") + require.False(t, rl.AllowWorkflow("user5"), "workflowID3") }