Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cappl 588/rate limit per workflow #16672

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
11 changes: 6 additions & 5 deletions core/capabilities/compute/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,12 @@ func (f *outgoingConnectorFetcherFactory) NewFetcher(log logger.Logger, emitter
}

resp, err := f.outgoingConnectorHandler.HandleSingleNodeRequest(ctx, messageID, ghcapabilities.Request{
URL: req.Url,
Method: req.Method,
Headers: headersReq,
Body: req.Body,
TimeoutMs: req.TimeoutMs,
URL: req.Url,
Method: req.Method,
Headers: headersReq,
Body: req.Body,
TimeoutMs: req.TimeoutMs,
WorkflowID: req.Metadata.WorkflowId,
})
if err != nil {
return nil, err
Expand Down
21 changes: 16 additions & 5 deletions core/capabilities/webapi/outgoing_connector_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceCo
// HandleSingleNodeRequest sends a request to first available gateway node and blocks until response is received
// TODO: handle retries
func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context, messageID string, req capabilities.Request) (*api.Message, error) {
lggr := logger.With(c.lggr, "messageID", messageID, "workflowID", req.WorkflowID)

if !c.rateLimiter.AllowWorkflow(req.WorkflowID) {
return nil, errors.New("exceeded limit of gateways requests")
}

// set default timeout if not provided for all outgoing requests
if req.TimeoutMs == 0 {
req.TimeoutMs = defaultFetchTimeoutMs
Expand All @@ -78,8 +84,7 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,
}
defer c.responses.cleanup(messageID)

l := logger.With(c.lggr, "messageID", messageID)
l.Debugw("sending request to gateway")
lggr.Debugw("sending request to gateway")

body := &api.MessageBody{
MessageId: messageID,
Expand All @@ -93,7 +98,9 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,
return nil, fmt.Errorf("failed to select gateway: %w", err)
}

l.Infow("selected gateway, awaiting connection", "gatewayID", selectedGateway)
lggr = logger.With(lggr, "gatewayID", selectedGateway)

lggr.Infow("selected gateway, awaiting connection")

if err := c.gc.AwaitConnection(ctx, selectedGateway); err != nil {
return nil, errors.Wrap(err, "await connection canceled")
Expand All @@ -105,22 +112,26 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context,

select {
case resp := <-ch:
l.Debugw("received response from gateway", "gatewayID", selectedGateway)
lggr.Debugw("received response from gateway")
return resp, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

// HandleGatewayMessage processes incoming messages from the Gateway,
// which are in response to a HandleSingleNodeRequest call.
func (c *OutgoingConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID string, msg *api.Message) {
body := &msg.Body
l := logger.With(c.lggr, "gatewayID", gatewayID, "method", body.Method, "messageID", msg.Body.MessageId)

if !c.rateLimiter.Allow(body.Sender) {
// error is logged here instead of warning because if a message from gateway is rate-limited,
// the workflow will eventually fail with timeout as there are no retries in place yet
c.lggr.Errorw("request rate-limited")
l.Errorw("request rate-limited")
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically out of scope, but could we return an error here back to the caller? I think this results in nicer UX by telling users they were rate limited rather than returning a timeout.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. Changing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

l.Debugw("handling gateway request")
switch body.Method {
case capabilities.MethodWebAPITarget, capabilities.MethodComputeAction, capabilities.MethodWorkflowSyncer:
Expand Down
78 changes: 70 additions & 8 deletions core/capabilities/webapi/outgoing_connector_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
ctx := tests.Context(t)
msgID := "msgID"
testURL := "http://localhost:8080"
connector, connectorHandler := newFunction(
connector, connectorHandler := newFunctionWithDefaultConfig(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
ctx := tests.Context(t)
msgID := "msgID"
testURL := "http://localhost:8080"
connector, connectorHandler := newFunction(
connector, connectorHandler := newFunctionWithDefaultConfig(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
ctx := tests.Context(t)
msgID := "msgID"
testURL := "http://localhost:8080"
connector, connectorHandler := newFunction(
connector, connectorHandler := newFunctionWithDefaultConfig(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestHandleSingleNodeRequest(t *testing.T) {
ctx := tests.Context(t)
msgID := "msgID"
testURL := "http://localhost:8080"
connector, connectorHandler := newFunction(
connector, connectorHandler := newFunctionWithDefaultConfig(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
Expand Down Expand Up @@ -192,11 +192,67 @@ func TestHandleSingleNodeRequest(t *testing.T) {
assert.False(t, found)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})

t.Run("rate limits outgoing traffic by workflow", func(t *testing.T) {
ctx := tests.Context(t)
msgID := "msgID"
testURL := "http://localhost:8080"
var config = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
PerWorkflowRPS: 1.0,
PerWorkflowBurst: 1,
},
}
connector, connectorHandler := newFunction(
t,
func(gc *gcmocks.GatewayConnector) {
gc.EXPECT().DonID().Return("donID")
gc.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil)
gc.EXPECT().GatewayIDs().Return([]string{"gateway1"})
},
config,
)

// build the expected body with the default timeout
req := ghcapabilities.Request{
URL: testURL,
TimeoutMs: defaultFetchTimeoutMs,
}
payload, err := json.Marshal(req)
require.NoError(t, err)

expectedBody := &api.MessageBody{
MessageId: msgID,
DonId: connector.DonID(),
Method: ghcapabilities.MethodComputeAction,
Payload: payload,
}

// expect the request body to contain the default timeout
connector.EXPECT().SignAndSendToGateway(mock.Anything, "gateway1", expectedBody).Run(func(ctx context.Context, gatewayID string, msg *api.MessageBody) {
connectorHandler.HandleGatewayMessage(ctx, "gateway1", gatewayResponse(t, msgID))
}).Return(nil).Times(1)

_, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{
URL: testURL,
})
require.NoError(t, err)

// Second request should error
_, err = connectorHandler.HandleSingleNodeRequest(ctx, msgID, ghcapabilities.Request{
URL: testURL,
})
require.Error(t, err)
require.ErrorContains(t, err, "exceeded limit of gateways requests")
})

}

func newFunction(t *testing.T, mockFn func(*gcmocks.GatewayConnector)) (*gcmocks.GatewayConnector, *OutgoingConnectorHandler) {
log := logger.TestLogger(t)
connector := gcmocks.NewGatewayConnector(t)
func newFunctionWithDefaultConfig(t *testing.T, mockFn func(*gcmocks.GatewayConnector)) (*gcmocks.GatewayConnector, *OutgoingConnectorHandler) {
var defaultConfig = ServiceConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
Expand All @@ -205,10 +261,16 @@ func newFunction(t *testing.T, mockFn func(*gcmocks.GatewayConnector)) (*gcmocks
PerSenderBurst: 100,
},
}
return newFunction(t, mockFn, defaultConfig)
}

// newFunction builds a mocked GatewayConnector and an OutgoingConnectorHandler
// configured with the given ServiceConfig, for use in tests. The mockFn hook
// lets each test install its own expectations on the connector mock before the
// handler is constructed.
func newFunction(t *testing.T, mockFn func(*gcmocks.GatewayConnector), serviceConfig ServiceConfig) (*gcmocks.GatewayConnector, *OutgoingConnectorHandler) {
	log := logger.TestLogger(t)
	connector := gcmocks.NewGatewayConnector(t)

	mockFn(connector)

	// Fixed: a stale duplicate call using the undefined identifier
	// `defaultConfig` was left next to this one; only the serviceConfig
	// variant is kept so the parameter is actually honored.
	connectorHandler, err := NewOutgoingConnectorHandler(connector, serviceConfig, ghcapabilities.MethodComputeAction, log)
	require.NoError(t, err)
	return connector, connectorHandler
}
Expand Down
19 changes: 12 additions & 7 deletions core/capabilities/webapi/target/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ func defaultIfNil[T any](value *T, defaultValue T) T {
return defaultValue
}

func getPayload(input webapicap.TargetPayload, cfg webapicap.TargetConfig) (ghcapabilities.Request, error) {
func getPayload(input webapicap.TargetPayload, cfg webapicap.TargetConfig, req capabilities.CapabilityRequest) (ghcapabilities.Request, error) {
if err := validation.ValidateWorkflowOrExecutionID(req.Metadata.WorkflowID); err != nil {
return ghcapabilities.Request{}, fmt.Errorf("workflow ID is invalid: %w", err)
}

method := defaultIfNil(input.Method, DefaultHTTPMethod)
body := defaultIfNil(input.Body, "")
timeoutMs := defaultIfNil(cfg.TimeoutMs, DefaultTimeoutMs)
Expand All @@ -102,11 +106,12 @@ func getPayload(input webapicap.TargetPayload, cfg webapicap.TargetConfig) (ghca
}

return ghcapabilities.Request{
URL: input.Url,
Method: method,
Headers: input.Headers,
Body: []byte(body),
TimeoutMs: timeoutMs,
URL: input.Url,
Method: method,
Headers: input.Headers,
Body: []byte(body),
TimeoutMs: timeoutMs,
WorkflowID: req.Metadata.WorkflowID,
}, nil
}

Expand All @@ -130,7 +135,7 @@ func (c *Capability) Execute(ctx context.Context, req capabilities.CapabilityReq
return capabilities.CapabilityResponse{}, err
}

payload, err := getPayload(input, workflowCfg)
payload, err := getPayload(input, workflowCfg, req)
if err != nil {
return capabilities.CapabilityResponse{}, err
}
Expand Down
1 change: 1 addition & 0 deletions core/services/gateway/handlers/capabilities/webapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Request struct {

// Maximum number of bytes to read from the response body. If the gateway max response size is smaller than this value, the gateway max response size will be used.
MaxResponseBytes uint32 `json:"maxBytes,omitempty"`
WorkflowID string
}

type Response struct {
Expand Down
58 changes: 46 additions & 12 deletions core/services/gateway/handlers/common/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,27 @@ import (
"golang.org/x/time/rate"
)

const (
defaultWorkflowRPS = 5.0
defaultWorkflowBurst = 50
)

// Wrapper around Go's rate.Limiter that supports both global and a per-sender rate limiting.
type RateLimiter struct {
global *rate.Limiter
perSender map[string]*rate.Limiter
config RateLimiterConfig
mu sync.Mutex
global *rate.Limiter
perSender map[string]*rate.Limiter
perWorkflow map[string]*rate.Limiter
config RateLimiterConfig
mu sync.Mutex
}

type RateLimiterConfig struct {
GlobalRPS float64 `json:"globalRPS"`
GlobalBurst int `json:"globalBurst"`
PerSenderRPS float64 `json:"perSenderRPS"`
PerSenderBurst int `json:"perSenderBurst"`
GlobalRPS float64 `json:"globalRPS"`
GlobalBurst int `json:"globalBurst"`
PerSenderRPS float64 `json:"perSenderRPS"`
PerSenderBurst int `json:"perSenderBurst"`
PerWorkflowRPS float64 `json:"perWorkflowRPS"`
PerWorkflowBurst int `json:"perWorkflowBurst"`
}

func NewRateLimiter(config RateLimiterConfig) (*RateLimiter, error) {
Expand All @@ -29,21 +37,47 @@ func NewRateLimiter(config RateLimiterConfig) (*RateLimiter, error) {
if config.GlobalBurst <= 0 || config.PerSenderBurst <= 0 {
return nil, errors.New("burst values must be positive")
}

if config.PerWorkflowBurst <= 0 {
config.PerWorkflowBurst = defaultWorkflowBurst
}

if config.PerWorkflowRPS <= 0.0 {
config.PerWorkflowRPS = defaultWorkflowRPS
}

return &RateLimiter{
global: rate.NewLimiter(rate.Limit(config.GlobalRPS), config.GlobalBurst),
perSender: make(map[string]*rate.Limiter),
config: config,
global: rate.NewLimiter(rate.Limit(config.GlobalRPS), config.GlobalBurst),
perSender: make(map[string]*rate.Limiter),
perWorkflow: make(map[string]*rate.Limiter),
config: config,
}, nil
}

// Allow checks that the sender is not rate limited.
// It consumes one token from the per-sender limiter and — only when that
// succeeds, because && short-circuits — one token from the global limiter.
func (rl *RateLimiter) Allow(sender string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Lazily create a limiter the first time a sender is seen.
	senderLimiter, ok := rl.perSender[sender]
	if !ok {
		senderLimiter = rate.NewLimiter(rate.Limit(rl.config.PerSenderRPS), rl.config.PerSenderBurst)
		rl.perSender[sender] = senderLimiter
	}

	// Fixed: an explicit rl.mu.Unlock() remained here alongside the deferred
	// unlock above, which would panic with "unlock of unlocked mutex" when
	// the function returned. The deferred unlock alone is kept.
	return senderLimiter.Allow() && rl.global.Allow()
}

// AllowWorkflow checks that the workflow is not rate limited.
// A token is taken from the workflow's own limiter first; the global limiter
// is only charged when the per-workflow check passes (&& short-circuits).
func (rl *RateLimiter) AllowWorkflow(id string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// First request from a workflow creates its dedicated limiter on demand.
	limiter, found := rl.perWorkflow[id]
	if !found {
		limiter = rate.NewLimiter(rate.Limit(rl.config.PerWorkflowRPS), rl.config.PerWorkflowBurst)
		rl.perWorkflow[id] = limiter
	}

	return limiter.Allow() && rl.global.Allow()
}
22 changes: 21 additions & 1 deletion core/services/gateway/handlers/common/ratelimiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common"
)

func TestRateLimiter_Simple(t *testing.T) {
func TestRateLimiter_PerSender(t *testing.T) {
t.Parallel()

config := common.RateLimiterConfig{
Expand All @@ -25,3 +25,23 @@ func TestRateLimiter_Simple(t *testing.T) {
require.False(t, rl.Allow("user1"))
require.False(t, rl.Allow("user3"))
}

// TestRateLimiter_PerWorkflow verifies both the per-workflow and the global
// limits of AllowWorkflow.
//
// Fixed: the workflow IDs were previously passed as the require failure
// message while unique "userN" strings were passed as the actual workflow ID,
// so every call hit a fresh per-workflow limiter and the test only ever
// exercised the global limit.
func TestRateLimiter_PerWorkflow(t *testing.T) {
	t.Parallel()

	config := common.RateLimiterConfig{
		GlobalRPS:        3.0,
		GlobalBurst:      3,
		PerSenderRPS:     1.0,
		PerSenderBurst:   2,
		PerWorkflowRPS:   1.0,
		PerWorkflowBurst: 2,
	}
	rl, err := common.NewRateLimiter(config)
	require.NoError(t, err)
	require.True(t, rl.AllowWorkflow("workflowID1"))  // workflow1 burst 2->1, global 3->2
	require.True(t, rl.AllowWorkflow("workflowID2"))  // workflow2 burst 2->1, global 2->1
	require.True(t, rl.AllowWorkflow("workflowID1"))  // workflow1 burst 1->0, global 1->0
	require.False(t, rl.AllowWorkflow("workflowID1")) // per-workflow burst exhausted; global not charged
	require.False(t, rl.AllowWorkflow("workflowID3")) // new workflow allowed, but global burst exhausted
}
Loading
Loading