
Commit 218f319

hchings authored and Superjomn committed
rpc cleanup
Signed-off-by: Erin Ho <[email protected]>
1 parent 63c0df8 commit 218f319

File tree

8 files changed: +68 additions, -272 deletions


tensorrt_llm/_utils.py

Lines changed: 0 additions & 7 deletions
@@ -524,13 +524,6 @@ def mpi_disabled() -> bool:
     return os.environ.get("TLLM_DISABLE_MPI") == "1"
 
 
-def ray_use_rpc() -> bool:
-    """True if TLLM_RAY_USE_RPC is set to "1", False otherwise.
-    # TODO: deprecate this once Ray is fully moved to use RPC client/server.
-    """
-    return os.environ.get("TLLM_RAY_USE_RPC") == "1"
-
-
 def mpi_rank():
     if mpi_disabled():
         try:
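For context, the deleted helper followed the same env-var flag idiom as mpi_disabled() above. A minimal, standalone sketch of that idiom (the _env_flag name is illustrative; only TLLM_RAY_USE_RPC comes from the diff):

import os

def _env_flag(name: str) -> bool:
    # True only when the variable is set to the literal string "1",
    # matching the convention mpi_disabled() uses above.
    return os.environ.get(name) == "1"

os.environ["TLLM_RAY_USE_RPC"] = "1"
print(_env_flag("TLLM_RAY_USE_RPC"))  # True, but after this commit the flag is ignored

After this commit, TLLM_RAY_USE_RPC is no longer read anywhere; the RPC path is unconditional.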

tensorrt_llm/executor/executor.py

Lines changed: 0 additions & 3 deletions
@@ -103,9 +103,6 @@ def __init__(self,
         self._iter_kv_events_result: IterationResult | None = None
         self._iter_stats_result: IterationResult | None = None
 
-    def use_ray_queue(self) -> bool:
-        return False
-
     @abstractmethod
     def submit(self, request: GenerationRequest) -> GenerationResult:
         pass
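With the use_ray_queue() probe gone, subclasses expose a single uniform surface: implement submit() and hand back a future-like result. A hedged sketch of that shape (simplified stand-in classes, not the real GenerationExecutor API):

from abc import ABC, abstractmethod

class ExecutorSketch(ABC):
    # Stand-in for GenerationExecutor: the queue-capability probe is gone,
    # so the only method a subclass must provide is submit().
    @abstractmethod
    def submit(self, request):
        """Return a future-like result for the request."""

class RpcExecutorSketch(ExecutorSketch):
    def submit(self, request):
        return f"future-for-{request}"

print(RpcExecutorSketch().submit("req-1"))  # future-for-req-1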

tensorrt_llm/executor/ray_executor.py

Lines changed: 51 additions & 110 deletions
@@ -13,15 +13,15 @@
                              placement_group)
 
 from tensorrt_llm._ray_utils import unwrap_ray_errors
-from tensorrt_llm._utils import get_free_port, nvtx_range_debug, ray_use_rpc
+from tensorrt_llm._utils import get_free_port, nvtx_range_debug
 from tensorrt_llm.logger import logger
 
 from ..llmapi.utils import logger_debug
 from .executor import GenerationExecutor
 from .postproc_worker import PostprocWorkerConfig
 from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
 from .request import GenerationRequest
-from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
+from .result import GenerationResult
 from .rpc_proxy_mixin import RpcExecutorMixin
 
 __all__ = [
@@ -76,38 +76,18 @@ def __init__(self,
             self.tp_size = tp_size
             self.master_address = ray.util.get_node_ip_address()
             self.master_port = get_free_port()
-            self.use_rpc = ray_use_rpc()
 
             worker_kwargs = dict(**worker_kwargs,
                                  postproc_worker_config=postproc_worker_config,
                                  is_llm_executor=is_llm_executor)
 
-            if self.use_rpc:
-                self.init_rpc_executor()
-                worker_kwargs['rpc_addr'] = self.rpc_addr
-                self.create_workers(RayGPUWorker, worker_kwargs)
-                self.setup_engine_remote()
-                self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
-                                    thread_name="ray_executor_main_loop")
-                logger.info(f"Connecting to RPC server at {self.rpc_addr}")
-            else:
-                self.response_queue = RayAsyncQueue.options(runtime_env={
-                    "env_vars": {
-                        "TLLM_DISABLE_MPI": "1"
-                    }
-                }).remote()
-                self.response_sync_queue = RaySyncQueue.options(runtime_env={
-                    "env_vars": {
-                        "TLLM_DISABLE_MPI": "1"
-                    }
-                }).remote()
-                self.async_response_queue_weakref = self.create_actor_weak_ref(
-                    self.response_queue)
-                self.sync_response_queue_weakref = self.create_actor_weak_ref(
-                    self.response_sync_queue)
-                self.response_queue.warmup.remote()
-                self.response_sync_queue.warmup.remote()
-                self.create_workers(RayGPUWorker, worker_kwargs)
+            self.init_rpc_executor()
+            worker_kwargs['rpc_addr'] = self.rpc_addr
+            self.create_workers(RayGPUWorker, worker_kwargs)
+            self.setup_engine_remote()
+            self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
+                                thread_name="ray_executor_main_loop")
+            logger.info(f"Connecting to RPC server at {self.rpc_addr}")
 
         except Exception as e:
             self.shutdown()
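The surviving branch gives RayExecutor a single startup sequence. A minimal sketch of that ordering (method names follow the diff; the bodies are stand-ins, and the literal address is a placeholder for whatever init_rpc_executor() would actually pick):

class RayExecutorSketch:
    """Illustrative ordering only; not the real RayExecutor."""

    def __init__(self, worker_kwargs: dict):
        self.rpc_addr = "ipc:///tmp/rpc-demo"       # placeholder address
        worker_kwargs["rpc_addr"] = self.rpc_addr   # every worker dials back here
        self.create_workers(worker_kwargs)          # 1. spawn Ray actor workers
        self.setup_engine_remote()                  # 2. build engines on all workers
        self.setup_mainloop()                       # 3. background loop draining responses

    def create_workers(self, kwargs: dict) -> None:
        pass

    def setup_engine_remote(self) -> None:
        pass

    def setup_mainloop(self) -> None:
        pass

RayExecutorSketch({})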
@@ -192,37 +172,21 @@ def collective_rpc(self,
     def submit(self, request: "GenerationRequest") -> "GenerationResult":
         """
             Low-level API to the executor. Return a "future" GenerationResult
-            which can be waited.
-            Forwards the request to the workers through RPC or Ray queues depending on mode.
+            which can be waited. Forwards the request to the workers through RPC.
         """
         request.set_id(self._get_next_client_id())
         logprob_params = self._get_logprob_params(request)
 
-        if self.use_rpc:
-            with nvtx_range_debug("rpc_submit"):
-                self.rpc_client.submit(request).remote(need_response=False)
-
-            result = GenerationResult(
-                request,
-                background_error_handler=self._handle_background_error,
-                executor=self,
-                disaggregated_params=request.disaggregated_params,
-                logprob_params=logprob_params)
-            self._results[request.id] = result
-        else:
-            result = GenerationResult(
-                request,
-                background_error_handler=self._handle_background_error,
-                executor=self,
-                disaggregated_params=request.disaggregated_params,
-                logprob_params=logprob_params)
-
-            with nvtx_range_debug("request_queue.put"):
-                self.call_all_ray_workers("enqueue_request",
-                                          leader_only=True,
-                                          request=request,
-                                          async_call=True,
-                                          result_wait_queue=result.queue)
+        with nvtx_range_debug("rpc_submit"):
+            self.rpc_client.submit(request).remote(need_response=False)
+
+        result = GenerationResult(
+            request,
+            background_error_handler=self._handle_background_error,
+            executor=self,
+            disaggregated_params=request.disaggregated_params,
+            logprob_params=logprob_params)
+        self._results[request.id] = result
 
         return result
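The surviving submit path is fire-and-forget: the RPC call returns without waiting for a reply (need_response=False), a local future is registered under the request id, and the background fetch loop resolves it later. A standalone sketch of that pattern (class and transport are illustrative, not the real RPC client):

from concurrent.futures import Future

class FireAndForgetClient:
    """Sketch only: no RPC reply is awaited on submit; a background
    fetch loop resolves the locally registered future later."""

    def __init__(self):
        self._results: dict[int, Future] = {}

    def submit(self, request_id: int) -> Future:
        fut: Future = Future()
        self._results[request_id] = fut  # mirrors self._results[request.id] = result
        # rpc_client.submit(request).remote(need_response=False)  # real call in the diff
        return fut

    def on_response(self, request_id: int, payload) -> None:
        # Called by the response-fetching loop when the worker replies.
        self._results.pop(request_id).set_result(payload)

client = FireAndForgetClient()
f = client.submit(1)
client.on_response(1, "tokens")
print(f.result())  # tokens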

@@ -238,9 +202,6 @@ def report_device_ids(self) -> list[str]:
                                             async_call=False)
         return sorted(gpu_ids)
 
-    def use_ray_queue(self) -> bool:
-        return not self.use_rpc
-
     def abort_request(self, request_id: int) -> None:
         self.call_all_ray_workers("abort_request",
                                   leader_only=True,
@@ -253,54 +214,40 @@ def shutdown(self):
         if hasattr(self, '_shutdown_event'):
             self._shutdown_event.set()
 
-        mode_str = "RPC mode" if self.use_rpc else "Ray queue mode"
-        logger_debug(f"Shutting down RayExecutor ({mode_str})", color="yellow")
+        logger_debug(f"Shutting down RayExecutor", color="yellow")
 
-        if self.use_rpc:
-            if hasattr(self, 'main_loop') and self.main_loop and hasattr(
-                    self, 'main_loop_task_obj') and self.main_loop_task_obj:
-                logger_debug("Cancelling main loop task.", color="yellow")
-                try:
-                    self.main_loop.call_soon_threadsafe(
-                        self.main_loop_task_obj.cancel)
-                except Exception as e:
-                    logger_debug(f"Error cancelling main loop task: {e}",
-                                 color="yellow")
+        if hasattr(self, 'main_loop') and self.main_loop and hasattr(
+                self, 'main_loop_task_obj') and self.main_loop_task_obj:
+            logger_debug("Cancelling main loop task.", color="yellow")
+            try:
+                self.main_loop.call_soon_threadsafe(
+                    self.main_loop_task_obj.cancel)
+            except Exception as e:
+                logger_debug(f"Error cancelling main loop task: {e}",
+                             color="yellow")
 
-            if hasattr(self, 'main_loop_thread'):
-                self.main_loop_thread.join()
+        if hasattr(self, 'main_loop_thread'):
+            self.main_loop_thread.join()
 
-            # Then, shutdown the workers
-            if hasattr(self, 'workers') and self.workers is not None:
-                try:
-                    logger_debug("Shutting down RPC remote", color="yellow")
-                    shutdown_refs = [
-                        worker.shutdown.remote() for worker in self.workers
-                    ]
-                    # Add timeout to prevent indefinite hanging
-                    ray.get(shutdown_refs, timeout=30.0)
-                except ray.exceptions.GetTimeoutError:
-                    logger.warning(
-                        "Timeout waiting for workers to shutdown after 30 seconds"
-                    )
-                except Exception as e:
-                    logger.warning(f"Error shutting down RPC remote: {e}")
-
-            if hasattr(self, 'rpc_client') and self.rpc_client is not None:
-                try:
-                    self.rpc_client.close()
-                except Exception as e:
-                    # Suppress errors during RPC client shutdown
-                    # These can occur if the client is already closed or if there are
-                    # pending operations that get cancelled during cleanup
-                    logger_debug(
-                        f"Suppressed error during RPC client close: {e}")
-        else:
-            # Release actors
-            self.response_queue = None
-            self.response_sync_queue = None
-            self.async_response_queue_weakref = None
-            self.sync_response_queue_weakref = None
+        # Then, shutdown the workers
+        if hasattr(self, 'workers') and self.workers is not None:
+            try:
+                shutdown_refs = [
+                    worker.shutdown.remote() for worker in self.workers
+                ]
+                # Add timeout to prevent indefinite hanging
+                ray.get(shutdown_refs, timeout=30.0)
+            except ray.exceptions.GetTimeoutError:
+                logger.warning(
+                    "Timeout waiting for workers to shutdown after 30 seconds")
+            except Exception as e:
+                logger.warning(f"Error shutting down: {e}")
+
+        if hasattr(self, 'rpc_client') and self.rpc_client is not None:
+            try:
+                self.rpc_client.close()
+            except Exception as e:
+                logger_debug(f"Suppressed error during RPC client close: {e}")
 
         self.workers = None
         if hasattr(self,
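The unified shutdown path stays defensive: cancel the main loop, join its thread, then bound the wait on worker shutdown so one hung actor cannot block process exit. A standalone sketch of the bounded-wait step (ray.get and GetTimeoutError as in the diff; the actor class is illustrative and requires a ray installation to run):

import ray

@ray.remote
class WorkerSketch:
    def shutdown(self) -> bool:
        return True

if __name__ == "__main__":
    ray.init()
    workers = [WorkerSketch.remote() for _ in range(2)]
    refs = [w.shutdown.remote() for w in workers]
    try:
        ray.get(refs, timeout=30.0)  # bound the wait, as in the diff
    except ray.exceptions.GetTimeoutError:
        print("Timeout waiting for workers to shutdown after 30 seconds")
    ray.shutdown()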
@@ -387,9 +334,3 @@ def enable_postprocess_parallel(self) -> bool:
         ret = super().enable_postprocess_parallel
         assert ret == False, "Postprocess parallel is not supported in RayExecutor"
         return ret
-
-    @staticmethod
-    def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
-        state, _, _ = actor_handle._serialization_helper()
-        return ray.actor.ActorHandle._deserialization_helper(state,
-                                                             weak_ref=True)

tensorrt_llm/executor/ray_gpu_worker.py

Lines changed: 5 additions & 9 deletions
@@ -12,7 +12,6 @@
 from tensorrt_llm._torch.virtual_memory import (materialize_with_tag,
                                                 release_with_tag,
                                                 verify_sleep_wakeup_tags)
-from tensorrt_llm._utils import ray_use_rpc
 
 from ..bindings import executor as tllm
 from ..builder import Engine
@@ -189,14 +188,11 @@ def __init__(
         if self.global_rank > 1:
             logger.set_rank(self.global_rank)
 
-        if ray_use_rpc():
-            if rpc_addr is None:
-                raise RuntimeError(
-                    "RPC mode enabled but no rpc_addr provided to RayGPUWorker")
-            self.init_rpc_worker(self.global_rank, rpc_addr)
-            self.start_rpc_server()
-        else:
-            self.setup_engine()
+        if rpc_addr is None:
+            raise RuntimeError(
+                "RPC mode enabled but no rpc_addr provided to RayGPUWorker")
+        self.init_rpc_worker(self.global_rank, rpc_addr)
+        self.start_rpc_server()
 
     def setup_engine(self):
         if torch.distributed.is_initialized(
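With the setup_engine() fallback gone, the executor must always thread rpc_addr into worker_kwargs (see the ray_executor.py hunk above) and the worker validates it at construction. A minimal sketch of that handshake between the two sides (the class and the address are illustrative; only the error message comes from the diff):

class RpcWorkerSketch:
    def __init__(self, rpc_addr: str | None = None):
        # rpc_addr is now mandatory; there is no non-RPC fallback path.
        if rpc_addr is None:
            raise RuntimeError(
                "RPC mode enabled but no rpc_addr provided to RayGPUWorker")
        self.rpc_addr = rpc_addr

worker_kwargs = {"rpc_addr": "ipc:///tmp/rpc-demo"}  # set by the executor
worker = RpcWorkerSketch(**worker_kwargs)
print(worker.rpc_addr)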
