Skip to content

Commit b1aa8ff

Browse files
committed
fix race condition
Signed-off-by: Superjomn <[email protected]> unwaive rpc tests simplify RPCServer shutdown Remove pending requests processing, shutdown immediately fix streaming cancelled share event_loop between proxy and client refactor RpcClient by unifying event_loop Simplify. refactor RPCServer by simpify add correctness tests fix worker refactor test_rpc_worker Focus on testing the RpcWorker APIs fix test_rpc_proxy.py restore RPCClient with a dedicated background thread The test_rpc_proxy.py tp1[1] passed fix test_rpc_proxy.py restore RPCClient with a dedicated background thread The test_rpc_proxy.py tp1[1] passed add threaded remote_call test add more debugging print dedicated thread for fetch_responses random hang with submit failed cleanup test_rpc.py fix race condition in zmq socket socket is used in both event_loop in two threads, unify the sending in the rpc_client's main loop thread add ipc TLLM_LLMAPI_ZMQ_DEBUG fix wait_for lost message test passed the race condition is resolved completely refine the pr add test_ipc.py fix tests
1 parent 6e5384d commit b1aa8ff

File tree

17 files changed

+2497
-1050
lines changed

17 files changed

+2497
-1050
lines changed

tensorrt_llm/executor/ipc.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import hmac
44
import os
55
import pickle # nosec B403
6+
import threading
67
import time
78
import traceback
89
from queue import Queue
@@ -65,6 +66,13 @@ def __init__(self,
6566
self.hmac_key = address[1] if address is not None else None
6667
self.use_hmac_encryption = use_hmac_encryption
6768

69+
self._setup_lock = threading.Lock()
70+
71+
# Thread safety debugging
72+
self._zmq_thread_id = None
73+
self._zmq_debug_enabled = os.environ.get('TLLM_LLMAPI_ZMQ_DEBUG',
74+
'0') != '0'
75+
6876
# Check HMAC key condition
6977
if self.use_hmac_encryption and not self.is_server and self.hmac_key is None:
7078
raise ValueError(
@@ -93,25 +101,52 @@ def __init__(self,
93101
self.address = (self.address_endpoint, self.hmac_key)
94102

95103
def setup_lazily(self):
104+
# Early return if setup is already done
96105
if self._setup_done:
97106
return
98-
self._setup_done = True
99107

100-
if not self.is_server:
101-
logger_debug(
102-
f"Client [{self.name}] connecting to {self.address_endpoint} in {self.socket_type_str[self.socket_type]}\n",
103-
"green")
104-
self.socket.connect(self.address_endpoint)
108+
with self._setup_lock:
109+
if self._setup_done:
110+
return
111+
self._setup_done = True
112+
113+
if not self.is_server:
114+
logger_debug(
115+
f"Client [{self.name}] connecting to {self.address_endpoint} in {self.socket_type_str[self.socket_type]}\n",
116+
"green")
117+
self.socket.connect(self.address_endpoint)
105118

106-
self.poller = zmq.Poller()
107-
self.poller.register(self.socket, zmq.POLLIN)
119+
self.poller = zmq.Poller()
120+
self.poller.register(self.socket, zmq.POLLIN)
121+
122+
def _check_thread_safety(self):
123+
"""Check if the current thread is the same as the thread that first used the socket."""
124+
if not self._zmq_debug_enabled:
125+
return
126+
127+
current_thread_id = threading.get_ident()
128+
129+
if self._zmq_thread_id is None:
130+
# First call - capture the thread ID
131+
self._zmq_thread_id = current_thread_id
132+
logger_debug(
133+
f"ZMQ socket [{self.name}] initialized on thread {current_thread_id}",
134+
"cyan")
135+
elif self._zmq_thread_id != current_thread_id:
136+
# Thread mismatch - raise error
137+
raise RuntimeError(
138+
f"ZMQ thread safety violation detected in [{self.name}]: "
139+
f"Socket created on thread {self._zmq_thread_id}, "
140+
f"but accessed from thread {current_thread_id}. "
141+
f"ZMQ sockets are not thread-safe!")
108142

109143
def poll(self, timeout: int) -> bool:
110144
"""
111145
Parameters:
112146
timeout (int): Timeout in seconds
113147
"""
114148
self.setup_lazily()
149+
self._check_thread_safety()
115150

116151
events = dict(self.poller.poll(timeout=timeout * 1000))
117152
if self.socket in events and events[self.socket] == zmq.POLLIN:
@@ -121,6 +156,7 @@ def poll(self, timeout: int) -> bool:
121156

122157
def put(self, obj: Any):
123158
self.setup_lazily()
159+
self._check_thread_safety()
124160
with nvtx_range_debug("send", color="blue", category="IPC"):
125161
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
126162
# Need manual serialization for encryption or ROUTER multipart
@@ -148,6 +184,7 @@ def put_noblock(self,
148184
assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed"
149185

150186
self.setup_lazily()
187+
self._check_thread_safety()
151188
with nvtx_range_debug("send", color="blue", category="IPC"):
152189

153190
data = self._prepare_data(obj)
@@ -162,6 +199,7 @@ def put_noblock(self,
162199

163200
async def put_async(self, obj: Any):
164201
self.setup_lazily()
202+
self._check_thread_safety()
165203
try:
166204
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
167205
# Need manual serialization for encryption or ROUTER multipart
@@ -182,6 +220,7 @@ async def put_async(self, obj: Any):
182220

183221
async def put_async_noblock(self, obj: Any):
184222
self.setup_lazily()
223+
self._check_thread_safety()
185224
try:
186225
if self.use_hmac_encryption:
187226
data = pickle.dumps(obj) # nosec B301
@@ -196,14 +235,55 @@ async def put_async_noblock(self, obj: Any):
196235

197236
def get(self) -> Any:
198237
self.setup_lazily()
238+
self._check_thread_safety()
199239
return self._recv_data()
200240

201241
async def get_async(self) -> Any:
202242
self.setup_lazily()
243+
self._check_thread_safety()
203244
return await self._recv_data_async()
204245

205246
async def get_async_noblock(self, timeout: float = 0.5) -> Any:
206-
return await asyncio.wait_for(self.get_async(), timeout)
247+
"""Get data with timeout using polling to avoid message drops.
248+
249+
This method uses ZMQ's NOBLOCK flag with polling instead of asyncio.wait_for
250+
to prevent cancelling recv operations which can cause message drops.
251+
252+
Args:
253+
timeout: Timeout in seconds
254+
255+
Returns:
256+
The received object
257+
258+
Raises:
259+
asyncio.TimeoutError: If timeout is reached without receiving data
260+
"""
261+
self.setup_lazily()
262+
self._check_thread_safety()
263+
264+
# Use polling loop instead of asyncio.wait_for to avoid cancelling recv
265+
# which can cause message drops
266+
deadline = asyncio.get_event_loop().time() + timeout
267+
while True:
268+
try:
269+
# Try non-blocking receive
270+
if self.socket_type == zmq.ROUTER:
271+
identity, data = await self.socket.recv_multipart(
272+
flags=zmq.NOBLOCK)
273+
self._last_identity = identity
274+
return self._parse_data(data)
275+
else:
276+
if self.use_hmac_encryption:
277+
data = await self.socket.recv(flags=zmq.NOBLOCK)
278+
return self._parse_data(data)
279+
else:
280+
return await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
281+
except zmq.Again:
282+
# No message available yet
283+
if asyncio.get_event_loop().time() >= deadline:
284+
raise asyncio.TimeoutError()
285+
# Short sleep to avoid busy-waiting
286+
await asyncio.sleep(0.01)
207287

208288
def close(self):
209289
if self.socket:
@@ -311,6 +391,7 @@ def notify_with_retry(self, message, max_retries=5, timeout=1):
311391
raise ValueError(
312392
"notify_with_retry is only supported for DEALER socket for now")
313393

394+
self._check_thread_safety()
314395
retry_count = 0
315396

316397
while retry_count < max_retries:

tensorrt_llm/executor/ray_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
2323
from .request import GenerationRequest
2424
from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
25-
from .rpc_proxy import RpcExecutorMixin
25+
from .rpc_proxy_mixin import RpcExecutorMixin
2626

2727
__all__ = [
2828
"RayExecutor",

tensorrt_llm/executor/ray_gpu_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .postproc_worker import PostprocWorkerConfig
2424
from .request import GenerationRequest
2525
from .result import GenerationResult
26-
from .rpc_worker import RpcWorkerMixin
26+
from .rpc_worker_mixin import RpcWorkerMixin
2727

2828
__all__ = [
2929
"RayGPUWorker",

0 commit comments

Comments
 (0)