
Commit 60b34cc

refactor density estimation
1 parent 453c197 commit 60b34cc

File tree

4 files changed: +130, -52 lines

.gitignore (+1)

```diff
@@ -0,0 +1 @@
+venv*
```

README.md (+42)

````diff
@@ -61,6 +61,48 @@ This implementation only applies color jitter before the CutPaste augmentation.
 ### Tensorflow vs PyTorch
 Li et al. use tensorflow for their implementation. This implementation uses PyTorch.
 
+### Kernel Density Estimation
+I implemented two density estimation pipelines: a Gaussian density with Mahalanobis distance and a sklearn Kernel Density Estimation.
+Li et al. use sklearn for the density estimation, but [Rippel et al.](https://github.com/ORippler/gaussian-ad-mvtec) have their own implementation.
+`eval.py` has a `--density` flag that can be set to `torch` for the Rippel et al. implementation or `sklearn` for my sklearn implementation.
+In my limited testing, the two implementations show only small differences in the resulting ROC AUCs:
+```
+> python eval.py --density torch --cuda 1 --head_layer 2 --save_plots 0 | grep AUC
+bottle AUC: 0.9944444444444445
+cable AUC: 0.8549475262368815
+capsule AUC: 0.8232947746310331
+carpet AUC: 0.9329855537720706
+grid AUC: 0.982456140350877
+hazelnut AUC: 0.9160714285714285
+leather AUC: 1.0
+metal_nut AUC: 0.9403714565004888
+pill AUC: 0.8046917621385706
+screw AUC: 0.701988112318098
+tile AUC: 0.9430014430014431
+toothbrush AUC: 0.8972222222222221
+transistor AUC: 0.9008333333333334
+wood AUC: 0.9815789473684211
+zipper AUC: 0.9997373949579832
+
+> python eval.py --density sklearn --cuda 1 --head_layer 2 --save_plots 0 | grep AUC
+bottle AUC: 0.9944444444444445
+cable AUC: 0.8549475262368815
+capsule AUC: 0.8232947746310331
+carpet AUC: 0.9329855537720706
+grid AUC: 0.982456140350877
+hazelnut AUC: 0.9160714285714285
+leather AUC: 1.0
+metal_nut AUC: 0.9403714565004888
+pill AUC: 0.8046917621385706
+screw AUC: 0.701988112318098
+tile AUC: 0.9430014430014431
+toothbrush AUC: 0.8972222222222221
+transistor AUC: 0.9008333333333334
+wood AUC: 0.9815789473684211
+zipper AUC: 0.9997373949579832
+```
+
+
 # Results
 This implementation only tries to recreate the main results from section 4.1, shown in table 1.
 ## CutPaste
````
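Both density estimators behind the new `--density` flag expose the same `fit`/`predict` interface, so `eval.py` can treat them interchangeably. A minimal sketch of that flow; `train_embed`, `embeds`, and `labels` are random stand-ins here, while `eval.py` computes them from the model:

```python
# Sketch of the shared density-estimation flow; the random tensors below
# stand in for the embeddings and labels that eval.py produces.
import torch
from sklearn.metrics import roc_curve, auc

from density import GaussianDensitySklearn, GaussianDensityTorch

train_embed = torch.randn(200, 512)          # defect-free train embeddings (stand-in)
embeds = torch.randn(50, 512)                # test embeddings (stand-in)
labels = torch.randint(0, 2, (50,)).numpy()  # 1 = defect, 0 = good (stand-in)

density = GaussianDensityTorch()             # or GaussianDensitySklearn()
density.fit(train_embed)                     # fit on defect-free train embeddings only
distances = density.predict(embeds)          # higher score = more anomalous

fpr, tpr, _ = roc_curve(labels, distances)
print(auc(fpr, tpr))
```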

density.py (+68, new file)

```python
from sklearn.covariance import LedoitWolf
from sklearn.neighbors import KernelDensity
import torch


class Density(object):
    def fit(self, embeddings):
        raise NotImplementedError

    def predict(self, embeddings):
        raise NotImplementedError


class GaussianDensityTorch(Density):
    """Gaussian density estimation similar to the implementation used by Rippel et al.
    The code of Rippel et al. can be found here: https://github.com/ORippler/gaussian-ad-mvtec.
    """
    def fit(self, embeddings):
        self.mean = torch.mean(embeddings, dim=0)
        # Ledoit-Wolf shrinkage keeps the covariance estimate invertible even
        # when the embedding dimension is large relative to the sample count.
        self.inv_cov = torch.from_numpy(LedoitWolf().fit(embeddings.cpu()).precision_).float()

    def predict(self, embeddings):
        distances = self.mahalanobis_distance(embeddings, self.mean, self.inv_cov)
        return distances

    @staticmethod
    def mahalanobis_distance(
        values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor
    ) -> torch.Tensor:
        """Compute the batched Mahalanobis distance.
        values is a batch of feature vectors.
        mean is either the mean of the distribution to compare, or a second
        batch of feature vectors.
        inv_covariance is the inverse covariance of the target distribution.

        from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308
        """
        assert values.dim() == 2
        assert 1 <= mean.dim() <= 2
        assert len(inv_covariance.shape) == 2
        assert values.shape[1] == mean.shape[-1]
        assert mean.shape[-1] == inv_covariance.shape[0]
        assert inv_covariance.shape[0] == inv_covariance.shape[1]

        if mean.dim() == 1:  # Distribution mean.
            mean = mean.unsqueeze(0)
        x_mu = values - mean  # batch x features
        # Same as dist = x_mu.t() @ inv_covariance @ x_mu, batch-wise
        dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu)
        return dist.sqrt()


class GaussianDensitySklearn(Density):
    """Li et al. use sklearn for density estimation.
    This implementation uses the sklearn KernelDensity module for fitting and predicting.
    """
    def fit(self, embeddings):
        # fit a Gaussian KDE with a fixed bandwidth; a grid search
        # cross-validation could be used to tune the bandwidth instead
        self.kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(embeddings)

    def predict(self, embeddings):
        scores = self.kde.score_samples(embeddings)

        # invert the log-likelihood scores so that higher means more anomalous,
        # matching the class labels for the AUC calculation
        scores = -scores

        return scores
```
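The einsum in `mahalanobis_distance` evaluates the quadratic form d_i = sqrt(x_mu_i^T * inv_cov * x_mu_i) for the whole batch in one contraction. A small sanity check (hypothetical, not part of this commit) against a per-sample loop:

```python
# Hypothetical sanity check: the batched einsum used in
# GaussianDensityTorch.mahalanobis_distance matches a per-sample
# quadratic form x_mu^T @ inv_cov @ x_mu followed by a square root.
import torch

torch.manual_seed(0)
x_mu = torch.randn(4, 3)            # batch of centered feature vectors
a = torch.randn(3, 3)
inv_cov = a @ a.t() + torch.eye(3)  # symmetric positive definite stand-in

batched = torch.einsum("im,mn,in->i", x_mu, inv_cov, x_mu).sqrt()
looped = torch.stack([(v @ inv_cov @ v).sqrt() for v in x_mu])
assert torch.allclose(batched, looped)
```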

eval.py (+19, -52)

```diff
@@ -1,6 +1,5 @@
 from sklearn.metrics import roc_curve, auc
 from sklearn.manifold import TSNE
-from sklearn.neighbors import KernelDensity
 from torchvision import transforms
 from torch.utils.data import DataLoader
 import torch
@@ -14,15 +13,15 @@
 from sklearn.utils import shuffle
 from sklearn.model_selection import GridSearchCV
 import numpy as np
-from sklearn.covariance import LedoitWolf
 from collections import defaultdict
+from density import GaussianDensitySklearn, GaussianDensityTorch
 import pandas as pd
 from utils import str2bool

 test_data_eval = None
 test_transform = None
 cached_type = None
-
+

 def get_train_embeds(model, size, defect_type, transform, device):
     # train data / train kde
     test_data = MVTecAT("Data", defect_type, size, transform=transform, mode="train")
@@ -38,7 +37,7 @@ def get_train_embeds(model, size, defect_type, transform, device):
     train_embed = torch.cat(train_embed)
     return train_embed

-def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, show_training_data=True, model=None, train_embed=None, head_layer=8):
+def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256, show_training_data=True, model=None, train_embed=None, head_layer=8, density=GaussianDensityTorch()):
     # create test dataset
     global test_data_eval,test_transform, cached_type

@@ -146,53 +145,10 @@ def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256,
         plot_tsne(tsne_labels, tsne_embeds, eval_dir / "tsne.png")
     else:
         eval_dir = Path("unused")
-    # TODO: put the GDE stuff into the Model class and do this at the end of the training
-    # # estemate KDE parameters
-    # # use grid search cross-validation to optimize the bandwidth
-    # params = {'bandwidth': np.logspace(-10, 10, 50)}
-    # grid = GridSearchCV(KernelDensity(), params)
-    # grid.fit(embeds)
-
-    # print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth))
-
-    # # use the best estimator to compute the kernel density estimate
-    # kde = grid.best_estimator_
-    # kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(train_embed)
-    # scores = kde.score_samples(embeds)
-    # print(scores)
-    # we get the probability to be in the correct distribution
-    # but our labels are inverted (1 for out of distribution)
-    # so we have to relabel
-
-    # use own formulation with malanobis distance
-    # from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308
-    def mahalanobis_distance(
-        values: torch.Tensor, mean: torch.Tensor, inv_covariance: torch.Tensor
-    ) -> torch.Tensor:
-        """Compute the batched mahalanobis distance.
-        values is a batch of feature vectors.
-        mean is either the mean of the distribution to compare, or a second
-        batch of feature vectors.
-        inv_covariance is the inverse covariance of the target distribution.
-        """
-        assert values.dim() == 2
-        assert 1 <= mean.dim() <= 2
-        assert len(inv_covariance.shape) == 2
-        assert values.shape[1] == mean.shape[-1]
-        assert mean.shape[-1] == inv_covariance.shape[0]
-        assert inv_covariance.shape[0] == inv_covariance.shape[1]
-
-        if mean.dim() == 1:  # Distribution mean.
-            mean = mean.unsqueeze(0)
-        x_mu = values - mean  # batch x features
-        # Same as dist = x_mu.t() * inv_covariance * x_mu batch wise
-        dist = torch.einsum("im,mn,in->i", x_mu, inv_covariance, x_mu)
-        return dist.sqrt()
-    # claculate mean
-    mean = torch.mean(train_embed, axis=0)
-    inv_cov = torch.Tensor(LedoitWolf().fit(train_embed.cpu()).precision_,device="cpu")
-
-    distances = mahalanobis_distance(embeds, mean, inv_cov)
+
+    print(f"using density estimation {density.__class__.__name__}")
+    density.fit(train_embed)
+    distances = density.predict(embeds)
     #TODO: set threshold on mahalanobis distances and use "real" probabilities

     roc_auc = plot_roc(labels, distances, eval_dir / "roc_plot.png", modelname=modelname, save_plots=save_plots)
@@ -250,6 +206,11 @@ def plot_tsne(labels, embeds, filename):
     parser.add_argument('--head_layer', default=8, type=int,
                         help='number of layers in the projection head (default: 8)')

+    parser.add_argument('--density', default="torch", choices=["torch", "sklearn"],
+                        help='density implementation to use. See `density.py` for both implementations. (default: torch)')
+
+    parser.add_argument('--save_plots', default=True, type=str2bool,
+                        help='save TSNE and roc plots')


     args = parser.parse_args()
@@ -279,6 +240,12 @@ def plot_tsne(labels, embeds, filename):

     device = "cuda" if args.cuda else "cpu"

+    density_mapping = {
+        "torch": GaussianDensityTorch,
+        "sklearn": GaussianDensitySklearn
+    }
+    density = density_mapping[args.density]
+
     # find models
     model_names = [list(Path(args.model_dir).glob(f"model-{data_type}*"))[0] for data_type in types if len(list(Path(args.model_dir).glob(f"model-{data_type}*"))) > 0]
     if len(model_names) < len(all_types):
@@ -288,7 +255,7 @@ def plot_tsne(labels, embeds, filename):
     for model_name, data_type in zip(model_names, types):
         print(f"evaluating {data_type}")

-        roc_auc = eval_model(model_name, data_type, save_plots=True, device=device, head_layer=args.head_layer)
+        roc_auc = eval_model(model_name, data_type, save_plots=args.save_plots, device=device, head_layer=args.head_layer, density=density())
         print(f"{data_type} AUC: {roc_auc}")
         obj["defect_type"].append(data_type)
         obj["roc_auc"].append(roc_auc)
```
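The removed commented-out block sketched tuning the KDE bandwidth with `GridSearchCV` instead of the fixed `bandwidth=1` that `GaussianDensitySklearn` now uses. A hypothetical sketch of reinstating that tuning, not part of this commit, with random stand-in embeddings:

```python
# Hypothetical bandwidth tuning for GaussianDensitySklearn, mirroring the
# commented-out grid search that this commit removes from eval.py.
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity

train_embed = np.random.rand(100, 32)  # stand-in for the (N, D) train embeddings

params = {'bandwidth': np.logspace(-10, 10, 50)}
grid = GridSearchCV(KernelDensity(kernel='gaussian'), params)
grid.fit(train_embed)

print(f"best bandwidth: {grid.best_estimator_.bandwidth}")
kde = grid.best_estimator_  # tuned estimator, usable like the fixed-bandwidth one
```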
