Skip to content

Commit

Permalink
(refactor): outgoing fetch rate limit by workflow only applies to mak…
Browse files Browse the repository at this point in the history
…ing request
  • Loading branch information
justinkaseman committed Mar 7, 2025
1 parent 6a81de5 commit ce9c1de
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
9 changes: 8 additions & 1 deletion core/capabilities/webapi/outgoing_connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceCo
// TODO: handle retries
func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context, messageID string, req capabilities.Request) (*api.Message, error) {
lggr := logger.With(c.lggr, "messageID", messageID, "workflowID", req.WorkflowID)

if !c.rateLimiter.AllowWorkflow(req.WorkflowID) {
return nil, errors.New("exceeded limit of gateways requests")
}

// set default timeout if not provided for all outgoing requests
if req.TimeoutMs == 0 {
req.TimeoutMs = defaultFetchTimeoutMs
Expand Down Expand Up @@ -114,6 +119,8 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,
}
}

// HandleGatewayMessage processes incoming messages from the Gateway,
// which are in response to a HandleSingleNodeRequest call.
func (c *OutgoingConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID string, msg *api.Message) {
body := &msg.Body
l := logger.With(c.lggr, "gatewayID", gatewayID, "method", body.Method, "messageID", msg.Body.MessageId)
Expand All @@ -124,7 +131,7 @@ func (c *OutgoingConnectorHandler) HandleGatewayMessage(ctx context.Context, gat
return
}

if !c.rateLimiter.Allow(body.Sender, req.WorkflowID) {
if !c.rateLimiter.Allow(body.Sender) {
// error is logged here instead of warning because if a message from gateway is rate-limited,
// the workflow will eventually fail with timeout as there are no retries in place yet
l.Errorw("request rate-limited")
Expand Down
3 changes: 2 additions & 1 deletion core/capabilities/webapi/outgoing_connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
// expect the request body to contain the default timeout
connector.EXPECT().SignAndSendToGateway(mock.Anything, "gateway1", expectedBody).Run(func(ctx context.Context, gatewayID string, msg *api.MessageBody) {
connectorHandler.HandleGatewayMessage(ctx, "gateway1", gatewayResponse(t, msgID))
}).Return(nil).Times(2)
}).Return(nil).Times(1)

_, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{
URL: testURL,
Expand All @@ -247,6 +247,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
URL: testURL,
})
require.Error(t, err)
require.ErrorContains(t, err, "exceeded limit of gateways requests")
})

}
Expand Down
24 changes: 10 additions & 14 deletions core/services/gateway/handlers/common/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ func NewRateLimiter(config RateLimiterConfig) (*RateLimiter, error) {
}

// Allow checks that the sender is not rate limited.
// If a workflow ID is passed, then Allow also requires that the workflow not be rate limited.
//
// Additional IDs beyond the first are ignored.
func (rl *RateLimiter) Allow(sender string, ids ...string) bool {
func (rl *RateLimiter) Allow(sender string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()

Expand All @@ -68,20 +65,19 @@ func (rl *RateLimiter) Allow(sender string, ids ...string) bool {
rl.perSender[sender] = senderLimiter
}

var wfID string
if len(ids) == 1 {
wfID = ids[0]
}
return senderLimiter.Allow() && rl.global.Allow()
}

if wfID == "" {
return senderLimiter.Allow() && rl.global.Allow()
}
// AllowWorkflow checks that the workflow is not rate limited.
func (rl *RateLimiter) AllowWorkflow(ID string) bool {

Check failure on line 72 in core/services/gateway/handlers/common/ratelimiter.go

View workflow job for this annotation

GitHub Actions / GolangCI Lint (.)

captLocal: `ID' should not be capitalized (gocritic)
rl.mu.Lock()
defer rl.mu.Unlock()

wfLimiter, ok := rl.perWorkflow[wfID]
wfLimiter, ok := rl.perWorkflow[ID]
if !ok {
wfLimiter = rate.NewLimiter(rate.Limit(rl.config.PerWorkflowRPS), rl.config.PerWorkflowBurst)
rl.perWorkflow[wfID] = wfLimiter
rl.perWorkflow[ID] = wfLimiter
}

return wfLimiter.Allow() && senderLimiter.Allow() && rl.global.Allow()
return wfLimiter.Allow() && rl.global.Allow()
}
10 changes: 5 additions & 5 deletions core/services/gateway/handlers/common/ratelimiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ func TestRateLimiter_PerWorkflow(t *testing.T) {
}
rl, err := common.NewRateLimiter(config)
require.NoError(t, err)
require.True(t, rl.Allow("user1"), "workflowID1")
require.True(t, rl.Allow("user2"), "workflowID2")
require.True(t, rl.Allow("user3"), "workflowID1")
require.False(t, rl.Allow("user4"), "workflowID1")
require.False(t, rl.Allow("user5"), "workflowID3")
require.True(t, rl.AllowWorkflow("user1"), "workflowID1")
require.True(t, rl.AllowWorkflow("user2"), "workflowID2")
require.True(t, rl.AllowWorkflow("user3"), "workflowID1")
require.False(t, rl.AllowWorkflow("user4"), "workflowID1")
require.False(t, rl.AllowWorkflow("user5"), "workflowID3")
}

0 comments on commit ce9c1de

Please sign in to comment.