22import concurrent .futures
33import threading
44import uuid
5- from typing import Any , AsyncIterator , Optional
5+ from typing import Any , AsyncIterator , Callable , Optional
66
77from ...llmapi .utils import logger_debug
88from ...logger import logger
1414class RemoteCall :
1515 """Helper class to enable chained remote call syntax like client.method().remote()"""
1616
17- def __init__ (self , client : 'RPCClient' , method_name : str , * args , ** kwargs ):
17+ def __init__ (self , client : 'RPCClient' , method_name : str , * args ,
18+ ** kwargs ) -> None :
1819 self .client = client
1920 self .method_name = method_name
2021 self .args = args
@@ -57,7 +58,7 @@ def remote(self,
5758
5859 def remote_async (self ,
5960 timeout : Optional [float ] = None ,
60- need_response : bool = True ):
61+ need_response : bool = True ) -> Any :
6162 """Asynchronous remote call that returns a coroutine.
6263
6364 Args:
@@ -106,9 +107,9 @@ class RPCClient:
106107
107108 def __init__ (self ,
108109 address : str ,
109- hmac_key = None ,
110+ hmac_key : Optional [ bytes ] = None ,
110111 timeout : Optional [float ] = None ,
111- num_workers : int = 4 ):
112+ num_workers : int = 4 ) -> None :
112113 '''
113114 Args:
114115 address: The ZMQ address to connect to.
@@ -137,7 +138,7 @@ def __init__(self,
137138
138139 logger_debug (f"RPC Client initialized. Connected to { self ._address } " )
139140
140- def shutdown_server (self ):
141+ def shutdown_server (self ) -> None :
141142 """Shutdown the server."""
142143 if self ._server_stopped :
143144 return
@@ -146,7 +147,7 @@ def shutdown_server(self):
146147
147148 self ._server_stopped = True
148149
149- def close (self ):
150+ def close (self ) -> None :
150151 """Gracefully close the client, cleaning up background tasks."""
151152 if self ._closed :
152153 return
@@ -182,112 +183,7 @@ def close(self):
182183
183184 logger_debug ("RPC Client closed" )
184185
185- def _handle_streaming_response (self , response : RPCResponse ):
186- """Handle a streaming response by putting it in the appropriate queue.
187-
188- Args:
189- response: The streaming response to handle
190- """
191- assert response .stream_status in [
192- 'start' , 'data' , 'end' , 'error'
193- ], f"Invalid stream status: { response .stream_status } "
194-
195- queue = self ._streaming_queues .get (response .request_id )
196- if queue :
197- # put to the sync queue, as the current event loop is
198- # different from the one in call_async or call_streaming
199- assert isinstance (queue , AsyncQueue )
200- if enable_llmapi_debug () or logger .level == 'debug' :
201- logger_debug (
202- f"RPC Client putting response to AsyncQueue: status={ response .stream_status } , request_id={ response .request_id } "
203- )
204- queue .sync_q .put (response )
205- # Clean up if stream ended
206- if response .stream_status in ['end' , 'error' ]:
207- self ._streaming_queues .pop (response .request_id , None )
208-
209- def _handle_regular_response (self , response : RPCResponse ):
210- """Handle a regular (non-streaming) response by setting the future result.
211-
212- Args:
213- response: The response to handle
214- """
215- if future_info := self ._pending_futures .get (response .request_id ):
216- future , target_loop = future_info
217-
218- if not future .done ():
219-
220- def safe_set_result ():
221- """Safely set result on future, handling race conditions."""
222- try :
223- if not future .done ():
224- if response .error is None :
225- future .set_result (response .result )
226- else :
227- future .set_exception (response .error )
228- except asyncio .InvalidStateError :
229- # Future was cancelled or completed between the check and set
230- # This is expected in high-load scenarios, just log and continue
231- if enable_llmapi_debug () or logger .level == 'debug' :
232- logger_debug (
233- f"Future already done for request_id: { response .request_id } , skipping"
234- )
235-
236- if enable_llmapi_debug () or logger .level == 'debug' :
237- if response .error is None :
238- logger_debug (
239- f"Setting result for request_id: { response .request_id } "
240- )
241- else :
242- logger_debug (
243- f"Setting exception for request_id: { response .request_id } , error: { response .error } "
244- )
245-
246- target_loop .call_soon_threadsafe (safe_set_result )
247- else :
248- if enable_llmapi_debug () or logger .level == 'debug' :
249- logger_debug (
250- f"No future found for request_id: { response .request_id } " )
251-
252- self ._pending_futures .pop (response .request_id , None )
253-
254- async def _handle_reader_exception (self , exception : Exception ):
255- """Propagate an exception to all pending futures and streaming queues.
256-
257- Args:
258- exception: The exception to propagate
259- """
260- logger .error (f"Exception in RPC response reader: { exception } " )
261-
262- # Propagate exception to all pending futures
263- for (future , target_loop ) in self ._pending_futures .values ():
264- if not future .done ():
265-
266- def safe_set_exception (f = future , exc = exception ):
267- """Safely set exception on future, handling race conditions."""
268- try :
269- if not f .done ():
270- f .set_exception (exc )
271- except asyncio .InvalidStateError :
272- # Future was cancelled or completed, this is fine
273- pass
274-
275- target_loop .call_soon_threadsafe (safe_set_exception )
276-
277- # Also signal error to streaming queues
278- for queue in self ._streaming_queues .values ():
279- await queue .put (RPCResponse ("" , None , exception , False , 0 , 'error' ))
280-
281- async def _wait_for_response (self ) -> RPCResponse :
282- """Wait for a response from the socket.
283-
284- Returns:
285- RPCResponse from the server
286- """
287- # Directly await the socket - cancellation will be handled by task cancellation
288- return await self ._client_socket .get_async ()
289-
290- async def _response_reader (self ):
186+ async def _response_reader (self ) -> None :
291187 """Task to read responses from the socket and set results on futures."""
292188 logger_debug ("Response reader started" )
293189
@@ -359,7 +255,7 @@ async def _response_reader(self):
359255 finally :
360256 logger_debug ("Response reader exiting gracefully" )
361257
362- def _ensure_reader_task (self ):
258+ def _ensure_reader_task (self ) -> None :
363259 """Ensure the response reader task is running."""
364260 with self ._reader_lock :
365261 if self ._reader_task is None or self ._reader_task .done ():
@@ -371,7 +267,7 @@ def _ensure_reader_task(self):
371267 # No running event loop, will be started when needed
372268 pass
373269
374- async def _call_async (self , method_name , * args , ** kwargs ):
270+ async def _call_async (self , method_name : str , * args , ** kwargs ) -> Any :
375271 """Async version of RPC call.
376272 Args:
377273 method_name: Method name to call
@@ -435,7 +331,7 @@ async def _call_async(self, method_name, *args, **kwargs):
435331 except Exception :
436332 raise
437333
438- def _call_sync (self , method_name , * args , ** kwargs ):
334+ def _call_sync (self , method_name : str , * args , ** kwargs ) -> Any :
439335 """Synchronous version of RPC call."""
440336 logger_debug (
441337 f"RPC Client calling method: { method_name } with args: { args } and kwargs: { kwargs } "
@@ -446,7 +342,7 @@ def _call_sync(self, method_name, *args, **kwargs):
446342 asyncio .get_running_loop ()
447343
448344 # We're inside an event loop, we need to run in a thread to avoid deadlock
449- def run_in_thread ():
345+ def run_in_thread () -> Any :
450346 return asyncio .run (
451347 self ._call_async (method_name , * args , ** kwargs ))
452348
@@ -566,12 +462,12 @@ async def _call_streaming(self, name: str, *args,
566462 # Clean up
567463 self ._streaming_queues .pop (request_id , None )
568464
569- def get_server_attr (self , name : str ):
465+ def get_server_attr (self , name : str ) -> Any :
570466 """ Get the attribute of the RPC server.
571467 This is mainly used for testing. """
572468 return self ._rpc_get_attr (name ).remote ()
573469
574- def __getattr__ (self , name ) :
470+ def __getattr__ (self , name : str ) -> Callable [..., RemoteCall ] :
575471 """
576472 Magically handles calls to non-existent methods.
577473 Returns a callable that when invoked returns a RemoteCall instance.
@@ -584,16 +480,16 @@ def __getattr__(self, name):
584480 """
585481 logger_debug (f"RPC Client getting attribute: { name } " )
586482
587- def method_caller (* args , ** kwargs ):
483+ def method_caller (* args , ** kwargs ) -> RemoteCall :
588484 return RemoteCall (self , name , * args , ** kwargs )
589485
590486 return method_caller
591487
592- def __enter__ (self ):
488+ def __enter__ (self ) -> 'RPCClient' :
593489 return self
594490
595- def __exit__ (self , exc_type , exc_value , traceback ) :
491+ def __exit__ (self , exc_type : Any , exc_value : Any , traceback : Any ) -> None :
596492 self .close ()
597493
598- def __del__ (self ):
494+ def __del__ (self ) -> None :
599495 self .close ()
0 commit comments