Skip to content

Include chainId in JWT audience field #215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

11 changes: 11 additions & 0 deletions aggregator/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a
"chainId", payload.ChainId,
)

if r.chainID != nil && payload.ChainId != r.chainID.Int64() {
return nil, status.Errorf(codes.InvalidArgument, "Invalid chainId: requested chainId %d does not match SmartWallet chainId %d", payload.ChainId, r.chainID.Int64())
}

if strings.Contains(payload.Signature, ".") {
// API key directly
authenticated, err := auth.VerifyJwtKeyForUser(r.config.JwtSecret, payload.Signature, submitAddress)
Expand Down Expand Up @@ -85,6 +89,7 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a
ExpiresAt: jwt.NewNumericDate(payload.ExpiredAt.AsTime()),
Issuer: auth.Issuer,
Subject: payload.Owner,
Audience: jwt.ClaimStrings{fmt.Sprintf("%d", payload.ChainId)},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do a comparision here as well. if payload.ChainId is different from our chainid in SmartWallet. reject the auth request

}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
Expand Down Expand Up @@ -143,6 +148,12 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

chainIdStr := fmt.Sprintf("%d", r.chainID)
aud, err := token.Claims.GetAudience()
if err != nil || len(aud) == 0 || aud[0] != chainIdStr {
return nil, fmt.Errorf("%s: invalid chainId in audience", auth.InvalidAuthenticationKey)
}

user := model.User{
Address: common.HexToAddress(claims["sub"].(string)),
}
Expand Down
55 changes: 55 additions & 0 deletions aggregator/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aggregator
import (
"context"
"fmt"
"math/big"
"testing"
"time"

Expand All @@ -11,6 +12,8 @@ import (
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/golang-jwt/jwt/v5"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"

"github.com/AvaProtocol/ap-avs/core/chainio/signer"
Expand Down Expand Up @@ -75,4 +78,56 @@ func TestGetKeyWithSignature(t *testing.T) {
if sub != "0x578B110b0a7c06e66b7B1a33C39635304aaF733c" {
t.Errorf("invalid subject. expected 0x578B110b0a7c06e66b7B1a33C39635304aaF733c but got %s", sub)
}

aud, _ := token.Claims.GetAudience()
if len(aud) != 1 || aud[0] != "11155111" {
t.Errorf("invalid audience. expected [11155111] but got %v", aud)
}
}

func TestCrossChainJWTValidation(t *testing.T) {
logger, _ := sdklogging.NewZapLogger("development")

// Create RpcServer with chainID set to Sepolia (11155111)
r := RpcServer{
config: &config.Config{
JwtSecret: []byte("test123"),
Logger: logger,
},
chainID: big.NewInt(11155111), // Sepolia chainID
}

owner := "0x578B110b0a7c06e66b7B1a33C39635304aaF733c"
differentChainID := int64(5) // Goerli chainID
issuedTs, _ := time.Parse(time.RFC3339, "2025-01-01T00:00:00Z")
expiredTs, _ := time.Parse(time.RFC3339, "2025-01-02T00:00:00Z")
issuedAt := timestamppb.New(issuedTs)
expiredAt := timestamppb.New(expiredTs)

text := fmt.Sprintf(authTemplate, differentChainID, issuedTs.UTC().Format("2006-01-02T15:04:05.000Z"), expiredTs.UTC().Format("2006-01-02T15:04:05.000Z"), owner)
privateKey, _ := crypto.HexToECDSA("e0502ddd5a0d05ec7b5c22614a01c8ce783810edaa98e44cc82f5fa5a819aaa9")
signature, _ := signer.SignMessage(privateKey, []byte(text))

payload := &avsproto.GetKeyReq{
ChainId: differentChainID,
IssuedAt: issuedAt,
ExpiredAt: expiredAt,
Owner: owner,
Signature: hexutil.Encode(signature),
}

_, err := r.GetKey(context.Background(), payload)

if err == nil {
t.Errorf("expected GetKey to fail for mismatched chainId, but it succeeded")
}

statusErr, ok := status.FromError(err)
if !ok {
t.Errorf("expected a gRPC status error, got: %v", err)
} else if statusErr.Code() != codes.InvalidArgument {
t.Errorf("expected InvalidArgument error code, got: %v", statusErr.Code())
} else if expected := fmt.Sprintf("Invalid chainId: requested chainId %d does not match SmartWallet chainId %d", differentChainID, r.chainID.Int64()); statusErr.Message() != expected {
t.Errorf("expected error message '%s', got: '%s'", expected, statusErr.Message())
}
}
7 changes: 7 additions & 0 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type RpcServer struct {
ethrpc *ethclient.Client

smartWalletRpc *ethclient.Client
chainID *big.Int
}

// Get nonce of an existing smart wallet of a given owner
Expand Down Expand Up @@ -398,6 +399,11 @@ func (agg *Aggregator) startRpcServer(ctx context.Context) error {
panic(err)
}

smartWalletChainID, err := smartwalletClient.ChainID(context.Background())
if err != nil {
panic(err)
}

rpcServer := &RpcServer{
cache: agg.cache,
db: agg.db,
Expand All @@ -408,6 +414,7 @@ func (agg *Aggregator) startRpcServer(ctx context.Context) error {

config: agg.config,
operatorPool: agg.operatorPool,
chainID: smartWalletChainID,
}

// TODO: split node and aggregator
Expand Down
Loading