Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(router): refactor complexity limits #1364

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
581 changes: 581 additions & 0 deletions router-tests/complexity_limits_test.go

Large diffs are not rendered by default.

172 changes: 0 additions & 172 deletions router-tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ import (
"testing"
"time"

"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/trace"
tracetest2 "go.opentelemetry.io/otel/sdk/trace/tracetest"

"github.com/wundergraph/cosmo/router/pkg/otel"
"github.com/wundergraph/cosmo/router/pkg/trace/tracetest"

"github.com/buger/jsonparser"
"github.com/sebdah/goldie/v2"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1031,168 +1024,3 @@ func TestDataNotSetOnPreExecutionErrors(t *testing.T) {
require.Equal(t, `{"errors":[{"message":"unexpected token - got: RBRACE want one of: [COLON]","locations":[{"line":1,"column":46}]}]}`, res.Body)
})
}

func TestQueryDepthLimit(t *testing.T) {
t.Parallel()
t.Run("max query depth of 0 doesn't block", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 0
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.JSONEq(t, `{"data":{"employee":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}}`, res.Body)
})
})

t.Run("allows queries up to the max depth", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 3
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.JSONEq(t, `{"data":{"employee":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}}`, res.Body)
})
})

t.Run("max query depth blocks queries over the limit", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.Equal(t, 400, res.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}]}`, res.Body)
})
})

t.Run("max query depth blocks persisted queries over the limit", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
header := make(http.Header)
header.Add("graphql-client-name", "my-client")
res, _ := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
OperationName: []byte(`Find`),
Variables: []byte(`{"criteria": {"nationality": "GERMAN" }}`),
Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "e33580cf6276de9a75fb3b1c4b7580fec2a1c8facd13f3487bf6c7c3f854f7e3"}}`),
Header: header,
})
require.Equal(t, 400, res.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}]}`, res.Body)
})
})

t.Run("max query depth doesn't block persisted queries if DisableDepthLimitPersistedOperations set", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
securityConfiguration.DepthLimit.IgnorePersistedOperations = true
},
}, func(t *testing.T, xEnv *testenv.Environment) {
header := make(http.Header)
header.Add("graphql-client-name", "my-client")
res, _ := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
OperationName: []byte(`Find`),
Variables: []byte(`{"criteria": {"nationality": "GERMAN" }}`),
Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "e33580cf6276de9a75fb3b1c4b7580fec2a1c8facd13f3487bf6c7c3f854f7e3"}}`),
Header: header,
})
require.Equal(t, 200, res.Response.StatusCode)
require.Equal(t, `{"data":{"findEmployees":[{"id":1,"details":{"forename":"Jens","surname":"Neuse"}},{"id":2,"details":{"forename":"Dustin","surname":"Deus"}},{"id":4,"details":{"forename":"Björn","surname":"Schwenzer"}},{"id":11,"details":{"forename":"Alexandra","surname":"Neuse"}}]}}`, res.Body)
})
})

t.Run("query depth validation caches success and failure runs", func(t *testing.T) {
t.Parallel()

metricReader := metric.NewManualReader()
exporter := tracetest.NewInMemoryExporter(t)
testenv.Run(t, &testenv.Config{
TraceExporter: exporter,
MetricReader: metricReader,
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
failedRes, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.Equal(t, 400, failedRes.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}]}`, failedRes.Body)

testSpan := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan.Attributes(), otel.WgQueryDepth.Int(3))
require.Contains(t, testSpan.Attributes(), otel.WgQueryDepthCacheHit.Bool(false))
exporter.Reset()

failedRes2, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.Equal(t, 400, failedRes2.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}]}`, failedRes2.Body)

testSpan2 := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepth.Int(3))
require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true))
exporter.Reset()

successRes := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.JSONEq(t, employeesIDData, successRes.Body)
testSpan3 := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepth.Int(2))
require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepthCacheHit.Bool(false))
exporter.Reset()

successRes2 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.JSONEq(t, employeesIDData, successRes2.Body)
testSpan4 := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan4.Attributes(), otel.WgQueryDepth.Int(2))
require.Contains(t, testSpan4.Attributes(), otel.WgQueryDepthCacheHit.Bool(true))
})
})
}

func requireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, name string) trace.ReadOnlySpan {
sn := exporter.GetSpans().Snapshots()
var testSpan trace.ReadOnlySpan
for _, span := range sn {
if span.Name() == name {
testSpan = span
break
}
}
require.NotNil(t, testSpan)
return testSpan
}
36 changes: 17 additions & 19 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,14 @@ func (s *graphServer) buildMultiGraphHandler(ctx context.Context, baseMux *chi.M
}

type graphMux struct {
mux *chi.Mux
planCache ExecutionPlanCache[uint64, *planWithMetaData]
persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
validationCache *ristretto.Cache[uint64, bool]
queryDepthCache *ristretto.Cache[uint64, int]
operationHashCache *ristretto.Cache[uint64, string]
accessLogsFileLogger *logging.BufferedLogger
mux *chi.Mux
planCache ExecutionPlanCache[uint64, *planWithMetaData]
persistedOperationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
complexityCalculationCache *ristretto.Cache[uint64, ComplexityCacheEntry]
validationCache *ristretto.Cache[uint64, bool]
operationHashCache *ristretto.Cache[uint64, string]
accessLogsFileLogger *logging.BufferedLogger
}

func (s *graphMux) Shutdown(_ context.Context) error {
Expand All @@ -295,8 +295,8 @@ func (s *graphMux) Shutdown(_ context.Context) error {
s.validationCache.Close()
}

if s.queryDepthCache != nil {
s.queryDepthCache.Close()
if s.complexityCalculationCache != nil {
s.complexityCalculationCache.Close()
}

if s.accessLogsFileLogger != nil {
Expand Down Expand Up @@ -436,13 +436,13 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
}
}

if s.securityConfiguration.DepthLimit.Enabled && s.securityConfiguration.DepthLimit.CacheSize > 0 {
queryDepthCacheConfig := &ristretto.Config[uint64, int]{
MaxCost: s.securityConfiguration.DepthLimit.CacheSize,
NumCounters: s.securityConfiguration.DepthLimit.CacheSize * 10,
if s.securityConfiguration.ComplexityCalculationCache != nil && s.securityConfiguration.ComplexityCalculationCache.Enabled && s.securityConfiguration.ComplexityCalculationCache.CacheSize > 0 {
complexityCalculationCacheConfig := &ristretto.Config[uint64, ComplexityCacheEntry]{
MaxCost: s.securityConfiguration.ComplexityCalculationCache.CacheSize,
NumCounters: s.securityConfiguration.ComplexityCalculationCache.CacheSize * 10,
BufferItems: 64,
}
gm.queryDepthCache, err = ristretto.NewCache[uint64, int](queryDepthCacheConfig)
gm.complexityCalculationCache, err = ristretto.NewCache[uint64, ComplexityCacheEntry](complexityCalculationCacheConfig)
if err != nil {
return nil, fmt.Errorf("failed to create query depth cache: %w", err)
}
Expand Down Expand Up @@ -710,7 +710,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
PersistedOpsNormalizationCache: gm.persistedOperationCache,
NormalizationCache: gm.normalizationCache,
ValidationCache: gm.validationCache,
QueryDepthCache: gm.queryDepthCache,
QueryDepthCache: gm.complexityCalculationCache,
OperationHashCache: gm.operationHashCache,
ParseKitPoolSize: s.engineExecutionConfiguration.ParseKitPoolSize,
IntrospectionEnabled: s.Config.introspection,
Expand Down Expand Up @@ -775,9 +775,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
FileUploadEnabled: s.fileUploadConfig.Enabled,
MaxUploadFiles: s.fileUploadConfig.MaxFiles,
MaxUploadFileSize: int(s.fileUploadConfig.MaxFileSizeBytes),
QueryDepthEnabled: s.securityConfiguration.DepthLimit.Enabled,
QueryDepthLimit: s.securityConfiguration.DepthLimit.Limit,
QueryIgnorePersistent: s.securityConfiguration.DepthLimit.IgnorePersistedOperations,
ComplexityLimits: s.securityConfiguration.ComplexityLimits,
AlwaysIncludeQueryPlan: s.engineExecutionConfiguration.Debug.AlwaysIncludeQueryPlan,
AlwaysSkipLoader: s.engineExecutionConfiguration.Debug.AlwaysSkipLoader,
QueryPlansEnabled: s.Config.queryPlansEnabled,
Expand Down
43 changes: 20 additions & 23 deletions router/core/graphql_prehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,18 @@ import (
)

type PreHandlerOptions struct {
Logger *zap.Logger
Executor *Executor
Metrics RouterMetrics
OperationProcessor *OperationProcessor
Planner *OperationPlanner
AccessController *AccessController
OperationBlocker *OperationBlocker
RouterPublicKey *ecdsa.PublicKey
TracerProvider *sdktrace.TracerProvider
MaxUploadFiles int
MaxUploadFileSize int
QueryDepthEnabled bool
QueryDepthLimit int
QueryIgnorePersistent bool
Logger *zap.Logger
Executor *Executor
Metrics RouterMetrics
OperationProcessor *OperationProcessor
Planner *OperationPlanner
AccessController *AccessController
OperationBlocker *OperationBlocker
RouterPublicKey *ecdsa.PublicKey
TracerProvider *sdktrace.TracerProvider
ComplexityLimits *config.ComplexityLimits
MaxUploadFiles int
MaxUploadFileSize int

FlushTelemetryAfterResponse bool
FileUploadEnabled bool
Expand Down Expand Up @@ -86,9 +84,7 @@ type PreHandler struct {
fileUploadEnabled bool
maxUploadFiles int
maxUploadFileSize int
queryDepthEnabled bool
queryDepthLimit int
queryIgnorePersistent bool
complexityLimits *config.ComplexityLimits
bodyReadBuffers *sync.Pool
trackSchemaUsageInfo bool
clientHeader config.ClientHeader
Expand Down Expand Up @@ -129,9 +125,7 @@ func NewPreHandler(opts *PreHandlerOptions) *PreHandler {
fileUploadEnabled: opts.FileUploadEnabled,
maxUploadFiles: opts.MaxUploadFiles,
maxUploadFileSize: opts.MaxUploadFileSize,
queryDepthEnabled: opts.QueryDepthEnabled,
queryDepthLimit: opts.QueryDepthLimit,
queryIgnorePersistent: opts.QueryIgnorePersistent,
complexityLimits: opts.ComplexityLimits,
bodyReadBuffers: &sync.Pool{},
alwaysIncludeQueryPlan: opts.AlwaysIncludeQueryPlan,
alwaysSkipLoader: opts.AlwaysSkipLoader,
Expand Down Expand Up @@ -692,9 +686,12 @@ func (h *PreHandler) handleOperation(req *http.Request, buf *bytes.Buffer, varia

// Validate that the planned query doesn't exceed the maximum query depth configured
// This check runs if they've configured a max query depth, and it can optionally be turned off for persisted operations
if h.queryDepthEnabled && h.queryDepthLimit > 0 && (!operationKit.parsedOperation.IsPersistedOperation || operationKit.parsedOperation.IsPersistedOperation && !h.queryIgnorePersistent) {
cacheHit, depth, queryDepthErr := operationKit.ValidateQueryDepth(h.queryDepthLimit, operationKit.kit.doc, h.executor.RouterSchema)
engineValidateSpan.SetAttributes(otel.WgQueryDepth.Int(depth))
if h.complexityLimits != nil {
cacheHit, complexityCalcs, queryDepthErr := operationKit.ValidateQueryComplexity(h.complexityLimits, operationKit.kit.doc, h.executor.RouterSchema, operationKit.parsedOperation.IsPersistedOperation)
engineValidateSpan.SetAttributes(otel.WgQueryDepth.Int(complexityCalcs.Depth))
engineValidateSpan.SetAttributes(otel.WgQueryTotalFields.Int(complexityCalcs.TotalFields))
engineValidateSpan.SetAttributes(otel.WgQueryRootFields.Int(complexityCalcs.RootFields))
engineValidateSpan.SetAttributes(otel.WgQueryRootFieldAliases.Int(complexityCalcs.RootFieldAliases))
engineValidateSpan.SetAttributes(otel.WgQueryDepthCacheHit.Bool(cacheHit))
if queryDepthErr != nil {
rtrace.AttachErrToSpan(engineValidateSpan, err)
Expand Down
Loading
Loading