Skip to content

Commit d79ab63

Browse files
committed
added tests coverage and fixed e2e tests
1 parent 31946bf commit d79ab63

File tree

8 files changed

+151
-16
lines changed

8 files changed

+151
-16
lines changed

pkg/audit/auditor.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func (a *Auditor) Middleware(next http.Handler) http.Handler {
9696
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
9797
// Handle SSE endpoints specially - log the connection event immediately
9898
// since SSE connections are long-lived and don't follow normal request/response pattern
99-
if a.isSSETransport() {
99+
if a.isSSETransport() && r.Method == http.MethodGet {
100100
// Log SSE connection event immediately
101101
a.logSSEConnectionEvent(r)
102102

@@ -188,16 +188,12 @@ func (a *Auditor) determineEventType(r *http.Request) string {
188188
return a.mapMCPMethodToEventType(mcpMethod)
189189
}
190190

191-
// Fall back to path-based detection for non-MCP requests
192-
path := r.URL.Path
193-
194191
// Handle SSE connection establishment
195-
if a.isSSETransport() {
196-
return EventTypeMCPInitialize
192+
if a.isSSETransport() && r.Method == http.MethodGet {
193+
return EventTypeSSEConnection
197194
}
198-
199195
// Handle MCP message endpoints that weren't parsed (malformed requests)
200-
if strings.Contains(path, "/messages") && r.Method == "POST" {
196+
if a.isSSETransport() && r.Method == http.MethodPost {
201197
return EventTypeMCPRequest
202198
}
203199

@@ -450,7 +446,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) {
450446
component := a.determineComponent(r)
451447

452448
// Create the audit event for SSE connection
453-
event := NewAuditEvent("sse_connection", source, OutcomeSuccess, subjects, component)
449+
event := NewAuditEvent(EventTypeSSEConnection, source, OutcomeSuccess, subjects, component)
454450

455451
// Add target information
456452
target := map[string]string{
@@ -462,7 +458,7 @@ func (a *Auditor) logSSEConnectionEvent(r *http.Request) {
462458

463459
// Add metadata
464460
event.Metadata.Extra = map[string]any{
465-
"transport": a.transportType,
461+
"transport": "sse",
466462
"user_agent": r.Header.Get("User-Agent"),
467463
}
468464

pkg/audit/auditor_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"fmt"
78
"net/http"
89
"net/http/httptest"
910
"strings"
@@ -112,6 +113,42 @@ func TestAuditorMiddlewareWithResponseData(t *testing.T) {
112113
assert.Equal(t, responseData, rr.Body.String())
113114
}
114115

116+
func TestAuditorMiddlewareWithDifferentSSEPaths(t *testing.T) {
117+
t.Parallel()
118+
config := &Config{}
119+
auditor, err := NewAuditorWithTransport(config, "sse")
120+
require.NoError(t, err)
121+
122+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
123+
w.WriteHeader(http.StatusOK)
124+
w.Write([]byte("test response"))
125+
})
126+
127+
middleware := auditor.Middleware(handler)
128+
129+
// Test different SSE paths to ensure transport type detection works correctly
130+
testPaths := []string{
131+
"/sse",
132+
"/v1/sse",
133+
"/api/sse",
134+
"/mcp/v2/sse",
135+
"/events", // Non-SSE path but SSE transport
136+
}
137+
138+
for _, path := range testPaths {
139+
t.Run(fmt.Sprintf("path_%s", strings.ReplaceAll(path, "/", "_")), func(t *testing.T) {
140+
req := httptest.NewRequest("GET", path, nil)
141+
rr := httptest.NewRecorder()
142+
143+
middleware.ServeHTTP(rr, req)
144+
145+
// All requests should succeed regardless of path since transport type is SSE
146+
assert.Equal(t, http.StatusOK, rr.Code)
147+
assert.Equal(t, "test response", rr.Body.String())
148+
})
149+
}
150+
}
151+
115152
func TestDetermineEventType(t *testing.T) {
116153
t.Parallel()
117154

@@ -129,6 +166,34 @@ func TestDetermineEventType(t *testing.T) {
129166
transport: "sse",
130167
expected: EventTypeMCPInitialize,
131168
},
169+
{
170+
name: "SSE endpoint with version path",
171+
path: "/v1/sse",
172+
method: "GET",
173+
transport: "sse",
174+
expected: EventTypeMCPInitialize,
175+
},
176+
{
177+
name: "SSE endpoint with API prefix",
178+
path: "/api/sse",
179+
method: "GET",
180+
transport: "sse",
181+
expected: EventTypeMCPInitialize,
182+
},
183+
{
184+
name: "SSE endpoint with nested path",
185+
path: "/mcp/v2/sse",
186+
method: "GET",
187+
transport: "sse",
188+
expected: EventTypeMCPInitialize,
189+
},
190+
{
191+
name: "SSE transport with non-SSE path",
192+
path: "/events",
193+
method: "GET",
194+
transport: "sse",
195+
expected: EventTypeMCPInitialize,
196+
},
132197
{
133198
name: "MCP messages endpoint",
134199
path: "/messages",

pkg/audit/mcp_events.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package audit
55
const (
66
// EventTypeMCPInitialize represents an MCP initialization event
77
EventTypeMCPInitialize = "mcp_initialize"
8+
// EventTypeSSEConnection represents an SSE connection event
9+
EventTypeSSEConnection = "sse_connection"
810
// EventTypeMCPToolCall represents an MCP tool call event
911
EventTypeMCPToolCall = "mcp_tool_call"
1012
// EventTypeMCPToolsList represents an MCP tools list event

test/e2e/audit_middleware_e2e_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,48 @@ var _ = Describe("Audit Middleware E2E", Label("middleware", "audit", "sse", "e2
324324
Expect(auditContent).ToNot(BeEmpty())
325325
})
326326
})
327+
328+
Context("when audit middleware is enabled with --enable-audit flag", func() {
329+
It("should capture audit events with default configuration", func() {
330+
By("Starting MCP server with --enable-audit flag")
331+
serverURL := startMCPServerWithEnableAuditFlag(config, workloadName, mcpServerName)
332+
333+
By("Making MCP HTTP requests to trigger audit events")
334+
// Make HTTP request to initialize endpoint
335+
initRequest := map[string]any{
336+
"jsonrpc": "2.0",
337+
"id": "enable-audit-init-1",
338+
"method": "initialize",
339+
"params": map[string]any{
340+
"protocolVersion": "2024-11-05",
341+
"clientInfo": map[string]any{
342+
"name": "enable-audit-test-client",
343+
"version": "1.0.0",
344+
},
345+
},
346+
}
347+
348+
makeHTTPMCPRequest(serverURL, initRequest)
349+
350+
// Make HTTP request to tools/list endpoint
351+
toolsRequest := map[string]any{
352+
"jsonrpc": "2.0",
353+
"id": "enable-audit-tools-1",
354+
"method": "tools/list",
355+
}
356+
357+
makeHTTPMCPRequest(serverURL, toolsRequest)
358+
359+
// Wait for audit events to be processed and written
360+
time.Sleep(3 * time.Second)
361+
362+
By("Verifying audit events were captured with --enable-audit flag")
363+
// With --enable-audit, audit events should be logged to stdout
364+
// We can verify this by checking that the server started successfully
365+
// and made the requests without errors
366+
Expect(serverURL).ToNot(BeEmpty(), "Server should be accessible")
367+
})
368+
})
327369
})
328370

329371
// Helper functions
@@ -379,6 +421,32 @@ func startMCPServerWithAuditConfig(config *e2e.TestConfig, workloadName, mcpServ
379421
return serverURL
380422
}
381423

424+
// startMCPServerWithEnableAuditFlag starts an MCP server with --enable-audit flag
425+
// Returns the server URL for making HTTP requests
426+
func startMCPServerWithEnableAuditFlag(config *e2e.TestConfig, workloadName, mcpServerName string) string {
427+
// Build args for running the MCP server with --enable-audit flag
428+
args := []string{
429+
"run",
430+
"--name", workloadName,
431+
"--transport", "sse", // Use SSE transport for HTTP-based testing
432+
"--enable-audit",
433+
mcpServerName,
434+
}
435+
436+
By(fmt.Sprintf("Starting MCP server with --enable-audit flag: %v", args))
437+
e2e.NewTHVCommand(config, args...).ExpectSuccess()
438+
439+
err := e2e.WaitForMCPServer(config, workloadName, 60*time.Second)
440+
Expect(err).ToNot(HaveOccurred())
441+
442+
// Get the server URL for making HTTP requests
443+
serverURL, err := e2e.GetMCPServerURL(config, workloadName)
444+
Expect(err).ToNot(HaveOccurred())
445+
446+
GinkgoWriter.Printf("MCP Server URL: %s\n", serverURL)
447+
return serverURL
448+
}
449+
382450
// makeHTTPMCPRequest makes an MCP request using the proper MCP client
383451
func makeHTTPMCPRequest(serverURL string, request map[string]any) {
384452
GinkgoWriter.Printf("Making MCP request to %s with payload: %s\n", serverURL, toJSONString(request))

test/e2e/osv_mcp_server_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() {
4848
})
4949

5050
It("should successfully start and be accessible via SSE [Serial]", func() {
51-
By("Starting the OSV MCP server with SSE transport")
51+
By("Starting the OSV MCP server with SSE transport and audit enabled")
5252
stdout, stderr := e2e.NewTHVCommand(config, "run",
5353
"--name", serverName,
5454
"--transport", "sse",
55+
"--enable-audit",
5556
"osv").ExpectSuccess()
5657

5758
// The command should indicate success
@@ -69,10 +70,11 @@ var _ = Describe("OsvMcpServer", Label("mcp", "sse", "e2e"), Serial, func() {
6970
})
7071

7172
It("should be accessible via HTTP SSE endpoint [Serial]", func() {
72-
By("Starting the OSV MCP server")
73+
By("Starting the OSV MCP server with audit enabled")
7374
e2e.NewTHVCommand(config, "run",
7475
"--name", serverName,
7576
"--transport", "sse",
77+
"--enable-audit",
7678
"osv").ExpectSuccess()
7779

7880
By("Waiting for the server to be running")

test/e2e/osv_streamable_http_mcp_server_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ var _ = Describe("OsvStreamableHttpMcpServer", Label("mcp", "streamable-http", "
3939
})
4040

4141
It("should successfully start and be accessible via Streamable HTTP [Serial]", func() {
42-
By("Starting the OSV MCP server with Streamable HTTP transport")
42+
By("Starting the OSV MCP server with Streamable HTTP transport and audit enabled")
4343
stdout, stderr := e2e.NewTHVCommand(config, "run",
4444
"--name", serverName,
4545
"--transport", "streamable-http",
46+
"--enable-audit",
4647
"osv").ExpectSuccess()
4748

4849
// The command should indicate success
@@ -60,10 +61,11 @@ var _ = Describe("OsvStreamableHttpMcpServer", Label("mcp", "streamable-http", "
6061
})
6162

6263
It("should be accessible via HTTP Streamable HTTP endpoint [Serial]", func() {
63-
By("Starting the OSV MCP server")
64+
By("Starting the OSV MCP server with audit enabled")
6465
e2e.NewTHVCommand(config, "run",
6566
"--name", serverName,
6667
"--transport", "streamable-http",
68+
"--enable-audit",
6769
"osv").ExpectSuccess()
6870

6971
By("Waiting for the server to be running")

test/e2e/proxy_stdio_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ var _ = Describe("Proxy Stdio E2E", Label("proxy", "stdio", "e2e"), Serial, func
4343

4444
JustBeforeEach(func() {
4545
// Build args after mcpServerName is set
46-
args := []string{"run", "--name", workloadName, "--transport", transportType.String()}
46+
args := []string{"run", "--name", workloadName, "--transport", transportType.String(), "--enable-audit"}
4747

4848
if transportType == types.TransportTypeStdio {
4949
Expect(proxyMode).ToNot(BeEmpty())

test/e2e/telemetry_middleware_e2e_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ var _ = Describe("Telemetry Middleware E2E", Label("middleware", "telemetry", "e
4343

4444
JustBeforeEach(func() {
4545
// Build args for running the MCP server
46-
args := []string{"run", "--name", workloadName, "--transport", transportType.String()}
46+
args := []string{"run", "--name", workloadName, "--transport", transportType.String(), "--enable-audit"}
4747

4848
if transportType == types.TransportTypeStdio {
4949
Expect(proxyMode).ToNot(BeEmpty())

0 commit comments

Comments
 (0)