@@ -68,6 +68,11 @@ def __init__(self,
6868
6969 self ._setup_lock = threading .Lock ()
7070
71+ # Thread safety debugging
72+ self ._zmq_thread_id = None
73+ self ._zmq_debug_enabled = os .environ .get ('TLLM_LLMAPI_ZMQ_DEBUG' ,
74+ '0' ) != '0'
75+
7176 # Check HMAC key condition
7277 if self .use_hmac_encryption and not self .is_server and self .hmac_key is None :
7378 raise ValueError (
@@ -114,12 +119,34 @@ def setup_lazily(self):
114119 self .poller = zmq .Poller ()
115120 self .poller .register (self .socket , zmq .POLLIN )
116121
122+ def _check_thread_safety (self ):
123+ """Check if the current thread is the same as the thread that first used the socket."""
124+ if not self ._zmq_debug_enabled :
125+ return
126+
127+ current_thread_id = threading .get_ident ()
128+
129+ if self ._zmq_thread_id is None :
130+ # First call - capture the thread ID
131+ self ._zmq_thread_id = current_thread_id
132+ logger_debug (
133+ f"ZMQ socket [{ self .name } ] initialized on thread { current_thread_id } " ,
134+ "cyan" )
135+ elif self ._zmq_thread_id != current_thread_id :
136+ # Thread mismatch - raise error
137+ raise RuntimeError (
138+ f"ZMQ thread safety violation detected in [{ self .name } ]: "
139+ f"Socket created on thread { self ._zmq_thread_id } , "
140+ f"but accessed from thread { current_thread_id } . "
141+ f"ZMQ sockets are not thread-safe!" )
142+
117143 def poll (self , timeout : int ) -> bool :
118144 """
119145 Parameters:
120146 timeout (int): Timeout in seconds
121147 """
122148 self .setup_lazily ()
149+ self ._check_thread_safety ()
123150
124151 events = dict (self .poller .poll (timeout = timeout * 1000 ))
125152 if self .socket in events and events [self .socket ] == zmq .POLLIN :
@@ -129,6 +156,7 @@ def poll(self, timeout: int) -> bool:
129156
130157 def put (self , obj : Any ):
131158 self .setup_lazily ()
159+ self ._check_thread_safety ()
132160 with nvtx_range_debug ("send" , color = "blue" , category = "IPC" ):
133161 if self .use_hmac_encryption or self .socket_type == zmq .ROUTER :
134162 # Need manual serialization for encryption or ROUTER multipart
@@ -156,6 +184,7 @@ def put_noblock(self,
156184 assert retry >= 0 and retry <= 10 , "Retry must be between 0 and 10, adjust the wait_time if needed"
157185
158186 self .setup_lazily ()
187+ self ._check_thread_safety ()
159188 with nvtx_range_debug ("send" , color = "blue" , category = "IPC" ):
160189
161190 data = self ._prepare_data (obj )
@@ -170,6 +199,7 @@ def put_noblock(self,
170199
171200 async def put_async (self , obj : Any ):
172201 self .setup_lazily ()
202+ self ._check_thread_safety ()
173203 try :
174204 if self .use_hmac_encryption or self .socket_type == zmq .ROUTER :
175205 # Need manual serialization for encryption or ROUTER multipart
@@ -190,6 +220,7 @@ async def put_async(self, obj: Any):
190220
191221 async def put_async_noblock (self , obj : Any ):
192222 self .setup_lazily ()
223+ self ._check_thread_safety ()
193224 try :
194225 if self .use_hmac_encryption :
195226 data = pickle .dumps (obj ) # nosec B301
@@ -204,13 +235,16 @@ async def put_async_noblock(self, obj: Any):
204235
205236 def get (self ) -> Any :
206237 self .setup_lazily ()
238+ self ._check_thread_safety ()
207239 return self ._recv_data ()
208240
209241 async def get_async (self ) -> Any :
210242 self .setup_lazily ()
243+ self ._check_thread_safety ()
211244 return await self ._recv_data_async ()
212245
213246 async def get_async_noblock (self , timeout : float = 0.5 ) -> Any :
247+ self ._check_thread_safety ()
214248 return await asyncio .wait_for (self .get_async (), timeout )
215249
216250 def close (self ):
@@ -319,6 +353,7 @@ def notify_with_retry(self, message, max_retries=5, timeout=1):
319353 raise ValueError (
320354 "notify_with_retry is only supported for DEALER socket for now" )
321355
356+ self ._check_thread_safety ()
322357 retry_count = 0
323358
324359 while retry_count < max_retries :
0 commit comments