Skip to content

Commit 8ef9b48

Browse files
committed
fix transport for auditor middleware
1 parent 8f1a7e9 commit 8ef9b48

File tree

7 files changed

+83
-69
lines changed

7 files changed

+83
-69
lines changed

pkg/audit/auditor.go

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/stacklok/toolhive/pkg/auth"
1616
"github.com/stacklok/toolhive/pkg/logger"
1717
"github.com/stacklok/toolhive/pkg/mcp"
18+
"github.com/stacklok/toolhive/pkg/transport/types"
1819
)
1920

2021
// LevelAudit is a custom audit log level - between Info and Warn
@@ -35,12 +36,13 @@ func NewAuditLogger(w io.Writer) *slog.Logger {
3536

3637
// Auditor handles audit logging for HTTP requests.
3738
type Auditor struct {
38-
config *Config
39-
auditLogger *slog.Logger
39+
config *Config
40+
auditLogger *slog.Logger
41+
transportType string // e.g., "sse", "streamable-http"
4042
}
4143

42-
// NewAuditor creates a new Auditor with the given configuration.
43-
func NewAuditor(config *Config) (*Auditor, error) {
44+
// NewAuditorWithTransport creates a new Auditor with the given configuration and transport information.
45+
func NewAuditorWithTransport(config *Config, transportType string) (*Auditor, error) {
4446
var logWriter io.Writer = os.Stdout // default to stdout
4547

4648
if config != nil {
@@ -54,11 +56,17 @@ func NewAuditor(config *Config) (*Auditor, error) {
5456
}
5557

5658
return &Auditor{
57-
config: config,
58-
auditLogger: NewAuditLogger(logWriter),
59+
config: config,
60+
auditLogger: NewAuditLogger(logWriter),
61+
transportType: transportType,
5962
}, nil
6063
}
6164

65+
// isSSETransport checks if the current transport is SSE
66+
func (a *Auditor) isSSETransport() bool {
67+
return a.transportType == types.TransportTypeSSE.String()
68+
}
69+
6270
// responseWriter wraps http.ResponseWriter to capture response data and status.
6371
type responseWriter struct {
6472
http.ResponseWriter
@@ -88,7 +96,7 @@ func (a *Auditor) Middleware(next http.Handler) http.Handler {
8896
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
8997
// Handle SSE endpoints specially - log the connection event immediately
9098
// since SSE connections are long-lived and don't follow normal request/response pattern
91-
if r.URL.Path == "/sse" {
99+
if a.isSSETransport() {
92100
// Log SSE connection event immediately
93101
a.logSSEConnectionEvent(r)
94102

@@ -164,7 +172,7 @@ func (a *Auditor) logAuditEvent(r *http.Request, rw *responseWriter, requestData
164172
}
165173

166174
// Add metadata
167-
a.addMetadata(event, r, duration, rw)
175+
a.addMetadata(event, duration, rw)
168176

169177
// Add request/response data if configured
170178
a.addEventData(event, r, rw, requestData)
@@ -184,7 +192,7 @@ func (a *Auditor) determineEventType(r *http.Request) string {
184192
path := r.URL.Path
185193

186194
// Handle SSE connection establishment
187-
if strings.Contains(path, "/sse") {
195+
if a.isSSETransport() {
188196
return EventTypeMCPInitialize
189197
}
190198

@@ -372,7 +380,7 @@ func (*Auditor) extractTarget(r *http.Request, eventType string) map[string]stri
372380
}
373381

374382
// addMetadata adds metadata to the audit event.
375-
func (*Auditor) addMetadata(event *AuditEvent, r *http.Request, duration time.Duration, rw *responseWriter) {
383+
func (a *Auditor) addMetadata(event *AuditEvent, duration time.Duration, rw *responseWriter) {
376384
if event.Metadata.Extra == nil {
377385
event.Metadata.Extra = make(map[string]any)
378386
}
@@ -381,11 +389,7 @@ func (*Auditor) addMetadata(event *AuditEvent, r *http.Request, duration time.Du
381389
event.Metadata.Extra[MetadataExtraKeyDuration] = duration.Milliseconds()
382390

383391
// Add transport information
384-
if strings.Contains(r.URL.Path, "/sse") {
385-
event.Metadata.Extra[MetadataExtraKeyTransport] = "sse"
386-
} else {
387-
event.Metadata.Extra[MetadataExtraKeyTransport] = "http"
388-
}
392+
event.Metadata.Extra[MetadataExtraKeyTransport] = a.transportType
389393

390394
// Add response size if available
391395
if rw.body != nil {
@@ -454,7 +458,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) {
454458

455459
// Add metadata
456460
event.Metadata.Extra = map[string]any{
457-
"transport": "sse",
461+
"transport": a.transportType,
458462
"user_agent": r.Header.Get("User-Agent"),
459463
}
460464

pkg/audit/auditor_test.go

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func init() {
2626
func TestNewAuditor(t *testing.T) {
2727
t.Parallel()
2828
config := &Config{}
29-
auditor, err := NewAuditor(config)
29+
auditor, err := NewAuditorWithTransport(config, "sse")
3030

3131
assert.NoError(t, err)
3232
assert.NotNil(t, auditor)
@@ -36,7 +36,7 @@ func TestNewAuditor(t *testing.T) {
3636
func TestAuditorMiddlewareDisabled(t *testing.T) {
3737
t.Parallel()
3838
config := &Config{}
39-
auditor, err := NewAuditor(config)
39+
auditor, err := NewAuditorWithTransport(config, "sse")
4040
require.NoError(t, err)
4141

4242
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -61,7 +61,7 @@ func TestAuditorMiddlewareWithRequestData(t *testing.T) {
6161
IncludeRequestData: true,
6262
MaxDataSize: 1024,
6363
}
64-
auditor, err := NewAuditor(config)
64+
auditor, err := NewAuditorWithTransport(config, "sse")
6565
require.NoError(t, err)
6666

6767
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -91,7 +91,7 @@ func TestAuditorMiddlewareWithResponseData(t *testing.T) {
9191
IncludeResponseData: true,
9292
MaxDataSize: 1024,
9393
}
94-
auditor, err := NewAuditor(config)
94+
auditor, err := NewAuditorWithTransport(config, "sse")
9595
require.NoError(t, err)
9696

9797
responseData := `{"result": "success"}`
@@ -114,38 +114,43 @@ func TestAuditorMiddlewareWithResponseData(t *testing.T) {
114114

115115
func TestDetermineEventType(t *testing.T) {
116116
t.Parallel()
117-
auditor, err := NewAuditor(&Config{})
118-
require.NoError(t, err)
119117

120118
tests := []struct {
121-
name string
122-
path string
123-
method string
124-
expected string
119+
name string
120+
path string
121+
method string
122+
transport string
123+
expected string
125124
}{
126125
{
127-
name: "SSE endpoint",
128-
path: "/sse",
129-
method: "GET",
130-
expected: EventTypeMCPInitialize,
126+
name: "SSE endpoint",
127+
path: "/sse",
128+
method: "GET",
129+
transport: "sse",
130+
expected: EventTypeMCPInitialize,
131131
},
132132
{
133-
name: "MCP messages endpoint",
134-
path: "/messages",
135-
method: "POST",
136-
expected: "mcp_request", // Since extractMCPMethod returns empty
133+
name: "MCP messages endpoint",
134+
path: "/messages",
135+
method: "POST",
136+
transport: "streamable-http",
137+
expected: "mcp_request", // Since extractMCPMethod returns empty
137138
},
138139
{
139-
name: "Regular HTTP request",
140-
path: "/api/health",
141-
method: "GET",
142-
expected: "http_request",
140+
name: "Regular HTTP request",
141+
path: "/api/health",
142+
method: "GET",
143+
transport: "streamable-http",
144+
expected: "http_request",
143145
},
144146
}
145147

146148
for _, tt := range tests {
147149
t.Run(tt.name, func(t *testing.T) {
148150
t.Parallel()
151+
auditor, err := NewAuditorWithTransport(&Config{}, tt.transport)
152+
require.NoError(t, err)
153+
149154
req := httptest.NewRequest(tt.method, tt.path, nil)
150155
result := auditor.determineEventType(req)
151156
assert.Equal(t, tt.expected, result)
@@ -174,7 +179,7 @@ func TestMapMCPMethodToEventType(t *testing.T) {
174179
{"unknown_method", "mcp_request"},
175180
}
176181

177-
auditor, err := NewAuditor(&Config{})
182+
auditor, err := NewAuditorWithTransport(&Config{}, "sse")
178183
require.NoError(t, err)
179184
for _, tt := range tests {
180185
t.Run(tt.mcpMethod, func(t *testing.T) {
@@ -187,7 +192,7 @@ func TestMapMCPMethodToEventType(t *testing.T) {
187192

188193
func TestDetermineOutcome(t *testing.T) {
189194
t.Parallel()
190-
auditor, err := NewAuditor(&Config{})
195+
auditor, err := NewAuditorWithTransport(&Config{}, "sse")
191196
require.NoError(t, err)
192197

193198
tests := []struct {
@@ -218,7 +223,7 @@ func TestDetermineOutcome(t *testing.T) {
218223

219224
func TestGetClientIP(t *testing.T) {
220225
t.Parallel()
221-
auditor, err := NewAuditor(&Config{})
226+
auditor, err := NewAuditorWithTransport(&Config{}, "sse")
222227
require.NoError(t, err)
223228

224229
tests := []struct {
@@ -268,7 +273,7 @@ func TestGetClientIP(t *testing.T) {
268273

269274
func TestExtractSubjects(t *testing.T) {
270275
t.Parallel()
271-
auditor, err := NewAuditor(&Config{})
276+
auditor, err := NewAuditorWithTransport(&Config{}, "sse")
272277
require.NoError(t, err)
273278

274279
t.Run("with JWT claims", func(t *testing.T) {
@@ -342,7 +347,7 @@ func TestDetermineComponent(t *testing.T) {
342347
t.Run("with configured component", func(t *testing.T) {
343348
t.Parallel()
344349
config := &Config{Component: "custom-component"}
345-
auditor, err := NewAuditor(config)
350+
auditor, err := NewAuditorWithTransport(config, "sse")
346351
require.NoError(t, err)
347352

348353
req := httptest.NewRequest("GET", "/test", nil)
@@ -354,7 +359,7 @@ func TestDetermineComponent(t *testing.T) {
354359
t.Run("without configured component", func(t *testing.T) {
355360
t.Parallel()
356361
config := &Config{}
357-
auditor, err := NewAuditor(config)
362+
auditor, err := NewAuditorWithTransport(config, "sse")
358363
require.NoError(t, err)
359364

360365
req := httptest.NewRequest("GET", "/test", nil)
@@ -366,7 +371,7 @@ func TestDetermineComponent(t *testing.T) {
366371

367372
func TestExtractTarget(t *testing.T) {
368373
t.Parallel()
369-
auditor, err := NewAuditor(&Config{})
374+
auditor, err := NewAuditorWithTransport(&Config{}, "sse")
370375
require.NoError(t, err)
371376

372377
tests := []struct {
@@ -423,18 +428,17 @@ func TestExtractTarget(t *testing.T) {
423428

424429
func TestAddMetadata(t *testing.T) {
425430
t.Parallel()
426-
auditor, err := NewAuditor(&Config{})
431+
auditor, err := NewAuditorWithTransport(&Config{}, "sse")
427432
require.NoError(t, err)
428433

429434
event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test")
430-
req := httptest.NewRequest("GET", "/sse/test", nil)
431435
duration := 150 * time.Millisecond
432436
rw := &responseWriter{
433437
ResponseWriter: httptest.NewRecorder(),
434438
body: bytes.NewBufferString("test response"),
435439
}
436440

437-
auditor.addMetadata(event, req, duration, rw)
441+
auditor.addMetadata(event, duration, rw)
438442

439443
require.NotNil(t, event.Metadata.Extra)
440444
assert.Equal(t, int64(150), event.Metadata.Extra[MetadataExtraKeyDuration])
@@ -450,7 +454,7 @@ func TestAddEventData(t *testing.T) {
450454
IncludeRequestData: true,
451455
IncludeResponseData: true,
452456
}
453-
auditor, err := NewAuditor(config)
457+
auditor, err := NewAuditorWithTransport(config, "sse")
454458
require.NoError(t, err)
455459

456460
event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test")
@@ -483,7 +487,7 @@ func TestAddEventData(t *testing.T) {
483487
IncludeRequestData: true,
484488
IncludeResponseData: true,
485489
}
486-
auditor, err := NewAuditor(config)
490+
auditor, err := NewAuditorWithTransport(config, "sse")
487491
require.NoError(t, err)
488492

489493
event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test")
@@ -511,7 +515,7 @@ func TestAddEventData(t *testing.T) {
511515
IncludeRequestData: false,
512516
IncludeResponseData: false,
513517
}
514-
auditor, err := NewAuditor(config)
518+
auditor, err := NewAuditorWithTransport(config, "sse")
515519
require.NoError(t, err)
516520

517521
event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test")
@@ -531,7 +535,7 @@ func TestResponseWriterCapture(t *testing.T) {
531535
IncludeResponseData: true,
532536
MaxDataSize: 10, // Small limit for testing
533537
}
534-
auditor, err := NewAuditor(config)
538+
auditor, err := NewAuditorWithTransport(config, "sse")
535539
require.NoError(t, err)
536540

537541
rw := &responseWriter{
@@ -568,7 +572,7 @@ func TestResponseWriterStatusCode(t *testing.T) {
568572

569573
func TestExtractSourceWithHeaders(t *testing.T) {
570574
t.Parallel()
571-
auditor, err := NewAuditor(&Config{})
575+
auditor, err := NewAuditorWithTransport(&Config{}, "sse")
572576
require.NoError(t, err)
573577

574578
req := httptest.NewRequest("GET", "/test", nil)

pkg/audit/config.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,25 +104,26 @@ func (c *Config) ShouldAuditEvent(eventType string) bool {
104104
return true
105105
}
106106

107-
// CreateMiddleware creates an HTTP middleware from the audit configuration.
108-
func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) {
109-
auditor, err := NewAuditor(c)
107+
// CreateMiddlewareWithTransport creates an HTTP middleware from the audit configuration with transport information.
108+
func (c *Config) CreateMiddlewareWithTransport(transportType string) (types.MiddlewareFunction, error) {
109+
auditor, err := NewAuditorWithTransport(c, transportType)
110110
if err != nil {
111111
return nil, fmt.Errorf("failed to create auditor: %w", err)
112112
}
113113
return auditor.Middleware, nil
114114
}
115115

116116
// GetMiddlewareFromFile loads the audit configuration from a file and creates an HTTP middleware.
117-
func GetMiddlewareFromFile(path string) (func(http.Handler) http.Handler, error) {
117+
// Note: This function requires a transport type to be provided separately.
118+
func GetMiddlewareFromFile(path string, transportType string) (func(http.Handler) http.Handler, error) {
118119
// Load the configuration
119120
config, err := LoadFromFile(path)
120121
if err != nil {
121122
return nil, fmt.Errorf("failed to load audit config: %w", err)
122123
}
123124

124-
// Create the middleware
125-
return config.CreateMiddleware()
125+
// Create the middleware with transport information
126+
return config.CreateMiddlewareWithTransport(transportType)
126127
}
127128

128129
// Validate validates the audit configuration.

pkg/audit/config_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func TestCreateMiddleware(t *testing.T) {
112112
t.Parallel()
113113
config := &Config{}
114114

115-
middleware, err := config.CreateMiddleware()
115+
middleware, err := config.CreateMiddlewareWithTransport("sse")
116116
assert.NoError(t, err)
117117
assert.NotNil(t, middleware)
118118
}
@@ -236,7 +236,7 @@ func TestConfigMinimalJSON(t *testing.T) {
236236
func TestGetMiddlewareFromFileError(t *testing.T) {
237237
t.Parallel()
238238
// Test with non-existent file
239-
_, err := GetMiddlewareFromFile("/non/existent/file.json")
239+
_, err := GetMiddlewareFromFile("/non/existent/file.json", "sse")
240240
assert.Error(t, err)
241241
assert.Contains(t, err.Error(), "failed to load audit config")
242242
}

0 commit comments

Comments
 (0)