1313 placement_group )
1414
1515from tensorrt_llm ._ray_utils import unwrap_ray_errors
16- from tensorrt_llm ._utils import get_free_port , nvtx_range_debug , ray_use_rpc
16+ from tensorrt_llm ._utils import get_free_port , nvtx_range_debug
1717from tensorrt_llm .logger import logger
1818
1919from ..llmapi .utils import logger_debug
2020from .executor import GenerationExecutor
2121from .postproc_worker import PostprocWorkerConfig
2222from .ray_gpu_worker import RayGPUWorker , RayWorkerWrapper
2323from .request import GenerationRequest
24- from .result import GenerationResult , RayAsyncQueue , RaySyncQueue
24+ from .result import GenerationResult
2525from .rpc_proxy_mixin import RpcExecutorMixin
2626
2727__all__ = [
@@ -76,38 +76,18 @@ def __init__(self,
7676 self .tp_size = tp_size
7777 self .master_address = ray .util .get_node_ip_address ()
7878 self .master_port = get_free_port ()
79- self .use_rpc = ray_use_rpc ()
8079
8180 worker_kwargs = dict (** worker_kwargs ,
8281 postproc_worker_config = postproc_worker_config ,
8382 is_llm_executor = is_llm_executor )
8483
85- if self .use_rpc :
86- self .init_rpc_executor ()
87- worker_kwargs ['rpc_addr' ] = self .rpc_addr
88- self .create_workers (RayGPUWorker , worker_kwargs )
89- self .setup_engine_remote ()
90- self .setup_mainloop (tasks = [self ._fetch_responses_loop_async ],
91- thread_name = "ray_executor_main_loop" )
92- logger .info (f"Connecting to RPC server at { self .rpc_addr } " )
93- else :
94- self .response_queue = RayAsyncQueue .options (runtime_env = {
95- "env_vars" : {
96- "TLLM_DISABLE_MPI" : "1"
97- }
98- }).remote ()
99- self .response_sync_queue = RaySyncQueue .options (runtime_env = {
100- "env_vars" : {
101- "TLLM_DISABLE_MPI" : "1"
102- }
103- }).remote ()
104- self .async_response_queue_weakref = self .create_actor_weak_ref (
105- self .response_queue )
106- self .sync_response_queue_weakref = self .create_actor_weak_ref (
107- self .response_sync_queue )
108- self .response_queue .warmup .remote ()
109- self .response_sync_queue .warmup .remote ()
110- self .create_workers (RayGPUWorker , worker_kwargs )
84+ self .init_rpc_executor ()
85+ worker_kwargs ['rpc_addr' ] = self .rpc_addr
86+ self .create_workers (RayGPUWorker , worker_kwargs )
87+ self .setup_engine_remote ()
88+ self .setup_mainloop (tasks = [self ._fetch_responses_loop_async ],
89+ thread_name = "ray_executor_main_loop" )
90+ logger .info (f"Connecting to RPC server at { self .rpc_addr } " )
11191
11292 except Exception as e :
11393 self .shutdown ()
@@ -192,37 +172,21 @@ def collective_rpc(self,
192172 def submit (self , request : "GenerationRequest" ) -> "GenerationResult" :
193173 """
194174 Low-level API to the executor. Return a "future" GenerationResult
195- which can be waited.
196- Forwards the request to the workers through RPC or Ray queues depending on mode.
175+ which can be waited. Forwards the request to the workers through RPC.
197176 """
198177 request .set_id (self ._get_next_client_id ())
199178 logprob_params = self ._get_logprob_params (request )
200179
201- if self .use_rpc :
202- with nvtx_range_debug ("rpc_submit" ):
203- self .rpc_client .submit (request ).remote (need_response = False )
204-
205- result = GenerationResult (
206- request ,
207- background_error_handler = self ._handle_background_error ,
208- executor = self ,
209- disaggregated_params = request .disaggregated_params ,
210- logprob_params = logprob_params )
211- self ._results [request .id ] = result
212- else :
213- result = GenerationResult (
214- request ,
215- background_error_handler = self ._handle_background_error ,
216- executor = self ,
217- disaggregated_params = request .disaggregated_params ,
218- logprob_params = logprob_params )
219-
220- with nvtx_range_debug ("request_queue.put" ):
221- self .call_all_ray_workers ("enqueue_request" ,
222- leader_only = True ,
223- request = request ,
224- async_call = True ,
225- result_wait_queue = result .queue )
180+ with nvtx_range_debug ("rpc_submit" ):
181+ self .rpc_client .submit (request ).remote (need_response = False )
182+
183+ result = GenerationResult (
184+ request ,
185+ background_error_handler = self ._handle_background_error ,
186+ executor = self ,
187+ disaggregated_params = request .disaggregated_params ,
188+ logprob_params = logprob_params )
189+ self ._results [request .id ] = result
226190
227191 return result
228192
@@ -238,9 +202,6 @@ def report_device_ids(self) -> list[str]:
238202 async_call = False )
239203 return sorted (gpu_ids )
240204
241- def use_ray_queue (self ) -> bool :
242- return not self .use_rpc
243-
244205 def abort_request (self , request_id : int ) -> None :
245206 self .call_all_ray_workers ("abort_request" ,
246207 leader_only = True ,
@@ -253,54 +214,40 @@ def shutdown(self):
253214 if hasattr (self , '_shutdown_event' ):
254215 self ._shutdown_event .set ()
255216
256- mode_str = "RPC mode" if self .use_rpc else "Ray queue mode"
257- logger_debug (f"Shutting down RayExecutor ({ mode_str } )" , color = "yellow" )
217+ logger_debug (f"Shutting down RayExecutor" , color = "yellow" )
258218
259- if self .use_rpc :
260- if hasattr (self , 'main_loop' ) and self .main_loop and hasattr (
261- self , 'main_loop_task_obj' ) and self .main_loop_task_obj :
262- logger_debug ("Cancelling main loop task." , color = "yellow" )
263- try :
264- self .main_loop .call_soon_threadsafe (
265- self .main_loop_task_obj .cancel )
266- except Exception as e :
267- logger_debug (f"Error cancelling main loop task: { e } " ,
268- color = "yellow" )
219+ if hasattr (self , 'main_loop' ) and self .main_loop and hasattr (
220+ self , 'main_loop_task_obj' ) and self .main_loop_task_obj :
221+ logger_debug ("Cancelling main loop task." , color = "yellow" )
222+ try :
223+ self .main_loop .call_soon_threadsafe (
224+ self .main_loop_task_obj .cancel )
225+ except Exception as e :
226+ logger_debug (f"Error cancelling main loop task: { e } " ,
227+ color = "yellow" )
269228
270- if hasattr (self , 'main_loop_thread' ):
271- self .main_loop_thread .join ()
229+ if hasattr (self , 'main_loop_thread' ):
230+ self .main_loop_thread .join ()
272231
273- # Then, shutdown the workers
274- if hasattr (self , 'workers' ) and self .workers is not None :
275- try :
276- logger_debug ("Shutting down RPC remote" , color = "yellow" )
277- shutdown_refs = [
278- worker .shutdown .remote () for worker in self .workers
279- ]
280- # Add timeout to prevent indefinite hanging
281- ray .get (shutdown_refs , timeout = 30.0 )
282- except ray .exceptions .GetTimeoutError :
283- logger .warning (
284- "Timeout waiting for workers to shutdown after 30 seconds"
285- )
286- except Exception as e :
287- logger .warning (f"Error shutting down RPC remote: { e } " )
288-
289- if hasattr (self , 'rpc_client' ) and self .rpc_client is not None :
290- try :
291- self .rpc_client .close ()
292- except Exception as e :
293- # Suppress errors during RPC client shutdown
294- # These can occur if the client is already closed or if there are
295- # pending operations that get cancelled during cleanup
296- logger_debug (
297- f"Suppressed error during RPC client close: { e } " )
298- else :
299- # Release actors
300- self .response_queue = None
301- self .response_sync_queue = None
302- self .async_response_queue_weakref = None
303- self .sync_response_queue_weakref = None
232+ # Then, shutdown the workers
233+ if hasattr (self , 'workers' ) and self .workers is not None :
234+ try :
235+ shutdown_refs = [
236+ worker .shutdown .remote () for worker in self .workers
237+ ]
238+ # Add timeout to prevent indefinite hanging
239+ ray .get (shutdown_refs , timeout = 30.0 )
240+ except ray .exceptions .GetTimeoutError :
241+ logger .warning (
242+ "Timeout waiting for workers to shutdown after 30 seconds" )
243+ except Exception as e :
244+ logger .warning (f"Error shutting down: { e } " )
245+
246+ if hasattr (self , 'rpc_client' ) and self .rpc_client is not None :
247+ try :
248+ self .rpc_client .close ()
249+ except Exception as e :
250+ logger_debug (f"Suppressed error during RPC client close: { e } " )
304251
305252 self .workers = None
306253 if hasattr (self ,
@@ -387,9 +334,3 @@ def enable_postprocess_parallel(self) -> bool:
387334 ret = super ().enable_postprocess_parallel
388335 assert ret == False , "Postprocess parallel is not supported in RayExecutor"
389336 return ret
390-
391- @staticmethod
392- def create_actor_weak_ref (actor_handle : ray .actor .ActorHandle ):
393- state , _ , _ = actor_handle ._serialization_helper ()
394- return ray .actor .ActorHandle ._deserialization_helper (state ,
395- weak_ref = True )
0 commit comments