Skip to content

Commit 54cace8

Browse files
authored
Merge pull request #2006 from josephschorr/version-middleware-serve-test
Add server version middleware to serve-testing
2 parents 4429644 + 4ef7f38 commit 54cace8

File tree

5 files changed

+237
-60
lines changed

5 files changed

+237
-60
lines changed

pkg/cmd/server/defaults.go

+84-32
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,75 @@ const (
168168
DefaultInternalMiddlewareServerSpecific = "servicespecific"
169169
)
170170

171+
//go:generate go run github.com/ecordell/optgen -output zz_generated.middlewareoption.go . MiddlewareOption
171172
type MiddlewareOption struct {
172-
logger zerolog.Logger
173-
authFunc grpcauth.AuthFunc
174-
enableVersionResponse bool
175-
dispatcher dispatch.Dispatcher
176-
ds datastore.Datastore
177-
enableRequestLog bool
178-
enableResponseLog bool
179-
disableGRPCHistogram bool
173+
Logger zerolog.Logger `debugmap:"hidden"`
174+
AuthFunc grpcauth.AuthFunc `debugmap:"hidden"`
175+
EnableVersionResponse bool `debugmap:"visible"`
176+
DispatcherForMiddleware dispatch.Dispatcher `debugmap:"hidden"`
177+
EnableRequestLog bool `debugmap:"visible"`
178+
EnableResponseLog bool `debugmap:"visible"`
179+
DisableGRPCHistogram bool `debugmap:"visible"`
180+
181+
unaryDatastoreMiddleware *ReferenceableMiddleware[grpc.UnaryServerInterceptor] `debugmap:"hidden"`
182+
streamDatastoreMiddleware *ReferenceableMiddleware[grpc.StreamServerInterceptor] `debugmap:"hidden"`
183+
}
184+
185+
type Middleware interface {
186+
UnaryServerInterceptor() grpc.UnaryServerInterceptor
187+
StreamServerInterceptor() grpc.StreamServerInterceptor
188+
}
189+
190+
func (m MiddlewareOption) WithDatastoreMiddleware(middleware Middleware) MiddlewareOption {
191+
unary := NewUnaryMiddleware().
192+
WithName(DefaultInternalMiddlewareDatastore).
193+
WithInternal(true).
194+
WithInterceptor(middleware.UnaryServerInterceptor()).
195+
Done()
196+
197+
stream := NewStreamMiddleware().
198+
WithName(DefaultInternalMiddlewareDatastore).
199+
WithInternal(true).
200+
WithInterceptor(middleware.StreamServerInterceptor()).
201+
Done()
202+
203+
return MiddlewareOption{
204+
Logger: m.Logger,
205+
AuthFunc: m.AuthFunc,
206+
EnableVersionResponse: m.EnableVersionResponse,
207+
DispatcherForMiddleware: m.DispatcherForMiddleware,
208+
EnableRequestLog: m.EnableRequestLog,
209+
EnableResponseLog: m.EnableResponseLog,
210+
DisableGRPCHistogram: m.DisableGRPCHistogram,
211+
unaryDatastoreMiddleware: &unary,
212+
streamDatastoreMiddleware: &stream,
213+
}
214+
}
215+
216+
func (m MiddlewareOption) WithDatastore(ds datastore.Datastore) MiddlewareOption {
217+
unary := NewUnaryMiddleware().
218+
WithName(DefaultInternalMiddlewareDatastore).
219+
WithInternal(true).
220+
WithInterceptor(datastoremw.UnaryServerInterceptor(ds)).
221+
Done()
222+
223+
stream := NewStreamMiddleware().
224+
WithName(DefaultInternalMiddlewareDatastore).
225+
WithInternal(true).
226+
WithInterceptor(datastoremw.StreamServerInterceptor(ds)).
227+
Done()
228+
229+
return MiddlewareOption{
230+
Logger: m.Logger,
231+
AuthFunc: m.AuthFunc,
232+
EnableVersionResponse: m.EnableVersionResponse,
233+
DispatcherForMiddleware: m.DispatcherForMiddleware,
234+
EnableRequestLog: m.EnableRequestLog,
235+
EnableResponseLog: m.EnableResponseLog,
236+
DisableGRPCHistogram: m.DisableGRPCHistogram,
237+
unaryDatastoreMiddleware: &unary,
238+
streamDatastoreMiddleware: &stream,
239+
}
180240
}
181241

