Skip to content

Commit 864b3e1

Browse files
committed
fix wait_for lost message
test passed the race condition is resolved completely Signed-off-by: Superjomn <[email protected]>
1 parent 7fce3b9 commit 864b3e1

File tree

3 files changed

+61
-3
lines changed

3 files changed

+61
-3
lines changed

tensorrt_llm/executor/ipc.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,46 @@ async def get_async(self) -> Any:
244244
return await self._recv_data_async()
245245

246246
async def get_async_noblock(self, timeout: float = 0.5) -> Any:
247+
"""Get data with timeout using polling to avoid message drops.
248+
249+
This method uses ZMQ's NOBLOCK flag with polling instead of asyncio.wait_for
250+
to prevent cancelling recv operations which can cause message drops.
251+
252+
Args:
253+
timeout: Timeout in seconds
254+
255+
Returns:
256+
The received object
257+
258+
Raises:
259+
asyncio.TimeoutError: If timeout is reached without receiving data
260+
"""
261+
self.setup_lazily()
247262
self._check_thread_safety()
248-
return await asyncio.wait_for(self.get_async(), timeout)
263+
264+
# Use polling loop instead of asyncio.wait_for to avoid cancelling recv
265+
# which can cause message drops
266+
deadline = asyncio.get_event_loop().time() + timeout
267+
while True:
268+
try:
269+
# Try non-blocking receive
270+
if self.socket_type == zmq.ROUTER:
271+
identity, data = await self.socket.recv_multipart(
272+
flags=zmq.NOBLOCK)
273+
self._last_identity = identity
274+
return self._parse_data(data)
275+
else:
276+
if self.use_hmac_encryption:
277+
data = await self.socket.recv(flags=zmq.NOBLOCK)
278+
return self._parse_data(data)
279+
else:
280+
return await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
281+
except zmq.Again:
282+
# No message available yet
283+
if asyncio.get_event_loop().time() >= deadline:
284+
raise asyncio.TimeoutError()
285+
# Short sleep to avoid busy-waiting
286+
await asyncio.sleep(0.01)
249287

250288
def close(self):
251289
if self.socket:

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import concurrent.futures
3+
import os
34
import threading
45
import time
56
import uuid
@@ -104,11 +105,20 @@ def __init__(self,
104105
'''
105106
self._address = address
106107
self._timeout = timeout
108+
109+
# Check if PAIR mode is enabled via environment variable
110+
use_pair_mode = os.environ.get('TLLM_LLMAPI_ZMQ_PAIR', '0') != '0'
111+
socket_type = zmq.PAIR if use_pair_mode else zmq.DEALER
112+
113+
if use_pair_mode:
114+
logger_debug(
115+
"[client] Using zmq.PAIR socket type for RPC communication")
116+
107117
self._client_socket = ZeroMqQueue(address=(address, hmac_key),
108118
is_server=False,
109119
is_async=True,
110120
use_hmac_encryption=False,
111-
socket_type=zmq.DEALER,
121+
socket_type=socket_type,
112122
name="rpc_client")
113123
self._pending_futures = {}
114124
# map request_id to the queue for streaming responses

tensorrt_llm/executor/rpc/rpc_server.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import inspect
3+
import os
34
import threading
45
import time
56
import traceback
@@ -95,11 +96,20 @@ def bind(self, address: str = "tcp://*:5555") -> None:
9596
address (str): The ZMQ address to bind the client-facing socket.
9697
"""
9798
self._address = address
99+
100+
# Check if PAIR mode is enabled via environment variable
101+
use_pair_mode = os.environ.get('TLLM_LLMAPI_ZMQ_PAIR', '0') != '0'
102+
socket_type = zmq.PAIR if use_pair_mode else zmq.ROUTER
103+
104+
if use_pair_mode:
105+
logger_debug(
106+
"[server] Using zmq.PAIR socket type for RPC communication")
107+
98108
self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
99109
is_server=True,
100110
is_async=True,
101111
use_hmac_encryption=False,
102-
socket_type=zmq.ROUTER,
112+
socket_type=socket_type,
103113
name="rpc_server")
104114
logger.info(f"RPCServer is bound to {self._address}")
105115

0 commit comments

Comments
 (0)