15
15
from contextlib import asynccontextmanager
16
16
from dataclasses import dataclass
17
17
from http import HTTPStatus
18
+ from types import TracebackType
18
19
19
20
import anyio
20
21
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
23
24
from starlette .requests import Request
24
25
from starlette .responses import Response
25
26
from starlette .types import Receive , Scope , Send
27
+ from typing_extensions import Self
26
28
27
29
from mcp .server .transport_security import (
28
30
TransportSecurityMiddleware ,
@@ -140,6 +142,7 @@ def __init__(
140
142
is_json_response_enabled : bool = False ,
141
143
event_store : EventStore | None = None ,
142
144
security_settings : TransportSecuritySettings | None = None ,
145
+ timeout : float | None = None ,
143
146
) -> None :
144
147
"""
145
148
Initialize a new StreamableHTTP server transport.
@@ -153,6 +156,9 @@ def __init__(
153
156
resumability will be enabled, allowing clients to
154
157
reconnect and resume messages.
155
158
security_settings: Optional security settings for DNS rebinding protection.
159
+ timeout: Optional idle timeout for transport. If provided, the transport will
160
+ terminate if it remains idle for longer than the defined timeout
161
+ duration in seconds.
156
162
157
163
Raises:
158
164
ValueError: If the session ID contains invalid characters.
@@ -172,6 +178,12 @@ def __init__(
172
178
],
173
179
] = {}
174
180
self ._terminated = False
181
+ self ._timeout = timeout
182
+
183
+ # for idle detection
184
+ self ._processing_request_count = 0
185
+ self ._idle_condition = anyio .Condition ()
186
+ self ._has_request = False
175
187
176
188
@property
177
189
def is_terminated (self ) -> bool :
@@ -626,6 +638,9 @@ async def terminate(self) -> None:
626
638
Once terminated, all requests with this session ID will receive 404 Not Found.
627
639
"""
628
640
641
+ if self ._terminated :
642
+ return
643
+
629
644
self ._terminated = True
630
645
logger .info (f"Terminating session: { self .mcp_session_id } " )
631
646
@@ -796,6 +811,42 @@ async def send_event(event_message: EventMessage) -> None:
796
811
)
797
812
await response (request .scope , request .receive , send )
798
813
814
+ async def __aenter__ (self ) -> Self :
815
+ async with self ._idle_condition :
816
+ self ._processing_request_count += 1
817
+ self ._has_request = True
818
+ return self
819
+
820
+ async def __aexit__ (
821
+ self ,
822
+ exc_type : type [BaseException ] | None ,
823
+ exc_value : BaseException | None ,
824
+ traceback : TracebackType | None ,
825
+ ) -> None :
826
+ async with self ._idle_condition :
827
+ self ._processing_request_count -= 1
828
+ if self ._processing_request_count == 0 :
829
+ self ._idle_condition .notify_all ()
830
+
831
+ async def _idle_timeout_terminate (self , timeout : float ) -> None :
832
+ """
833
+ Terminate the transport if it remains idle for longer than the defined timeout duration.
834
+ """
835
+ while not self ._terminated :
836
+ # wait for transport to be idle
837
+ async with self ._idle_condition :
838
+ if self ._processing_request_count > 0 :
839
+ await self ._idle_condition .wait ()
840
+ self ._has_request = False
841
+
842
+ # wait for idle timeout
843
+ await anyio .sleep (timeout )
844
+
845
+ # If there are no requests during the wait period, terminate the transport
846
+ if not self ._has_request :
847
+ logger .debug (f"Terminating transport due to idle timeout: { self .mcp_session_id } " )
848
+ await self .terminate ()
849
+
799
850
@asynccontextmanager
800
851
async def connect (
801
852
self ,
@@ -812,6 +863,10 @@ async def connect(
812
863
Tuple of (read_stream, write_stream) for bidirectional communication
813
864
"""
814
865
866
+ # Terminated transports should not be connected again
867
+ if self ._terminated :
868
+ raise RuntimeError ("Transport is terminated" )
869
+
815
870
# Create the memory streams for this connection
816
871
817
872
read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
@@ -884,20 +939,13 @@ async def message_router():
884
939
# Start the message router
885
940
tg .start_soon (message_router )
886
941
942
+ # Start idle timeout task if timeout is set
943
+ if self ._timeout is not None :
944
+ tg .start_soon (self ._idle_timeout_terminate , self ._timeout )
945
+
887
946
try :
888
947
# Yield the streams for the caller to use
889
948
yield read_stream , write_stream
890
949
finally :
891
- for stream_id in list (self ._request_streams .keys ()):
892
- await self ._clean_up_memory_streams (stream_id )
893
- self ._request_streams .clear ()
894
-
895
- # Clean up the read and write streams
896
- try :
897
- await read_stream_writer .aclose ()
898
- await read_stream .aclose ()
899
- await write_stream_reader .aclose ()
900
- await write_stream .aclose ()
901
- except Exception as e :
902
- # During cleanup, we catch all exceptions since streams might be in various states
903
- logger .debug (f"Error closing streams: { e } " )
950
+ # Terminate the transport when the context manager exits
951
+ await self .terminate ()
0 commit comments