Skip to content

Commit 7fce3b9

Browse files
committed
add ipc TLLM_LLMAPI_ZMQ_DEBUG
Signed-off-by: Superjomn <[email protected]>
1 parent 01ef3f1 commit 7fce3b9

File tree

6 files changed

+149
-95
lines changed

6 files changed

+149
-95
lines changed

tensorrt_llm/executor/ipc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def __init__(self,
6868

6969
self._setup_lock = threading.Lock()
7070

71+
# Thread safety debugging
72+
self._zmq_thread_id = None
73+
self._zmq_debug_enabled = os.environ.get('TLLM_LLMAPI_ZMQ_DEBUG',
74+
'0') != '0'
75+
7176
# Check HMAC key condition
7277
if self.use_hmac_encryption and not self.is_server and self.hmac_key is None:
7378
raise ValueError(
@@ -114,12 +119,34 @@ def setup_lazily(self):
114119
self.poller = zmq.Poller()
115120
self.poller.register(self.socket, zmq.POLLIN)
116121

122+
def _check_thread_safety(self):
123+
"""Check if the current thread is the same as the thread that first used the socket."""
124+
if not self._zmq_debug_enabled:
125+
return
126+
127+
current_thread_id = threading.get_ident()
128+
129+
if self._zmq_thread_id is None:
130+
# First call - capture the thread ID
131+
self._zmq_thread_id = current_thread_id
132+
logger_debug(
133+
f"ZMQ socket [{self.name}] initialized on thread {current_thread_id}",
134+
"cyan")
135+
elif self._zmq_thread_id != current_thread_id:
136+
# Thread mismatch - raise error
137+
raise RuntimeError(
138+
f"ZMQ thread safety violation detected in [{self.name}]: "
139+
f"Socket created on thread {self._zmq_thread_id}, "
140+
f"but accessed from thread {current_thread_id}. "
141+
f"ZMQ sockets are not thread-safe!")
142+
117143
def poll(self, timeout: int) -> bool:
118144
"""
119145
Parameters:
120146
timeout (int): Timeout in seconds
121147
"""
122148
self.setup_lazily()
149+
self._check_thread_safety()
123150

124151
events = dict(self.poller.poll(timeout=timeout * 1000))
125152
if self.socket in events and events[self.socket] == zmq.POLLIN:
@@ -129,6 +156,7 @@ def poll(self, timeout: int) -> bool:
129156

130157
def put(self, obj: Any):
131158
self.setup_lazily()
159+
self._check_thread_safety()
132160
with nvtx_range_debug("send", color="blue", category="IPC"):
133161
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
134162
# Need manual serialization for encryption or ROUTER multipart
@@ -156,6 +184,7 @@ def put_noblock(self,
156184
assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed"
157185

158186
self.setup_lazily()
187+
self._check_thread_safety()
159188
with nvtx_range_debug("send", color="blue", category="IPC"):
160189

161190
data = self._prepare_data(obj)
@@ -170,6 +199,7 @@ def put_noblock(self,
170199

171200
async def put_async(self, obj: Any):
172201
self.setup_lazily()
202+
self._check_thread_safety()
173203
try:
174204
if self.use_hmac_encryption or self.socket_type == zmq.ROUTER:
175205
# Need manual serialization for encryption or ROUTER multipart
@@ -190,6 +220,7 @@ async def put_async(self, obj: Any):
190220

191221
async def put_async_noblock(self, obj: Any):
192222
self.setup_lazily()
223+
self._check_thread_safety()
193224
try:
194225
if self.use_hmac_encryption:
195226
data = pickle.dumps(obj) # nosec B301
@@ -204,13 +235,16 @@ async def put_async_noblock(self, obj: Any):
204235

205236
def get(self) -> Any:
206237
self.setup_lazily()
238+
self._check_thread_safety()
207239
return self._recv_data()
208240

209241
async def get_async(self) -> Any:
210242
self.setup_lazily()
243+
self._check_thread_safety()
211244
return await self._recv_data_async()
212245

213246
async def get_async_noblock(self, timeout: float = 0.5) -> Any:
247+
self._check_thread_safety()
214248
return await asyncio.wait_for(self.get_async(), timeout)
215249

216250
def close(self):
@@ -319,6 +353,7 @@ def notify_with_retry(self, message, max_retries=5, timeout=1):
319353
raise ValueError(
320354
"notify_with_retry is only supported for DEALER socket for now")
321355

356+
self._check_thread_safety()
322357
retry_count = 0
323358

324359
while retry_count < max_retries:

0 commit comments

Comments
 (0)