Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## Unreleased

### Added

- New `attention` pooling mode in `eds.span_pooler`
- New `word_pooling_mode=False` option in `eds.transformer` to return the wordpiece embeddings directly, instead of the mean-pooled word embeddings. At the moment, this only works with `eds.span_pooler`, which can pool over wordpieces or words seamlessly.

## v0.18.0 (2025-09-02)

📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 in the next major release (v0.19.0), in October 2025. Please upgrade to Python 3.10 or later.
Expand All @@ -13,6 +20,7 @@
- New `eds.explode` pipe that splits one document into multiple documents, one per span yielded by its `span_getter` parameter, each new document containing exactly that single span.
- New `Training a span classifier` tutorial, and reorganized deep-learning docs
- `ScheduledOptimizer` now warns when a parameter selector does not match any parameter.
- New `attention` pooling mode in `eds.span_pooler`

### Fixed

Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ We provide step-by-step guides to get you started. We cover the following use-ca

### Base tutorials

<!-- --8<-- [start:tutorials] -->
<!-- --8<-- [start:classic-tutorials] -->

=== card {: href=/tutorials/spacy101 }
Expand Down Expand Up @@ -85,6 +86,8 @@ We provide step-by-step guides to get you started. We cover the following use-ca
---
Quickly visualize the results of your pipeline as annotations or tables.

<!-- --8<-- [end:classic-tutorials] -->

### Deep learning tutorials

We also provide tutorials on how to train deep-learning models with EDS-NLP. These tutorials cover the training API, hyperparameter tuning, and more.
Expand Down Expand Up @@ -123,8 +126,5 @@ We also provide tutorials on how to train deep-learning models with EDS-NLP. The
---
Learn how to tune hyperparameters of a model with `edsnlp.tune`.


<!-- --8<-- [end:deep-learning-tutorials] -->


