1
1
from sklearn .metrics import roc_curve , auc
2
2
from sklearn .manifold import TSNE
3
- from sklearn .neighbors import KernelDensity
4
3
from torchvision import transforms
5
4
from torch .utils .data import DataLoader
6
5
import torch
14
13
from sklearn .utils import shuffle
15
14
from sklearn .model_selection import GridSearchCV
16
15
import numpy as np
17
- from sklearn .covariance import LedoitWolf
18
16
from collections import defaultdict
17
+ from density import GaussianDensitySklearn , GaussianDensityTorch
19
18
import pandas as pd
20
19
from utils import str2bool
21
20
22
21
test_data_eval = None
23
22
test_transform = None
24
23
cached_type = None
25
-
24
+
26
25
def get_train_embeds (model , size , defect_type , transform , device ):
27
26
# train data / train kde
28
27
test_data = MVTecAT ("Data" , defect_type , size , transform = transform , mode = "train" )
@@ -38,7 +37,7 @@ def get_train_embeds(model, size, defect_type, transform, device):
38
37
train_embed = torch .cat (train_embed )
39
38
return train_embed
40
39
41
- def eval_model (modelname , defect_type , device = "cpu" , save_plots = False , size = 256 , show_training_data = True , model = None , train_embed = None , head_layer = 8 ):
40
+ def eval_model (modelname , defect_type , device = "cpu" , save_plots = False , size = 256 , show_training_data = True , model = None , train_embed = None , head_layer = 8 , density = GaussianDensityTorch () ):
42
41
# create test dataset
43
42
global test_data_eval ,test_transform , cached_type
44
43
@@ -146,53 +145,10 @@ def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256,
146
145
plot_tsne (tsne_labels , tsne_embeds , eval_dir / "tsne.png" )
147
146
else :
148
147
eval_dir = Path ("unused" )
149
- # TODO: put the GDE stuff into the Model class and do this at the end of the training
150
- # # estimate KDE parameters
151
- # # use grid search cross-validation to optimize the bandwidth
152
- # params = {'bandwidth': np.logspace(-10, 10, 50)}
153
- # grid = GridSearchCV(KernelDensity(), params)
154
- # grid.fit(embeds)
155
-
156
- # print("best bandwidth: {0}".format(grid.best_estimator_.bandwidth))
157
-
158
- # # use the best estimator to compute the kernel density estimate
159
- # kde = grid.best_estimator_
160
- # kde = KernelDensity(kernel='gaussian', bandwidth=1).fit(train_embed)
161
- # scores = kde.score_samples(embeds)
162
- # print(scores)
163
- # we get the probability to be in the correct distribution
164
- # but our labels are inverted (1 for out of distribution)
165
- # so we have to relabel
166
-
167
- # use own formulation with Mahalanobis distance
168
- # from https://github.com/ORippler/gaussian-ad-mvtec/blob/4e85fb5224eee13e8643b684c8ef15ab7d5d016e/src/gaussian/model.py#L308
169
- def mahalanobis_distance (
170
- values : torch .Tensor , mean : torch .Tensor , inv_covariance : torch .Tensor
171
- ) -> torch .Tensor :
172
- """Compute the batched mahalanobis distance.
173
- values is a batch of feature vectors.
174
- mean is either the mean of the distribution to compare, or a second
175
- batch of feature vectors.
176
- inv_covariance is the inverse covariance of the target distribution.
177
- """
178
- assert values .dim () == 2
179
- assert 1 <= mean .dim () <= 2
180
- assert len (inv_covariance .shape ) == 2
181
- assert values .shape [1 ] == mean .shape [- 1 ]
182
- assert mean .shape [- 1 ] == inv_covariance .shape [0 ]
183
- assert inv_covariance .shape [0 ] == inv_covariance .shape [1 ]
184
-
185
- if mean .dim () == 1 : # Distribution mean.
186
- mean = mean .unsqueeze (0 )
187
- x_mu = values - mean # batch x features
188
- # Same as dist = x_mu.t() * inv_covariance * x_mu batch wise
189
- dist = torch .einsum ("im,mn,in->i" , x_mu , inv_covariance , x_mu )
190
- return dist .sqrt ()
191
- # calculate mean
192
- mean = torch .mean (train_embed , axis = 0 )
193
- inv_cov = torch .Tensor (LedoitWolf ().fit (train_embed .cpu ()).precision_ ,device = "cpu" )
194
-
195
- distances = mahalanobis_distance (embeds , mean , inv_cov )
148
+
149
+ print (f"using density estimation { density .__class__ .__name__ } " )
150
+ density .fit (train_embed )
151
+ distances = density .predict (embeds )
196
152
#TODO: set threshold on mahalanobis distances and use "real" probabilities
197
153
198
154
roc_auc = plot_roc (labels , distances , eval_dir / "roc_plot.png" , modelname = modelname , save_plots = save_plots )
@@ -250,6 +206,11 @@ def plot_tsne(labels, embeds, filename):
250
206
parser .add_argument ('--head_layer' , default = 8 , type = int ,
251
207
help = 'number of layers in the projection head (default: 8)' )
252
208
209
+ parser .add_argument ('--density' , default = "torch" , choices = ["torch" , "sklearn" ],
210
+ help = 'density implementation to use. See `density.py` for both implementations. (default: torch)' )
211
+
212
+ parser .add_argument ('--save_plots' , default = True , type = str2bool ,
213
+ help = 'save TSNE and roc plots' )
253
214
254
215
255
216
args = parser .parse_args ()
@@ -279,6 +240,12 @@ def plot_tsne(labels, embeds, filename):
279
240
280
241
device = "cuda" if args .cuda else "cpu"
281
242
243
+ density_mapping = {
244
+ "torch" : GaussianDensityTorch ,
245
+ "sklearn" : GaussianDensitySklearn
246
+ }
247
+ density = density_mapping [args .density ]
248
+
282
249
# find models
283
250
model_names = [list (Path (args .model_dir ).glob (f"model-{ data_type } *" ))[0 ] for data_type in types if len (list (Path (args .model_dir ).glob (f"model-{ data_type } *" ))) > 0 ]
284
251
if len (model_names ) < len (all_types ):
@@ -288,7 +255,7 @@ def plot_tsne(labels, embeds, filename):
288
255
for model_name , data_type in zip (model_names , types ):
289
256
print (f"evaluating { data_type } " )
290
257
291
- roc_auc = eval_model (model_name , data_type , save_plots = True , device = device , head_layer = args .head_layer )
258
+ roc_auc = eval_model (model_name , data_type , save_plots = args . save_plots , device = device , head_layer = args .head_layer , density = density () )
292
259
print (f"{ data_type } AUC: { roc_auc } " )
293
260
obj ["defect_type" ].append (data_type )
294
261
obj ["roc_auc" ].append (roc_auc )
0 commit comments