Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidOrchard committed Sep 19, 2024
1 parent 3573fd8 commit 7464429
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 25 deletions.
25 changes: 17 additions & 8 deletions core/capabilities/webapi/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/workflow"
"github.com/smartcontractkit/chainlink/v2/core/services/job"
)

const defaultSendChannelBufferSize = 1000
Expand Down Expand Up @@ -50,6 +49,7 @@ type triggerConnectorHandler struct {
mu sync.Mutex
// Will this have to get pulled into a store to have the topic and workflow ID?
registeredWorkflows map[string]chan capabilities.TriggerResponse
allowedSendersMap map[string]bool
signerKey *ecdsa.PrivateKey
rateLimiter *common.RateLimiter
}
Expand All @@ -61,7 +61,7 @@ var _ services.Service = &triggerConnectorHandler{}
// Once connected to a Gateway, each connector handler periodically sends metadata messages containing aggregated
// config for all registered workflow specs using web-trigger.

func NewTrigger(config TriggerConfig, registry core.CapabilitiesRegistry, connector connector.GatewayConnector, signerKey *ecdsa.PrivateKey, lggr logger.Logger) (job.ServiceCtx, error) {
func NewTrigger(config TriggerConfig, registry core.CapabilitiesRegistry, connector connector.GatewayConnector, signerKey *ecdsa.PrivateKey, lggr logger.Logger) (*triggerConnectorHandler, error) {
// TODO (CAPPL-22, CAPPL-24):
// - decode config
// - create an implementation of the capability API and add it to the Registry
Expand All @@ -73,13 +73,18 @@ func NewTrigger(config TriggerConfig, registry core.CapabilitiesRegistry, connec
if err != nil {
return nil, err
}
allowedSendersMap := map[string]bool{}
for _, k := range config.AllowedSenders {
allowedSendersMap[k.String()] = true
}

handler := &triggerConnectorHandler{
config: config,
connector: connector,
signerKey: signerKey,
rateLimiter: rateLimiter,
lggr: lggr.Named("WorkflowConnectorHandler"),
allowedSendersMap: allowedSendersMap,
config: config,
connector: connector,
signerKey: signerKey,
rateLimiter: rateLimiter,
lggr: lggr.Named("WorkflowConnectorHandler"),
}

return handler, nil
Expand Down Expand Up @@ -131,7 +136,10 @@ func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gate
h.lggr.Errorw("request rate-limited")
return
}
// TODO: apply allowlist
if !h.allowedSendersMap[sender.String()] {
h.lggr.Errorw("Unauthorized Sender")
return
}
h.lggr.Debugw("handling gateway request", "id", gatewayID, "method", body.Method, "sender", sender)
var payload TriggerRequestPayload
err := json.Unmarshal(body.Payload, &payload)
Expand All @@ -142,6 +150,7 @@ func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gate
switch body.Method {
case workflow.MethodWebAPITrigger:
h.lggr.Debugw("added MethodWebAPITrigger message", "payload", string(body.Payload))
// TODO: Is the staleness check supposed to be in the gateway?
currentTime := time.Now()
// TODO: check against h.config.MaxAllowedMessageAgeSec
if currentTime.Unix()-3000 > payload.Timestamp {
Expand Down
150 changes: 150 additions & 0 deletions core/capabilities/webapi/trigger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package webapi

import (
"encoding/json"
"flag"
"testing"

ethCommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-common/pkg/capabilities"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
registrymock "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks"
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"
corelogger "github.com/smartcontractkit/chainlink/v2/core/logger"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/api"
gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks"
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common"
)

const (
workflowID1 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0"
workflowExecutionID1 = "95ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0abbadeed"
owner1 = "0x00000000000000000000000000000000000000aa"
)

type testHarness struct {
registry *registrymock.CapabilitiesRegistry
connector *gcmocks.GatewayConnector
lggr logger.Logger
config TriggerConfig
trigger *triggerConnectorHandler
}

func setup(t *testing.T) testHarness {
privateKey, _ := testutils.NewPrivateKeyAndAddress(t)
registry := registrymock.NewCapabilitiesRegistry(t)
connector := gcmocks.NewGatewayConnector(t)
lggr := corelogger.TestLogger(t)
config := TriggerConfig{
RateLimiter: common.RateLimiterConfig{
GlobalRPS: 100.0,
GlobalBurst: 100,
PerSenderRPS: 100.0,
PerSenderBurst: 100,
},
AllowedSenders: []ethCommon.Address{ethCommon.HexToAddress("a")},
}
trigger, err := NewTrigger(config, registry, connector, privateKey, lggr)
require.NoError(t, err)

return testHarness{
registry: registry,
connector: connector,
lggr: lggr,
config: config,
trigger: trigger,
}
}

func gatewayRequest(t *testing.T) *api.Message {
// TODO: are flags like this ok? this is how the upload_workflow test script does it
privateKey := flag.String("private_key", "65456ffb8af4a2b93959256a8e04f6f2fe0943579fb3c9c3350593aabb89023f", "Private key to sign the message with")
messageID := flag.String("id", "12345", "Request ID")
methodName := flag.String("method", "web_trigger", "Method name")
donID := flag.String("don_id", "workflow_don_1", "DON ID")

flag.Parse()
key, err := crypto.HexToECDSA(*privateKey)
require.NoError(t, err)

payload := `{
trigger_id: "[email protected]",
trigger_event_id: "action_1234567890",
timestamp: 1234567890,
topics: ["daily_price_update"],
params: {
bid: "101",
ask: "102"
}
}
`
payloadJSON := []byte(payload)
msg := &api.Message{
Body: api.MessageBody{
MessageId: *messageID,
Method: *methodName,
DonId: *donID,
Payload: json.RawMessage(payloadJSON),
},
}
err = msg.Sign(key)
require.NoError(t, err)

return msg
}

func TestCapability_Execute(t *testing.T) {
th := setup(t)
ctx := testutils.Context(t)

t.Run("happy case", func(t *testing.T) {
triggerReq := capabilities.TriggerRegistrationRequest{
Metadata: capabilities.RequestMetadata{
WorkflowID: workflowID1,
WorkflowOwner: owner1,
},
}
_, err := th.trigger.RegisterTrigger(ctx, triggerReq)
require.NoError(t, err)

gatewayRequest := gatewayRequest(t)

th.connector.On("SendToGateway", mock.Anything, mock.Anything).Return(nil).Once()

// TODO: verify SendToGateway called
th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest)

// TODO: verify message sent to trigger channel
})

// TODO: allowedSenders fail
// TODO: rateLimit fail
// TODO: empty allowedSenders
// TODO: missing required parameters
// TODO: invalid message
// TODO: other edge cases? empty topics?
// TODO: Test duplicate messages, ie PENDING returned.
// TODO: Test message sent to multiple trigger channels
}

func TestRegisterUnregister(t *testing.T) {
th := setup(t)
ctx := testutils.Context(t)

triggerReq := capabilities.TriggerRegistrationRequest{
Metadata: capabilities.RequestMetadata{
WorkflowID: workflowID1,
WorkflowOwner: owner1,
},
}
_, err := th.trigger.RegisterTrigger(ctx, triggerReq)
require.NoError(t, err)

err = th.trigger.UnregisterTrigger(ctx, triggerReq)
require.NoError(t, err)
}
22 changes: 5 additions & 17 deletions core/scripts/gateway/web_api_trigger/invoke_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ func main() {
messageID := flag.String("id", "12345", "Request ID")
methodName := flag.String("method", "web_trigger", "Method name")
donID := flag.String("don_id", "workflow_don_1", "DON ID")
// workflowSpec := flag.String("workflow_spec", "[my spec abcd]", "Workflow Spec")
// payloadJSON := []byte("{\"spec\": \"" + *workflowSpec + "\"}")

flag.Parse()

Expand All @@ -81,21 +79,11 @@ func main() {
trigger_id: "[email protected]",
trigger_event_id: "action_1234567890",
timestamp: 1234567890,
sub-events: [
{
topics: ["daily_price_update"],
params: {
bid: "101",
ask: "102"
}
},
{
topics: ["daily_message", "summary"],
params: {
message: "all good!",
}
},
]
topics: ["daily_price_update"],
params: {
bid: "101",
ask: "102"
}
}
`
payloadJSON := []byte(payload)
Expand Down

0 comments on commit 7464429

Please sign in to comment.