8 changes: 8 additions & 0 deletions changelog.md
@@ -1,5 +1,13 @@
# Changelog

+## Unreleased
+
+### Fixed
+
+- Don't pass `seed` directly to OpenAI API calls; send it only through the extra request body
+- Default to an alignment threshold of 0 (better recall) when aligning LLM-annotated markup with the original text
+- Fix `eds.llm_markup_extractor` context splitting so that full documents are yielded rather than fragments of documents
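As a minimal sketch of what "only through the extra request body" means, assuming an `openai>=1.x` client (the endpoint and model name are placeholders): with `extra_body`, `seed` travels in the JSON payload instead of being a named argument that some OpenAI-compatible servers may reject.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="EMPTY")

# Before: client.chat.completions.create(..., seed=42), a named argument.
# After: the seed rides along only in the extra request body.
response = client.chat.completions.create(
    model="my-custom-model",
    messages=[{"role": "user", "content": "Bonjour"}],
    extra_body={"seed": 42},
)
```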

## v0.19.0 (2025-10-04)

📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 in the next major release (v0.20.0), in October 2025. Please upgrade to Python 3.10 or later.
106 changes: 63 additions & 43 deletions edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py
@@ -1,5 +1,6 @@
import os
import warnings
+from collections import deque
from typing import (
Any,
Callable,
@@ -196,6 +197,8 @@
The markup format to use when formatting the few-shot examples and
parsing the model's output. Either "xml" (default) or "md" (Markdown).
Make sure the prompt template matches the chosen format.
+    alignment_threshold : float
+        The similarity threshold used to align the model's output with the
+        original text. Defaults to 0.0, which favors recall.
prompt : Union[str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]]
The prompt is the main way to control the model's behavior.
It can be either:
@@ -262,6 +265,7 @@
str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]
],
markup_mode: Literal["xml", "md"] = "xml",
+        alignment_threshold: float = 0.0,
examples: Iterable[Doc] = (),
max_few_shot_examples: int = -1,
use_retriever: Optional[bool] = None,
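A hedged usage sketch for the new parameter; the endpoint, model name and prompt are placeholders borrowed from the tests further down, not a recommended configuration.

```python
import edsnlp
import edsnlp.pipes as eds

PROMPT = "Annotate diagnoses and treatments with XML tags."  # placeholder prompt

nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.llm_markup_extractor(
        api_url="http://localhost:8080/v1",  # placeholder endpoint
        model="my-custom-model",             # placeholder model name
        prompt=PROMPT,
        markup_mode="xml",
        alignment_threshold=0.0,  # new default: favors recall during alignment
        seed=42,  # now forwarded through the API kwargs, not as a named argument
    )
)
```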
@@ -301,7 +305,9 @@
self.api_kwargs = api_kwargs or {}
self.max_concurrent_requests = max_concurrent_requests
self.on_error = on_error
-        self.seed = seed
+        self.alignment_threshold = alignment_threshold
+        if seed is not None:
+            self.api_kwargs["seed"] = seed
self.retriever = None
if self.max_few_shot_examples > 0 and use_retriever is not False:
self.build_few_shot_retriever_(self.examples)
@@ -335,6 +341,7 @@
aligned = align(
{"text": res_text, "entities": ents},
{"text": stripped_text, "entities": []},
+            threshold=self.alignment_threshold,
)
res_ents = [
(f["begin"], f["end"], e["label"], e["attributes"])
@@ -410,8 +417,8 @@
def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
"""
Extract entities concurrently, but yield results in the same order
-        as the input `docs`. Up to `max_concurrent_requests` documents are
-        processed in parallel.
+        as the input `docs`. Up to `max_concurrent_requests` span-level
+        requests are processed in parallel.

Parameters
----------
@@ -424,48 +431,70 @@
Processed documents in the original input order.
"""
if self.max_concurrent_requests <= 1: # pragma: no cover
-            for ctx in docs:
-                yield self.process(ctx)
+            for doc in docs:
+                yield self.process(doc)
return

worker = AsyncRequestWorker.instance()

-        # Documents sent to the worker, waiting for results
+        # Documents that are currently being processed, keyed by their
+        # index in the input stream.
pending_docs: Dict[int, Doc] = {}
-        # Documents already processed, waiting to be yielded in order
+        # Number of remaining contexts to process for each document.
+        remaining_ctx_counts: Dict[int, int] = {}
+        # Fully processed documents waiting to be yielded in order.
buffer: Dict[int, Doc] = {}
next_to_yield = 0
-        in_flight: Dict[int, int] = {}
-
-        ctx_iter = enumerate(
-            ctx for doc in docs for ctx in get_spans(doc, self.context_getter)
-        )
+        # In-flight LLM requests: task_id -> (doc_index, context)
+        in_flight: Dict[int, Tuple[int, Any]] = {}

-        for _ in range(self.max_concurrent_requests):
-            try:
-                i, ctx = next(ctx_iter)
-            except StopIteration:
-                break
-            messages = self.build_prompt(ctx)
-            task_id = worker.submit(self._llm_request_coro(messages))
-            in_flight[task_id] = i
-            pending_docs[i] = ctx
+        docs_iter = enumerate(docs)
+        ctx_queue: "deque[Tuple[int, Any]]" = deque()

+        def enqueue_new_docs() -> None:
+            # Fill the context queue up to `max_concurrent_requests`
+            nonlocal docs_iter
+            while len(ctx_queue) < self.max_concurrent_requests:
+                try:
+                    doc_idx, doc = next(docs_iter)
+                except StopIteration:
+                    break
+
+                pending_docs[doc_idx] = doc
+                contexts = list(get_spans(doc, self.context_getter))
+
+                if not contexts:
+                    remaining_ctx_counts[doc_idx] = 0
+                    buffer[doc_idx] = doc
+                else:
+                    remaining_ctx_counts[doc_idx] = len(contexts)
+                    for ctx in contexts:
+                        ctx_queue.append((doc_idx, ctx))
+
+        def submit_until_full() -> None:
+            while len(in_flight) < self.max_concurrent_requests and ctx_queue:
+                doc_idx, ctx = ctx_queue.popleft()
+                messages = self.build_prompt(ctx)
+                task_id = worker.submit(self._llm_request_coro(messages))
+                in_flight[task_id] = (doc_idx, ctx)
+
+        enqueue_new_docs()
+        submit_until_full()

while in_flight:
done_task_id = worker.wait_for_any(in_flight.keys())
result = worker.pop_result(done_task_id)
-            i = in_flight.pop(done_task_id)
-            ctx = pending_docs.pop(i)
+            doc_idx, ctx = in_flight.pop(done_task_id)

if result is None:
-                buffer[i] = ctx
+                pass
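+                # No result came back for this context (e.g. the task was
+                # cancelled): the context is simply skipped, and the document
+                # still completes through the remaining-context countdown below.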

else:
res, err = result
if err is not None:
self._handle_err(
f"[llm_markup_extractor] failed for document #{i}: {err!r}"
f"[llm_markup_extractor] failed for doc #{doc_idx}: {err!r}"
)
-                    buffer[i] = ctx
else:
try:
self.apply_markup_to_doc_(ctx, str(res))
@@ -474,23 +503,16 @@

traceback.print_exc()
self._handle_err(
f"[llm_markup_extractor] "
f"failed to parse result for document #{i}: {e!r} in "
f"{res!r}"
f"[llm_markup_extractor] failed to parse result for doc "
f"#{doc_idx}: {e!r} in {res!r}"
)
buffer[i] = ctx

-            while True:
-                try:
-                    if len(in_flight) >= self.max_concurrent_requests:
-                        break
-                    i2, d2 = next(ctx_iter)
-                except StopIteration:
-                    break
-                messages2 = self.build_prompt(d2)
-                task_id2 = worker.submit(self._llm_request_coro(messages2))
-                in_flight[task_id2] = i2
-                pending_docs[i2] = d2
+            remaining_ctx_counts[doc_idx] -= 1
+            if remaining_ctx_counts[doc_idx] == 0:
+                buffer[doc_idx] = pending_docs.pop(doc_idx)
+
+            enqueue_new_docs()
+            submit_until_full()

while next_to_yield in buffer:
yield buffer.pop(next_to_yield)
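The scheduling pattern in `pipe`, reduced to a self-contained sketch: `process` is a hypothetical coroutine standing in for one LLM request, and unlike the real pipe this version tracks one result per item rather than a per-document count of pending contexts.

```python
import asyncio
from collections import deque

async def ordered_concurrent(items, process, max_concurrent=4):
    """Run `process(item)` with bounded concurrency; yield results in input order."""
    queue = deque(enumerate(items))
    in_flight = {}   # task -> input index
    buffer = {}      # input index -> finished result
    next_to_yield = 0
    while queue or in_flight:
        # Top up to the concurrency limit.
        while queue and len(in_flight) < max_concurrent:
            idx, item = queue.popleft()
            in_flight[asyncio.ensure_future(process(item))] = idx
        # Wait for at least one task to finish, then bank its result.
        done, _ = await asyncio.wait(list(in_flight), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            buffer[in_flight.pop(task)] = task.result()
        # Release everything that is ready, strictly in input order.
        while next_to_yield in buffer:
            yield buffer.pop(next_to_yield)
            next_to_yield += 1
```

Driving this with `async for` preserves input order even when later items finish first; the real `pipe` does the same bookkeeping per span-level context and only releases a document once its remaining-context count reaches zero.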
@@ -504,17 +526,15 @@
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
-            seed=self.seed,
**self.api_kwargs,
)
-        return response.choices[0].message.content
+        return str(response.choices[0].message.content)

def _llm_request_coro(self, messages) -> Coroutine[Any, Any, str]:
async def _coro():
response = await self.async_client.chat.completions.create(
model=self.model,
messages=messages,
-            seed=self.seed,
**self.api_kwargs,
)
return response.choices[0].message.content
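For orientation, the `AsyncRequestWorker` contract as used in this diff, written out as a hedged sketch; the method names come from the call sites above, while the exact semantics are inferred, not taken from the implementation.

```python
# Inferred from the call sites in pipe(); not the actual implementation.
worker = AsyncRequestWorker.instance()    # shared worker with its own event loop
task_id = worker.submit(coro)             # schedule a coroutine, get a task id back
done_id = worker.wait_for_any(task_ids)   # block until any submitted task finishes
result = worker.pop_result(done_id)       # None, or a (value, error) pair
```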
46 changes: 46 additions & 0 deletions tests/pipelines/llm/test_llm_markup_extractor.py
@@ -145,3 +145,49 @@ def responder(**kw):
docs = docs.to_iterable(converter="markup", preset="md")
docs = list(docs)
assert docs == md


+def test_context_getter_async():
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe("eds.normalizer")
+    nlp.add_pipe("eds.sentences")
+    nlp.add_pipe(
+        eds.llm_markup_extractor(
+            api_url="http://localhost:8080/v1",
+            model="my-custom-model",
+            prompt=PROMPT,
+            max_concurrent_requests=2,
+            context_getter="sents",
+        )
+    )
+
+    md = [
+        "Le patient souffre de [tuberculose](diagnosis). On débute une "
+        "[antibiothérapie](treatment) dès ajd.",
+        "Il a une [pneumonie](diagnosis) du thorax. C'est très grave.",
+    ]
+
+    counter = 0
+
+    def responder(messages, **kw):
+        nonlocal counter
+        counter += 1
+        assert len(messages) == 2  # 1 system + 1 user
+        res = (
+            messages[-1]["content"]
+            .replace("tuberculose", "<diagnosis>tuberculose</diagnosis>")
+            .replace("antibiothérapie", "<treatment>antibiothérapie</treatment>")
+            .replace("pneumonie", "<diagnosis>pneumonie</diagnosis>")
+            .replace("grave", "grave</diagnosis>")
+        )
+        return res
+
+    with mock_llm_service(responder=responder):
+        docs = edsnlp.data.from_iterable(md, converter="markup", preset="md")
+        docs = docs.map(lambda x: x.text)
+        docs = docs.map_pipeline(nlp)
+        docs = docs.to_iterable(converter="markup", preset="md")
+        docs = list(docs)
+        assert docs == md

+    assert counter == 4  # 2 docs with 2 sentences each: one request per sentence