from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
-import torchmetrics as tm
from torchvision.datasets.mnist import MNIST

@@ -39,21 +38,9 @@ def __init__(self, backbone, num_epochs: int = 5, lr=1e-4):
        self.loss_fn = nn.CrossEntropyLoss()

        # Define cross-validation metrics
-        self.train_acc = tm.Accuracy(average='weighted', num_classes=10)
-        self.val_acc = tm.Accuracy(average='weighted', num_classes=10)
-        self.test_acc = tm.Accuracy(average='weighted', num_classes=10)
-
-        self.train_auroc = tm.AUROC(average='weighted')
-        self.val_auroc = tm.AUROC(average='weighted')
-        self.test_auroc = tm.AUROC(average='weighted')
-
-        self.train_auprc = tm.AveragePrecision()
-        self.val_auprc = tm.AveragePrecision()
-        self.test_auprc = tm.AveragePrecision()
-
-        self.train_f1 = tm.F1(average='weighted', num_classes=10)
-        self.val_f1 = tm.F1(average='weighted', num_classes=10)
-        self.test_f1 = tm.F1(average='weighted', num_classes=10)
+        self.train_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
+        self.val_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
+        self.test_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)

    def forward(self, x):
        # use forward for inference/predictions
@@ -65,48 +52,30 @@ def training_step(self, batch, batch_idx):
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_acc', self.train_acc(y_hat, y), sync_dist=True)
-        self.log('train_auroc', self.train_auroc(y_hat, y), sync_dist=True)
-        self.log('train_auprc', self.train_auprc(y_hat, y), sync_dist=True)
-        self.log('train_f1', self.train_f1(y_hat, y), sync_dist=True)
        return loss

    def training_epoch_end(self, outputs):
        self.train_acc.reset()
-        self.train_auroc.reset()
-        self.train_auprc.reset()
-        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('val_acc', self.train_acc(y_hat, y), sync_dist=True)
-        self.log('val_auroc', self.train_auroc(y_hat, y), sync_dist=True)
-        self.log('val_auprc', self.train_auprc(y_hat, y), sync_dist=True)
-        self.log('val_f1', self.train_f1(y_hat, y), sync_dist=True)
        return loss

    def validation_epoch_end(self, outputs):
        self.val_acc.reset()
-        self.val_auroc.reset()
-        self.val_auprc.reset()
-        self.val_f1.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('test_acc', self.train_acc(y_hat, y), sync_dist=True)
-        self.log('test_auroc', self.train_auroc(y_hat, y), sync_dist=True)
-        self.log('test_auprc', self.train_auprc(y_hat, y), sync_dist=True)
-        self.log('test_f1', self.train_f1(y_hat, y), sync_dist=True)
        return loss

    def test_epoch_end(self, outputs):
        self.test_acc.reset()
-        self.test_auroc.reset()
-        self.test_auprc.reset()
-        self.test_f1.reset()

    # ---------------------
    # training setup
@@ -115,7 +84,7 @@ def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = CosineAnnealingLR(optimizer, self.hparams.num_epochs)
-        metric_to_track = 'val_auroc'
+        metric_to_track = 'val_acc'
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
@@ -202,17 +171,17 @@ def cli_main():
    # Resume from checkpoint if path to a valid one is provided
    args.ckpt_name = args.ckpt_name \
        if args.ckpt_name is not None \
-        else 'LitClassifier-{epoch:02d}-{val_auroc:.2f}.ckpt'
+        else 'LitClassifier-{epoch:02d}-{val_acc:.2f}.ckpt'
    checkpoint_path = os.path.join(args.ckpt_dir, args.ckpt_name)
    trainer.resume_from_checkpoint = checkpoint_path if os.path.exists(checkpoint_path) else None

    # ------------
    # training
    # ------------
    # Create and use callbacks
-    early_stop_callback = EarlyStopping(monitor='val_auroc', mode='min', min_delta=0.00, patience=3)
-    checkpoint_callback = ModelCheckpoint(monitor='val_auroc', save_top_k=3, dirpath=args.ckpt_dir,
-                                          filename='LitClassifier-{epoch:02d}-{val_auroc:.2f}')
+    early_stop_callback = EarlyStopping(monitor='val_acc', mode='max', min_delta=0.01, patience=3)
+    checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=3, dirpath=args.ckpt_dir,
+                                          filename='LitClassifier-{epoch:02d}-{val_acc:.2f}')
    lr_callback = LearningRateMonitor(logging_interval='epoch')  # Use with a learning rate scheduler
    trainer.callbacks = [early_stop_callback, checkpoint_callback, lr_callback]
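
For reference, a minimal runnable sketch of how the pieces touched by this change fit together: the value logged under 'val_acc' during validation is the exact key that EarlyStopping and ModelCheckpoint monitor, with 'max' mode since higher accuracy is better. The LitSketch class, its backbone argument, and the Trainer settings below are illustrative placeholders, not code from this repository; pl.metrics here is the metrics namespace bundled with older PyTorch Lightning releases, which is what this diff switches back to from torchmetrics.

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint


class LitSketch(pl.LightningModule):
    def __init__(self, backbone: nn.Module, lr: float = 1e-4):
        super().__init__()
        self.backbone = backbone
        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()
        # A dedicated metric object for validation keeps its running state
        # separate from any training-time metric.
        self.val_acc = pl.metrics.Accuracy()

    def forward(self, x):
        return self.backbone(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.loss_fn(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Logged under 'val_acc', the exact key the callbacks below monitor.
        self.log('val_acc', self.val_acc(torch.softmax(y_hat, dim=-1), y), prog_bar=True)
        return self.loss_fn(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


# Both callbacks watch the logged 'val_acc'; mode='max' is set explicitly so the
# highest validation accuracy is treated as the best.
early_stop_callback = EarlyStopping(monitor='val_acc', mode='max', min_delta=0.01, patience=3)
checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=3,
                                      filename='LitClassifier-{epoch:02d}-{val_acc:.2f}')
lr_callback = LearningRateMonitor(logging_interval='epoch')
trainer = pl.Trainer(max_epochs=5, callbacks=[early_stop_callback, checkpoint_callback, lr_callback])

With this wiring, trainer.fit(LitSketch(backbone), train_loader, val_loader) keeps the three best checkpoints by validation accuracy and stops once 'val_acc' fails to improve by at least 0.01 for three consecutive validation epochs.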