@@ -11,11 +11,11 @@ def doc_classification_metric(
1111 examples : Union [Tuple [Iterable [Doc ], Iterable [Doc ]], Iterable [Example ]],
1212 label_attr : str = "label" ,
1313 micro_key : str = "micro" ,
14+ macro_key : str = "macro" ,
1415 filter_expr : Optional [str ] = None ,
1516) -> Dict [str , Any ]:
1617 """
1718 Scores document-level classification (accuracy, precision, recall, F1).
18-
1919 Parameters
2020 ----------
2121 examples: Examples
@@ -25,9 +25,10 @@ def doc_classification_metric(
2525 The Doc._ attribute containing the label
2626 micro_key: str
2727 The key to use to store the micro-averaged results
28+ macro_key: str
29+ The key to use to store the macro-averaged results
2830 filter_expr: str
2931 The filter expression to use to filter the documents
30-
3132 Returns
3233 -------
3334 Dict[str, Any]
@@ -46,33 +47,88 @@ def doc_classification_metric(
4647 gold_labels .append (gold )
4748
4849 labels = set (gold_labels ) | set (pred_labels )
50+ labels = {label for label in labels if label is not None }
4951 results = {}
52+
5053 for label in labels :
51- pred_set = [i for i , p in enumerate (pred_labels ) if p == label ]
52- gold_set = [i for i , g in enumerate (gold_labels ) if g == label ]
53- tp = len (set (pred_set ) & set (gold_set ))
54- num_pred = len (pred_set )
55- num_gold = len (gold_set )
54+ tp = sum (
55+ 1 for p , g in zip (pred_labels , gold_labels ) if p == label and g == label
56+ )
57+ fp = sum (
58+ 1 for p , g in zip (pred_labels , gold_labels ) if p == label and g != label
59+ )
60+ fn = sum (
61+ 1 for p , g in zip (pred_labels , gold_labels ) if g == label and p != label
62+ )
63+
64+ precision = tp / (tp + fp ) if (tp + fp ) > 0 else 0.0
65+ recall = tp / (tp + fn ) if (tp + fn ) > 0 else 0.0
66+ f1 = (
67+ (2 * precision * recall ) / (precision + recall )
68+ if (precision + recall ) > 0
69+ else 0.0
70+ )
71+
5672 results [label ] = {
57- "f" : 2 * tp / max ( 1 , num_pred + num_gold ) ,
58- "p" : 1 if tp == num_pred else ( tp / num_pred ) if num_pred else 0.0 ,
59- "r" : 1 if tp == num_gold else ( tp / num_gold ) if num_gold else 0.0 ,
73+ "f" : f1 ,
74+ "p" : precision ,
75+ "r" : recall ,
6076 "tp" : tp ,
61- "support" : num_gold ,
62- "positives" : num_pred ,
77+ "fp" : fp ,
78+ "fn" : fn ,
79+ "support" : tp + fn ,
80+ "positives" : tp + fp ,
6381 }
6482
65- tp = sum (1 for p , g in zip (pred_labels , gold_labels ) if p == g )
66- num_pred = len (pred_labels )
67- num_gold = len (gold_labels )
83+ total_tp = sum (1 for p , g in zip (pred_labels , gold_labels ) if p == g )
84+ total_fp = sum (1 for p , g in zip (pred_labels , gold_labels ) if p != g )
85+ total_fn = total_fp
86+
87+ micro_precision = (
88+ total_tp / (total_tp + total_fp ) if (total_tp + total_fp ) > 0 else 0.0
89+ )
90+ micro_recall = (
91+ total_tp / (total_tp + total_fn ) if (total_tp + total_fn ) > 0 else 0.0
92+ )
93+ micro_f1 = (
94+ (2 * micro_precision * micro_recall ) / (micro_precision + micro_recall )
95+ if (micro_precision + micro_recall ) > 0
96+ else 0.0
97+ )
98+ accuracy = total_tp / len (pred_labels ) if len (pred_labels ) > 0 else 0.0
99+
68100 results [micro_key ] = {
69- "accuracy" : tp / num_gold if num_gold else 0.0 ,
70- "f" : 2 * tp / max (1 , num_pred + num_gold ),
71- "p" : tp / num_pred if num_pred else 0.0 ,
72- "r" : tp / num_gold if num_gold else 0.0 ,
73- "tp" : tp ,
74- "support" : num_gold ,
75- "positives" : num_pred ,
101+ "accuracy" : accuracy ,
102+ "f" : micro_f1 ,
103+ "p" : micro_precision ,
104+ "r" : micro_recall ,
105+ "tp" : total_tp ,
106+ "fp" : total_fp ,
107+ "fn" : total_fn ,
108+ "support" : len (gold_labels ),
109+ "positives" : len (pred_labels ),
110+ }
111+
112+ per_class_precisions = [results [label ]["p" ] for label in labels ]
113+ per_class_recalls = [results [label ]["r" ] for label in labels ]
114+ per_class_f1s = [results [label ]["f" ] for label in labels ]
115+
116+ macro_precision = (
117+ sum (per_class_precisions ) / len (per_class_precisions )
118+ if per_class_precisions
119+ else 0.0
120+ )
121+ macro_recall = (
122+ sum (per_class_recalls ) / len (per_class_recalls ) if per_class_recalls else 0.0
123+ )
124+ macro_f1 = sum (per_class_f1s ) / len (per_class_f1s ) if per_class_f1s else 0.0
125+
126+ results [macro_key ] = {
127+ "f" : macro_f1 ,
128+ "p" : macro_precision ,
129+ "r" : macro_recall ,
130+ "support" : len (labels ),
131+ "classes" : len (labels ),
76132 }
77133 return results
78134
0 commit comments