diff --git a/changelog.md b/changelog.md
index 241515e97..169176e20 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,13 @@
 # Changelog
 
+## Unreleased
+
+### Fixed
+
+- Don't pass `seed` directly to OpenAI API calls; forward it through `api_kwargs` instead
+- Default the alignment threshold to 0 (better recall) when aligning LLM-annotated markup with the original text
+- Fix `eds.llm_markup_extractor` context splitting so that full docs are yielded rather than parts of docs
+
 ## v0.19.0 (2025-10-04)
 
 📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 support in the next major release (v0.20.0), in October 2025. Please upgrade to Python 3.10 or later.
diff --git a/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py b/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py
index 27597e56b..77bec520f 100644
--- a/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py
+++ b/edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py
@@ -1,5 +1,6 @@
 import os
 import warnings
+from collections import deque
 from typing import (
     Any,
     Callable,
@@ -196,6 +197,8 @@ def prompt(doc_text, examples):
         The markup format to use when formatting the few-shot examples and parsing
         the model's output. Either "xml" (default) or "md" (Markdown). Make sure
         the prompt template matches the chosen format.
+    alignment_threshold : float
+        The threshold used to align the model's output with the original text.
     prompt : Union[str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]]
         The prompt is the main way to control the model's behavior. It can be
         either:
@@ -262,6 +265,7 @@ def __init__(
             str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]
         ],
         markup_mode: Literal["xml", "md"] = "xml",
+        alignment_threshold: float = 0.0,
         examples: Iterable[Doc] = (),
         max_few_shot_examples: int = -1,
         use_retriever: Optional[bool] = None,
@@ -301,7 +305,9 @@ def __init__(
         self.api_kwargs = api_kwargs or {}
         self.max_concurrent_requests = max_concurrent_requests
         self.on_error = on_error
-        self.seed = seed
+        self.alignment_threshold = alignment_threshold
+        if seed is not None:
+            self.api_kwargs["seed"] = seed
         self.retriever = None
         if self.max_few_shot_examples > 0 and use_retriever is not False:
             self.build_few_shot_retriever_(self.examples)
@@ -335,6 +341,7 @@ def apply_markup_to_doc_(self, doclike: Any, markup_answer: str):
         aligned = align(
             {"text": res_text, "entities": ents},
             {"text": stripped_text, "entities": []},
+            threshold=self.alignment_threshold,
         )
         res_ents = [
             (f["begin"], f["end"], e["label"], e["attributes"])
@@ -410,8 +417,8 @@ def process(self, doc):
     def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
         """
         Extract entities concurrently, but yield results in the same order
-        as the input `docs`. Up to `max_concurrent_requests` documents are
-        processed in parallel.
+        as the input `docs`. Up to `max_concurrent_requests` span-level
+        requests are processed in parallel.
 
         Parameters
         ----------
@@ -424,48 +431,70 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
             Processed documents in the original input order.
         """
         if self.max_concurrent_requests <= 1:  # pragma: no cover
-            for ctx in docs:
-                yield self.process(ctx)
+            for doc in docs:
+                yield self.process(doc)
             return
 
         worker = AsyncRequestWorker.instance()
 
-        # Documents sent to the worker, waiting for results
+        # Documents that are currently being processed, keyed by their
+        # index in the input stream.
         pending_docs: Dict[int, Doc] = {}
-        # Documents already processed, waiting to be yielded in order
+        # Number of remaining contexts to process for each document.
+        remaining_ctx_counts: Dict[int, int] = {}
+        # Fully processed documents waiting to be yielded in order.
         buffer: Dict[int, Doc] = {}
         next_to_yield = 0
-        in_flight: Dict[int, int] = {}
-        ctx_iter = enumerate(
-            ctx for doc in docs for ctx in get_spans(doc, self.context_getter)
-        )
+        # In-flight LLM requests: task_id -> (doc_index, context)
+        in_flight: Dict[int, Tuple[int, Any]] = {}
 
-        for _ in range(self.max_concurrent_requests):
-            try:
-                i, ctx = next(ctx_iter)
-            except StopIteration:
-                break
-            messages = self.build_prompt(ctx)
-            task_id = worker.submit(self._llm_request_coro(messages))
-            in_flight[task_id] = i
-            pending_docs[i] = ctx
+        docs_iter = enumerate(docs)
+        ctx_queue: "deque[Tuple[int, Any]]" = deque()
+
+        def enqueue_new_docs() -> None:
+            # Fill the context queue up to `max_concurrent_requests`
+            nonlocal docs_iter
+            while len(ctx_queue) < self.max_concurrent_requests:
+                try:
+                    doc_idx, doc = next(docs_iter)
+                except StopIteration:
+                    break
+
+                pending_docs[doc_idx] = doc
+                contexts = list(get_spans(doc, self.context_getter))
+
+                if not contexts:
+                    remaining_ctx_counts[doc_idx] = 0
+                    buffer[doc_idx] = doc
+                else:
+                    remaining_ctx_counts[doc_idx] = len(contexts)
+                    for ctx in contexts:
+                        ctx_queue.append((doc_idx, ctx))
+
+        def submit_until_full() -> None:
+            while len(in_flight) < self.max_concurrent_requests and ctx_queue:
+                doc_idx, ctx = ctx_queue.popleft()
+                messages = self.build_prompt(ctx)
+                task_id = worker.submit(self._llm_request_coro(messages))
+                in_flight[task_id] = (doc_idx, ctx)
+
+        enqueue_new_docs()
+        submit_until_full()
 
         while in_flight:
             done_task_id = worker.wait_for_any(in_flight.keys())
             result = worker.pop_result(done_task_id)
-            i = in_flight.pop(done_task_id)
-            ctx = pending_docs.pop(i)
+            doc_idx, ctx = in_flight.pop(done_task_id)
 
             if result is None:
-                buffer[i] = ctx
+                pass
             else:
                 res, err = result
                 if err is not None:
                     self._handle_err(
-                        f"[llm_markup_extractor] failed for document #{i}: {err!r}"
+                        f"[llm_markup_extractor] failed for doc #{doc_idx}: {err!r}"
                     )
-                    buffer[i] = ctx
                 else:
                     try:
                         self.apply_markup_to_doc_(ctx, str(res))
@@ -474,23 +503,16 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
 
                         traceback.print_exc()
                         self._handle_err(
-                            f"[llm_markup_extractor] "
-                            f"failed to parse result for document #{i}: {e!r} in "
-                            f"{res!r}"
+                            f"[llm_markup_extractor] failed to parse result for doc "
+                            f"#{doc_idx}: {e!r} in {res!r}"
                         )
-                        buffer[i] = ctx
 
-            while True:
-                try:
-                    if len(in_flight) >= self.max_concurrent_requests:
-                        break
-                    i2, d2 = next(ctx_iter)
-                except StopIteration:
-                    break
-                messages2 = self.build_prompt(d2)
-                task_id2 = worker.submit(self._llm_request_coro(messages2))
-                in_flight[task_id2] = i2
-                pending_docs[i2] = d2
+            remaining_ctx_counts[doc_idx] -= 1
+            if remaining_ctx_counts[doc_idx] == 0:
+                buffer[doc_idx] = pending_docs.pop(doc_idx)
+
+            enqueue_new_docs()
+            submit_until_full()
 
             while next_to_yield in buffer:
                 yield buffer.pop(next_to_yield)
@@ -504,17 +526,15 @@ def _llm_request_sync(self, messages) -> str:
         response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
-            seed=self.seed,
             **self.api_kwargs,
         )
-        return response.choices[0].message.content
+        return str(response.choices[0].message.content)
 
     def _llm_request_coro(self, messages) -> Coroutine[Any, Any, str]:
         async def _coro():
             response = await self.async_client.chat.completions.create(
                 model=self.model,
                 messages=messages,
-                seed=self.seed,
                 **self.api_kwargs,
             )
             return response.choices[0].message.content
diff --git a/tests/pipelines/llm/test_llm_markup_extractor.py b/tests/pipelines/llm/test_llm_markup_extractor.py
index 0c430e218..8d6636854 100644
--- a/tests/pipelines/llm/test_llm_markup_extractor.py
+++ b/tests/pipelines/llm/test_llm_markup_extractor.py
@@ -145,3 +145,49 @@ def responder(**kw):
     docs = docs.to_iterable(converter="markup", preset="md")
     docs = list(docs)
     assert docs == md
+
+
+def test_context_getter_async():
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe("eds.normalizer")
+    nlp.add_pipe("eds.sentences")
+    nlp.add_pipe(
+        eds.llm_markup_extractor(
+            api_url="http://localhost:8080/v1",
+            model="my-custom-model",
+            prompt=PROMPT,
+            max_concurrent_requests=2,
+            context_getter="sents",
+        )
+    )
+
+    md = [
+        "La patient souffre de [tuberculose](diagnosis). On débute une "
+        "[antibiothérapie](treatment) dès ajd.",
+        "Il a une [pneumonie](diagnosis) du thorax. C'est très grave.",
+    ]
+
+    counter = 0
+
+    def responder(messages, **kw):
+        nonlocal counter
+        counter += 1
+        assert len(messages) == 2  # 1 system + 1 user
+        res = (
+            messages[-1]["content"]
+            .replace("tuberculose", "<diagnosis>tuberculose</diagnosis>")
+            .replace("antibiothérapie", "<treatment>antibiothérapie</treatment>")
+            .replace("pneumonie", "<diagnosis>pneumonie</diagnosis>")
+            .replace("grave", "grave")
+        )
+        return res
+
+    with mock_llm_service(responder=responder):
+        docs = edsnlp.data.from_iterable(md, converter="markup", preset="md")
+        docs = docs.map(lambda x: x.text)
+        docs = docs.map_pipeline(nlp)
+        docs = docs.to_iterable(converter="markup", preset="md")
+        docs = list(docs)
+        assert docs == md
+
+    assert counter == 4
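
Usage sketch (not part of the diff): the snippet below shows how the new `alignment_threshold` parameter and the relocated seed handling can be exercised when building a pipeline. The endpoint URL and model name are placeholders modeled on the test above, and the `prompt` callable is only a minimal example of the `(doc_text, examples) -> messages` signature documented on the component.

```python
import edsnlp
import edsnlp.pipes as eds


def prompt(doc_text, examples):
    # Minimal chat prompt: ask the model to echo the text with XML markup,
    # matching the default markup_mode="xml".
    return [
        {
            "role": "system",
            "content": "Wrap each clinical entity in an XML tag named after its "
            "label, e.g. <diagnosis>...</diagnosis>, and change nothing else.",
        },
        {"role": "user", "content": doc_text},
    ]


nlp = edsnlp.blank("eds")
nlp.add_pipe("eds.normalizer")
nlp.add_pipe("eds.sentences")
nlp.add_pipe(
    eds.llm_markup_extractor(
        api_url="http://localhost:8080/v1",  # placeholder local endpoint
        model="my-custom-model",  # placeholder model name
        prompt=prompt,
        context_getter="sents",  # one request per sentence
        max_concurrent_requests=2,  # span-level requests in flight
        alignment_threshold=0.0,  # default: 0 favors recall when aligning
        seed=42,  # folded into api_kwargs by the constructor
    )
)
```

With this change, `seed=42` is folded into `api_kwargs` rather than hard-coded into each request call.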
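For readers who want to see the ordering logic of the rewritten `pipe()` in isolation, here is a small, self-contained sketch of the same pattern built on `concurrent.futures` instead of the internal `AsyncRequestWorker`. Every name in it is illustrative and none of it is part of the edsnlp API: contexts from several docs run under a global concurrency bound, and a document is released, in input order, only once all of its contexts have completed.

```python
from collections import deque
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def pipe_in_order(docs, get_contexts, handle_context, max_concurrent=2):
    """Yield docs in input order while their contexts run concurrently."""
    with ThreadPoolExecutor(max_workers=max_concurrent) as pool:
        docs_iter = enumerate(docs)
        pending = {}  # doc_idx -> doc with unfinished contexts
        remaining = {}  # doc_idx -> number of contexts not yet handled
        buffer = {}  # doc_idx -> finished doc awaiting in-order yield
        in_flight = {}  # future -> doc_idx
        queue = deque()  # (doc_idx, context) waiting to be submitted
        next_to_yield = 0

        def enqueue_new_docs():
            # Pull docs until enough contexts are queued (or input is exhausted).
            while len(queue) < max_concurrent:
                try:
                    idx, doc = next(docs_iter)
                except StopIteration:
                    break
                contexts = list(get_contexts(doc))
                if not contexts:
                    buffer[idx] = doc  # nothing to do, ready immediately
                    continue
                pending[idx] = doc
                remaining[idx] = len(contexts)
                queue.extend((idx, ctx) for ctx in contexts)

        def submit_until_full():
            while len(in_flight) < max_concurrent and queue:
                idx, ctx = queue.popleft()
                in_flight[pool.submit(handle_context, ctx)] = idx

        enqueue_new_docs()
        submit_until_full()
        while in_flight:
            done, _ = wait(in_flight, return_when=FIRST_COMPLETED)
            for future in done:
                idx = in_flight.pop(future)
                future.result()  # the real pipe applies the LLM markup here
                remaining[idx] -= 1
                if remaining[idx] == 0:
                    buffer[idx] = pending.pop(idx)
            enqueue_new_docs()
            submit_until_full()
            while next_to_yield in buffer:
                yield buffer.pop(next_to_yield)
                next_to_yield += 1
        # Flush docs that had no contexts at the tail of the input.
        while next_to_yield in buffer:
            yield buffer.pop(next_to_yield)
            next_to_yield += 1
```

The real implementation does the same bookkeeping with `AsyncRequestWorker` task ids instead of futures.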