Commit 9a55bad

async worker and headers

1 parent 992642b

File tree

2 files changed: +60 / -24 lines changed


edsnlp/pipes/qualifiers/llm/llm_utils.py

Lines changed: 55 additions & 23 deletions
@@ -82,8 +82,22 @@ def __init__(
         self.client = AsyncOpenAI(
             api_key=api_key,
             base_url=api_url,
+            default_headers={"Connection": "close"},
         )
 
+    async def __aenter__(self):
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit - properly close the client."""
+        await self.aclose()
+
+    async def aclose(self):
+        """Properly close the AsyncOpenAI client to prevent resource leaks."""
+        if hasattr(self, "client") and self.client is not None:
+            await self.client.close()
+
     @property
     def lock(self):
         if self._lock is None:
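Note on the hunk above: the new __aenter__/__aexit__/aclose trio makes the wrapper usable as an async context manager, so the AsyncOpenAI client (and its HTTP connections) is released deterministically. A minimal sketch of the pattern in isolation; the Client class below is a hypothetical stand-in, not the class from this file:

import asyncio

class Client:
    # Hypothetical stand-in; in the commit this role is played by the
    # wrapper holding the AsyncOpenAI client.
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()

    async def aclose(self):
        print("client closed")

async def main():
    async with Client() as client:
        pass  # issue requests through client here
    # aclose() has already run at this point, even if the body raised

asyncio.run(main())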
@@ -170,25 +184,36 @@ async def async_worker(
         name: str,
         id_messages_tuples: AsyncIterator[Tuple[int, List[List[Dict[str, str]]]]],
     ):
-        """ """
-
-        async for (
-            idx,
-            message,
-        ) in id_messages_tuples:
-            logger.info(idx)
-
+        while True:
             try:
+                (
+                    idx,
+                    message,
+                ) = await anext(id_messages_tuples)  # noqa: F821
                 idx, response = await self.call_llm(idx, message)
 
                 logger.info(f"Worker {name} has finished process {idx}")
-            except Exception as e:
-                logger.error(f"[{name}] Exception raised on chunk {idx}\n{e}")
+            except StopAsyncIteration:
+                # Everything has been parsed!
+                logger.info(
+                    f"[{name}] Received StopAsyncIteration, worker will shutdown"
+                )
+                break
+            except TimeoutError as e:
+                logger.error(f"[{name}] TimeoutError on chunk {idx}\n{e}")
+                logger.error(f"Timeout was set to {self.timeout} seconds")
+                if self.n_completions == 1:
+                    response = ""
+                else:
+                    response = [""] * self.n_completions
+            except BaseException as e:
+                logger.error(
+                    f"[{name}] Exception raised on chunk {idx}\n{e}"
+                )  # type(e)
                 if self.n_completions == 1:
                     response = ""
                 else:
                     response = [""] * self.n_completions
-
             async with self.lock:
                 self.store_responses(
                     idx,
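Note on the hunk above: swapping async for for an explicit while True / await anext(...) loop lets a worker survive a failed chunk: errors from call_llm are handled per item and the loop continues, while StopAsyncIteration becomes the explicit shutdown signal once the shared generator is exhausted (anext is a builtin only from Python 3.10 onward, hence, presumably, the noqa: F821). A self-contained sketch of the pattern, with several workers draining one shared async generator:

import asyncio

async def generate(n):
    # Shared async generator: each yielded chunk is consumed by exactly
    # one of the concurrent workers.
    for i in range(n):
        yield i

async def worker(name, items):
    while True:
        try:
            idx = await anext(items)  # builtin since Python 3.10
        except StopAsyncIteration:
            print(f"[{name}] generator exhausted, shutting down")
            break
        await asyncio.sleep(0.01)  # stand-in for the LLM call
        print(f"[{name}] finished chunk {idx}")

async def main():
    items = generate(5)
    await asyncio.gather(
        *(asyncio.create_task(worker(f"Worker-{i}", items)) for i in range(2))
    )

asyncio.run(main())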
@@ -225,21 +250,28 @@ async def __call__(self, batch_messages: List[List[Dict[str, str]]]):
             List of message batches to send to the LLM, where each batch is a list
             of dictionaries with keys 'role' and 'content'.
         """
-        # Shared prompt generator
-        id_messages_tuples = self.async_id_message_generator(batch_messages)
+        try:
+            # Shared prompt generator
+            id_messages_tuples = self.async_id_message_generator(batch_messages)
 
-        # n concurrent tasks
-        tasks = {
-            asyncio.create_task(self.async_worker(f"Worker-{i}", id_messages_tuples))
-            for i in range(self.n_concurrent_tasks)
-        }
+            # n concurrent tasks
+            tasks = {
+                asyncio.create_task(
+                    self.async_worker(f"Worker-{i}", id_messages_tuples)
+                )
+                for i in range(self.n_concurrent_tasks)
+            }
 
-        await asyncio.gather(*tasks)
-        tasks.clear()
-        predictions = self.sort_responses()
-        self.clean_storage()
+            await asyncio.gather(*tasks)
+            tasks.clear()
+            predictions = self.sort_responses()
+            self.clean_storage()
 
-        return predictions
+            return predictions
+        except Exception:
+            # Ensure cleanup even if an exception occurs
+            await self.aclose()
+            raise
 
 
 def create_prompt_messages(
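Note on the hunk above: asyncio.gather re-raises the first worker exception at the await, so the new try/except guarantees aclose() runs before the error propagates, and the bare raise preserves the original traceback. The shape of that pattern, sketched standalone:

import asyncio

async def task(i):
    if i == 2:
        raise RuntimeError("boom")
    await asyncio.sleep(0.01)

async def main():
    try:
        await asyncio.gather(*(task(i) for i in range(4)))
    except Exception:
        # gather() surfaces the first worker exception here, making this
        # the right place to release resources before re-raising.
        print("cleaning up")
        raise

try:
    asyncio.run(main())
except RuntimeError as exc:
    print(f"propagated: {exc}")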

tests/pipelines/qualifiers/test_llm_utils.py

Lines changed: 5 additions & 1 deletion
@@ -34,7 +34,11 @@ def test_async_llm(n_concurrent_tasks):
     response = run_async(
         llm_api(
             batch_messages=[
-                [{"role": "user", "content": "your prompt here"}],
+                [
+                    {"role": "user", "content": "your prompt here"},
+                    {"role": "assistant", "content": "Hello!"},
+                    {"role": "user", "content": "your second prompt here"},
+                ],
                 [{"role": "user", "content": "your second prompt here"}],
             ]
         )
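Note on the test change: the first batch entry now exercises a multi-turn conversation (user / assistant / user) rather than a single message. Each such entry is ultimately the messages payload of an OpenAI-style chat completion request; roughly, under the assumption that call_llm wraps something like the following (the model name is a placeholder):

from openai import AsyncOpenAI

async def ask(client: AsyncOpenAI) -> str:
    completion = await client.chat.completions.create(
        model="my-model",  # placeholder
        messages=[
            {"role": "user", "content": "your prompt here"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "your second prompt here"},
        ],
    )
    return completion.choices[0].message.content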
