Skip to content

Commit ad92d4b

Browse files
committed
add correctness tests
Signed-off-by: Superjomn <[email protected]>
1 parent 403f13d commit ad92d4b

File tree

4 files changed

+244
-102
lines changed

4 files changed

+244
-102
lines changed

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 19 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import concurrent.futures
33
import threading
44
import uuid
5-
from typing import Any, AsyncIterator, Optional
5+
from typing import Any, AsyncIterator, Callable, Optional
66

77
from ...llmapi.utils import logger_debug
88
from ...logger import logger
@@ -14,7 +14,8 @@
1414
class RemoteCall:
1515
"""Helper class to enable chained remote call syntax like client.method().remote()"""
1616

17-
def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs):
17+
def __init__(self, client: 'RPCClient', method_name: str, *args,
18+
**kwargs) -> None:
1819
self.client = client
1920
self.method_name = method_name
2021
self.args = args
@@ -57,7 +58,7 @@ def remote(self,
5758

5859
def remote_async(self,
5960
timeout: Optional[float] = None,
60-
need_response: bool = True):
61+
need_response: bool = True) -> Any:
6162
"""Asynchronous remote call that returns a coroutine.
6263
6364
Args:
@@ -106,9 +107,9 @@ class RPCClient:
106107

107108
def __init__(self,
108109
address: str,
109-
hmac_key=None,
110+
hmac_key: Optional[bytes] = None,
110111
timeout: Optional[float] = None,
111-
num_workers: int = 4):
112+
num_workers: int = 4) -> None:
112113
'''
113114
Args:
114115
address: The ZMQ address to connect to.
@@ -137,7 +138,7 @@ def __init__(self,
137138

138139
logger_debug(f"RPC Client initialized. Connected to {self._address}")
139140

140-
def shutdown_server(self):
141+
def shutdown_server(self) -> None:
141142
"""Shutdown the server."""
142143
if self._server_stopped:
143144
return
@@ -146,7 +147,7 @@ def shutdown_server(self):
146147

147148
self._server_stopped = True
148149

149-
def close(self):
150+
def close(self) -> None:
150151
"""Gracefully close the client, cleaning up background tasks."""
151152
if self._closed:
152153
return
@@ -172,7 +173,7 @@ def close(self):
172173

173174
logger_debug("RPC Client closed")
174175

175-
async def _response_reader(self):
176+
async def _response_reader(self) -> None:
176177
"""Task to read responses from the socket and set results on futures."""
177178
logger_debug("Response reader started")
178179

@@ -242,7 +243,7 @@ async def _response_reader(self):
242243
finally:
243244
logger_debug("Response reader exiting gracefully")
244245

245-
def _ensure_reader_task(self):
246+
def _ensure_reader_task(self) -> None:
246247
"""Ensure the response reader task is running."""
247248
with self._reader_lock:
248249
if self._reader_task is None or self._reader_task.done():
@@ -254,7 +255,7 @@ def _ensure_reader_task(self):
254255
# No running event loop, will be started when needed
255256
pass
256257

257-
async def _call_async(self, method_name, *args, **kwargs):
258+
async def _call_async(self, method_name: str, *args, **kwargs) -> Any:
258259
"""Async version of RPC call.
259260
Args:
260261
method_name: Method name to call
@@ -320,7 +321,7 @@ async def _call_async(self, method_name, *args, **kwargs):
320321
except Exception:
321322
raise
322323

323-
def _call_sync(self, method_name, *args, **kwargs):
324+
def _call_sync(self, method_name: str, *args, **kwargs) -> Any:
324325
"""Synchronous version of RPC call."""
325326
logger_debug(
326327
f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
@@ -331,7 +332,7 @@ def _call_sync(self, method_name, *args, **kwargs):
331332
asyncio.get_running_loop()
332333

333334
# We're inside an event loop, we need to run in a thread to avoid deadlock
334-
def run_in_thread():
335+
def run_in_thread() -> Any:
335336
return asyncio.run(
336337
self._call_async(method_name, *args, **kwargs))
337338

@@ -454,12 +455,12 @@ async def _call_streaming(self, name: str, *args,
454455
# Clean up
455456
self._streaming_queues.pop(request_id, None)
456457

457-
def get_server_attr(self, name: str):
458+
def get_server_attr(self, name: str) -> Any:
458459
""" Get the attribute of the RPC server.
459460
This is mainly used for testing. """
460461
return self._rpc_get_attr(name).remote()
461462

462-
def __getattr__(self, name):
463+
def __getattr__(self, name: str) -> Callable[..., RemoteCall]:
463464
"""
464465
Magically handles calls to non-existent methods.
465466
Returns a callable that when invoked returns a RemoteCall instance.
@@ -472,16 +473,16 @@ def __getattr__(self, name):
472473
"""
473474
logger_debug(f"RPC Client getting attribute: {name}")
474475

475-
def method_caller(*args, **kwargs):
476+
def method_caller(*args, **kwargs) -> RemoteCall:
476477
return RemoteCall(self, name, *args, **kwargs)
477478

478479
return method_caller
479480

480-
def __enter__(self):
481+
def __enter__(self) -> 'RPCClient':
481482
return self
482483

483-
def __exit__(self, exc_type, exc_value, traceback):
484+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
484485
self.close()
485486

486-
def __del__(self):
487+
def __del__(self) -> None:
487488
self.close()

tensorrt_llm/executor/rpc/rpc_common.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,5 +86,5 @@ class RPCResponse(NamedTuple):
8686
result: Any
8787
error: Optional[RPCError] = None
8888
is_streaming: bool = False # True if more responses coming
89-
sequence_number: int = 0 # For ordering streaming responses
89+
chunk_index: int = 0 # For ordering streaming responses
9090
stream_status: Literal['start', 'data', 'end', 'error'] = 'data'

0 commit comments

Comments
 (0)