Skip to content

Commit 254cb2d

Browse files
yuehaiisd2k
authored andcommitted
fix: notification break the client tool call (#642)
* feat: client roots feature * feat: finish client roots, pass unit and integration test * client roots http sample code * client roots for stdio and pass integration test * update roots stio client example * add godoc and const of rootlist * update godoc and data format * update examples for client roots * add fallback for demonstration * adjust roots path and signals of examples * update roots http client example * samples: fix unit test and refactor with lint * examples: refactor to adapt windows os and nitpick comments * update for nitpick comments * refactor for nitpick comments * fix: notifications breaking the tool call Signed-off-by: hai.yue <[email protected]> * add a regression test #642 (comment) Signed-off-by: hai.yue <[email protected]>
1 parent 37e9cb5 commit 254cb2d

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

server/session.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp
171171
func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) {
172172
s.sessions.Range(func(k, v any) bool {
173173
if session, ok := v.(ClientSession); ok && session.Initialized() {
174+
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
175+
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
176+
}
174177
select {
175178
case session.NotificationChannel() <- notification:
176179
// Successfully sent notification

server/streamable_http_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2251,3 +2251,91 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) {
22512251
}
22522252
})
22532253
}
2254+
2255+
// TestStreamableHTTP_AddToolDuringToolCall tests that adding a tool while a tool call
2256+
// is in progress doesn't break the client's response.
2257+
// This is a regression test for issue #638 where notifications sent via
2258+
// sendNotificationToAllClients during an in-progress request would cause
2259+
// the response to fail with "unexpected nil response".
2260+
func TestStreamableHTTP_AddToolDuringToolCall(t *testing.T) {
2261+
mcpServer := NewMCPServer("test-mcp-server", "1.0",
2262+
WithToolCapabilities(true), // Enable tool list change notifications
2263+
)
2264+
// Add a tool that takes some time to complete
2265+
mcpServer.AddTool(mcp.NewTool("slow_tool",
2266+
mcp.WithDescription("A tool that takes time to complete"),
2267+
), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
2268+
// Simulate work that takes some time
2269+
time.Sleep(100 * time.Millisecond)
2270+
return mcp.NewToolResultText("done"), nil
2271+
})
2272+
server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true))
2273+
defer server.Close()
2274+
// Initialize to get session
2275+
resp, err := postJSON(server.URL, initRequest)
2276+
if err != nil {
2277+
t.Fatalf("Failed to initialize: %v", err)
2278+
}
2279+
sessionID := resp.Header.Get(HeaderKeySessionID)
2280+
resp.Body.Close()
2281+
if sessionID == "" {
2282+
t.Fatal("Expected session ID in response header")
2283+
}
2284+
// Start the tool call in a goroutine
2285+
resultChan := make(chan struct {
2286+
statusCode int
2287+
body string
2288+
err error
2289+
})
2290+
go func() {
2291+
toolRequest := map[string]any{
2292+
"jsonrpc": "2.0",
2293+
"id": 1,
2294+
"method": "tools/call",
2295+
"params": map[string]any{
2296+
"name": "slow_tool",
2297+
},
2298+
}
2299+
toolBody, _ := json.Marshal(toolRequest)
2300+
req, _ := http.NewRequest("POST", server.URL, bytes.NewReader(toolBody))
2301+
req.Header.Set("Content-Type", "application/json")
2302+
req.Header.Set(HeaderKeySessionID, sessionID)
2303+
resp, err := server.Client().Do(req)
2304+
if err != nil {
2305+
resultChan <- struct {
2306+
statusCode int
2307+
body string
2308+
err error
2309+
}{0, "", err}
2310+
return
2311+
}
2312+
defer resp.Body.Close()
2313+
body, _ := io.ReadAll(resp.Body)
2314+
resultChan <- struct {
2315+
statusCode int
2316+
body string
2317+
err error
2318+
}{resp.StatusCode, string(body), nil}
2319+
}()
2320+
// Wait a bit then add a new tool while the slow_tool is executing
2321+
// This triggers sendNotificationToAllClients
2322+
time.Sleep(50 * time.Millisecond)
2323+
mcpServer.AddTool(mcp.NewTool("new_tool",
2324+
mcp.WithDescription("A new tool added during execution"),
2325+
), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
2326+
return mcp.NewToolResultText("new tool result"), nil
2327+
})
2328+
// Wait for the tool call to complete
2329+
result := <-resultChan
2330+
if result.err != nil {
2331+
t.Fatalf("Tool call failed with error: %v", result.err)
2332+
}
2333+
if result.statusCode != http.StatusOK {
2334+
t.Errorf("Expected status 200, got %d. Body: %s", result.statusCode, result.body)
2335+
}
2336+
// The response should contain the tool result
2337+
// It may be SSE format (text/event-stream) due to the notification upgrade
2338+
if !strings.Contains(result.body, "done") {
2339+
t.Errorf("Expected response to contain 'done', got: %s", result.body)
2340+
}
2341+
}

0 commit comments

Comments
 (0)