diff --git a/mcp_proxy_for_aws/__init__.py b/mcp_proxy_for_aws/__init__.py index 21ceeeb..0c72d68 100644 --- a/mcp_proxy_for_aws/__init__.py +++ b/mcp_proxy_for_aws/__init__.py @@ -16,6 +16,8 @@ from importlib.metadata import version as _metadata_version +import mcp_proxy_for_aws.fastmcp_patch as _fastmcp_patch + __all__ = ['__version__'] __version__ = _metadata_version('mcp-proxy-for-aws') diff --git a/mcp_proxy_for_aws/fastmcp_patch.py b/mcp_proxy_for_aws/fastmcp_patch.py new file mode 100644 index 0000000..0ea949f --- /dev/null +++ b/mcp_proxy_for_aws/fastmcp_patch.py @@ -0,0 +1,34 @@ +import fastmcp.server.low_level as low_level_module +import mcp.types +from functools import wraps +from mcp import McpError +from mcp.server.stdio import stdio_server as stdio_server +from mcp.shared.session import RequestResponder + + +original_receive_request = low_level_module.MiddlewareServerSession._received_request + + +@wraps(original_receive_request) +async def _received_request( + self, + responder: RequestResponder[mcp.types.ClientRequest, mcp.types.ServerResult], +): + """Monkey patch fastmcp so that the initialize error from the middleware can be send back to the client. + + https://github.com/jlowin/fastmcp/pull/2531 + """ + if isinstance(responder.request.root, mcp.types.InitializeRequest): + try: + return await original_receive_request(self, responder) + except McpError as e: + if not responder._completed: + with responder: + return await responder.respond(e.error) + + raise e + else: + return await original_receive_request(self, responder) + + +low_level_module.MiddlewareServerSession._received_request = _received_request diff --git a/mcp_proxy_for_aws/middleware/initialize_middleware.py b/mcp_proxy_for_aws/middleware/initialize_middleware.py index 06fa277..5d74def 100644 --- a/mcp_proxy_for_aws/middleware/initialize_middleware.py +++ b/mcp_proxy_for_aws/middleware/initialize_middleware.py @@ -25,6 +25,20 @@ async def on_initialize( try: logger.debug('Received initialize request %s.', context.message) self._client_factory.set_init_params(context.message) + client = await self._client_factory.get_client() + # connect the http client, fail and don't succeed the stdio connect + # if remote client cannot be connected + client_name = context.message.params.clientInfo.name.lower() + if 'kiro cli' not in client_name and 'q dev cli' not in client_name: + # q cli / kiro cli uses the rust SDK which does not handle json rpc error + # properly during initialization. + # https://github.com/modelcontextprotocol/rust-sdk/pull/569 + # if calling _connect below raise mcp error, the q cli will skip the message + # and continue wait for a json rpc response message which will never come. + # Luckily, q cli calls list tool immediately after being connected to a mcp server + # the list_tool call will require the client to be connected again, so the mcp error + # will be displayed in the q cli logs. + await client._connect() return await call_next(context) except Exception: logger.exception('Initialize failed in middleware.') diff --git a/tests/unit/test_fastmcp_patch.py b/tests/unit/test_fastmcp_patch.py new file mode 100644 index 0000000..acb50f0 --- /dev/null +++ b/tests/unit/test_fastmcp_patch.py @@ -0,0 +1,106 @@ +import mcp.types as mt +import pytest +from mcp import McpError +from mcp.shared.session import RequestResponder +from unittest.mock import AsyncMock, Mock, patch + + +@pytest.mark.asyncio +async def test_patched_received_request_initialize_success(): + """Test that patched _received_request calls original for successful initialize.""" + # Import after patching is applied + import fastmcp.server.low_level as low_level_module + from mcp_proxy_for_aws import fastmcp_patch + + mock_self = Mock() + mock_self.fastmcp = Mock() + + mock_request = Mock() + mock_request.root = Mock(spec=mt.InitializeRequest) + + mock_responder = Mock(spec=RequestResponder) + mock_responder.request = mock_request + + with patch.object( + fastmcp_patch, 'original_receive_request', new_callable=AsyncMock + ) as mock_original: + await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder) + mock_original.assert_called_once_with(mock_self, mock_responder) + + +@pytest.mark.asyncio +async def test_patched_received_request_initialize_mcp_error_not_completed(): + """Test that patched _received_request handles McpError when responder not completed.""" + import fastmcp.server.low_level as low_level_module + from mcp_proxy_for_aws import fastmcp_patch + + mock_self = Mock() + mock_self.fastmcp = Mock() + + mock_request = Mock() + mock_request.root = Mock(spec=mt.InitializeRequest) + + mock_responder = Mock(spec=RequestResponder) + mock_responder.request = mock_request + mock_responder._completed = False + mock_responder.__enter__ = Mock(return_value=mock_responder) + mock_responder.__exit__ = Mock(return_value=False) + mock_responder.respond = AsyncMock() + + error = mt.ErrorData(code=1, message='test error') + mcp_error = McpError(error=error) + + with patch.object( + fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error + ): + await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder) + mock_responder.respond.assert_called_once_with(error) + + +@pytest.mark.asyncio +async def test_patched_received_request_initialize_mcp_error_completed(): + """Test that patched _received_request re-raises McpError when responder completed.""" + import fastmcp.server.low_level as low_level_module + from mcp_proxy_for_aws import fastmcp_patch + + mock_self = Mock() + mock_self.fastmcp = Mock() + + mock_request = Mock() + mock_request.root = Mock(spec=mt.InitializeRequest) + + mock_responder = Mock(spec=RequestResponder) + mock_responder.request = mock_request + mock_responder._completed = True + + error = mt.ErrorData(code=1, message='test error') + mcp_error = McpError(error=error) + + with patch.object( + fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error + ): + with pytest.raises(McpError): + await low_level_module.MiddlewareServerSession._received_request( + mock_self, mock_responder + ) + + +@pytest.mark.asyncio +async def test_patched_received_request_non_initialize(): + """Test that patched _received_request calls original for non-initialize requests.""" + import fastmcp.server.low_level as low_level_module + from mcp_proxy_for_aws import fastmcp_patch + + mock_self = Mock() + + mock_request = Mock() + mock_request.root = Mock(spec=mt.CallToolRequest) + + mock_responder = Mock(spec=RequestResponder) + mock_responder.request = mock_request + + with patch.object( + fastmcp_patch, 'original_receive_request', new_callable=AsyncMock + ) as mock_original: + await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder) + mock_original.assert_called_once_with(mock_self, mock_responder) diff --git a/tests/unit/test_initialize_middleware.py b/tests/unit/test_initialize_middleware.py new file mode 100644 index 0000000..d347b64 --- /dev/null +++ b/tests/unit/test_initialize_middleware.py @@ -0,0 +1,98 @@ +import mcp.types as mt +import pytest +from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware +from unittest.mock import AsyncMock, Mock + + +def create_initialize_request(client_name: str) -> mt.InitializeRequest: + """Create a real InitializeRequest object.""" + return mt.InitializeRequest( + method='initialize', + params=mt.InitializeRequestParams( + protocolVersion='2024-11-05', + capabilities=mt.ClientCapabilities(), + clientInfo=mt.Implementation(name=client_name, version='1.0'), + ), + ) + + +@pytest.mark.asyncio +async def test_on_initialize_connects_client(): + """Test that on_initialize calls client._connect().""" + mock_client = Mock() + mock_client._connect = AsyncMock() + + mock_factory = Mock() + mock_factory.set_init_params = Mock() + mock_factory.get_client = AsyncMock(return_value=mock_client) + + middleware = InitializeMiddleware(mock_factory) + + mock_context = Mock() + mock_context.message = create_initialize_request('test-client') + + mock_call_next = AsyncMock() + + await middleware.on_initialize(mock_context, mock_call_next) + + mock_factory.set_init_params.assert_called_once_with(mock_context.message) + mock_factory.get_client.assert_called_once() + mock_client._connect.assert_called_once() + mock_call_next.assert_called_once_with(mock_context) + + +@pytest.mark.asyncio +async def test_on_initialize_fails_if_connect_fails(): + """Test that on_initialize raises exception if _connect() fails.""" + mock_client = Mock() + mock_client._connect = AsyncMock(side_effect=Exception('Connection failed')) + + mock_factory = Mock() + mock_factory.set_init_params = Mock() + mock_factory.get_client = AsyncMock(return_value=mock_client) + + middleware = InitializeMiddleware(mock_factory) + + mock_context = Mock() + mock_context.message = create_initialize_request('test-client') + + mock_call_next = AsyncMock() + + with pytest.raises(Exception, match='Connection failed'): + await middleware.on_initialize(mock_context, mock_call_next) + + mock_call_next.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'client_name', + [ + 'Kiro CLI', + 'kiro cli', + 'KIRO CLI', + 'Amazon Q Dev CLI', + 'amazon q dev cli', + 'Q DEV CLI', + ], +) +async def test_on_initialize_skips_connect_for_special_clients(client_name): + """Test that on_initialize skips _connect() for Kiro CLI and Q Dev CLI.""" + mock_client = Mock() + mock_client._connect = AsyncMock() + + mock_factory = Mock() + mock_factory.set_init_params = Mock() + mock_factory.get_client = AsyncMock(return_value=mock_client) + + middleware = InitializeMiddleware(mock_factory) + + mock_context = Mock() + mock_context.message = create_initialize_request(client_name) + + mock_call_next = AsyncMock() + + await middleware.on_initialize(mock_context, mock_call_next) + + mock_client._connect.assert_not_called() + mock_call_next.assert_called_once_with(mock_context)