Skip to content

Commit 453c197

Browse files
committed
fix boolean conversion and some spelling
1 parent c4bb088 commit 453c197

File tree

7 files changed

+30
-9
lines changed

7 files changed

+30
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Implementation of CutPaste
22

3-
This is an **unofficial** work in progress PyTorch reimplementation of [CutPaste: Self-Supervised Learning for Anomaly Detection and Localization](https://arxiv.org/abs/2104.04015) and in no way affiliated with the original authors. Use at own risk. Pull requests and feedback is appreciated.
3+
This is an **unofficial** work in progress PyTorch reimplementation of [CutPaste: Self-Supervised Learning for Anomaly Detection and Localization](https://arxiv.org/abs/2104.04015) and in no way affiliated with the original authors. Use at own risk. Pull requests and feedback is appreciated.
44

55
## Setup
66
Download the MVTec Anomaly detection Dataset from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad) and extract it into a new folder named `Data`.

dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, root_dir, defect_name, size, transform=None, mode="train"):
2424
root_dir (string): Directory with the MVTec AD dataset.
2525
defect_name (string): defect to load.
2626
transform: Transform to apply to data
27-
mode: "train" loads training sammples "test" test samples default "train"
27+
mode: "train" loads training samples "test" test samples default "train"
2828
"""
2929
self.root_dir = Path(root_dir)
3030
self.defect_name = defect_name

eval.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.covariance import LedoitWolf
1818
from collections import defaultdict
1919
import pandas as pd
20+
from utils import str2bool
2021

2122
test_data_eval = None
2223
test_transform = None
@@ -95,8 +96,8 @@ def eval_model(modelname, defect_type, device="cpu", save_plots=False, size=256,
9596
# also show some of the training data
9697
show_training_data = False
9798
if show_training_data:
98-
#augmentation settig
99-
# TODO: do all of this in a seperate function that we can call in training and evaluation.
99+
#augmentation setting
100+
# TODO: do all of this in a separate function that we can call in training and evaluation.
100101
# very ugly to just copy the code lol
101102
min_scale = 0.5
102103

@@ -243,7 +244,7 @@ def plot_tsne(labels, embeds, filename):
243244
parser.add_argument('--model_dir', default="models",
244245
help=' directory contating models to evaluate (default: models)')
245246

246-
parser.add_argument('--cuda', default=False,
247+
parser.add_argument('--cuda', default=False, type=str2bool,
247248
help='use cuda for model predictions (default: False)')
248249

249250
parser.add_argument('--head_layer', default=8, type=int,

model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ def __init__(self, pretrained=True, head_layers=[512,512,512,512,512,512,512,512
2121
last_layer = num_neurons
2222

2323
#the last layer without activation
24-
#TODO: is this correct? check one classe representation framework paper/code
25-
# sequential_layers.append(nn.Linear(last_layer, head_layers[-1]))
26-
# last_layer = head_layers[-1]
2724

2825
head = nn.Sequential(
2926
*sequential_layers

requirements.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
torch
2+
torchvision
3+
sklearn
4+
pandas
5+
seaborn
6+
tqdm
7+
tensorboard

run_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from cutpaste import CutPasteNormal,CutPasteScar, CutPaste3Way, CutPasteUnion, cut_paste_collate_fn
1818
from model import ProjectionNet
1919
from eval import eval_model
20+
from util import str2bool
2021

2122
def run_training(data_type="screw",
2223
model_dir="models",
@@ -204,7 +205,7 @@ def get_data_inf():
204205

205206
parser.add_argument('--variant', default="3way", choices=['normal', 'scar', '3way', 'union'], help='cutpaste variant to use (dafault: "3way")')
206207

207-
parser.add_argument('--cuda', default=False,
208+
parser.add_argument('--cuda', default=False, type=str2bool,
208209
help='use cuda for training (default: False)')
209210

210211
parser.add_argument('--workers', default=8, type=int, help="number of workers to use for data loading (default:8)")

utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import distutils
2+
3+
def str2bool(v):
4+
"""argparse handels type=bool in a weird way.
5+
See this stack overflow: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
6+
we can use this function as type converter for boolean values
7+
"""
8+
if isinstance(v, bool):
9+
return v
10+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
11+
return True
12+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
13+
return False
14+
else:
15+
raise argparse.ArgumentTypeError('Boolean value expected.')

0 commit comments

Comments
 (0)