Skip to content

Commit 57d8627

Browse files
committed
Add the possibility to pass labels to doc_classifier
1 parent 0880427 commit 57d8627

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -47,10 +47,12 @@ def __init__(
4747
label2id: Optional[Dict[str, int]] = None,
4848
id2label: Optional[Dict[int, str]] = None,
4949
loss_fn=None,
50+
labels: Optional[Sequence[str]] = None,
5051
):
5152
self.label_attr: Attributes = label_attr
5253
self.label2id = label2id or {}
5354
self.id2label = id2label or {}
55+
self.labels = labels
5456
super().__init__(nlp, name)
5557
self.embedding = embedding
5658
self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
@@ -70,11 +72,14 @@ def set_extensions(self) -> None:
7072

7173
def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
7274
if not self.label2id:
73-
labels = set()
74-
for doc in gold_data:
75-
label = getattr(doc._, self.label_attr, None)
76-
if isinstance(label, str):
77-
labels.add(label)
75+
if self.labels is not None:
76+
labels = set(self.labels)
77+
else:
78+
labels = set()
79+
for doc in gold_data:
80+
label = getattr(doc._, self.label_attr, None)
81+
if isinstance(label, str):
82+
labels.add(label)
7883
if labels:
7984
self.label2id = {}
8085
self.id2label = {}

0 commit comments

Comments
 (0)