@@ -36,10 +36,17 @@ class RpcWorker(BaseWorker):
3636 - `shutdown`: Shutdown the worker.
3737 """
3838
39- # Number of RPC server workers
39+ # Default number of RPC server workers
4040 # Increased to handle concurrent requests and prevent thread pool exhaustion
4141 # Need enough workers for: submit requests + fetch_responses + other operations
42- NUM_WORKERS = 32
42+ # Can be overridden via constructor parameter
43+ DEFAULT_NUM_WORKERS = 32
44+
45+ # Default timeout for fetch_responses in seconds
46+ # This is a short timeout to prevent blocking the event loop while still allowing
47+ # responses to be fetched efficiently. The value is tuned to balance responsiveness
48+ # and CPU usage. Can be overridden via constructor parameter.
49+ DEFAULT_FETCH_TIMEOUT = 0.1
4350
4451 def __init__ (
4552 self ,
@@ -51,6 +58,8 @@ def __init__(
5158 hf_model_dir : Optional [Path ] = None ,
5259 tokenizer : Optional [TokenizerBase ] = None ,
5360 llm_args : Optional [BaseLlmArgs ] = None ,
61+ num_workers : Optional [int ] = None ,
62+ fetch_timeout : Optional [float ] = None ,
5463 ) -> None :
5564 super ().__init__ (
5665 engine = engine ,
@@ -63,6 +72,12 @@ def __init__(
6372 llm_args = llm_args ,
6473 )
6574
75+ # Configure number of RPC workers
76+ self .num_workers = num_workers if num_workers is not None else self .DEFAULT_NUM_WORKERS
77+
78+ # Configure fetch timeout
79+ self ._fetch_timeout = fetch_timeout if fetch_timeout is not None else self .DEFAULT_FETCH_TIMEOUT
80+
6681 # Extract garbage_collection_gen0_threshold from llm_args if available
6782 self .garbage_collection_gen0_threshold = (
6883 llm_args .garbage_collection_gen0_threshold if llm_args is not None
@@ -95,7 +110,9 @@ def fetch_responses(self, timeout: Optional[float] = None) -> list:
95110 color = "orange" ,
96111 category = "Worker" ):
97112 # NOTE: This is a blocking call; it waits for responses to become available.
98- responses = super ().await_responses (timeout = 0.1 )
113+ # Use the configured fetch timeout if no timeout is provided
114+ actual_timeout = timeout if timeout is not None else self ._fetch_timeout
115+ responses = super ().await_responses (timeout = actual_timeout )
99116 self ._await_response_helper .responses_handler (responses )
100117 logger_debug (f"[worker] Fetched { len (responses )} responses" ,
101118 color = "green" )
@@ -248,11 +265,11 @@ def main_task(
248265
249266 else :
250267 logger_debug (
251- f"[worker] Worker { mpi_rank ()} is creating the RPC service" ,
268+ f"[worker] Worker { mpi_rank ()} is creating the RPC service with { worker . num_workers } workers " ,
252269 color = "yellow" )
253270 # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
254271 # Set num_workers to a value larger than 1, since some streaming tasks, such as await_responses_async, run indefinitely.
255- rpc_server = RPCServer (worker , num_workers = RpcWorker . NUM_WORKERS )
272+ rpc_server = RPCServer (worker , num_workers = worker . num_workers )
256273 rpc_server .bind (rpc_addr )
257274 rpc_server .start ()
258275 logger_debug (f"[worker] RPC server { mpi_rank ()} is started" ,
0 commit comments