99from torch .optim .lr_scheduler import CosineAnnealingLR
1010from torch .utils .data import DataLoader , random_split
1111from torchvision import transforms
12- import torchmetrics as tm
1312from torchvision .datasets .mnist import MNIST
1413
1514
@@ -39,21 +38,9 @@ def __init__(self, backbone, num_epochs: int = 5, lr=1e-4):
3938 self .loss_fn = nn .CrossEntropyLoss ()
4039
4140 # Define cross-validation metrics
42- self .train_acc = tm .Accuracy (average = 'weighted' , num_classes = 10 )
43- self .val_acc = tm .Accuracy (average = 'weighted' , num_classes = 10 )
44- self .test_acc = tm .Accuracy (average = 'weighted' , num_classes = 10 )
45-
46- self .train_auroc = tm .AUROC (average = 'weighted' )
47- self .val_auroc = tm .AUROC (average = 'weighted' )
48- self .test_auroc = tm .AUROC (average = 'weighted' )
49-
50- self .train_auprc = tm .AveragePrecision ()
51- self .val_auprc = tm .AveragePrecision ()
52- self .test_auprc = tm .AveragePrecision ()
53-
54- self .train_f1 = tm .F1 (average = 'weighted' , num_classes = 10 )
55- self .val_f1 = tm .F1 (average = 'weighted' , num_classes = 10 )
56- self .test_f1 = tm .F1 (average = 'weighted' , num_classes = 10 )
41+ self .train_acc = pl .metrics .Accuracy (average = 'weighted' , num_classes = 10 )
42+ self .val_acc = pl .metrics .Accuracy (average = 'weighted' , num_classes = 10 )
43+ self .test_acc = pl .metrics .Accuracy (average = 'weighted' , num_classes = 10 )
5744
5845 def forward (self , x ):
5946 # use forward for inference/predictions
@@ -65,48 +52,30 @@ def training_step(self, batch, batch_idx):
6552 y_hat = self (x )
6653 loss = self .loss_fn (y_hat , y )
6754 self .log ('train_acc' , self .train_acc (y_hat , y ), sync_dist = True )
68- self .log ('train_auroc' , self .train_auroc (y_hat , y ), sync_dist = True )
69- self .log ('train_auprc' , self .train_auprc (y_hat , y ), sync_dist = True )
70- self .log ('train_f1' , self .train_f1 (y_hat , y ), sync_dist = True )
7155 return loss
7256
7357 def training_epoch_end (self , outputs ):
7458 self .train_acc .reset ()
75- self .train_auroc .reset ()
76- self .train_auprc .reset ()
77- self .train_f1 .reset ()
7859
7960 def validation_step (self , batch , batch_idx ):
8061 x , y = batch
8162 y_hat = self (x )
8263 loss = self .loss_fn (y_hat , y )
8364 self .log ('val_acc' , self .train_acc (y_hat , y ), sync_dist = True )
84- self .log ('val_auroc' , self .train_auroc (y_hat , y ), sync_dist = True )
85- self .log ('val_auprc' , self .train_auprc (y_hat , y ), sync_dist = True )
86- self .log ('val_f1' , self .train_f1 (y_hat , y ), sync_dist = True )
8765 return loss
8866
8967 def validation_epoch_end (self , outputs ):
9068 self .val_acc .reset ()
91- self .val_auroc .reset ()
92- self .val_auprc .reset ()
93- self .val_f1 .reset ()
9469
9570 def test_step (self , batch , batch_idx ):
9671 x , y = batch
9772 y_hat = self (x )
9873 loss = self .loss_fn (y_hat , y )
9974 self .log ('test_acc' , self .train_acc (y_hat , y ), sync_dist = True )
100- self .log ('test_auroc' , self .train_auroc (y_hat , y ), sync_dist = True )
101- self .log ('test_auprc' , self .train_auprc (y_hat , y ), sync_dist = True )
102- self .log ('test_f1' , self .train_f1 (y_hat , y ), sync_dist = True )
10375 return loss
10476
10577 def test_epoch_end (self , outputs ):
10678 self .test_acc .reset ()
107- self .test_auroc .reset ()
108- self .test_auprc .reset ()
109- self .test_f1 .reset ()
11079
11180 # ---------------------
11281 # training setup
@@ -115,7 +84,7 @@ def configure_optimizers(self):
11584 # self.hparams available because we called self.save_hyperparameters()
11685 optimizer = torch .optim .Adam (self .parameters (), lr = self .hparams .lr )
11786 scheduler = CosineAnnealingLR (optimizer , self .hparams .num_epochs )
118- metric_to_track = 'val_auroc '
87+ metric_to_track = 'val_acc '
11988 return {
12089 'optimizer' : optimizer ,
12190 'lr_scheduler' : scheduler ,
@@ -202,17 +171,17 @@ def cli_main():
202171 # Resume from checkpoint if path to a valid one is provided
203172 args .ckpt_name = args .ckpt_name \
204173 if args .ckpt_name is not None \
205- else 'LitClassifier-{epoch:02d}-{val_auroc :.2f}.ckpt'
174+ else 'LitClassifier-{epoch:02d}-{val_acc :.2f}.ckpt'
206175 checkpoint_path = os .path .join (args .ckpt_dir , args .ckpt_name )
207176 trainer .resume_from_checkpoint = checkpoint_path if os .path .exists (checkpoint_path ) else None
208177
209178 # ------------
210179 # training
211180 # ------------
212181 # Create and use callbacks
213- early_stop_callback = EarlyStopping (monitor = 'val_auroc ' , mode = 'min ' , min_delta = 0.00 , patience = 3 )
214- checkpoint_callback = ModelCheckpoint (monitor = 'val_auroc ' , save_top_k = 3 , dirpath = args .ckpt_dir ,
215- filename = 'LitClassifier-{epoch:02d}-{val_auroc :.2f}' )
182+ early_stop_callback = EarlyStopping (monitor = 'val_acc ' , mode = 'max ' , min_delta = 0.01 , patience = 3 )
183+ checkpoint_callback = ModelCheckpoint (monitor = 'val_acc ' , save_top_k = 3 , dirpath = args .ckpt_dir ,
184+ filename = 'LitClassifier-{epoch:02d}-{val_acc :.2f}' )
216185 lr_callback = LearningRateMonitor (logging_interval = 'epoch' ) # Use with a learning rate scheduler
217186 trainer .callbacks = [early_stop_callback , checkpoint_callback , lr_callback ]
218187
0 commit comments