182242
// gRPCMetricsUnaryInterceptor creates the default prometheus metrics interceptor for unary gRPCs
@@ -212,7 +272,7 @@ func doesNotMatchRoute(route string) func(_ context.Context, c interceptors.Call
212272

213273
// DefaultUnaryMiddleware generates the default middleware chain used for the public SpiceDB Unary gRPC methods
214274
func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryServerInterceptor], error) {
215-
grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.disableGRPCHistogram)
275+
grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.DisableGRPCHistogram)
216276
chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
217277
NewUnaryMiddleware().
218278
WithName(DefaultMiddlewareRequestID).
@@ -232,15 +292,15 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS
232292
NewUnaryMiddleware().
233293
WithName(DefaultMiddlewareGRPCLog + "-debug").
234294
WithInterceptor(selector.UnaryServerInterceptor(
235-
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
295+
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
236296
selector.MatchFunc(matchesRoute(healthCheckRoute)))).
237297
EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
238298
Done(),
239299

240300
NewUnaryMiddleware().
241301
WithName(DefaultMiddlewareGRPCLog).
242302
WithInterceptor(selector.UnaryServerInterceptor(
243-
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
303+
grpclog.UnaryServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
244304
selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))).
245305
EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
246306
Done(),
@@ -252,26 +312,22 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS
252312

253313
NewUnaryMiddleware().
254314
WithName(DefaultMiddlewareGRPCAuth).
255-
WithInterceptor(grpcauth.UnaryServerInterceptor(opts.authFunc)).
315+
WithInterceptor(grpcauth.UnaryServerInterceptor(opts.AuthFunc)).
256316
EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
257317
Done(),
258318

259319
NewUnaryMiddleware().
260320
WithName(DefaultMiddlewareServerVersion).
261-
WithInterceptor(serverversion.UnaryServerInterceptor(opts.enableVersionResponse)).
321+
WithInterceptor(serverversion.UnaryServerInterceptor(opts.EnableVersionResponse)).
262322
Done(),
263323

264324
NewUnaryMiddleware().
265325
WithName(DefaultInternalMiddlewareDispatch).
266326
WithInternal(true).
267-
WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.dispatcher)).
327+
WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.DispatcherForMiddleware)).
268328
Done(),
269329

270-
NewUnaryMiddleware().
271-
WithName(DefaultInternalMiddlewareDatastore).
272-
WithInternal(true).
273-
WithInterceptor(datastoremw.UnaryServerInterceptor(opts.ds)).
274-
Done(),
330+
*opts.unaryDatastoreMiddleware,
275331

276332
NewUnaryMiddleware().
277333
WithName(DefaultInternalMiddlewareConsistency).
@@ -290,7 +346,7 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS
290346

