Skip to content

Commit ddb9823

Browse files
committed
add correctness tests
Signed-off-by: Superjomn <[email protected]>
1 parent 5edba20 commit ddb9823

File tree

4 files changed

+244
-207
lines changed

4 files changed

+244
-207
lines changed

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 19 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import concurrent.futures
33
import threading
44
import uuid
5-
from typing import Any, AsyncIterator, Optional
5+
from typing import Any, AsyncIterator, Callable, Optional
66

77
from ...llmapi.utils import logger_debug
88
from ...logger import logger
@@ -14,7 +14,8 @@
1414
class RemoteCall:
1515
"""Helper class to enable chained remote call syntax like client.method().remote()"""
1616

17-
def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs):
17+
def __init__(self, client: 'RPCClient', method_name: str, *args,
18+
**kwargs) -> None:
1819
self.client = client
1920
self.method_name = method_name
2021
self.args = args
@@ -57,7 +58,7 @@ def remote(self,
5758

5859
def remote_async(self,
5960
timeout: Optional[float] = None,
60-
need_response: bool = True):
61+
need_response: bool = True) -> Any:
6162
"""Asynchronous remote call that returns a coroutine.
6263
6364
Args:
@@ -106,9 +107,9 @@ class RPCClient:
106107

107108
def __init__(self,
108109
address: str,
109-
hmac_key=None,
110+
hmac_key: Optional[bytes] = None,
110111
timeout: Optional[float] = None,
111-
num_workers: int = 4):
112+
num_workers: int = 4) -> None:
112113
'''
113114
Args:
114115
address: The ZMQ address to connect to.
@@ -137,7 +138,7 @@ def __init__(self,
137138

138139
logger_debug(f"RPC Client initialized. Connected to {self._address}")
139140

140-
def shutdown_server(self):
141+
def shutdown_server(self) -> None:
141142
"""Shutdown the server."""
142143
if self._server_stopped:
143144
return
@@ -146,7 +147,7 @@ def shutdown_server(self):
146147

147148
self._server_stopped = True
148149

149-
def close(self):
150+
def close(self) -> None:
150151
"""Gracefully close the client, cleaning up background tasks."""
151152
if self._closed:
152153
return
@@ -182,112 +183,7 @@ def close(self):
182183

183184
logger_debug("RPC Client closed")
184185

185-
def _handle_streaming_response(self, response: RPCResponse):
186-
"""Handle a streaming response by putting it in the appropriate queue.
187-
188-
Args:
189-
response: The streaming response to handle
190-
"""
191-
assert response.stream_status in [
192-
'start', 'data', 'end', 'error'
193-
], f"Invalid stream status: {response.stream_status}"
194-
195-
queue = self._streaming_queues.get(response.request_id)
196-
if queue:
197-
# put to the sync queue, as the current event loop is
198-
# different from the one in call_async or call_streaming
199-
assert isinstance(queue, AsyncQueue)
200-
if enable_llmapi_debug() or logger.level == 'debug':
201-
logger_debug(
202-
f"RPC Client putting response to AsyncQueue: status={response.stream_status}, request_id={response.request_id}"
203-
)
204-
queue.sync_q.put(response)
205-
# Clean up if stream ended
206-
if response.stream_status in ['end', 'error']:
207-
self._streaming_queues.pop(response.request_id, None)
208-
209-
def _handle_regular_response(self, response: RPCResponse):
210-
"""Handle a regular (non-streaming) response by setting the future result.
211-
212-
Args:
213-
response: The response to handle
214-
"""
215-
if future_info := self._pending_futures.get(response.request_id):
216-
future, target_loop = future_info
217-
218-
if not future.done():
219-
220-
def safe_set_result():
221-
"""Safely set result on future, handling race conditions."""
222-
try:
223-
if not future.done():
224-
if response.error is None:
225-
future.set_result(response.result)
226-
else:
227-
future.set_exception(response.error)
228-
except asyncio.InvalidStateError:
229-
# Future was cancelled or completed between the check and set
230-
# This is expected in high-load scenarios, just log and continue
231-
if enable_llmapi_debug() or logger.level == 'debug':
232-
logger_debug(
233-
f"Future already done for request_id: {response.request_id}, skipping"
234-
)
235-
236-
if enable_llmapi_debug() or logger.level == 'debug':
237-
if response.error is None:
238-
logger_debug(
239-
f"Setting result for request_id: {response.request_id}"
240-
)
241-
else:
242-
logger_debug(
243-
f"Setting exception for request_id: {response.request_id}, error: {response.error}"
244-
)
245-
246-
target_loop.call_soon_threadsafe(safe_set_result)
247-
else:
248-
if enable_llmapi_debug() or logger.level == 'debug':
249-
logger_debug(
250-
f"No future found for request_id: {response.request_id}")
251-
252-
self._pending_futures.pop(response.request_id, None)
253-
254-
async def _handle_reader_exception(self, exception: Exception):
255-
"""Propagate an exception to all pending futures and streaming queues.
256-
257-
Args:
258-
exception: The exception to propagate
259-
"""
260-
logger.error(f"Exception in RPC response reader: {exception}")
261-
262-
# Propagate exception to all pending futures
263-
for (future, target_loop) in self._pending_futures.values():
264-
if not future.done():
265-
266-
def safe_set_exception(f=future, exc=exception):
267-
"""Safely set exception on future, handling race conditions."""
268-
try:
269-
if not f.done():
270-
f.set_exception(exc)
271-
except asyncio.InvalidStateError:
272-
# Future was cancelled or completed, this is fine
273-
pass
274-
275-
target_loop.call_soon_threadsafe(safe_set_exception)
276-
277-
# Also signal error to streaming queues
278-
for queue in self._streaming_queues.values():
279-
await queue.put(RPCResponse("", None, exception, False, 0, 'error'))
280-
281-
async def _wait_for_response(self) -> RPCResponse:
282-
"""Wait for a response from the socket.
283-
284-
Returns:
285-
RPCResponse from the server
286-
"""
287-
# Directly await the socket - cancellation will be handled by task cancellation
288-
return await self._client_socket.get_async()
289-
290-
async def _response_reader(self):
186+
async def _response_reader(self) -> None:
291187
"""Task to read responses from the socket and set results on futures."""
292188
logger_debug("Response reader started")
293189

@@ -359,7 +255,7 @@ async def _response_reader(self):
359255
finally:
360256
logger_debug("Response reader exiting gracefully")
361257

362-
def _ensure_reader_task(self):
258+
def _ensure_reader_task(self) -> None:
363259
"""Ensure the response reader task is running."""
364260
with self._reader_lock:
365261
if self._reader_task is None or self._reader_task.done():
@@ -371,7 +267,7 @@ def _ensure_reader_task(self):
371267
# No running event loop, will be started when needed
372268
pass
373269

374-
async def _call_async(self, method_name, *args, **kwargs):
270+
async def _call_async(self, method_name: str, *args, **kwargs) -> Any:
375271
"""Async version of RPC call.
376272
Args:
377273
method_name: Method name to call
@@ -435,7 +331,7 @@ async def _call_async(self, method_name, *args, **kwargs):
435331
except Exception:
436332
raise
437333

