33import hmac
44import os
55import pickle # nosec B403
6+ import threading
67import time
78import traceback
89from queue import Queue
@@ -65,6 +66,13 @@ def __init__(self,
6566 self .hmac_key = address [1 ] if address is not None else None
6667 self .use_hmac_encryption = use_hmac_encryption
6768
69+ self ._setup_lock = threading .Lock ()
70+
71+ # Thread safety debugging
72+ self ._zmq_thread_id = None
73+ self ._zmq_debug_enabled = os .environ .get ('TLLM_LLMAPI_ZMQ_DEBUG' ,
74+ '0' ) != '0'
75+
6876 # Check HMAC key condition
6977 if self .use_hmac_encryption and not self .is_server and self .hmac_key is None :
7078 raise ValueError (
@@ -93,25 +101,52 @@ def __init__(self,
93101 self .address = (self .address_endpoint , self .hmac_key )
94102
def setup_lazily(self):
    """Perform the deferred socket setup on first use; later calls are no-ops.

    A client (``is_server`` false) connects its socket to
    ``self.address_endpoint``. In both roles a ``zmq.Poller`` is created and
    the socket is registered on it for ``zmq.POLLIN``.

    The work is guarded by ``self._setup_lock``. ``self._setup_done`` is set
    only after the poller exists, so a caller that sees the flag set on the
    lock-free fast path can rely on ``self.poller`` being present. If
    connecting or poller creation raises, the flag stays false and a later
    call tries again.
    """
    # Fast path: no locking once setup has fully completed.
    if self._setup_done:
        return

    with self._setup_lock:
        # Re-check under the lock: another thread may have finished the
        # setup while this one was waiting to acquire it.
        if self._setup_done:
            return

        if not self.is_server:
            logger_debug(
                f"Client [{self.name}] connecting to {self.address_endpoint} in {self.socket_type_str[self.socket_type]}\n",
                "green")
            self.socket.connect(self.address_endpoint)

        self.poller = zmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

        # Publish the flag last so it never becomes visible before the
        # connection and poller it advertises actually exist.
        self._setup_done = True
121+
def _check_thread_safety(self):
    """Debug aid: verify the socket is always used from a single thread.

    Does nothing unless ``self._zmq_debug_enabled`` is true. On the first
    checked call the calling thread's ident is stored in
    ``self._zmq_thread_id``; on later calls a different ident is treated as
    a violation.

    Raises:
        RuntimeError: If the calling thread differs from the recorded one.
    """
    if self._zmq_debug_enabled:
        caller_ident = threading.get_ident()
        owner_ident = self._zmq_thread_id

        if owner_ident is None:
            # First checked use: remember which thread owns the socket.
            self._zmq_thread_id = caller_ident
            logger_debug(
                f"ZMQ socket [{self.name}] initialized on thread {caller_ident}",
                "cyan")
        elif owner_ident != caller_ident:
            # Used from a second thread: fail loudly rather than risk
            # corrupting the socket state.
            raise RuntimeError(
                f"ZMQ thread safety violation detected in [{self.name}]: "
                f"Socket created on thread {owner_ident}, "
                f"but accessed from thread {caller_ident}. "
                f"ZMQ sockets are not thread-safe!")
108142
109143 def poll (self , timeout : int ) -> bool :
110144 """
111145 Parameters:
112146 timeout (int): Timeout in seconds
113147 """
114148 self .setup_lazily ()
149+ self ._check_thread_safety ()
115150
116151 events = dict (self .poller .poll (timeout = timeout * 1000 ))
117152 if self .socket in events and events [self .socket ] == zmq .POLLIN :
@@ -121,6 +156,7 @@ def poll(self, timeout: int) -> bool:
121156
122157 def put (self , obj : Any ):
123158 self .setup_lazily ()
159+ self ._check_thread_safety ()
124160 with nvtx_range_debug ("send" , color = "blue" , category = "IPC" ):
125161 if self .use_hmac_encryption or self .socket_type == zmq .ROUTER :
126162 # Need manual serialization for encryption or ROUTER multipart
@@ -148,6 +184,7 @@ def put_noblock(self,
148184 assert retry >= 0 and retry <= 10 , "Retry must be between 0 and 10, adjust the wait_time if needed"
149185
150186 self .setup_lazily ()
187+ self ._check_thread_safety ()
151188 with nvtx_range_debug ("send" , color = "blue" , category = "IPC" ):
152189
153190 data = self ._prepare_data (obj )
@@ -162,6 +199,7 @@ def put_noblock(self,
162199
163200 async def put_async (self , obj : Any ):
164201 self .setup_lazily ()
202+ self ._check_thread_safety ()
165203 try :
166204 if self .use_hmac_encryption or self .socket_type == zmq .ROUTER :
167205 # Need manual serialization for encryption or ROUTER multipart
@@ -182,6 +220,7 @@ async def put_async(self, obj: Any):
182220
183221 async def put_async_noblock (self , obj : Any ):
184222 self .setup_lazily ()
223+ self ._check_thread_safety ()
185224 try :
186225 if self .use_hmac_encryption :
187226 data = pickle .dumps (obj ) # nosec B301
@@ -196,14 +235,55 @@ async def put_async_noblock(self, obj: Any):
196235
def get(self) -> Any:
    """Receive one object from the socket.

    Runs the lazy setup and the thread-safety check, then returns whatever
    ``self._recv_data()`` yields.
    """
    self.setup_lazily()
    self._check_thread_safety()
    received = self._recv_data()
    return received
200240
async def get_async(self) -> Any:
    """Asynchronously receive one object from the socket.

    Runs the lazy setup and the thread-safety check, then awaits
    ``self._recv_data_async()`` and returns its result.
    """
    self.setup_lazily()
    self._check_thread_safety()
    received = await self._recv_data_async()
    return received
204245
async def get_async_noblock(self, timeout: float = 0.5) -> Any:
    """Receive one object, giving up after ``timeout`` seconds.

    The socket is polled with ``zmq.NOBLOCK`` receives instead of wrapping a
    blocking receive in ``asyncio.wait_for``: cancelling an in-flight recv
    can cause messages to be dropped.

    Args:
        timeout: Maximum time to wait, in seconds.

    Returns:
        The received object.

    Raises:
        asyncio.TimeoutError: If no message arrived before the deadline.
    """
    self.setup_lazily()
    self._check_thread_safety()

    # Fetch the running loop once; its monotonic clock drives the deadline.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    poll_interval = 0.01  # upper bound on the back-off between attempts

    while True:
        try:
            if self.socket_type == zmq.ROUTER:
                # ROUTER frames arrive as (identity, payload); remember the
                # sender identity of the last message received.
                identity, data = await self.socket.recv_multipart(
                    flags=zmq.NOBLOCK)
                self._last_identity = identity
                return self._parse_data(data)
            if self.use_hmac_encryption:
                data = await self.socket.recv(flags=zmq.NOBLOCK)
                return self._parse_data(data)
            return await self.socket.recv_pyobj(flags=zmq.NOBLOCK)
        except zmq.Again:
            # No message available yet.
            remaining = deadline - loop.time()
            if remaining <= 0:
                raise asyncio.TimeoutError()
            # Short sleep to avoid busy-waiting, clamped so the call does
            # not overshoot the requested timeout.
            await asyncio.sleep(min(poll_interval, remaining))
207287
208288 def close (self ):
209289 if self .socket :
@@ -311,6 +391,7 @@ def notify_with_retry(self, message, max_retries=5, timeout=1):
311391 raise ValueError (
312392 "notify_with_retry is only supported for DEALER socket for now" )
313393
394+ self ._check_thread_safety ()
314395 retry_count = 0
315396
316397 while retry_count < max_retries :
0 commit comments