77from spacy .tokens import Doc
88from typing_extensions import Literal , NotRequired , TypedDict
99
10+ import edsnlp
1011from edsnlp .core .pipeline import PipelineProtocol
1112from edsnlp .core .torch_component import BatchInput , TorchComponent
1213from edsnlp .pipes .base import BaseComponent
3334)
3435
3536
37+ @edsnlp .registry .misc .register ("focal_loss" )
38+ class FocalLoss (nn .Module ):
39+ """
40+ Focal Loss implementation for multi-class classification.
41+
42+ Parameters
43+ ----------
44+ alpha : torch.Tensor or float, optional
45+ Class weights. If None, no weighting is applied
46+ gamma : float, default=2.0
47+ Focusing parameter. Higher values give more weight to hard examples
48+ reduction : str, default='mean'
49+ Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
50+ """
51+
52+ def __init__ (
53+ self ,
54+ alpha : Optional [Union [torch .Tensor , float ]] = None ,
55+ gamma : float = 2.0 ,
56+ reduction : str = "mean" ,
57+ ):
58+ super ().__init__ ()
59+ self .alpha = alpha
60+ self .gamma = gamma
61+ self .reduction = reduction
62+
63+ def forward (self , inputs : torch .Tensor , targets : torch .Tensor ) -> torch .Tensor :
64+ """
65+ Forward pass
66+ """
67+ ce_loss = torch .nn .functional .cross_entropy (
68+ inputs , targets , weight = self .alpha , reduction = "none"
69+ )
70+
71+ pt = torch .exp (- ce_loss )
72+
73+ focal_loss = (1 - pt ) ** self .gamma * ce_loss
74+
75+ if self .reduction == "mean" :
76+ return focal_loss .mean ()
77+ elif self .reduction == "sum" :
78+ return focal_loss .sum ()
79+ else :
80+ return focal_loss
81+
82+
3683class TrainableDocClassifier (
3784 TorchComponent [DocClassifierBatchOutput , DocClassifierBatchInput ],
3885 BaseComponent ,
@@ -49,9 +96,9 @@ def __init__(
4996 label_attr : str = "label" ,
5097 label2id : Optional [Dict [str , int ]] = None ,
5198 id2label : Optional [Dict [int , str ]] = None ,
52- loss_fn = None ,
99+ loss : Literal [ "ce" , "focal" ] = "ce" ,
53100 labels : Optional [Sequence [str ]] = None ,
54- class_weights : Optional [Union [ Dict [str , float ], str ]] = None ,
101+ class_weights : Optional [Dict [str , float ]] = None ,
55102 hidden_size : Optional [int ] = None ,
56103 activation_mode : Literal ["relu" , "gelu" , "silu" ] = "relu" ,
57104 dropout_rate : Optional [float ] = 0.0 ,
@@ -71,8 +118,7 @@ def __init__(
71118 super ().__init__ (nlp , name )
72119 self .embedding = embedding
73120
74- self ._loss_fn = loss_fn
75- self .loss_fn = None
121+ self .loss = loss
76122
77123 if not hasattr (self .embedding , "output_size" ):
78124 raise ValueError (
@@ -112,17 +158,13 @@ def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
112158
113159 return weights
114160
115- def _load_class_weights_from_file (self , filepath : str ) -> Dict [str , int ]:
116- """Load class weights from pickle file."""
117- with open (filepath , "rb" ) as f :
118- return pickle .load (f )
119-
120161 def set_extensions (self ) -> None :
121162 super ().set_extensions ()
122163 if not Doc .has_extension (self .label_attr ):
123164 Doc .set_extension (self .label_attr , default = {})
124165
125166 def post_init (self , gold_data : Iterable [Doc ], exclude : Set [str ]):
167+ print ("post_init" )
126168 if not self .label2id :
127169 if self .labels is not None :
128170 labels = set (self .labels )
@@ -141,22 +183,19 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
141183 self .num_classes = len (self .label2id )
142184 print ("num classes:" , self .num_classes )
143185 self .build_classifier ()
144-
186+ print ( "label2id fini" )
145187 weight_tensor = None
146188 if self .class_weights is not None :
147- if isinstance (self .class_weights , str ):
148- freq_dict = self ._load_class_weights_from_file (self .class_weights )
149- weight_tensor = self ._compute_class_weights (freq_dict )
150- elif isinstance (self .class_weights , dict ):
151- weight_tensor = self ._compute_class_weights (self .class_weights )
152-
189+ weight_tensor = self ._compute_class_weights (self .class_weights )
153190 print (f"Using class weights: { weight_tensor } " )
154-
155- if self ._loss_fn is not None :
156- self .loss_fn = self ._loss_fn
157- else :
191+ print ("weight tensor fini" )
192+ if self .loss == "ce" :
158193 self .loss_fn = torch .nn .CrossEntropyLoss (weight = weight_tensor )
159-
194+ elif self .loss == "focal" :
195+ self .loss_fn = FocalLoss (alpha = weight_tensor , gamma = 2.0 , reduction = "mean" )
196+ else :
197+ raise ValueError (f"Unknown loss: { self .loss } " )
198+ print ("loss finie" )
160199 super ().post_init (gold_data , exclude = exclude )
161200
162201 def preprocess (self , doc : Doc ) -> Dict [str , Any ]:
0 commit comments