Skip to content

Commit 8621718

Browse files
committed
Change document classification metric: add macro-averaged scores and per-label fp/fn counts
1 parent 57d8627 commit 8621718

File tree

1 file changed

+78
-22
lines changed

1 file changed

+78
-22
lines changed

edsnlp/metrics/doc_classif.py

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ def doc_classification_metric(
1111
examples: Union[Tuple[Iterable[Doc], Iterable[Doc]], Iterable[Example]],
1212
label_attr: str = "label",
1313
micro_key: str = "micro",
14+
macro_key: str = "macro",
1415
filter_expr: Optional[str] = None,
1516
) -> Dict[str, Any]:
1617
"""
1718
Scores document-level classification (accuracy, precision, recall, F1).
18-
1919
Parameters
2020
----------
2121
examples: Examples
@@ -25,9 +25,10 @@ def doc_classification_metric(
2525
The Doc._ attribute containing the label
2626
micro_key: str
2727
The key to use to store the micro-averaged results
28+
macro_key: str
29+
The key to use to store the macro-averaged results
2830
filter_expr: str
2931
The filter expression to use to filter the documents
30-
3132
Returns
3233
-------
3334
Dict[str, Any]
@@ -46,33 +47,88 @@ def doc_classification_metric(
4647
gold_labels.append(gold)
4748

4849
labels = set(gold_labels) | set(pred_labels)
50+
labels = {label for label in labels if label is not None}
4951
results = {}
52+
5053
for label in labels:
51-
pred_set = [i for i, p in enumerate(pred_labels) if p == label]
52-
gold_set = [i for i, g in enumerate(gold_labels) if g == label]
53-
tp = len(set(pred_set) & set(gold_set))
54-
num_pred = len(pred_set)
55-
num_gold = len(gold_set)
54+
tp = sum(
55+
1 for p, g in zip(pred_labels, gold_labels) if p == label and g == label
56+
)
57+
fp = sum(
58+
1 for p, g in zip(pred_labels, gold_labels) if p == label and g != label
59+
)
60+
fn = sum(
61+
1 for p, g in zip(pred_labels, gold_labels) if g == label and p != label
62+
)
63+
64+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
65+
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
66+
f1 = (
67+
(2 * precision * recall) / (precision + recall)
68+
if (precision + recall) > 0
69+
else 0.0
70+
)
71+
5672
results[label] = {
57-
"f": 2 * tp / max(1, num_pred + num_gold),
58-
"p": 1 if tp == num_pred else (tp / num_pred) if num_pred else 0.0,
59-
"r": 1 if tp == num_gold else (tp / num_gold) if num_gold else 0.0,
73+
"f": f1,
74+
"p": precision,
75+
"r": recall,
6076
"tp": tp,
61-
"support": num_gold,
62-
"positives": num_pred,
77+
"fp": fp,
78+
"fn": fn,
79+
"support": tp + fn,
80+
"positives": tp + fp,
6381
}
6482

65-
tp = sum(1 for p, g in zip(pred_labels, gold_labels) if p == g)
66-
num_pred = len(pred_labels)
67-
num_gold = len(gold_labels)
83+
total_tp = sum(1 for p, g in zip(pred_labels, gold_labels) if p == g)
84+
total_fp = sum(1 for p, g in zip(pred_labels, gold_labels) if p != g)
85+
total_fn = total_fp
86+
87+
micro_precision = (
88+
total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
89+
)
90+
micro_recall = (
91+
total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
92+
)
93+
micro_f1 = (
94+
(2 * micro_precision * micro_recall) / (micro_precision + micro_recall)
95+
if (micro_precision + micro_recall) > 0
96+
else 0.0
97+
)
98+
accuracy = total_tp / len(pred_labels) if len(pred_labels) > 0 else 0.0
99+
68100
results[micro_key] = {
69-
"accuracy": tp / num_gold if num_gold else 0.0,
70-
"f": 2 * tp / max(1, num_pred + num_gold),
71-
"p": tp / num_pred if num_pred else 0.0,
72-
"r": tp / num_gold if num_gold else 0.0,
73-
"tp": tp,
74-
"support": num_gold,
75-
"positives": num_pred,
101+
"accuracy": accuracy,
102+
"f": micro_f1,
103+
"p": micro_precision,
104+
"r": micro_recall,
105+
"tp": total_tp,
106+
"fp": total_fp,
107+
"fn": total_fn,
108+
"support": len(gold_labels),
109+
"positives": len(pred_labels),
110+
}
111+
112+
per_class_precisions = [results[label]["p"] for label in labels]
113+
per_class_recalls = [results[label]["r"] for label in labels]
114+
per_class_f1s = [results[label]["f"] for label in labels]
115+
116+
macro_precision = (
117+
sum(per_class_precisions) / len(per_class_precisions)
118+
if per_class_precisions
119+
else 0.0
120+
)
121+
macro_recall = (
122+
sum(per_class_recalls) / len(per_class_recalls) if per_class_recalls else 0.0
123+
)
124+
macro_f1 = sum(per_class_f1s) / len(per_class_f1s) if per_class_f1s else 0.0
125+
126+
results[macro_key] = {
127+
"f": macro_f1,
128+
"p": macro_precision,
129+
"r": macro_recall,
130+
"support": len(labels),
131+
"classes": len(labels),
76132
}
77133
return results
78134

0 commit comments

Comments
 (0)