@@ -82,8 +82,22 @@ def __init__(
         self.client = AsyncOpenAI(
             api_key=api_key,
             base_url=api_url,
+            default_headers={"Connection": "close"},
         )
 
+    async def __aenter__(self):
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit - properly close the client."""
+        await self.aclose()
+
+    async def aclose(self):
+        """Properly close the AsyncOpenAI client to prevent resource leaks."""
+        if hasattr(self, "client") and self.client is not None:
+            await self.client.close()
+
     @property
     def lock(self):
         if self._lock is None:
@@ -170,25 +184,36 @@ async def async_worker(
         name: str,
         id_messages_tuples: AsyncIterator[Tuple[int, List[List[Dict[str, str]]]]],
     ):
-        """ """
-
-        async for (
-            idx,
-            message,
-        ) in id_messages_tuples:
-            logger.info(idx)
-
+        while True:
             try:
+                (
+                    idx,
+                    message,
+                ) = await anext(id_messages_tuples)  # noqa: F821
                 idx, response = await self.call_llm(idx, message)
 
                 logger.info(f"Worker {name} has finished process {idx}")
-            except Exception as e:
-                logger.error(f"[{name}] Exception raised on chunk {idx}\n{e}")
+            except StopAsyncIteration:
+                # Everything has been parsed!
+                logger.info(
+                    f"[{name}] Received StopAsyncIteration, worker will shutdown"
+                )
+                break
+            except TimeoutError as e:
+                logger.error(f"[{name}] TimeoutError on chunk {idx}\n{e}")
+                logger.error(f"Timeout was set to {self.timeout} seconds")
+                if self.n_completions == 1:
+                    response = ""
+                else:
+                    response = [""] * self.n_completions
+            except BaseException as e:
+                logger.error(
+                    f"[{name}] Exception raised on chunk {idx}\n{e}"
+                )  # type(e)
                 if self.n_completions == 1:
                     response = ""
                 else:
                     response = [""] * self.n_completions
-
             async with self.lock:
                 self.store_responses(
                     idx,
@@ -225,21 +250,28 @@ async def __call__(self, batch_messages: List[List[Dict[str, str]]]):
             List of message batches to send to the LLM, where each batch is a list
             of dictionaries with keys 'role' and 'content'.
         """
-        # Shared prompt generator
-        id_messages_tuples = self.async_id_message_generator(batch_messages)
+        try:
+            # Shared prompt generator
+            id_messages_tuples = self.async_id_message_generator(batch_messages)
 
-        # n concurrent tasks
-        tasks = {
-            asyncio.create_task(self.async_worker(f"Worker-{i}", id_messages_tuples))
-            for i in range(self.n_concurrent_tasks)
-        }
+            # n concurrent tasks
+            tasks = {
+                asyncio.create_task(
+                    self.async_worker(f"Worker-{i}", id_messages_tuples)
+                )
+                for i in range(self.n_concurrent_tasks)
+            }
 
-        await asyncio.gather(*tasks)
-        tasks.clear()
-        predictions = self.sort_responses()
-        self.clean_storage()
+            await asyncio.gather(*tasks)
+            tasks.clear()
+            predictions = self.sort_responses()
+            self.clean_storage()
 
-        return predictions
+            return predictions
+        except Exception:
+            # Ensure cleanup even if an exception occurs
+            await self.aclose()
+            raise
 
 
 def create_prompt_messages(
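Usage note: with the `__aenter__`/`__aexit__`/`aclose` methods added above, callers can wrap the object in `async with` so the underlying `AsyncOpenAI` client is always closed. A minimal sketch, assuming the class is named `AsyncLLMCaller` and lives in a module `llm_caller` (neither name is visible in this diff; only `api_key`/`api_url` appear in the constructor hunk):

```python
import asyncio

# Hypothetical import path and class name; not confirmed by the diff.
from llm_caller import AsyncLLMCaller


async def main():
    batch_messages = [
        [{"role": "user", "content": "Hello!"}],
    ]
    # __aexit__ calls aclose(), so AsyncOpenAI.close() runs even if __call__ raises.
    async with AsyncLLMCaller(api_key="sk-...", api_url="https://api.example.com/v1") as caller:
        predictions = await caller(batch_messages)
    print(predictions)


asyncio.run(main())
```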