Skip to content

Commit df9d081

Browse files
committed
Add idle timeout termination for StreamableHTTPServerTransport
1 parent 6566c08 commit df9d081

File tree

3 files changed

+150
-25
lines changed

3 files changed

+150
-25
lines changed

src/mcp/server/streamable_http.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from contextlib import asynccontextmanager
1616
from dataclasses import dataclass
1717
from http import HTTPStatus
18+
from types import TracebackType
1819

1920
import anyio
2021
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -23,6 +24,7 @@
2324
from starlette.requests import Request
2425
from starlette.responses import Response
2526
from starlette.types import Receive, Scope, Send
27+
from typing_extensions import Self
2628

2729
from mcp.server.transport_security import (
2830
TransportSecurityMiddleware,
@@ -140,6 +142,7 @@ def __init__(
140142
is_json_response_enabled: bool = False,
141143
event_store: EventStore | None = None,
142144
security_settings: TransportSecuritySettings | None = None,
145+
timeout: float | None = None,
143146
) -> None:
144147
"""
145148
Initialize a new StreamableHTTP server transport.
@@ -153,6 +156,9 @@ def __init__(
153156
resumability will be enabled, allowing clients to
154157
reconnect and resume messages.
155158
security_settings: Optional security settings for DNS rebinding protection.
159+
timeout: Optional idle timeout for transport. If provided, the transport will
160+
terminate if it remains idle for longer than the defined timeout
161+
duration in seconds.
156162
157163
Raises:
158164
ValueError: If the session ID contains invalid characters.
@@ -172,6 +178,12 @@ def __init__(
172178
],
173179
] = {}
174180
self._terminated = False
181+
self._timeout = timeout
182+
183+
# for idle detection
184+
self._processing_request_count = 0
185+
self._idle_condition = anyio.Condition()
186+
self._has_request = False
175187

176188
@property
177189
def is_terminated(self) -> bool:
@@ -626,6 +638,9 @@ async def terminate(self) -> None:
626638
Once terminated, all requests with this session ID will receive 404 Not Found.
627639
"""
628640

641+
if self._terminated:
642+
return
643+
629644
self._terminated = True
630645
logger.info(f"Terminating session: {self.mcp_session_id}")
631646

@@ -796,6 +811,42 @@ async def send_event(event_message: EventMessage) -> None:
796811
)
797812
await response(request.scope, request.receive, send)
798813

