File tree Expand file tree Collapse file tree 1 file changed +10
-5
lines changed
edsnlp/pipes/trainable/doc_classifier Expand file tree Collapse file tree 1 file changed +10
-5
lines changed Original file line number Diff line number Diff line change @@ -47,10 +47,12 @@ def __init__(
4747 label2id : Optional [Dict [str , int ]] = None ,
4848 id2label : Optional [Dict [int , str ]] = None ,
4949 loss_fn = None ,
50+ labels : Optional [Sequence [str ]] = None ,
5051 ):
5152 self .label_attr : Attributes = label_attr
5253 self .label2id = label2id or {}
5354 self .id2label = id2label or {}
55+ self .labels = labels
5456 super ().__init__ (nlp , name )
5557 self .embedding = embedding
5658 self .loss_fn = loss_fn or torch .nn .CrossEntropyLoss ()
@@ -70,11 +72,14 @@ def set_extensions(self) -> None:
7072
7173 def post_init (self , gold_data : Iterable [Doc ], exclude : Set [str ]):
7274 if not self .label2id :
73- labels = set ()
74- for doc in gold_data :
75- label = getattr (doc ._ , self .label_attr , None )
76- if isinstance (label , str ):
77- labels .add (label )
75+ if self .labels is not None :
76+ labels = set (self .labels )
77+ else :
78+ labels = set ()
79+ for doc in gold_data :
80+ label = getattr (doc ._ , self .label_attr , None )
81+ if isinstance (label , str ):
82+ labels .add (label )
7883 if labels :
7984 self .label2id = {}
8085 self .id2label = {}
You can’t perform that action at this time.
0 commit comments