Skip to content

Commit ae40798

Browse files
committed
brave running
1 parent e7b85cd commit ae40798

9 files changed

+1067
-339
lines changed

mcp_testing/adapters/base.py

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: AGPL-3.0-or-later
3+
4+
"""
5+
Base server adapter for MCP testing.
6+
7+
This module defines the base server adapter class that will be extended
8+
by specific server implementations.
9+
"""
10+
11+
import asyncio
12+
import json
13+
import logging
14+
from abc import ABC, abstractmethod
15+
from typing import Dict, Any, List, Optional, Tuple, Union
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class MCPServerAdapter(ABC):
21+
"""
22+
Base class for MCP server adapters.
23+
24+
Server adapters are responsible for starting, communicating with, and stopping
25+
MCP servers during testing. This abstract base class defines the interface that
26+
all server adapters must implement.
27+
"""
28+
29+
def __init__(self, protocol_version: str, debug: bool = False):
30+
"""
31+
Initialize a server adapter.
32+
33+
Args:
34+
protocol_version: The MCP protocol version to use
35+
debug: Whether to enable debug logging
36+
"""
37+
self.protocol_version = protocol_version
38+
self.debug = debug
39+
self.server_info = None
40+
self._request_id = 0
41+
42+
@abstractmethod
43+
async def start(self) -> bool:
44+
"""
45+
Start the server.
46+
47+
Returns:
48+
True if started successfully, False otherwise
49+
"""
50+
pass
51+
52+
@abstractmethod
53+
async def stop(self) -> bool:
54+
"""
55+
Stop the server.
56+
57+
Returns:
58+
True if stopped successfully, False otherwise
59+
"""
60+
pass
61+
62+
@abstractmethod
63+
async def send_request(self, method: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
64+
"""
65+
Send a request to the server and wait for a response.
66+
67+
Args:
68+
method: The JSON-RPC method name
69+
params: The method parameters
70+
71+
Returns:
72+
The server's response
73+
74+
Raises:
75+
RuntimeError: If the server is not started or the request fails
76+
"""
77+
pass
78+
79+
@abstractmethod
80+
async def send_notification(self, method: str, params: Dict[str, Any] = None) -> None:
81+
"""
82+
Send a notification to the server (no response expected).
83+
84+
Args:
85+
method: The JSON-RPC method name
86+
params: The method parameters
87+
88+
Raises:
89+
RuntimeError: If the server is not started or the notification fails
90+
"""
91+
pass
92+
93+
async def initialize(self) -> Dict[str, Any]:
94+
"""
95+
Initialize the server.
96+
97+
This sends the standard initialize request to the server.
98+
99+
Returns:
100+
The server's initialization response
101+
102+
Raises:
103+
RuntimeError: If initialization fails
104+
"""
105+
params = {
106+
"protocolVersion": self.protocol_version,
107+
"options": {}
108+
}
109+
110+
response = await self.send_request("initialize", params)
111+
112+
if "error" in response:
113+
error_msg = response.get("error", {}).get("message", "Unknown error")
114+
logger.error(f"Server initialization failed: {error_msg}")
115+
raise RuntimeError(f"Failed to initialize server: {error_msg}")
116+
117+
if "result" not in response:
118+
raise RuntimeError("Invalid initialize response, missing 'result' field")
119+
120+
self.server_info = response["result"]
121+
return response
122+
123+
async def shutdown(self) -> Optional[Dict[str, Any]]:
124+
"""
125+
Send a shutdown request to the server.
126+
127+
Returns:
128+
The server's shutdown response, or None if the server doesn't support shutdown
129+
"""
130+
try:
131+
response = await self.send_request("shutdown", {})
132+
await self.send_notification("exit")
133+
return response
134+
except Exception as e:
135+
logger.warning(f"Failed to shut down server: {str(e)}")
136+
return None
137+
138+
async def list_tools(self) -> List[Dict[str, Any]]:
139+
"""
140+
Get the list of available tools from the server.
141+
142+
Returns:
143+
A list of tool definitions
144+
145+
Raises:
146+
RuntimeError: If the request fails
147+
"""
148+
response = await self.send_request("listTools", {})
149+
150+
if "error" in response:
151+
error_msg = response.get("error", {}).get("message", "Unknown error")
152+
raise RuntimeError(f"Failed to list tools: {error_msg}")
153+
154+
if "result" not in response:
155+
raise RuntimeError("Invalid listTools response, missing 'result' field")
156+
157+
return response["result"]
158+
159+
async def call_tool(self, name: str, params: Dict[str, Any]) -> Dict[str, Any]:
160+
"""
161+
Call a tool on the server.
162+
163+
Args:
164+
name: The name of the tool to call
165+
params: The tool parameters
166+
167+
Returns:
168+
The tool's response
169+
170+
Raises:
171+
RuntimeError: If the tool call fails
172+
"""
173+
request_params = {
174+
"name": name,
175+
"params": params
176+
}
177+
178+
response = await self.send_request("callTool", request_params)
179+
180+
if "error" in response:
181+
error = response.get("error", {})
182+
error_msg = error.get("message", "Unknown error")
183+
logger.error(f"Tool call failed: {error_msg}")
184+
return response
185+
186+
if "result" not in response:
187+
logger.error("Invalid callTool response, missing 'result' field")
188+
189+
return response
190+
191+
def _get_next_request_id(self) -> int:
192+
"""
193+
Get the next request ID.
194+
195+
Returns:
196+
The next request ID
197+
"""
198+
self._request_id += 1
199+
return self._request_id

0 commit comments

Comments
 (0)