diff --git a/changelog.md b/changelog.md
index 241515e97..169176e20 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,13 @@
# Changelog
+## Unreleased
+
+### Fixed
+
+- Don't pass `seed` directly to OpenAI API calls; forward it through the extra API kwargs instead
+- Default to an alignment threshold of 0 (better recall) when aligning LLM-annotated markup with the original text (see the sketch below)
+- Fix `eds.llm_markup_extractor` context splitting so that full docs are yielded rather than parts of docs
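+
+A minimal sketch of the affected configuration (the pipeline setup and the
+endpoint below are hypothetical):
+
+```python
+nlp.add_pipe(
+    eds.llm_markup_extractor(
+        api_url="http://localhost:8080/v1",
+        model="my-model",
+        prompt=PROMPT,
+        seed=42,  # forwarded via the extra API kwargs, not as a direct argument
+        alignment_threshold=0.0,  # new default, favors recall
+    )
+)
+```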
+
## v0.19.0 (2025-10-04)
📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 in the next major release (v0.20.0) in October 2025. Please upgrade to Python 3.10 or later.
diff --git a/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py b/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py
index 27597e56b..77bec520f 100644
--- a/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py
+++ b/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py
@@ -1,5 +1,6 @@
import os
import warnings
+from collections import deque
from typing import (
Any,
Callable,
@@ -196,6 +197,8 @@ def prompt(doc_text, examples):
The markup format to use when formatting the few-shot examples and
parsing the model's output. Either "xml" (default) or "md" (Markdown).
Make sure the prompt template matches the chosen format.
+ alignment_threshold : float
+ The threshold used to align the model's output with the original text.
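+        Lower values favor recall. Defaults to 0.0.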
prompt : Union[str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]]
The prompt is the main way to control the model's behavior.
It can be either:
@@ -262,6 +265,7 @@ def __init__(
str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]
],
markup_mode: Literal["xml", "md"] = "xml",
+ alignment_threshold: float = 0.0,
examples: Iterable[Doc] = (),
max_few_shot_examples: int = -1,
use_retriever: Optional[bool] = None,
@@ -301,7 +305,9 @@ def __init__(
self.api_kwargs = api_kwargs or {}
self.max_concurrent_requests = max_concurrent_requests
self.on_error = on_error
- self.seed = seed
+ self.alignment_threshold = alignment_threshold
+ if seed is not None:
+            self.api_kwargs["seed"] = seed
self.retriever = None
if self.max_few_shot_examples > 0 and use_retriever is not False:
self.build_few_shot_retriever_(self.examples)
@@ -335,6 +341,7 @@ def apply_markup_to_doc_(self, doclike: Any, markup_answer: str):
aligned = align(
{"text": res_text, "entities": ents},
{"text": stripped_text, "entities": []},
+ threshold=self.alignment_threshold,
)
res_ents = [
(f["begin"], f["end"], e["label"], e["attributes"])
@@ -410,8 +417,8 @@ def process(self, doc):
def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
"""
Extract entities concurrently, but yield results in the same order
- as the input `docs`. Up to `max_concurrent_requests` documents are
- processed in parallel.
+ as the input `docs`. Up to `max_concurrent_requests` span-level
+ requests are processed in parallel.
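+        Contexts (e.g. sentences) from the same document may therefore be
+        processed concurrently.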
Parameters
----------
@@ -424,48 +431,70 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
Processed documents in the original input order.
"""
if self.max_concurrent_requests <= 1: # pragma: no cover
- for ctx in docs:
- yield self.process(ctx)
+ for doc in docs:
+ yield self.process(doc)
return
worker = AsyncRequestWorker.instance()
- # Documents sent to the worker, waiting for results
+ # Documents that are currently being processed, keyed by their
+ # index in the input stream.
pending_docs: Dict[int, Doc] = {}
- # Documents already processed, waiting to be yielded in order
+ # Number of remaining contexts to process for each document.
+ remaining_ctx_counts: Dict[int, int] = {}
+ # Fully processed documents waiting to be yielded in order.
buffer: Dict[int, Doc] = {}
next_to_yield = 0
- in_flight: Dict[int, int] = {}
- ctx_iter = enumerate(
- ctx for doc in docs for ctx in get_spans(doc, self.context_getter)
- )
-        for _ in range(self.max_concurrent_requests):
-            try:
-                i, ctx = next(ctx_iter)
-            except StopIteration:
-                break
-            messages = self.build_prompt(ctx)
-            task_id = worker.submit(self._llm_request_coro(messages))
-            in_flight[task_id] = i
-            pending_docs[i] = ctx
+        # In-flight LLM requests: task_id -> (doc_index, context)
+        in_flight: Dict[int, Tuple[int, Any]] = {}
+        docs_iter = enumerate(docs)
+        ctx_queue: "deque[Tuple[int, Any]]" = deque()
+
+ def enqueue_new_docs() -> None:
+ # Fill the context queue up to `max_concurrent_requests`
+ while len(ctx_queue) < self.max_concurrent_requests:
+ try:
+ doc_idx, doc = next(docs_iter)
+ except StopIteration:
+ break
+
+ pending_docs[doc_idx] = doc
+ contexts = list(get_spans(doc, self.context_getter))
+
+                if not contexts:
+                    # No context to query: the doc is already complete.
+                    buffer[doc_idx] = pending_docs.pop(doc_idx)
+ else:
+ remaining_ctx_counts[doc_idx] = len(contexts)
+ for ctx in contexts:
+ ctx_queue.append((doc_idx, ctx))
+
+ def submit_until_full() -> None:
+ while len(in_flight) < self.max_concurrent_requests and ctx_queue:
+ doc_idx, ctx = ctx_queue.popleft()
+ messages = self.build_prompt(ctx)
+ task_id = worker.submit(self._llm_request_coro(messages))
+ in_flight[task_id] = (doc_idx, ctx)
+
+ enqueue_new_docs()
+ submit_until_full()
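+        # Main loop: wait for any in-flight request to finish, handle its
+        # result, then top the context queue and submissions back up.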
while in_flight:
done_task_id = worker.wait_for_any(in_flight.keys())
result = worker.pop_result(done_task_id)
- i = in_flight.pop(done_task_id)
- ctx = pending_docs.pop(i)
+ doc_idx, ctx = in_flight.pop(done_task_id)
if result is None:
- buffer[i] = ctx
+                # No result for this task (e.g. it was cancelled): there is
+                # nothing to apply, but the context still counts as done.
+                pass
else:
res, err = result
if err is not None:
self._handle_err(
- f"[llm_markup_extractor] failed for document #{i}: {err!r}"
+ f"[llm_markup_extractor] failed for doc #{doc_idx}: {err!r}"
)
- buffer[i] = ctx
else:
try:
self.apply_markup_to_doc_(ctx, str(res))
@@ -474,23 +503,16 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
traceback.print_exc()
self._handle_err(
- f"[llm_markup_extractor] "
- f"failed to parse result for document #{i}: {e!r} in "
- f"{res!r}"
+ f"[llm_markup_extractor] failed to parse result for doc "
+ f"#{doc_idx}: {e!r} in {res!r}"
)
- buffer[i] = ctx
- while True:
- try:
- if len(in_flight) >= self.max_concurrent_requests:
- break
- i2, d2 = next(ctx_iter)
- except StopIteration:
- break
- messages2 = self.build_prompt(d2)
- task_id2 = worker.submit(self._llm_request_coro(messages2))
- in_flight[task_id2] = i2
- pending_docs[i2] = d2
+            # One more context done for this doc, whatever the outcome.
+            remaining_ctx_counts[doc_idx] -= 1
+            if remaining_ctx_counts[doc_idx] == 0:
+                del remaining_ctx_counts[doc_idx]
+                buffer[doc_idx] = pending_docs.pop(doc_idx)
+
+ enqueue_new_docs()
+ submit_until_full()
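+            # Yield documents that are now complete, preserving input order.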
while next_to_yield in buffer:
yield buffer.pop(next_to_yield)
@@ -504,17 +526,15 @@ def _llm_request_sync(self, messages) -> str:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
- seed=self.seed,
**self.api_kwargs,
)
- return response.choices[0].message.content
+ return str(response.choices[0].message.content)
def _llm_request_coro(self, messages) -> Coroutine[Any, Any, str]:
async def _coro():
response = await self.async_client.chat.completions.create(
model=self.model,
messages=messages,
- seed=self.seed,
**self.api_kwargs,
)
return response.choices[0].message.content
diff --git a/tests/pipelines/llm/test_llm_markup_extractor.py b/tests/pipelines/llm/test_llm_markup_extractor.py
index 0c430e218..8d6636854 100644
--- a/tests/pipelines/llm/test_llm_markup_extractor.py
+++ b/tests/pipelines/llm/test_llm_markup_extractor.py
@@ -145,3 +145,49 @@ def responder(**kw):
docs = docs.to_iterable(converter="markup", preset="md")
docs = list(docs)
assert docs == md
+
+
+def test_context_getter_async():
+ nlp = edsnlp.blank("eds")
+ nlp.add_pipe("eds.normalizer")
+ nlp.add_pipe("eds.sentences")
+ nlp.add_pipe(
+ eds.llm_markup_extractor(
+ api_url="http://localhost:8080/v1",
+ model="my-custom-model",
+ prompt=PROMPT,
+ max_concurrent_requests=2,
+ context_getter="sents",
+ )
+ )
+
+ md = [
+ "La patient souffre de [tuberculose](diagnosis). On débute une "
+ "[antibiothérapie](treatment) dès ajd.",
+ "Il a une [pneumonie](diagnosis) du thorax. C'est très grave.",
+ ]
+
+ counter = 0
+
+ def responder(messages, **kw):
+ nonlocal counter
+ counter += 1
+ assert len(messages) == 2 # 1 system + 1 user
+        # Echo the message back with entities wrapped in the XML markup
+        # that the extractor parses by default.
+        res = (
+            messages[-1]["content"]
+            .replace("tuberculose", "<diagnosis>tuberculose</diagnosis>")
+            .replace("antibiothérapie", "<treatment>antibiothérapie</treatment>")
+            .replace("pneumonie", "<diagnosis>pneumonie</diagnosis>")
+        )
+        return res
+
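+    # With 4 sentence-level requests and max_concurrent_requests=2, the
+    # extractor must refill its request queue as results come back.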
+ with mock_llm_service(responder=responder):
+ docs = edsnlp.data.from_iterable(md, converter="markup", preset="md")
+ docs = docs.map(lambda x: x.text)
+ docs = docs.map_pipeline(nlp)
+ docs = docs.to_iterable(converter="markup", preset="md")
+ docs = list(docs)
+ assert docs == md
+
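+    # 2 docs × 2 sentences each = 4 span-level LLM requests.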
+ assert counter == 4