438-
def _call_sync(self, method_name, *args, **kwargs):
334+
def _call_sync(self, method_name: str, *args, **kwargs) -> Any:
439335
"""Synchronous version of RPC call."""
440336
logger_debug(
441337
f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
@@ -446,7 +342,7 @@ def _call_sync(self, method_name, *args, **kwargs):
446342
asyncio.get_running_loop()
447343

448344
# We're inside an event loop, we need to run in a thread to avoid deadlock
449-
def run_in_thread():
345+
def run_in_thread() -> Any:
450346
return asyncio.run(
451347
self._call_async(method_name, *args, **kwargs))
452348

@@ -566,12 +462,12 @@ async def _call_streaming(self, name: str, *args,
566462
# Clean up
567463
self._streaming_queues.pop(request_id, None)
568464

569-
def get_server_attr(self, name: str):
465+
def get_server_attr(self, name: str) -> Any:
570466
""" Get the attribute of the RPC server.
571467
This is mainly used for testing. """
572468
return self._rpc_get_attr(name).remote()
573469

574-
def __getattr__(self, name):
470+
def __getattr__(self, name: str) -> Callable[..., RemoteCall]:
575471
"""
576472
Magically handles calls to non-existent methods.
577473
Returns a callable that when invoked returns a RemoteCall instance.
@@ -584,16 +480,16 @@ def __getattr__(self, name):
584480
"""
585481
logger_debug(f"RPC Client getting attribute: {name}")
586482

587-
def method_caller(*args, **kwargs):
483+
def method_caller(*args, **kwargs) -> RemoteCall:
588484
return RemoteCall(self, name, *args, **kwargs)
589485

590486
return method_caller
591487

592-
def __enter__(self):
488+
def __enter__(self) -> 'RPCClient':
593489
return self
594490

595-
def __exit__(self, exc_type, exc_value, traceback):
491+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
596492
self.close()
597493

598-
def __del__(self):
494+
def __del__(self) -> None:
599495
self.close()

tensorrt_llm/executor/rpc/rpc_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,5 @@ class RPCResponse(NamedTuple):
8686
result: Any
8787
error: Optional[RPCError] = None
8888
is_streaming: bool = False # True if more responses coming
89-
sequence_number: int = 0 # For ordering streaming responses
89+
chunk_index: int = 0 # For ordering streaming responses
9090
stream_status: Literal['start', 'data', 'end', 'error'] = 'data'

0 commit comments

Comments
 (0)