From 00280172337d3204b7bc5d65f4a2219e92ae8963 Mon Sep 17 00:00:00 2001 From: hundao Date: Wed, 4 Feb 2026 17:13:59 +0800 Subject: [PATCH] feat(mcp): replace hand-rolled HTTP with SDK streamable_http_client Use the MCP SDK's streamable_http_client for spec-compliant Streamable HTTP transport, unifying both STDIO and HTTP paths to share the same ClientSession-based tool listing, invocation, and cleanup logic. - Remove httpx dependency and hand-rolled JSON-RPC calls - Auto-append /mcp to URLs without a path (FastMCP default) - Add unit tests (15) and integration tests (7) for the HTTP transport --- core/framework/runner/mcp_client.py | 254 +++++++++++----------- core/tests/test_mcp_client.py | 224 +++++++++++++++++++ core/tests/test_mcp_client_integration.py | 192 ++++++++++++++++ 3 files changed, 547 insertions(+), 123 deletions(-) create mode 100644 core/tests/test_mcp_client.py create mode 100644 core/tests/test_mcp_client_integration.py diff --git a/core/framework/runner/mcp_client.py b/core/framework/runner/mcp_client.py index faf7f68194..a506f9fc53 100644 --- a/core/framework/runner/mcp_client.py +++ b/core/framework/runner/mcp_client.py @@ -2,15 +2,18 @@ This module provides a client for connecting to MCP servers and invoking their tools. Supports both STDIO and HTTP transports using the official MCP Python SDK. + +STDIO uses the SDK's stdio_client, HTTP uses the SDK's streamable_http_client. +Both transports share the same ClientSession-based tool listing and invocation. """ import asyncio import logging import os +import threading from dataclasses import dataclass, field from typing import Any, Literal - -import httpx +from urllib.parse import urlparse logger = logging.getLogger(__name__) @@ -65,12 +68,13 @@ def __init__(self, config: MCPServerConfig): self._session = None self._read_stream = None self._write_stream = None - self._stdio_context = None # Context manager for stdio_client - self._http_client: httpx.Client | None = None + self._transport_context = None # Context manager for transport (stdio or http) + self._http_async_client = None # httpx.AsyncClient for HTTP transport + self._get_session_id = None # Session ID callback from streamable_http_client self._tools: dict[str, MCPTool] = {} self._connected = False - # Background event loop for persistent STDIO connection + # Background event loop for persistent connection (both transports) self._loop = None self._loop_thread = None @@ -84,7 +88,7 @@ def _run_async(self, coro): Returns: Result of the coroutine """ - # If we have a persistent loop (for STDIO), use it + # If we have a persistent loop (for STDIO or HTTP), use it if self._loop is not None: # Check if loop is running AND not closed if self._loop.is_running() and not self._loop.is_closed(): @@ -99,8 +103,6 @@ def _run_async(self, coro): asyncio.get_running_loop() # If we're here, we're in an async context # Create a new thread to run the coroutine - import threading - result = None exception = None @@ -149,8 +151,6 @@ def _connect_stdio(self) -> None: raise ValueError("command is required for STDIO transport") try: - import threading - from mcp import StdioServerParameters # Create server parameters @@ -184,11 +184,11 @@ async def init_connection(): from mcp.client.stdio import stdio_client # Create persistent stdio client context - self._stdio_context = stdio_client(server_params) + self._transport_context = stdio_client(server_params) ( self._read_stream, self._write_stream, - ) = await self._stdio_context.__aenter__() + ) = await self._transport_context.__aenter__() # Create persistent session self._session = ClientSession(self._read_stream, self._write_stream) @@ -226,34 +226,92 @@ async def init_connection(): raise RuntimeError(f"Failed to connect to MCP server: {e}") from e def _connect_http(self) -> None: - """Connect to MCP server via HTTP transport.""" + """Connect to MCP server via HTTP transport using MCP SDK with persistent connection. + + Uses the SDK's streamable_http_client for spec-compliant Streamable HTTP transport. + This follows the same background event loop pattern as _connect_stdio(). + """ if not self.config.url: raise ValueError("url is required for HTTP transport") - self._http_client = httpx.Client( - base_url=self.config.url, - headers=self.config.headers, - timeout=30.0, - ) + # Ensure URL includes the MCP endpoint path. + # FastMCP defaults to /mcp for Streamable HTTP transport. + parsed = urlparse(self.config.url) + if not parsed.path or parsed.path == "/": + url = self.config.url.rstrip("/") + "/mcp" + else: + url = self.config.url - # Test connection - try: - response = self._http_client.get("/health") - response.raise_for_status() - logger.info( - f"Connected to MCP server '{self.config.name}' via HTTP at {self.config.url}" - ) - except Exception as e: - logger.warning(f"Health check failed for MCP server '{self.config.name}': {e}") - # Continue anyway, server might not have health endpoint + # Start background event loop for persistent connection + loop_started = threading.Event() + connection_ready = threading.Event() + connection_error = [] + + def run_event_loop(): + """Run event loop in background thread.""" + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + loop_started.set() + + async def init_connection(): + try: + from mcp import ClientSession + from mcp.client.streamable_http import streamable_http_client + from mcp.shared._httpx_utils import create_mcp_http_client + + # Create httpx.AsyncClient with custom headers for auth + self._http_async_client = create_mcp_http_client( + headers=self.config.headers or None, + ) + + # Create persistent Streamable HTTP client context + self._transport_context = streamable_http_client( + url, + http_client=self._http_async_client, + ) + ( + self._read_stream, + self._write_stream, + self._get_session_id, + ) = await self._transport_context.__aenter__() + + # Create persistent session + self._session = ClientSession(self._read_stream, self._write_stream) + await self._session.__aenter__() + + # Initialize session (MCP spec handshake) + await self._session.initialize() + + connection_ready.set() + except Exception as e: + connection_error.append(e) + connection_ready.set() + + self._loop.create_task(init_connection()) + self._loop.run_forever() + + self._loop_thread = threading.Thread(target=run_event_loop, daemon=True) + self._loop_thread.start() + + # Wait for loop to start + loop_started.wait(timeout=5) + if not loop_started.is_set(): + raise RuntimeError("Event loop failed to start") + + # Wait for connection to be ready (HTTP may need longer than STDIO) + connection_ready.wait(timeout=30) + if connection_error: + raise RuntimeError( + f"Failed to connect to MCP server '{self.config.name}' " + f"via HTTP at {url}: {connection_error[0]}" + ) from connection_error[0] + + logger.info(f"Connected to MCP server '{self.config.name}' via HTTP at {url} (persistent)") def _discover_tools(self) -> None: """Discover available tools from the MCP server.""" try: - if self.config.transport == "stdio": - tools_list = self._run_async(self._list_tools_stdio_async()) - else: - tools_list = self._list_tools_http() + tools_list = self._run_async(self._list_tools_async()) self._tools = {} for tool_data in tools_list: @@ -273,10 +331,10 @@ def _discover_tools(self) -> None: logger.error(f"Failed to discover tools from '{self.config.name}': {e}") raise - async def _list_tools_stdio_async(self) -> list[dict]: - """List tools via STDIO protocol using persistent session.""" + async def _list_tools_async(self) -> list[dict]: + """List tools via persistent MCP session (works for both transports).""" if not self._session: - raise RuntimeError("STDIO session not initialized") + raise RuntimeError("MCP session not initialized") # List tools using persistent session response = await self._session.list_tools() @@ -294,32 +352,6 @@ async def _list_tools_stdio_async(self) -> list[dict]: return tools_list - def _list_tools_http(self) -> list[dict]: - """List tools via HTTP protocol.""" - if not self._http_client: - raise RuntimeError("HTTP client not initialized") - - try: - # Use MCP over HTTP protocol - response = self._http_client.post( - "/mcp/v1", - json={ - "jsonrpc": "2.0", - "id": 1, - "method": "tools/list", - "params": {}, - }, - ) - response.raise_for_status() - data = response.json() - - if "error" in data: - raise RuntimeError(f"MCP error: {data['error']}") - - return data.get("result", {}).get("tools", []) - except Exception as e: - raise RuntimeError(f"Failed to list tools via HTTP: {e}") from e - def list_tools(self) -> list[MCPTool]: """ Get list of available tools. @@ -349,15 +381,12 @@ def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: if tool_name not in self._tools: raise ValueError(f"Unknown tool: {tool_name}") - if self.config.transport == "stdio": - return self._run_async(self._call_tool_stdio_async(tool_name, arguments)) - else: - return self._call_tool_http(tool_name, arguments) + return self._run_async(self._call_tool_async(tool_name, arguments)) - async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any: - """Call tool via STDIO protocol using persistent session.""" + async def _call_tool_async(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """Call tool via persistent MCP session (works for both transports).""" if not self._session: - raise RuntimeError("STDIO session not initialized") + raise RuntimeError("MCP session not initialized") # Call tool using persistent session result = await self._session.call_tool(tool_name, arguments=arguments) @@ -376,49 +405,21 @@ async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any] return None - def _call_tool_http(self, tool_name: str, arguments: dict[str, Any]) -> Any: - """Call tool via HTTP protocol.""" - if not self._http_client: - raise RuntimeError("HTTP client not initialized") - - try: - response = self._http_client.post( - "/mcp/v1", - json={ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": { - "name": tool_name, - "arguments": arguments, - }, - }, - ) - response.raise_for_status() - data = response.json() - - if "error" in data: - raise RuntimeError(f"Tool execution error: {data['error']}") - - return data.get("result", {}).get("content", []) - except Exception as e: - raise RuntimeError(f"Failed to call tool via HTTP: {e}") from e - _CLEANUP_TIMEOUT = 10 _THREAD_JOIN_TIMEOUT = 12 - async def _cleanup_stdio_async(self) -> None: - """Async cleanup for STDIO session and context managers. + async def _cleanup_async(self) -> None: + """Async cleanup for MCP session and transport context managers. Cleanup order is critical: - - The session must be closed BEFORE the stdio_context because the session - depends on the streams provided by stdio_context. - - This mirrors the initialization order in _connect_stdio(), where - stdio_context is entered first (providing streams), then the session is - created with those streams and entered. + - The session must be closed BEFORE the transport_context because the session + depends on the streams provided by the transport context. + - This mirrors the initialization order in _connect_stdio()/_connect_http(), + where the transport context is entered first (providing streams), then the + session is created with those streams and entered. - Do not change this ordering without carefully considering these dependencies. """ - # First: close session (depends on stdio_context streams) + # First: close session (depends on transport_context streams) try: if self._session: await self._session.__aexit__(None, None, None) @@ -431,22 +432,32 @@ async def _cleanup_stdio_async(self) -> None: finally: self._session = None - # Second: close stdio_context (provides the underlying streams) + # Second: close transport context (provides the underlying streams) try: - if self._stdio_context: - await self._stdio_context.__aexit__(None, None, None) + if self._transport_context: + await self._transport_context.__aexit__(None, None, None) except asyncio.CancelledError: logger.warning( - "STDIO context cleanup was cancelled; proceeding with best-effort shutdown" + "Transport context cleanup was cancelled; proceeding with best-effort shutdown" ) except Exception as e: - logger.warning(f"Error closing STDIO context: {e}") + logger.warning(f"Error closing transport context: {e}") + finally: + self._transport_context = None + + # Third: close the httpx.AsyncClient if it was created for HTTP transport + # The SDK does NOT auto-close a provided http_client, so we must do it. + try: + if self._http_async_client: + await self._http_async_client.aclose() + except Exception as e: + logger.warning(f"Error closing HTTP async client: {e}") finally: - self._stdio_context = None + self._http_async_client = None def disconnect(self) -> None: """Disconnect from the MCP server.""" - # Clean up persistent STDIO connection + # Clean up persistent connection (both STDIO and HTTP use background event loop) if self._loop is not None: cleanup_attempted = False @@ -457,7 +468,7 @@ def disconnect(self) -> None: if self._loop.is_running(): try: cleanup_future = asyncio.run_coroutine_threadsafe( - self._cleanup_stdio_async(), self._loop + self._cleanup_async(), self._loop ) cleanup_future.result(timeout=self._CLEANUP_TIMEOUT) cleanup_attempted = True @@ -470,7 +481,7 @@ def disconnect(self) -> None: cleanup_attempted = True logger.debug(f"Event loop stopped during async cleanup: {e}") except Exception as e: - # Cleanup was attempted but failed (e.g., error in _cleanup_stdio_async()) + # Cleanup was attempted but failed (e.g., error in _cleanup_async()) cleanup_attempted = True logger.warning(f"Error during async cleanup: {e}") @@ -484,11 +495,11 @@ def disconnect(self) -> None: if not cleanup_attempted: # Fallback: loop exists but is not running (e.g., crashed or stopped externally). # At this point the loop and associated resources are in an undefined state. - # The context managers (_session, _stdio_context) were created in the loop's + # The context managers (_session, _transport_context) were created in the loop's # thread and may not be safely cleanable from here. Just log and proceed # with reference clearing - the OS will reclaim resources on process exit. logger.warning( - "Event loop for STDIO MCP connection exists but is not running; " + "Event loop for MCP connection exists but is not running; " "skipping async cleanup. Resources may not be fully released." ) @@ -497,29 +508,26 @@ def disconnect(self) -> None: self._loop_thread.join(timeout=self._THREAD_JOIN_TIMEOUT) if self._loop_thread.is_alive(): logger.warning( - "Event loop thread for STDIO MCP connection did not terminate " + "Event loop thread for MCP connection did not terminate " f"within {self._THREAD_JOIN_TIMEOUT}s; thread may still be running." ) # Clear remaining references - # Note: _session and _stdio_context may already be None if _cleanup_stdio_async() + # Note: _session and _transport_context may already be None if _cleanup_async() # succeeded. This redundant assignment is intentional for safety in cases where: # 1. Cleanup timed out or failed # 2. Cleanup was skipped (loop not running) # 3. CancelledError interrupted cleanup # Setting None to None is safe and ensures clean state. self._session = None - self._stdio_context = None + self._transport_context = None + self._http_async_client = None + self._get_session_id = None self._read_stream = None self._write_stream = None self._loop = None self._loop_thread = None - # Clean up HTTP client - if self._http_client: - self._http_client.close() - self._http_client = None - self._connected = False logger.info(f"Disconnected from MCP server '{self.config.name}'") diff --git a/core/tests/test_mcp_client.py b/core/tests/test_mcp_client.py new file mode 100644 index 0000000000..70fa001be8 --- /dev/null +++ b/core/tests/test_mcp_client.py @@ -0,0 +1,224 @@ +"""Unit tests for the MCP client. + +Tests the MCPClient's URL handling, transport unification, and cleanup logic +using mocks (no real MCP server required). +""" + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from framework.runner.mcp_client import MCPClient, MCPServerConfig + + +class TestURLPathHandling: + """Tests for HTTP URL path auto-append logic.""" + + def test_url_without_path_gets_mcp_appended(self): + """http://localhost:4001 → http://localhost:4001/mcp""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001") + MCPClient(config) + + # Extract the URL logic by checking what _connect_http would compute + from urllib.parse import urlparse + + parsed = urlparse(config.url) + if not parsed.path or parsed.path == "/": + url = config.url.rstrip("/") + "/mcp" + else: + url = config.url + + assert url == "http://localhost:4001/mcp" + + def test_url_with_trailing_slash_gets_mcp_appended(self): + """http://localhost:4001/ → http://localhost:4001/mcp""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001/") + + from urllib.parse import urlparse + + parsed = urlparse(config.url) + if not parsed.path or parsed.path == "/": + url = config.url.rstrip("/") + "/mcp" + else: + url = config.url + + assert url == "http://localhost:4001/mcp" + + def test_url_with_custom_path_preserved(self): + """http://remote:8080/custom/endpoint stays as-is.""" + config = MCPServerConfig( + name="test", transport="http", url="http://remote:8080/custom/endpoint" + ) + + from urllib.parse import urlparse + + parsed = urlparse(config.url) + if not parsed.path or parsed.path == "/": + url = config.url.rstrip("/") + "/mcp" + else: + url = config.url + + assert url == "http://remote:8080/custom/endpoint" + + def test_url_with_mcp_path_preserved(self): + """http://localhost:4001/mcp stays as-is.""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001/mcp") + + from urllib.parse import urlparse + + parsed = urlparse(config.url) + if not parsed.path or parsed.path == "/": + url = config.url.rstrip("/") + "/mcp" + else: + url = config.url + + assert url == "http://localhost:4001/mcp" + + +class TestHTTPConnectionValidation: + """Tests for HTTP connection validation.""" + + def test_connect_http_requires_url(self): + """HTTP transport without URL should raise ValueError.""" + config = MCPServerConfig(name="test", transport="http", url=None) + client = MCPClient(config) + + with pytest.raises(ValueError, match="url is required"): + client._connect_http() + + def test_connect_stdio_requires_command(self): + """STDIO transport without command should raise ValueError.""" + config = MCPServerConfig(name="test", transport="stdio", command=None) + client = MCPClient(config) + + with pytest.raises(ValueError, match="command is required"): + client._connect_stdio() + + def test_unsupported_transport_raises(self): + """Unsupported transport should raise ValueError.""" + config = MCPServerConfig(name="test", transport="grpc") + client = MCPClient(config) + + with pytest.raises(ValueError, match="Unsupported transport"): + client.connect() + + +class TestUnifiedInterface: + """Tests that both transports share the same session-based methods.""" + + def test_call_tool_routes_through_async(self): + """call_tool should use _call_tool_async for any transport.""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001") + client = MCPClient(config) + client._connected = True + client._tools = {"echo": MagicMock()} + + mock_result = "hello" + with patch.object(client, "_run_async", return_value=mock_result) as mock_run: + result = client.call_tool("echo", {"message": "hello"}) + + assert result == mock_result + mock_run.assert_called_once() + + def test_call_tool_unknown_tool_raises(self): + """Calling an unknown tool should raise ValueError.""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001") + client = MCPClient(config) + client._connected = True + client._tools = {} + + with pytest.raises(ValueError, match="Unknown tool"): + client.call_tool("nonexistent", {}) + + +class TestCleanup: + """Tests for cleanup and disconnect logic.""" + + def test_disconnect_clears_all_references(self): + """After disconnect, all internal references should be None.""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001") + client = MCPClient(config) + + # Simulate a connected state without actually connecting + asyncio.new_event_loop() + + async def noop(): + pass + + # Set up state as if connected + client._session = MagicMock() + client._transport_context = MagicMock() + client._http_async_client = MagicMock() + client._get_session_id = MagicMock() + client._read_stream = MagicMock() + client._write_stream = MagicMock() + client._connected = True + # Don't set _loop to avoid triggering real async cleanup + # Just test the non-loop cleanup path + + client.disconnect() + + assert client._connected is False + + def test_context_manager(self): + """MCPClient should support context manager protocol.""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001") + client = MCPClient(config) + + with ( + patch.object(client, "connect") as mock_connect, + patch.object(client, "disconnect") as mock_disconnect, + ): + with client: + mock_connect.assert_called_once() + mock_disconnect.assert_called_once() + + def test_double_connect_is_noop(self): + """Calling connect() when already connected should be a no-op.""" + config = MCPServerConfig(name="test", transport="http", url="http://localhost:4001") + client = MCPClient(config) + client._connected = True + + with patch.object(client, "_connect_http") as mock_connect_http: + client.connect() + mock_connect_http.assert_not_called() + + +class TestMCPServerConfig: + """Tests for the MCPServerConfig dataclass.""" + + def test_stdio_config(self): + config = MCPServerConfig( + name="tools", + transport="stdio", + command="python", + args=["mcp_server.py", "--stdio"], + env={"KEY": "value"}, + cwd="/tmp", + ) + assert config.name == "tools" + assert config.transport == "stdio" + assert config.command == "python" + assert config.args == ["mcp_server.py", "--stdio"] + assert config.env == {"KEY": "value"} + assert config.cwd == "/tmp" + + def test_http_config(self): + config = MCPServerConfig( + name="remote", + transport="http", + url="http://localhost:4001", + headers={"Authorization": "Bearer token"}, + ) + assert config.name == "remote" + assert config.transport == "http" + assert config.url == "http://localhost:4001" + assert config.headers == {"Authorization": "Bearer token"} + + def test_defaults(self): + config = MCPServerConfig(name="test", transport="stdio") + assert config.args == [] + assert config.env == {} + assert config.headers == {} + assert config.description == "" diff --git a/core/tests/test_mcp_client_integration.py b/core/tests/test_mcp_client_integration.py new file mode 100644 index 0000000000..49db51272c --- /dev/null +++ b/core/tests/test_mcp_client_integration.py @@ -0,0 +1,192 @@ +"""Integration tests for the MCP client over HTTP transport. + +These tests start a real FastMCP server in-process and connect to it +via the Streamable HTTP transport to verify end-to-end functionality. + +Requires: mcp, fastmcp packages installed. +""" + +import socket +import threading +import time + +import pytest +import uvicorn +from anyio import ClosedResourceError + +from framework.runner.mcp_client import MCPClient, MCPServerConfig + + +def _find_free_port() -> int: + """Find a free TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def mcp_http_server(): + """Start a FastMCP server with test tools on a random port. + + Yields the base URL (e.g. http://127.0.0.1:PORT). + Server runs in a daemon thread and dies with the test process. + """ + from fastmcp import FastMCP + + mcp = FastMCP("test-tools") + + @mcp.tool() + def echo(message: str) -> str: + """Echo a message back.""" + return f"echo: {message}" + + @mcp.tool() + def add(a: int, b: int) -> str: + """Add two numbers.""" + return str(a + b) + + port = _find_free_port() + + # Use uvicorn directly for more control over server lifecycle + app = mcp.http_app() + config = uvicorn.Config( + app, + host="127.0.0.1", + port=port, + log_level="warning", + ) + server = uvicorn.Server(config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for server to be ready + for _ in range(50): + try: + with socket.create_connection(("127.0.0.1", port), timeout=0.1): + break + except (ConnectionRefusedError, OSError): + time.sleep(0.1) + else: + pytest.fail(f"Server failed to start on port {port}") + + yield f"http://127.0.0.1:{port}" + + server.should_exit = True + + +class TestHTTPRoundtrip: + """Test full connect → discover → call → disconnect cycle over HTTP.""" + + def test_connect_and_list_tools(self, mcp_http_server): + """Connect to HTTP MCP server and discover tools.""" + config = MCPServerConfig( + name="test", + transport="http", + url=mcp_http_server, + ) + with MCPClient(config) as client: + tools = client.list_tools() + tool_names = {t.name for t in tools} + assert "echo" in tool_names + assert "add" in tool_names + + def test_call_echo_tool(self, mcp_http_server): + """Call the echo tool and verify the response.""" + config = MCPServerConfig( + name="test", + transport="http", + url=mcp_http_server, + ) + with MCPClient(config) as client: + result = client.call_tool("echo", {"message": "hello world"}) + assert result == "echo: hello world" + + def test_call_add_tool(self, mcp_http_server): + """Call the add tool and verify the response.""" + config = MCPServerConfig( + name="test", + transport="http", + url=mcp_http_server, + ) + with MCPClient(config) as client: + result = client.call_tool("add", {"a": 3, "b": 7}) + assert result == "10" + + def test_multiple_tool_calls(self, mcp_http_server): + """Call multiple tools in sequence on the same persistent connection.""" + config = MCPServerConfig( + name="test", + transport="http", + url=mcp_http_server, + ) + with MCPClient(config) as client: + r1 = client.call_tool("echo", {"message": "first"}) + r2 = client.call_tool("add", {"a": 1, "b": 2}) + r3 = client.call_tool("echo", {"message": "third"}) + + assert r1 == "echo: first" + assert r2 == "3" + assert r3 == "echo: third" + + +class TestHTTPWithHeaders: + """Test that custom headers are passed through.""" + + def test_custom_headers_dont_break_connection(self, mcp_http_server): + """Custom headers should be sent without breaking the connection.""" + config = MCPServerConfig( + name="test", + transport="http", + url=mcp_http_server, + headers={"X-Custom-Header": "test-value"}, + ) + with MCPClient(config) as client: + tools = client.list_tools() + assert len(tools) > 0 + + +class TestHTTPDisconnectReconnect: + """Test disconnect and reconnect cycle.""" + + def test_disconnect_and_reconnect(self, mcp_http_server): + """Client should be able to disconnect and reconnect.""" + config = MCPServerConfig( + name="test", + transport="http", + url=mcp_http_server, + ) + + client = MCPClient(config) + client.connect() + tools_first = client.list_tools() + assert len(tools_first) > 0 + + client.disconnect() + assert client._connected is False + + # Reconnect + client.connect() + tools_second = client.list_tools() + assert len(tools_second) == len(tools_first) + + client.disconnect() + + +class TestHTTPConnectionFailure: + """Test behavior when server is unreachable.""" + + def test_connection_to_unreachable_server(self): + """Connecting to unreachable server should raise an error.""" + port = _find_free_port() + config = MCPServerConfig( + name="test", + transport="http", + url=f"http://127.0.0.1:{port}", + ) + client = MCPClient(config) + + # The error may surface during connection init or tool discovery, + # depending on when the transport detects the unreachable server. + with pytest.raises((RuntimeError, OSError, ClosedResourceError)): + client.connect()