814+
async def __aenter__(self) -> Self:
815+
async with self._idle_condition:
816+
self._processing_request_count += 1
817+
self._has_request = True
818+
return self
819+
820+
async def __aexit__(
821+
self,
822+
exc_type: type[BaseException] | None,
823+
exc_value: BaseException | None,
824+
traceback: TracebackType | None,
825+
) -> None:
826+
async with self._idle_condition:
827+
self._processing_request_count -= 1
828+
if self._processing_request_count == 0:
829+
self._idle_condition.notify_all()
830+
831+
async def _idle_timeout_terminate(self, timeout: float) -> None:
832+
"""
833+
Terminate the transport if it remains idle for longer than the defined timeout duration.
834+
"""
835+
while not self._terminated:
836+
# wait for transport to be idle
837+
async with self._idle_condition:
838+
if self._processing_request_count > 0:
839+
await self._idle_condition.wait()
840+
self._has_request = False
841+
842+
# wait for idle timeout
843+
await anyio.sleep(timeout)
844+
845+
# If there are no requests during the wait period, terminate the transport
846+
if not self._has_request:
847+
logger.debug(f"Terminating transport due to idle timeout: {self.mcp_session_id}")
848+
await self.terminate()
849+
799850
@asynccontextmanager
800851
async def connect(
801852
self,
@@ -812,6 +863,10 @@ async def connect(
812863
Tuple of (read_stream, write_stream) for bidirectional communication
813864
"""
814865

866+
# Terminated transports should not be connected again
867+
if self._terminated:
868+
raise RuntimeError("Transport is terminated")
869+
815870
# Create the memory streams for this connection
816871

817872
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
@@ -884,20 +939,13 @@ async def message_router():
884939
# Start the message router
885940
tg.start_soon(message_router)
886941

942+
# Start idle timeout task if timeout is set
943+
if self._timeout is not None:
944+
tg.start_soon(self._idle_timeout_terminate, self._timeout)
945+
887946
try:
888947
# Yield the streams for the caller to use
889948
yield read_stream, write_stream
890949
finally:
891-
for stream_id in list(self._request_streams.keys()):
892-
await self._clean_up_memory_streams(stream_id)
893-
self._request_streams.clear()
894-
895-
# Clean up the read and write streams
896-
try:
897-
await read_stream_writer.aclose()
898-
await read_stream.aclose()
899-
await write_stream_reader.aclose()
900-
await write_stream.aclose()
901-
except Exception as e:
902-
# During cleanup, we catch all exceptions since streams might be in various states
903-
logger.debug(f"Error closing streams: {e}")
950+
# Terminate the transport when the context manager exits
951+
await self.terminate()

src/mcp/server/streamable_http_manager.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class StreamableHTTPSessionManager:
5252
json_response: Whether to use JSON responses instead of SSE streams
5353
stateless: If True, creates a completely fresh transport for each request
5454
with no session tracking or state persistence between requests.
55+
timeout: Optional idle timeout for the stateful transport. If specified,
56+
the stateful transport will terminate if it remains idle for longer
57+
than the defined timeout duration in seconds.
5558
"""
5659

5760
def __init__(
@@ -60,12 +63,14 @@ def __init__(
6063
event_store: EventStore | None = None,
6164
json_response: bool = False,
6265
stateless: bool = False,
66+
timeout: float | None = None,
6367
security_settings: TransportSecuritySettings | None = None,
6468
):
6569
self.app = app
6670
self.event_store = event_store
6771
self.json_response = json_response
6872
self.stateless = stateless
73+
self.timeout = timeout
6974
self.security_settings = security_settings
7075

7176
# Session tracking (only used if not stateless)
@@ -187,11 +192,12 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
187192
# Start the server task
188193
await self._task_group.start(run_stateless_server)
189194

190-
# Handle the HTTP request and return the response
191-
await http_transport.handle_request(scope, receive, send)
192-
193-
# Terminate the transport after the request is handled
194-
await http_transport.terminate()
195+
try:
196+
# Handle the HTTP request and return the response
197+
await http_transport.handle_request(scope, receive, send)
198+
finally:
199+
# Terminate the transport after the request is handled
200+
await http_transport.terminate()
195201

196202
async def _handle_stateful_request(
197203
self,
@@ -214,7 +220,8 @@ async def _handle_stateful_request(
214220
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
215221
transport = self._server_instances[request_mcp_session_id]
216222
logger.debug("Session already exists, handling request directly")
217-
await transport.handle_request(scope, receive, send)
223+
async with transport:
224+
await transport.handle_request(scope, receive, send)
218225
return
219226

220227
if request_mcp_session_id is None:
@@ -227,6 +234,7 @@ async def _handle_stateful_request(
227234
is_json_response_enabled=self.json_response,
228235
event_store=self.event_store, # May be None (no resumability)
229236
security_settings=self.security_settings,
237+
timeout=self.timeout,
230238
)
231239

232240
assert http_transport.mcp_session_id is not None
@@ -251,11 +259,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
251259
exc_info=True,
252260
)
253261
finally:
254-
# Only remove from instances if not terminated
262+
# remove from instances, we do not need to terminate the transport
263+
# as it will be terminated when the context manager exits
255264
if (
256265
http_transport.mcp_session_id
257266
and http_transport.mcp_session_id in self._server_instances
258-
and not http_transport.is_terminated
259267
):
260268
logger.info(
261269
"Cleaning up crashed session "
@@ -270,11 +278,13 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
270278
await self._task_group.start(run_server)
271279

272280
# Handle the HTTP request and return the response
273-
await http_transport.handle_request(scope, receive, send)
281+
async with http_transport:
282+
await http_transport.handle_request(scope, receive, send)
274283
else:
275-
# Invalid session ID
284+
# Client may send a outdated session ID
285+
# We should return 404 to notify the client to start a new session
276286
response = Response(
277-
"Bad Request: No valid session ID provided",
278-
status_code=HTTPStatus.BAD_REQUEST,
287+
"Not Found: Session has been terminated",
288+
status_code=HTTPStatus.NOT_FOUND,
279289
)
280290
await response(scope, receive, send)

tests/server/test_streamable_http_manager.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,70 @@ async def mock_receive():
260260

261261
# Verify internal state is cleaned up
262262
assert len(transport._request_streams) == 0, "Transport should have no active request streams"
263+
264+
265+
@pytest.mark.anyio
266+
async def test_stateful_session_cleanup_on_idle_timeout():
267+
"""Test that stateful sessions are cleaned up when idle timeout with real transports and sessions."""
268+
app = Server("test-stateful-idle-timeout")
269+
manager = StreamableHTTPSessionManager(app=app, timeout=0.01)
270+
271+
created_transports: list[streamable_http_manager.StreamableHTTPServerTransport] = []
272+
273+
original_transport_constructor = streamable_http_manager.StreamableHTTPServerTransport
274+
275+
def track_transport(*args, **kwargs):
276+
transport = original_transport_constructor(*args, **kwargs)
277+
created_transports.append(transport)
278+
return transport
279+
280+
with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=track_transport):
281+
async with manager.run():
282+
sent_messages = []
283+
284+
async def mock_send(message):
285+
sent_messages.append(message)
286+
287+
scope = {
288+
"type": "http",
289+
"method": "POST",
290+
"path": "/mcp",
291+
"headers": [(b"content-type", b"application/json")],
292+
}
293+
294+
async def mock_receive():
295+
return {"type": "http.request", "body": b"", "more_body": False}
296+
297+
# Trigger session creation
298+
await manager.handle_request(scope, mock_receive, mock_send)
299+
300+
session_id = None
301+
for msg in sent_messages:
302+
if msg["type"] == "http.response.start":
303+
for header_name, header_value in msg.get("headers", []):
304+
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
305+
session_id = header_value.decode()
306+
break
307+
if session_id: # Break outer loop if session_id is found
308+
break
309+
310+
assert session_id is not None, "Session ID not found in response headers"
311+
312+
assert len(created_transports) == 1, "Should have created one transport"
313+
314+
transport = created_transports[0]
315+
316+
# the transport should not be terminated before idle timeout
317+
assert not transport.is_terminated, "Transport should not be terminated before idle timeout"
318+
assert session_id in manager._server_instances, (
319+
"Session ID should be tracked in _server_instances before idle timeout"
320+
)
321+
322+
# wait for idle timeout
323+
await anyio.sleep(0.1)
324+
325+
assert transport.is_terminated, "Transport should be terminated after idle timeout"
326+
assert session_id not in manager._server_instances, (
327+
"Session ID should be removed from _server_instances after idle timeout"
328+
)
329+
assert not manager._server_instances, "No sessions should be tracked after the only session idle timeout"

0 commit comments

Comments
 (0)