Skip to content

Commit 48bd61e

Browse files
committed
wip: feat: new llm extractor pipe
1 parent d49f929 commit 48bd61e

File tree

9 files changed

+406
-159
lines changed

9 files changed

+406
-159
lines changed

edsnlp/core/stream.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -701,37 +701,41 @@ def map_pipeline(
701701
"""
702702
new_ops = []
703703
tokenizer = model.tokenizer
704-
has_batches = batch_by is not None
705704
for op in self.ops:
706705
# check if the pipe has a "tokenizer" kwarg and update the kwargs if needed
707706
op = copy(op)
708707
if (
709-
isinstance(op, MapOp)
710-
and "tokenizer" in signature(op.pipe).parameters
711-
and "tokenizer" not in op.kwargs
712-
):
713-
op.kwargs["tokenizer"] = tokenizer
714-
elif (
715-
isinstance(op, MapBatchesOp)
716-
and hasattr(op.pipe, "batch_process")
717-
and "tokenizer" in signature(op.pipe.batch_process).parameters
718-
and "tokenizer" not in op.kwargs
719-
) or (
720-
isinstance(op, MapBatchesOp)
721-
and callable(op.pipe)
722-
and "tokenizer" in signature(op.pipe).parameters
723-
and "tokenizer" not in op.kwargs
708+
(
709+
isinstance(op, MapOp)
710+
and "tokenizer" in signature(op.pipe).parameters
711+
and "tokenizer" not in op.kwargs
712+
)
713+
or (
714+
isinstance(op, MapBatchesOp)
715+
and hasattr(op.pipe, "batch_process")
716+
and "tokenizer" in signature(op.pipe.batch_process).parameters
717+
and "tokenizer" not in op.kwargs
718+
)
719+
or (
720+
isinstance(op, MapBatchesOp)
721+
and callable(op.pipe)
722+
and "tokenizer" in signature(op.pipe).parameters
723+
and "tokenizer" not in op.kwargs
724+
)
724725
):
725726
op.kwargs["tokenizer"] = tokenizer
726-
has_batches = True
727727
if isinstance(op, (MapOp, MapBatchesOp)):
728728
op.context["tokenizer"] = tokenizer
729729
new_ops.append(op)
730+
has_batches = batch_by is not None or (
731+
hasattr(p, "batch_process") for n, p in model.pipeline
732+
)
730733
new_ops.append(MapOp(model._ensure_doc, {}))
731734
if has_batches:
732735
batch_size, batch_by = self.validate_batching(batch_size, batch_by)
733736
batch_by = batchify_fns.get(batch_by, batch_by)
734737
new_ops.append(BatchifyOp(batch_size, batch_by))
738+
735739
for name, pipe in model.pipeline:
736740
if name not in model._disabled:
737741
op = (

0 commit comments

Comments
 (0)