From e76a774793dfa4e164fea29365907608bd3a54cc Mon Sep 17 00:00:00 2001 From: "inaku@wsl" Date: Tue, 4 Nov 2025 22:13:27 +0800 Subject: [PATCH 1/4] feat: enhance `ClientSessionGroup.call_tool` method to support full `ClientSession.callback` --- src/mcp/client/session_group.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index ec2eb18fe..36e4efa12 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -25,6 +25,7 @@ from mcp.client.stdio import StdioServerParameters from mcp.client.streamable_http import streamablehttp_client from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT class SseServerParameters(BaseModel): @@ -172,11 +173,25 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools - async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + async def call_tool( + self, + name: str, + args: dict[str, Any], + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] session_tool_name = self.tools[name].name - return await session.call_tool(session_tool_name, args) + return await session.call_tool( + session_tool_name, + args, + read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + meta=meta, + ) async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" From 47bc4ca18ceeef76b448b09dd2a2b74b622946b3 Mon Sep 17 00:00:00 2001 From: "inaku@wsl" Date: Tue, 4 Nov 2025 22:28:41 +0800 Subject: [PATCH 2/4] update `call_tool` method to replace `args` with `arguments` parameter --- src/mcp/client/session_group.py | 30 +++++++++++++++++++++++++++--- tests/client/test_session_group.py | 2 +- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 36e4efa12..532c65ee6 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -13,11 +13,11 @@ from collections.abc import Callable from datetime import timedelta from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, overload import anyio from pydantic import BaseModel -from typing_extensions import Self +from typing_extensions import Self, deprecated import mcp from mcp import types @@ -173,21 +173,45 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools + @overload async def call_tool( self, name: str, + arguments: dict[str, Any], + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + @overload + @deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.") + async def call_tool( + self, + name: str, + *, args: dict[str, Any], read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, *, meta: dict[str, Any] | None = None, + args: dict[str, Any] | None = None, ) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] session_tool_name = self.tools[name].name return await session.call_tool( session_tool_name, - args, + arguments if args is None else args, read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, meta=meta, diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index c38cfeabc..f8a577974 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -62,7 +62,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # --- Test Execution --- result = await mcp_session_group.call_tool( name="server1-my_tool", - args={ + arguments={ "name": "value1", "args": {}, }, From 2ed88b8e232baf03edfeaa81f79bc8ba3efe9329 Mon Sep 17 00:00:00 2001 From: "inaku@wsl" Date: Tue, 4 Nov 2025 23:03:11 +0800 Subject: [PATCH 3/4] feat: add `ClientSessionParameters` to enhance `ClientSessionGroup.connect_to_server` method --- src/mcp/client/session_group.py | 43 ++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 532c65ee6..26a291779 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,6 +11,7 @@ import contextlib import logging from collections.abc import Callable +from dataclasses import dataclass from datetime import timedelta from types import TracebackType from typing import Any, TypeAlias, overload @@ -27,6 +28,8 @@ from mcp.shared.exceptions import McpError from mcp.shared.session import ProgressFnT +from .session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT + class SseServerParameters(BaseModel): """Parameters for intializing a sse_client.""" @@ -66,6 +69,21 @@ class StreamableHttpParameters(BaseModel): ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters +# Use dataclass instead of pydantic BaseModel +# because pydantic BaseModel cannot handle Protocol fields. +@dataclass +class ClientSessionParameters: + """Parameters for establishing a client session to an MCP server.""" + + read_timeout_seconds: timedelta | None = None + sampling_callback: SamplingFnT | None = None + elicitation_callback: ElicitationFnT | None = None + list_roots_callback: ListRootsFnT | None = None + logging_callback: LoggingFnT | None = None + message_handler: MessageHandlerFnT | None = None + client_info: types.Implementation | None = None + + class ClientSessionGroup: """Client for managing connections to multiple MCP servers. @@ -264,13 +282,16 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, + session_params: ClientSessionParameters | None = None, ) -> mcp.ClientSession: """Connects to a single MCP server.""" - server_info, session = await self._establish_session(server_params) + server_info, session = await self._establish_session(server_params, session_params) return await self.connect_with_session(server_info, session) async def _establish_session( - self, server_params: ServerParameters + self, + server_params: ServerParameters, + session_params: ClientSessionParameters | None = None, ) -> tuple[types.Implementation, mcp.ClientSession]: """Establish a client session to an MCP server.""" @@ -298,7 +319,23 @@ async def _establish_session( ) read, write, _ = await session_stack.enter_async_context(client) - session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) + if session_params is None: + session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) + else: + session = await session_stack.enter_async_context( + mcp.ClientSession( + read, + write, + read_timeout_seconds=session_params.read_timeout_seconds, + sampling_callback=session_params.sampling_callback, + elicitation_callback=session_params.elicitation_callback, + list_roots_callback=session_params.list_roots_callback, + logging_callback=session_params.logging_callback, + message_handler=session_params.message_handler, + client_info=session_params.client_info, + ) + ) + result = await session.initialize() # Session successfully initialized. From d137b5eb228537ac3bf03abb7da676b3c0494a8c Mon Sep 17 00:00:00 2001 From: "inaku@wsl" Date: Tue, 4 Nov 2025 23:44:09 +0800 Subject: [PATCH 4/4] fix `test_call_tool` assertion --- tests/client/test_session_group.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index f8a577974..584c9bddf 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -73,6 +73,9 @@ def hook(name: str, server_info: types.Implementation) -> str: mock_session.call_tool.assert_called_once_with( "my_tool", {"name": "value1", "args": {}}, + read_timeout_seconds=None, + progress_callback=None, + meta=None, ) async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack):