Skip to content

Commit 6717b2b

Browse files
committed
Merge branch 'feat/postprocess/add_mebin_post_processor' of https://github.com/StarPlatinum7/anomalib into feat/postprocess/add_mebin_post_processor
2 parents a7863dc + 5cb8f53 commit 6717b2b

File tree

5 files changed

+421
-9
lines changed

5 files changed

+421
-9
lines changed

src/anomalib/metrics/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
- ``BinaryPrecisionRecallCurve``: Computes precision-recall curves
2626
- ``Evaluator``: Combines multiple metrics for evaluation
2727
- ``MinMax``: Normalizes scores to [0,1] range
28+
- ``PBn``: Presorted bad with n% good samples misclassified
29+
- ``PGn``: Presorted good with n% bad samples missed
2830
- ``PRO``: Per-Region Overlap score
2931
- ``PIMO``: Per-Image Missed Overlap score
3032
@@ -56,6 +58,7 @@
5658
from .evaluator import Evaluator
5759
from .f1_score import F1Max, F1Score
5860
from .min_max import MinMax
61+
from .pg_pb import PBn, PGn
5962
from .pimo import AUPIMO, PIMO
6063
from .precision_recall_curve import BinaryPrecisionRecallCurve
6164
from .pro import PRO
@@ -75,6 +78,8 @@
7578
"F1Score",
7679
"ManualThreshold",
7780
"MinMax",
81+
"PGn",
82+
"PBn",
7883
"PRO",
7984
"PIMO",
8085
"AUPIMO",

src/anomalib/metrics/anomaly_score_distribution.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
"""Compute statistics of anomaly score distributions.
55
6-
This module provides the ``AnomalyScoreDistribution`` class which computes mean
7-
and standard deviation statistics of anomaly scores from normal training data.
6+
This module provides the ``AnomalyScoreDistribution`` class, which computes the mean
7+
and standard deviation statistics of anomaly scores.
88
Statistics are computed for both image-level and pixel-level scores.
9+
The ``plot`` method generates a histogram of anomaly scores,
10+
separated by label, to visualize score distributions for normal and abnormal samples.
911
1012
The class tracks:
1113
- Image-level statistics: Mean and std of image anomaly scores
@@ -17,29 +19,34 @@
1719
>>> # Create sample data
1820
>>> scores = torch.tensor([0.1, 0.2, 0.15]) # Image anomaly scores
1921
>>> maps = torch.tensor([[0.1, 0.2], [0.15, 0.25]]) # Pixel anomaly maps
22+
>>> labels = torch.tensor([0, 1, 0]) # Binary labels
2023
>>> # Initialize and compute stats
2124
>>> dist = AnomalyScoreDistribution()
22-
>>> dist.update(anomaly_scores=scores, anomaly_maps=maps)
25+
>>> dist.update(anomaly_scores=scores, anomaly_maps=maps, labels=labels)
2326
>>> image_mean, image_std, pixel_mean, pixel_std = dist.compute()
27+
>>> fig, title = dist.plot()
2428
2529
Note:
2630
The input scores and maps are log-transformed before computing statistics.
27-
Both image-level scores and pixel-level maps are optional inputs.
31+
Image-level scores, pixel-level maps, and labels are optional inputs.
2832
"""
2933

3034
import torch
35+
from matplotlib.figure import Figure
3136
from torchmetrics import Metric
3237

38+
from .utils import plot_score_histogram
39+
3340

3441
class AnomalyScoreDistribution(Metric):
3542
"""Compute distribution statistics of anomaly scores.
3643
3744
This class tracks and computes the mean and standard deviation of anomaly
38-
scores from the normal samples in the training set. Statistics are computed
39-
for both image-level scores and pixel-level anomaly maps.
45+
scores. Statistics are computed for both image-level scores and pixel-level
46+
anomaly maps.
4047
41-
The metric maintains internal state to accumulate scores and maps across
42-
batches before computing final statistics.
48+
The metric maintains internal state to accumulate scores, anomaly maps,
49+
and labels across batches before computing final statistics.
4350
4451
Example:
4552
>>> dist = AnomalyScoreDistribution()
@@ -59,6 +66,7 @@ def __init__(self, **kwargs) -> None:
5966
super().__init__(**kwargs)
6067
self.anomaly_maps: list[torch.Tensor] = []
6168
self.anomaly_scores: list[torch.Tensor] = []
69+
self.labels: list[torch.Tensor] = []
6270

6371
self.add_state("image_mean", torch.empty(0), persistent=True)
6472
self.add_state("image_std", torch.empty(0), persistent=True)
@@ -75,6 +83,7 @@ def update(
7583
*args,
7684
anomaly_scores: torch.Tensor | None = None,
7785
anomaly_maps: torch.Tensor | None = None,
86+
labels: torch.Tensor | None = None,
7887
**kwargs,
7988
) -> None:
8089
"""Update the internal state with new scores and maps.
@@ -83,6 +92,7 @@ def update(
8392
*args: Unused positional arguments.
8493
anomaly_scores: Batch of image-level anomaly scores.
8594
anomaly_maps: Batch of pixel-level anomaly maps.
95+
labels: Batch of binary labels.
8696
**kwargs: Unused keyword arguments.
8797
"""
8898
del args, kwargs # These variables are not used.
@@ -91,6 +101,8 @@ def update(
91101
self.anomaly_maps.append(anomaly_maps)
92102
if anomaly_scores is not None:
93103
self.anomaly_scores.append(anomaly_scores)
104+
if labels is not None:
105+
self.labels.append(labels)
94106

95107
def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
96108
"""Compute distribution statistics from accumulated scores and maps.
@@ -116,3 +128,53 @@ def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tenso
116128
self.pixel_std = anomaly_maps.std(dim=0).squeeze()
117129

118130
return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std
131+
132+
def plot(
    self,
    bins: int = 30,
    good_color: str = "skyblue",
    bad_color: str = "salmon",
    xlabel: str = "Score",
    ylabel: str = "Relative Count",
    title: str = "Score Histogram",
    legend_labels: tuple[str, str] = ("Good", "Bad"),
) -> tuple[Figure, str]:
    """Draw a per-label histogram of the accumulated image-level scores.

    All score batches and all label batches collected so far are stacked
    into flat tensors and handed to ``plot_score_histogram`` together with
    the styling options.

    Args:
        bins (int, optional): Number of histogram bins. Defaults to 30.
        good_color (str, optional): Color for good samples. Defaults to "skyblue".
        bad_color (str, optional): Color for bad samples. Defaults to "salmon".
        xlabel (str, optional): Label for the x-axis. Defaults to "Score".
        ylabel (str, optional): Label for the y-axis. Defaults to "Relative Count".
        title (str, optional): Title of the plot. Defaults to "Score Histogram".
        legend_labels (tuple[str, str], optional): Legend labels for good and
            bad samples. Defaults to ("Good", "Bad").

    Returns:
        tuple[Figure, str]: The histogram figure and the title passed in,
            so that the caller can use the title for logging.

    Raises:
        ValueError: If no anomaly scores or no labels have been accumulated.
    """
    # Guard clauses: both scores and labels are required to draw the plot.
    if not self.anomaly_scores:
        msg = "No anomaly scores available."
        raise ValueError(msg)
    if not self.labels:
        msg = "No labels available."
        raise ValueError(msg)

    stacked_scores = torch.hstack(self.anomaly_scores)
    stacked_labels = torch.hstack(self.labels)

    # Only the figure is needed; the second element of the helper's result is dropped.
    fig, _ = plot_score_histogram(
        scores=stacked_scores,
        labels=stacked_labels,
        bins=bins,
        good_color=good_color,
        bad_color=bad_color,
        xlabel=xlabel,
        ylabel=ylabel,
        title=title,
        legend_labels=legend_labels,
    )
    return fig, title

src/anomalib/metrics/pg_pb.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""PGn and PBn metrics for binary image-level classification tasks.
5+
6+
This module provides two metrics for evaluating binary image-level classification performance
7+
on the assumption that bad (anomalous) samples are considered to be the positive class:
8+
9+
- ``PGn``: Presorted good with n% bad samples missed, can be interpreted as true negative rate
10+
at a fixed false negative rate (TNR@nFNR).
11+
- ``PBn``: Presorted bad with n% good samples misclassified, can be interpreted as true positive rate
12+
at a fixed false positive rate (TPR@nFPR).
13+
14+
These metrics emphasize the practical applications of anomaly detection models by showing their potential
15+
to reduce human operator workload while maintaining an acceptable level of misclassification.
16+
17+
Example:
18+
>>> from anomalib.metrics import PGn, PBn
19+
>>> from anomalib.data import ImageBatch
20+
>>> import torch
21+
>>> # Create sample batch
22+
>>> batch = ImageBatch(
23+
... image=torch.rand(4, 3, 32, 32),
24+
... pred_score=torch.tensor([0.1, 0.4, 0.35, 0.8]),
25+
... gt_label=torch.tensor([0, 0, 1, 1])
26+
... )
27+
>>> pg = PGn(fnr=0.2)
28+
>>> # Print name of the metric
29+
>>> print(pg.name)
30+
PG20
31+
>>> # Compute PGn score
32+
>>> pg.update(batch)
33+
>>> pg.compute()
34+
tensor(1.)
35+
>>> pb = PBn(fpr=0.2)
36+
>>> # Print name of the metric
37+
>>> print(pb.name)
38+
PB20
39+
>>> # Compute PBn score
40+
>>> pb.update(batch)
41+
>>> pb.compute()
42+
tensor(1.)
43+
44+
Note:
45+
Scores for both metrics range from 0 to 1, with 1 indicating perfect separation
46+
of the respective class with ``n``% or less of the other class misclassified.
47+
48+
Reference:
49+
Aimira Baitieva, Yacine Bouaouni, Alexandre Briot, Dick Ameln, Souhaiel Khalfaoui,
50+
Samet Akcay; Beyond Academic Benchmarks: Critical Analysis and Best Practices
51+
for Visual Industrial Anomaly Detection; in: Proceedings of the IEEE/CVF Conference
52+
on Computer Vision and Pattern Recognition (CVPR) Workshops, 2025, pp. 4024-4034,
53+
https://arxiv.org/abs/2503.23451
54+
"""
55+
56+
import torch
57+
from torchmetrics import Metric
58+
from torchmetrics.utilities import dim_zero_cat
59+
60+
from anomalib.metrics.base import AnomalibMetric
61+
62+
63+
class _PGn(Metric):
    """Presorted good metric.

    This class calculates the Presorted good (PGn) metric, which is the true negative rate
    at a fixed false negative rate. The acceptance threshold is the ``fnr``-quantile of the
    scores of the positive (label ``1``) samples; the metric is the fraction of negative
    (label ``0``) samples whose score lies strictly below that threshold.

    Args:
        fnr (float): Fixed false negative rate (bad parts misclassified).
            Must lie in ``[0, 1]``. Defaults to ``0.05``.
        **kwargs: Additional arguments passed to the parent ``Metric`` class.

    Raises:
        ValueError: If ``fnr`` is outside the range ``[0, 1]``.

    Example:
        >>> from anomalib.metrics.pg_pb import _PGn
        >>> import torch
        >>> # Create sample data
        >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
        >>> target = torch.tensor([0, 0, 1, 1])
        >>> # Compute PGn score
        >>> pg = _PGn(fnr=0.2)
        >>> pg.update(preds, target)
        >>> pg.compute()
        tensor(1.)
    """

    def __init__(self, fnr: float = 0.05, **kwargs) -> None:
        super().__init__(**kwargs)
        if fnr < 0 or fnr > 1:
            msg = f"False negative rate must be in the range between 0 and 1, got {fnr}."
            raise ValueError(msg)

        self.fnr = torch.tensor(fnr, dtype=torch.float32)
        # round() rather than int(): binary floating point makes e.g.
        # 0.29 * 100 == 28.999999999999996, which int() would truncate to 28.
        self.name = f"PG{round(fnr * 100)}"

        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with new values.

        Args:
            preds (torch.Tensor): predictions of the model
            target (torch.Tensor): ground truth targets
        """
        self.target.append(target)
        self.preds.append(preds)

    def compute(self) -> torch.Tensor:
        """Compute the PGn score at a given false negative rate.

        Returns:
            torch.Tensor: PGn score value, with the dtype and device of the predictions.

        Raises:
            ValueError: If no positive or no negative samples are found.
        """
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)

        pos_scores = preds[target == 1]
        neg_scores = preds[target == 0]
        # torch.quantile rejects an empty input, so fail early with a clear message.
        if pos_scores.numel() == 0:
            msg = "No positive samples found. Cannot compute PGn score."
            raise ValueError(msg)
        if neg_scores.numel() == 0:
            msg = "No negative samples found. Cannot compute PGn score."
            raise ValueError(msg)

        # ``self.fnr`` is a plain attribute (not a registered buffer), so it does not
        # follow the metric across devices; torch.quantile requires a tensor ``q`` to
        # share dtype and device with its input.
        quantile = self.fnr.to(device=pos_scores.device, dtype=pos_scores.dtype)
        thr_accept = torch.quantile(pos_scores, quantile)

        pg = neg_scores[neg_scores < thr_accept].numel() / neg_scores.numel()

        return torch.tensor(pg, dtype=preds.dtype, device=preds.device)
133+
134+
135+
class PGn(AnomalibMetric, _PGn):  # type: ignore[misc]
    """Anomalib-compatible version of the ``_PGn`` metric.

    Mixes ``AnomalibMetric`` into ``_PGn`` so that the metric can be used with
    Anomalib's batch processing. By default it is associated with the
    ``pred_score`` and ``gt_label`` fields.
    """

    default_fields = ("pred_score", "gt_label")
143+
144+
145+
class _PBn(Metric):
    """Presorted bad metric.

    This class calculates the Presorted bad (PBn) metric, which is the true positive rate
    at a fixed false positive rate. The acceptance threshold is the ``1 - fpr`` quantile of
    the scores of the negative (label ``0``) samples; the metric is the fraction of positive
    (label ``1``) samples whose score lies strictly above that threshold.

    Args:
        fpr (float): Fixed false positive rate (good parts misclassified).
            Must lie in ``[0, 1]``. Defaults to ``0.05``.
        **kwargs: Additional arguments passed to the parent ``Metric`` class.

    Raises:
        ValueError: If ``fpr`` is outside the range ``[0, 1]``.

    Example:
        >>> from anomalib.metrics.pg_pb import _PBn
        >>> import torch
        >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
        >>> target = torch.tensor([0, 0, 1, 1])
        >>> pb = _PBn(fpr=0.2)
        >>> pb.update(preds, target)
        >>> pb.compute()
        tensor(1.)
    """

    def __init__(self, fpr: float = 0.05, **kwargs) -> None:
        super().__init__(**kwargs)
        if fpr < 0 or fpr > 1:
            msg = f"False positive rate must be in the range between 0 and 1, got {fpr}."
            raise ValueError(msg)

        self.fpr = torch.tensor(fpr, dtype=torch.float32)
        # round() rather than int(): binary floating point makes e.g.
        # 0.29 * 100 == 28.999999999999996, which int() would truncate to 28.
        self.name = f"PB{round(fpr * 100)}"

        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with new values.

        Args:
            preds (torch.Tensor): predictions of the model
            target (torch.Tensor): ground truth targets
        """
        self.target.append(target)
        self.preds.append(preds)

    def compute(self) -> torch.Tensor:
        """Compute the PBn score at a given false positive rate.

        Returns:
            torch.Tensor: PBn score value, with the dtype and device of the predictions.

        Raises:
            ValueError: If no positive or no negative samples are found.
        """
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)

        neg_scores = preds[target == 0]
        pos_scores = preds[target == 1]
        # torch.quantile rejects an empty input, so fail early with a clear message.
        if neg_scores.numel() == 0:
            msg = "No negative samples found. Cannot compute PBn score."
            raise ValueError(msg)
        if pos_scores.numel() == 0:
            msg = "No positive samples found. Cannot compute PBn score."
            raise ValueError(msg)

        # ``self.fpr`` is a plain attribute (not a registered buffer), so it does not
        # follow the metric across devices; torch.quantile requires a tensor ``q`` to
        # share dtype and device with its input.
        quantile = (1 - self.fpr).to(device=neg_scores.device, dtype=neg_scores.dtype)
        thr_accept = torch.quantile(neg_scores, quantile)

        pb = pos_scores[pos_scores > thr_accept].numel() / pos_scores.numel()

        return torch.tensor(pb, dtype=preds.dtype, device=preds.device)
210+
211+
212+
class PBn(AnomalibMetric, _PBn):  # type: ignore[misc]
    """Anomalib-compatible version of the ``_PBn`` metric.

    Mixes ``AnomalibMetric`` into ``_PBn`` so that the metric can be used with
    Anomalib's batch processing. By default it is associated with the
    ``pred_score`` and ``gt_label`` fields.
    """

    default_fields = ("pred_score", "gt_label")

0 commit comments

Comments
 (0)