<!-- --8<-- [end:tutorials] -->
40 changes: 34 additions & 6 deletions edsnlp/core/torch_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,14 @@ def compute_training_metrics(
This is useful to compute averages when doing multi-gpu training or mini-batch
accumulation since full denominators are not known during the forward pass.
"""
return batch_output
return (
{
**batch_output,
"loss": batch_output["loss"] / count,
}
if "loss" in batch_output
else batch_output
)

def module_forward(self, *args, **kwargs): # pragma: no cover
"""
Expand All @@ -348,6 +355,31 @@ def module_forward(self, *args, **kwargs): # pragma: no cover
"""
return torch.nn.Module.__call__(self, *args, **kwargs)

def preprocess_batch(self, docs: Sequence[Doc], supervision=False, **kwargs):
    """
    Preprocess every document of a batch and merge the per-document features.

    Each document is passed through `preprocess_supervised` (when `supervision`
    is true) or `preprocess` (otherwise). The resulting per-document feature
    dicts are then regrouped so that features sharing the same path end up in
    a single list under the same key.

    Parameters
    ----------
    docs: Sequence[Doc]
        Batch of documents
    supervision: bool
        Whether to extract supervision features or not

    Returns
    -------
    Dict[str, Sequence[Any]]
        The batch of features
    """
    # Pick the per-document extractor once instead of branching for each doc.
    extract = self.preprocess_supervised if supervision else self.preprocess
    per_doc_features = [extract(doc) for doc in docs]
    # Compress the list of dicts into path-keyed lists, then rebuild the
    # nested structure expected downstream.
    return decompress_dict(list(batch_compress_dict(per_doc_features)))

def prepare_batch(
self,
docs: Sequence[Doc],
Expand All @@ -372,11 +404,7 @@ def prepare_batch(
-------
Dict[str, Sequence[Any]]
"""
batch = [
(self.preprocess_supervised(doc) if supervision else self.preprocess(doc))
for doc in docs
]
batch = decompress_dict(list(batch_compress_dict(batch)))
batch = self.preprocess_batch(docs, supervision=supervision)
batch = self.collate(batch)
batch = self.batch_to_device(batch, device=device)
return batch
Expand Down
190 changes: 190 additions & 0 deletions edsnlp/metrics/doc_classif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from spacy.tokens import Doc
from spacy.training import Example

from edsnlp import registry
from edsnlp.metrics import make_examples


def _precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Return (precision, recall, F1) from raw counts, 0.0 on empty denominators."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    return precision, recall, f1


def _mean(values: List[float]) -> float:
    """Return the arithmetic mean of `values`, or 0.0 when the list is empty."""
    return sum(values) / len(values) if values else 0.0


def doc_classification_metric(
    examples: Union[Tuple[Iterable[Doc], Iterable[Doc]], Iterable[Example]],
    label_attr: List[str],
    micro_key: str = "micro",
    macro_key: str = "macro",
    filter_expr: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Scores multi-head document-level classification (accuracy, precision, recall, F1)
    for each head.

    For every head, the label is read from the `Doc._` extension named after the
    head on both the predicted and the reference document (missing attributes
    are read as `None`). `None` is never reported as a class of its own, but
    a document whose predicted and gold labels are both `None` still counts as
    a correct prediction in the micro-averaged scores and in the accuracy.

    Parameters
    ----------
    examples: Examples
        The examples to score, either a tuple of (golds, preds) or a list of
        spacy.training.Example objects
    label_attr: List[str]
        The list of Doc._ attributes containing the labels for each head
    micro_key: str
        The key to use to store the micro-averaged results
    macro_key: str
        The key to use to store the macro-averaged results
    filter_expr: str
        The filter expression to use to filter the documents. It is evaluated
        as the body of a `lambda doc: ...` applied to the reference document.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Dictionary mapping head names to their respective metrics: one entry
        per observed label, plus one entry under `micro_key` and one under
        `macro_key`
    """
    # The examples are traversed once per head: materialize them so that a
    # one-shot iterable is not exhausted after the first head.
    examples = list(make_examples(examples))
    if filter_expr is not None:
        # NOTE(security): `filter_expr` is executed through `eval`. It must only
        # come from trusted configuration, never from untrusted user input.
        filter_fn = eval(f"lambda doc: {filter_expr}")
        examples = [eg for eg in examples if filter_fn(eg.reference)]

    all_head_results = {}

    for head_name in label_attr:
        pred_labels = []
        gold_labels = []

        for eg in examples:
            pred_labels.append(getattr(eg.predicted._, head_name, None))
            gold_labels.append(getattr(eg.reference._, head_name, None))

        labels = {
            label for label in set(gold_labels) | set(pred_labels) if label is not None
        }

        # Single pass over the documents to collect every per-label count,
        # instead of three full passes per label.
        tp_counts = dict.fromkeys(labels, 0)
        fp_counts = dict.fromkeys(labels, 0)
        fn_counts = dict.fromkeys(labels, 0)
        total_tp = 0
        for pred, gold in zip(pred_labels, gold_labels):
            if pred == gold:
                total_tp += 1
                if pred is not None:
                    tp_counts[pred] += 1
            else:
                if pred is not None:
                    fp_counts[pred] += 1
                if gold is not None:
                    fn_counts[gold] += 1

        head_results = {}
        for label in labels:
            tp, fp, fn = tp_counts[label], fp_counts[label], fn_counts[label]
            precision, recall, f1 = _precision_recall_f1(tp, fp, fn)
            head_results[label] = {
                "f": f1,
                "p": precision,
                "r": recall,
                "tp": tp,
                "fp": fp,
                "fn": fn,
                "support": tp + fn,
                "positives": tp + fp,
            }

        # Each document carries exactly one predicted and one gold label, so
        # every mismatch is both a false positive and a false negative.
        total_fp = len(pred_labels) - total_tp
        total_fn = total_fp

        micro_precision, micro_recall, micro_f1 = _precision_recall_f1(
            total_tp, total_fp, total_fn
        )
        accuracy = total_tp / len(pred_labels) if len(pred_labels) > 0 else 0.0

        head_results[micro_key] = {
            "accuracy": accuracy,
            "f": micro_f1,
            "p": micro_precision,
            "r": micro_recall,
            "tp": total_tp,
            "fp": total_fp,
            "fn": total_fn,
            "support": len(gold_labels),
            "positives": len(pred_labels),
        }

        # Macro scores are unweighted means of the per-class scores.
        head_results[macro_key] = {
            "f": _mean([head_results[label]["f"] for label in labels]),
            "p": _mean([head_results[label]["p"] for label in labels]),
            "r": _mean([head_results[label]["r"] for label in labels]),
            "support": len(labels),
            "classes": len(labels),
        }

        all_head_results[head_name] = head_results

    return all_head_results


@registry.metrics.register("eds.doc_classification")
class DocClassificationMetric:
    def __init__(
        self,
        label_attr: List[str],
        micro_key: str = "micro",
        macro_key: str = "macro",
        filter_expr: Optional[str] = None,
    ):
        """
        Configurable wrapper around `doc_classification_metric`, scoring
        document-level classification over one or several heads.

        Parameters
        ----------
        label_attr: List[str]
            Names of the Doc._ attributes holding the label of each head
        micro_key: str
            Key under which the micro-averaged results are stored
        macro_key: str
            Key under which the macro-averaged results are stored
        filter_expr: str
            Filter expression used to select the documents to score
        """
        self.filter_expr = filter_expr
        self.macro_key = macro_key
        self.micro_key = micro_key
        self.label_attr = label_attr

    def __call__(self, *examples):
        # The positional arguments are forwarded as a single tuple, together
        # with the options stored at construction time.
        options = {
            "label_attr": self.label_attr,
            "micro_key": self.micro_key,
            "macro_key": self.macro_key,
            "filter_expr": self.filter_expr,
        }
        return doc_classification_metric(examples, **options)


__all__ = [
"doc_classification_metric",
"DocClassificationMetric",
]
2 changes: 2 additions & 0 deletions edsnlp/pipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,7 @@
from .trainable.embeddings.span_pooler.factory import create_component as span_pooler
from .trainable.embeddings.transformer.factory import create_component as transformer
from .trainable.embeddings.text_cnn.factory import create_component as text_cnn
from .trainable.embeddings.doc_pooler.factory import create_component as doc_pooler
from .trainable.doc_classifier.factory import create_component as doc_classifier
from .misc.split import Split as split
from .misc.explode import Explode as explode
1 change: 1 addition & 0 deletions edsnlp/pipes/trainable/doc_classifier/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .factory import create_component
Loading