Skip to content

Commit

Permalink
fix(router): send graphql closing boundary to fit Apollo client (#1579)
Browse files Browse the repository at this point in the history
Co-authored-by: Ludwig <[email protected]>
Co-authored-by: Alessandro Pagnin <[email protected]>
  • Loading branch information
3 people authored Feb 18, 2025
1 parent 94addca commit c3d089a
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 28 deletions.
138 changes: 138 additions & 0 deletions router-tests/events/nats_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,144 @@ func TestNatsEvents(t *testing.T) {
}
})
})
t.Run("subscribe after message don't a boundary", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate,
RouterOptions: []core.Option{
core.WithApolloCompatibilityFlagsConfig(config.ApolloCompatibilityFlags{
SubscriptionMultipartPrintBoundary: config.ApolloCompatibilitySubscriptionMultipartPrintBoundary{
Enabled: false,
},
}),
},
EnableNats: true,
}, func(t *testing.T, xEnv *testenv.Environment) {

subscribePayload := []byte(`{"query":"subscription { countFor(count: 0) }"}`)

var done atomic.Bool

go func() {
defer done.Store(true)

client := http.Client{}
req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload))
resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()

reader := bufio.NewReader(resp.Body)

// Read the first part

expected := "\r\n--graphql\nContent-Type: application/json\r\n\r\n{\"payload\":{\"data\":{\"countFor\":0}}}\n"
read := make([]byte, len(expected))
_, err = reader.Read(read)
assert.NoError(t, err)
assert.Equal(t, expected, string(read))
}()

xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout)
require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100)
})
})
})

t.Run("multipart with apollo compatibility", func(t *testing.T) {
t.Parallel()

t.Run("subscribe after message add a boundary", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate,
RouterOptions: []core.Option{
core.WithApolloCompatibilityFlagsConfig(config.ApolloCompatibilityFlags{
SubscriptionMultipartPrintBoundary: config.ApolloCompatibilitySubscriptionMultipartPrintBoundary{
Enabled: true,
},
}),
},
EnableNats: true,
}, func(t *testing.T, xEnv *testenv.Environment) {

subscribePayload := []byte(`{"query":"subscription { countFor(count: 0) }"}`)

var done atomic.Bool

go func() {
defer done.Store(true)

client := http.Client{}
req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload))
resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()

reader := bufio.NewReader(resp.Body)

// Read the first part

expected := "\r\n--graphql\nContent-Type: application/json\r\n\r\n{\"payload\":{\"data\":{\"countFor\":0}}}\n\r\n--graphql"
read := make([]byte, len(expected))
_, err = reader.Read(read)
assert.NoError(t, err)
assert.Equal(t, expected, string(read))
}()

xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout)
require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100)
})
})

t.Run("subscribe with closing channel", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate,
RouterOptions: []core.Option{
core.WithApolloCompatibilityFlagsConfig(config.ApolloCompatibilityFlags{
SubscriptionMultipartPrintBoundary: config.ApolloCompatibilitySubscriptionMultipartPrintBoundary{
Enabled: true,
},
}),
},
EnableNats: true,
}, func(t *testing.T, xEnv *testenv.Environment) {

subscribePayload := []byte(`{"query":"subscription { countFor(count: 3) }"}`)

var done atomic.Bool

go func() {
defer done.Store(true)

client := http.Client{}
req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload))
resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()

reader := bufio.NewReader(resp.Body)

// Read the first part
assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":0}}}")
assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":1}}}")
assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":2}}}")
assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":3}}}")
assertLineEquals(t, reader, "")
assertLineEquals(t, reader, "--graphql--")
}()

xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout)
require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100)
})
})
})

