22import concurrent .futures
33import threading
44import uuid
5- from typing import Any , AsyncIterator , Optional
5+ from typing import Any , AsyncIterator , Callable , Optional
66
77from ...llmapi .utils import logger_debug
88from ...logger import logger
1414class RemoteCall :
1515 """Helper class to enable chained remote call syntax like client.method().remote()"""
1616
17- def __init__ (self , client : 'RPCClient' , method_name : str , * args , ** kwargs ):
17+ def __init__ (self , client : 'RPCClient' , method_name : str , * args ,
18+ ** kwargs ) -> None :
1819 self .client = client
1920 self .method_name = method_name
2021 self .args = args
@@ -57,7 +58,7 @@ def remote(self,
5758
5859 def remote_async (self ,
5960 timeout : Optional [float ] = None ,
60- need_response : bool = True ):
61+ need_response : bool = True ) -> Any :
6162 """Asynchronous remote call that returns a coroutine.
6263
6364 Args:
@@ -106,9 +107,9 @@ class RPCClient:
106107
107108 def __init__ (self ,
108109 address : str ,
109- hmac_key = None ,
110+ hmac_key : Optional [ bytes ] = None ,
110111 timeout : Optional [float ] = None ,
111- num_workers : int = 4 ):
112+ num_workers : int = 4 ) -> None :
112113 '''
113114 Args:
114115 address: The ZMQ address to connect to.
@@ -137,7 +138,7 @@ def __init__(self,
137138
138139 logger_debug (f"RPC Client initialized. Connected to { self ._address } " )
139140
140- def shutdown_server (self ):
141+ def shutdown_server (self ) -> None :
141142 """Shutdown the server."""
142143 if self ._server_stopped :
143144 return
@@ -146,7 +147,7 @@ def shutdown_server(self):
146147
147148 self ._server_stopped = True
148149
149- def close (self ):
150+ def close (self ) -> None :
150151 """Gracefully close the client, cleaning up background tasks."""
151152 if self ._closed :
152153 return
@@ -172,7 +173,7 @@ def close(self):
172173
173174 logger_debug ("RPC Client closed" )
174175
175- async def _response_reader (self ):
176+ async def _response_reader (self ) -> None :
176177 """Task to read responses from the socket and set results on futures."""
177178 logger_debug ("Response reader started" )
178179
@@ -242,7 +243,7 @@ async def _response_reader(self):
242243 finally :
243244 logger_debug ("Response reader exiting gracefully" )
244245
245- def _ensure_reader_task (self ):
246+ def _ensure_reader_task (self ) -> None :
246247 """Ensure the response reader task is running."""
247248 with self ._reader_lock :
248249 if self ._reader_task is None or self ._reader_task .done ():
@@ -254,7 +255,7 @@ def _ensure_reader_task(self):
254255 # No running event loop, will be started when needed
255256 pass
256257
257- async def _call_async (self , method_name , * args , ** kwargs ):
258+ async def _call_async (self , method_name : str , * args , ** kwargs ) -> Any :
258259 """Async version of RPC call.
259260 Args:
260261 method_name: Method name to call
@@ -320,7 +321,7 @@ async def _call_async(self, method_name, *args, **kwargs):
320321 except Exception :
321322 raise
322323
323- def _call_sync (self , method_name , * args , ** kwargs ):
324+ def _call_sync (self , method_name : str , * args , ** kwargs ) -> Any :
324325 """Synchronous version of RPC call."""
325326 logger_debug (
326327 f"RPC Client calling method: { method_name } with args: { args } and kwargs: { kwargs } "
@@ -331,7 +332,7 @@ def _call_sync(self, method_name, *args, **kwargs):
331332 asyncio .get_running_loop ()
332333
333334 # We're inside an event loop, we need to run in a thread to avoid deadlock
334- def run_in_thread ():
335+ def run_in_thread () -> Any :
335336 return asyncio .run (
336337 self ._call_async (method_name , * args , ** kwargs ))
337338
@@ -454,12 +455,12 @@ async def _call_streaming(self, name: str, *args,
454455 # Clean up
455456 self ._streaming_queues .pop (request_id , None )
456457
457- def get_server_attr (self , name : str ):
458+ def get_server_attr (self , name : str ) -> Any :
458459 """ Get the attribute of the RPC server.
459460 This is mainly used for testing. """
460461 return self ._rpc_get_attr (name ).remote ()
461462
462- def __getattr__ (self , name ) :
463+ def __getattr__ (self , name : str ) -> Callable [..., RemoteCall ] :
463464 """
464465 Magically handles calls to non-existent methods.
465466 Returns a callable that when invoked returns a RemoteCall instance.
@@ -472,16 +473,16 @@ def __getattr__(self, name):
472473 """
473474 logger_debug (f"RPC Client getting attribute: { name } " )
474475
475- def method_caller (* args , ** kwargs ):
476+ def method_caller (* args , ** kwargs ) -> RemoteCall :
476477 return RemoteCall (self , name , * args , ** kwargs )
477478
478479 return method_caller
479480
480- def __enter__ (self ):
481+ def __enter__ (self ) -> 'RPCClient' :
481482 return self
482483
483- def __exit__ (self , exc_type , exc_value , traceback ) :
484+ def __exit__ (self , exc_type : Any , exc_value : Any , traceback : Any ) -> None :
484485 self .close ()
485486
486- def __del__ (self ):
487+ def __del__ (self ) -> None :
487488 self .close ()
0 commit comments