291347
// DefaultStreamingMiddleware generates the default middleware chain used for the public SpiceDB Streaming gRPC methods
292348
func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.StreamServerInterceptor], error) {
293-
_, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.disableGRPCHistogram)
349+
_, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.DisableGRPCHistogram)
294350
chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.StreamServerInterceptor]{
295351
NewStreamMiddleware().
296352
WithName(DefaultMiddlewareRequestID).
@@ -310,15 +366,15 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St
310366
NewStreamMiddleware().
311367
WithName(DefaultMiddlewareGRPCLog + "-debug").
312368
WithInterceptor(selector.StreamServerInterceptor(
313-
grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
369+
grpclog.StreamServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
314370
selector.MatchFunc(matchesRoute(healthCheckRoute)))).
315371
EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
316372
Done(),
317373

318374
NewStreamMiddleware().
319375
WithName(DefaultMiddlewareGRPCLog).
320376
WithInterceptor(selector.StreamServerInterceptor(
321-
grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
377+
grpclog.StreamServerInterceptor(InterceptorLogger(opts.Logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
322378
selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))).
323379
EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
324380
Done(),
@@ -330,26 +386,22 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St
330386

331387
NewStreamMiddleware().
332388
WithName(DefaultMiddlewareGRPCAuth).
333-
WithInterceptor(grpcauth.StreamServerInterceptor(opts.authFunc)).
389+
WithInterceptor(grpcauth.StreamServerInterceptor(opts.AuthFunc)).
334390
EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
335391
Done(),
336392

337393
NewStreamMiddleware().
338394
WithName(DefaultMiddlewareServerVersion).
339-
WithInterceptor(serverversion.StreamServerInterceptor(opts.enableVersionResponse)).
395+
WithInterceptor(serverversion.StreamServerInterceptor(opts.EnableVersionResponse)).
340396
Done(),
341397

342398
NewStreamMiddleware().
343399
WithName(DefaultInternalMiddlewareDispatch).
344400
WithInternal(true).
345-
WithInterceptor(dispatchmw.StreamServerInterceptor(opts.dispatcher)).
401+
WithInterceptor(dispatchmw.StreamServerInterceptor(opts.DispatcherForMiddleware)).
346402
Done(),
347403

348-
NewStreamMiddleware().
349-
WithName(DefaultInternalMiddlewareDatastore).
350-
WithInternal(true).
351-
WithInterceptor(datastoremw.StreamServerInterceptor(opts.ds)).
352-
Done(),
404+
*opts.streamDatastoreMiddleware,
353405

354406
NewStreamMiddleware().
355407
WithName(DefaultInternalMiddlewareConsistency).
@@ -368,11 +420,11 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St
368420

369421
func determineEventsToLog(opts MiddlewareOption) grpclog.Option {
370422
eventsToLog := []grpclog.LoggableEvent{grpclog.FinishCall}
371-
if opts.enableRequestLog {
423+
if opts.EnableRequestLog {
372424
eventsToLog = append(eventsToLog, grpclog.PayloadReceived)
373425
}
374426

375-
if opts.enableResponseLog {
427+
if opts.EnableResponseLog {
376428
eventsToLog = append(eventsToLog, grpclog.PayloadSent)
377429
}
378430

pkg/cmd/server/server.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,14 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) {
381381
c.GRPCAuthFunc,
382382
!c.DisableVersionResponse,
383383
dispatcher,
384-
ds,
385384
c.EnableRequestLogs,
386385
c.EnableResponseLogs,
387386
c.DisableGRPCLatencyHistogram,
387+
nil,
388+
nil,
388389
}
390+
opts = opts.WithDatastore(ds)
391+
389392
defaultUnaryMiddlewareChain, err := DefaultUnaryMiddleware(opts)
390393
if err != nil {
391394
return nil, fmt.Errorf("error building default middlewares: %w", err)

pkg/cmd/server/server_test.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ func TestModifyUnaryMiddleware(t *testing.T) {
231231
},
232232
}}
233233

234-
opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false}
234+
opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil}
235+
opt = opt.WithDatastore(nil)
236+
235237
defaultMw, err := DefaultUnaryMiddleware(opt)
236238
require.NoError(t, err)
237239

@@ -257,7 +259,9 @@ func TestModifyStreamingMiddleware(t *testing.T) {
257259
},
258260
}}
259261

260-
opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false}
262+
opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, nil, nil}
263+
opt = opt.WithDatastore(nil)
264+
261265
defaultMw, err := DefaultStreamingMiddleware(opt)
262266
require.NoError(t, err)
263267

pkg/cmd/server/zz_generated.middlewareoption.go

+121
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)