t.Run("subscribe once", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion router/core/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func writeRequestErrors(r *http.Request, w http.ResponseWriter, statusCode int,
// writeMultipartError writes the error response in a multipart format with proper boundaries and headers.
func writeMultipartError(w http.ResponseWriter, requestErrors graphqlerrors.RequestErrors, requestLogger *zap.Logger) error {
// Start with the multipart boundary
prefix := GetWriterPrefix(false, true)
prefix := GetWriterPrefix(false, true, true)
if _, err := w.Write([]byte(prefix)); err != nil {
return err
}
Expand Down
51 changes: 37 additions & 14 deletions router/core/flushwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package core
import (
"bytes"
"context"
"github.com/wundergraph/astjson"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"io"
"mime"
"net/http"
"strconv"
"strings"

"github.com/wundergraph/astjson"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)

const (
Expand All @@ -22,6 +23,7 @@ const (
sseMimeType = "text/event-stream"
heartbeat = "{}"
multipartContent = multipartMime + "; boundary=" + multipartBoundary
multipartStart = "\r\n--" + multipartBoundary
)

type HttpFlushWriter struct {
Expand All @@ -33,6 +35,10 @@ type HttpFlushWriter struct {
sse bool
multipart bool
buf *bytes.Buffer
firstMessage bool
// apolloSubscriptionMultipartPrintBoundary if set to true will send the multipart boundary at the end of the message to allow
// misbehaving client (like apollo client) to read the message just sent before the next one or the heartbeat
apolloSubscriptionMultipartPrintBoundary bool
}

func (f *HttpFlushWriter) Complete() {
Expand All @@ -43,7 +49,11 @@ func (f *HttpFlushWriter) Complete() {
_, _ = f.writer.Write([]byte("event: complete"))
} else if f.multipart {
// Write the final boundary in the multipart response
_, _ = f.writer.Write([]byte("--" + multipartBoundary + "--\n"))
if f.apolloSubscriptionMultipartPrintBoundary {
_, _ = f.writer.Write([]byte("--\n"))
} else {
_, _ = f.writer.Write([]byte("--" + multipartBoundary + "--\n"))
}
}
f.Close()
}
Expand Down Expand Up @@ -72,7 +82,10 @@ func (f *HttpFlushWriter) Flush() (err error) {
resp := f.buf.Bytes()
f.buf.Reset()

flushBreak := GetWriterPrefix(f.sse, f.multipart)
flushBreak := GetWriterPrefix(f.sse, f.multipart, !f.apolloSubscriptionMultipartPrintBoundary || f.firstMessage)
if f.firstMessage {
f.firstMessage = false
}
if f.multipart && len(resp) > 0 {
var err error
resp, err = wrapMultipartMessage(resp)
Expand All @@ -83,7 +96,11 @@ func (f *HttpFlushWriter) Flush() (err error) {

separation := "\n\n"
if f.multipart {
separation = "\n"
if !f.apolloSubscriptionMultipartPrintBoundary {
separation = "\n"
} else {
separation = "\n" + multipartStart
}
} else if f.subscribeOnce {
separation = ""
}
Expand All @@ -100,7 +117,7 @@ func (f *HttpFlushWriter) Flush() (err error) {
return nil
}

func GetSubscriptionResponseWriter(ctx *resolve.Context, r *http.Request, w http.ResponseWriter) (*resolve.Context, resolve.SubscriptionResponseWriter, bool) {
func GetSubscriptionResponseWriter(ctx *resolve.Context, r *http.Request, w http.ResponseWriter, apolloSubscriptionMultipartPrintBoundary bool) (*resolve.Context, resolve.SubscriptionResponseWriter, bool) {
type withFlushWriter interface {
SubscriptionResponseWriter() resolve.SubscriptionResponseWriter
}
Expand All @@ -119,12 +136,14 @@ func GetSubscriptionResponseWriter(ctx *resolve.Context, r *http.Request, w http
flusher.Flush()

flushWriter := &HttpFlushWriter{
writer: w,
flusher: flusher,
sse: wgParams.UseSse,
multipart: wgParams.UseMultipart,
subscribeOnce: wgParams.SubscribeOnce,
buf: &bytes.Buffer{},
writer: w,
flusher: flusher,
sse: wgParams.UseSse,
multipart: wgParams.UseMultipart,
subscribeOnce: wgParams.SubscribeOnce,
buf: &bytes.Buffer{},
firstMessage: true,
apolloSubscriptionMultipartPrintBoundary: apolloSubscriptionMultipartPrintBoundary,
}

flushWriter.ctx, flushWriter.cancel = context.WithCancel(ctx.Context())
Expand Down Expand Up @@ -231,12 +250,16 @@ type SubscriptionParams struct {
UseMultipart bool
}

func GetWriterPrefix(sse bool, multipart bool) string {
func GetWriterPrefix(sse bool, multipart bool, firstMessage bool) string {
flushBreak := ""
if sse {
flushBreak = "event: next\ndata: "
} else if multipart {
flushBreak = "\r\n--" + multipartBoundary + "\nContent-Type: " + jsonContent + "\r\n\r\n"
messageStart := ""
if firstMessage {
messageStart = multipartStart
}
flushBreak = messageStart + "\nContent-Type: " + jsonContent + "\r\n\r\n"
}

return flushBreak
Expand Down
4 changes: 4 additions & 0 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,10 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
}
}

if s.apolloCompatibilityFlags.SubscriptionMultipartPrintBoundary.Enabled {
handlerOpts.ApolloSubscriptionMultipartPrintBoundary = s.apolloCompatibilityFlags.SubscriptionMultipartPrintBoundary.Enabled
}

graphqlHandler := NewGraphQLHandler(handlerOpts)
executor.Resolver.SetAsyncErrorWriter(graphqlHandler)

Expand Down
16 changes: 10 additions & 6 deletions router/core/graphql_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type HandlerOptions struct {
RateLimitConfig *config.RateLimitConfiguration
SubgraphErrorPropagation config.SubgraphErrorPropagationConfiguration
EngineLoaderHooks resolve.LoaderHooks
ApolloSubscriptionMultipartPrintBoundary bool
}

func NewGraphQLHandler(opts HandlerOptions) *GraphQLHandler {
Expand All @@ -92,11 +93,12 @@ func NewGraphQLHandler(opts HandlerOptions) *GraphQLHandler {
"wundergraph/cosmo/router/graphql_handler",
trace.WithInstrumentationVersion("0.0.1"),
),
authorizer: opts.Authorizer,
rateLimiter: opts.RateLimiter,
rateLimitConfig: opts.RateLimitConfig,
subgraphErrorPropagation: opts.SubgraphErrorPropagation,
engineLoaderHooks: opts.EngineLoaderHooks,
authorizer: opts.Authorizer,
rateLimiter: opts.RateLimiter,
rateLimitConfig: opts.RateLimitConfig,
subgraphErrorPropagation: opts.SubgraphErrorPropagation,
engineLoaderHooks: opts.EngineLoaderHooks,
apolloSubscriptionMultipartPrintBoundary: opts.ApolloSubscriptionMultipartPrintBoundary,
}
return graphQLHandler
}
Expand Down Expand Up @@ -127,6 +129,8 @@ type GraphQLHandler struct {
enablePersistedOperationCacheResponseHeader bool
enableNormalizationCacheResponseHeader bool
enableResponseHeaderPropagation bool

apolloSubscriptionMultipartPrintBoundary bool
}

func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -191,7 +195,7 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.setDebugCacheHeaders(w, requestContext.operation)

defer propagateSubgraphErrors(ctx)
ctx, writer, ok = GetSubscriptionResponseWriter(ctx, r, w)
ctx, writer, ok = GetSubscriptionResponseWriter(ctx, r, w, h.apolloSubscriptionMultipartPrintBoundary)
if !ok {
requestContext.logger.Error("unable to get subscription response writer", zap.Error(errCouldNotFlushResponse))
trackFinalResponseError(r.Context(), errCouldNotFlushResponse)
Expand Down
1 change: 1 addition & 0 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,7 @@ func WithApolloCompatibilityFlagsConfig(cfg config.ApolloCompatibilityFlags) Opt
cfg.ReplaceUndefinedOpFieldErrors.Enabled = true
cfg.ReplaceInvalidVarErrors.Enabled = true
cfg.ReplaceValidationErrorStatus.Enabled = true
cfg.SubscriptionMultipartPrintBoundary.Enabled = true
}
r.apolloCompatibilityFlags = cfg
}
Expand Down
19 changes: 12 additions & 7 deletions router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,13 +756,14 @@ type AccessLogsSubgraphsConfig struct {
}

type ApolloCompatibilityFlags struct {
EnableAll bool `yaml:"enable_all" envDefault:"false" env:"APOLLO_COMPATIBILITY_ENABLE_ALL"`
ValueCompletion ApolloCompatibilityValueCompletion `yaml:"value_completion"`
TruncateFloats ApolloCompatibilityTruncateFloats `yaml:"truncate_floats"`
SuppressFetchErrors ApolloCompatibilitySuppressFetchErrors `yaml:"suppress_fetch_errors"`
ReplaceUndefinedOpFieldErrors ApolloCompatibilityReplaceUndefinedOpFieldErrors `yaml:"replace_undefined_op_field_errors"`
ReplaceInvalidVarErrors ApolloCompatibilityReplaceInvalidVarErrors `yaml:"replace_invalid_var_errors"`
ReplaceValidationErrorStatus ApolloCompatibilityReplaceValidationErrorStatus `yaml:"replace_validation_error_status"`
EnableAll bool `yaml:"enable_all" envDefault:"false" env:"APOLLO_COMPATIBILITY_ENABLE_ALL"`
ValueCompletion ApolloCompatibilityValueCompletion `yaml:"value_completion"`
TruncateFloats ApolloCompatibilityTruncateFloats `yaml:"truncate_floats"`
SuppressFetchErrors ApolloCompatibilitySuppressFetchErrors `yaml:"suppress_fetch_errors"`
ReplaceUndefinedOpFieldErrors ApolloCompatibilityReplaceUndefinedOpFieldErrors `yaml:"replace_undefined_op_field_errors"`
ReplaceInvalidVarErrors ApolloCompatibilityReplaceInvalidVarErrors `yaml:"replace_invalid_var_errors"`
ReplaceValidationErrorStatus ApolloCompatibilityReplaceValidationErrorStatus `yaml:"replace_validation_error_status"`
SubscriptionMultipartPrintBoundary ApolloCompatibilitySubscriptionMultipartPrintBoundary `yaml:"subscription_multipart_print_boundary"`
}

type ApolloCompatibilityValueCompletion struct {
Expand Down Expand Up @@ -794,6 +795,10 @@ type ApolloCompatibilityReplaceValidationErrorStatus struct {
Enabled bool `yaml:"enabled" envDefault:"false" env:"APOLLO_COMPATIBILITY_REPLACE_VALIDATION_ERROR_STATUS_ENABLED"`
}

type ApolloCompatibilitySubscriptionMultipartPrintBoundary struct {
Enabled bool `yaml:"enabled" envDefault:"false" env:"APOLLO_COMPATIBILITY_SUBSCRIPTION_MULTIPART_PRINT_BOUNDARY_ENABLED"`
}

type ApolloRouterCompatibilityFlags struct {
ReplaceInvalidVarErrors ApolloRouterCompatibilityReplaceInvalidVarErrors `yaml:"replace_invalid_var_errors"`
SubrequestHTTPError ApolloRouterCompatibilitySubrequestHTTPError `yaml:"subrequest_http_error"`
Expand Down
11 changes: 11 additions & 0 deletions router/pkg/config/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,17 @@
"default": false
}
}
},
"subscription_multipart_print_boundary": {
"type": "object",
"description": "Prints the multipart boundary right after the message in multipart subscriptions. Without this flag, the Apollo client wouldn’t parse a message until the next one is pushed.",
"additionalProperties": false,
"properties": {
"enabled": {
"type": "boolean",
"default": false
}
}
}
}
},
Expand Down
Loading

0 comments on commit c3d089a

Please sign in to comment.