
Commit a50ba3b

Try PL 1.2.4 instead
1 parent 3184729 commit a50ba3b

File tree

project/lit_image_classifier.py
requirements.txt

2 files changed, +9 −41 lines

project/lit_image_classifier.py

+8 −39
@@ -9,7 +9,6 @@
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader, random_split
 from torchvision import transforms
-import torchmetrics as tm
 from torchvision.datasets.mnist import MNIST
 
 
@@ -39,21 +38,9 @@ def __init__(self, backbone, num_epochs: int = 5, lr=1e-4):
         self.loss_fn = nn.CrossEntropyLoss()
 
         # Define cross-validation metrics
-        self.train_acc = tm.Accuracy(average='weighted', num_classes=10)
-        self.val_acc = tm.Accuracy(average='weighted', num_classes=10)
-        self.test_acc = tm.Accuracy(average='weighted', num_classes=10)
-
-        self.train_auroc = tm.AUROC(average='weighted')
-        self.val_auroc = tm.AUROC(average='weighted')
-        self.test_auroc = tm.AUROC(average='weighted')
-
-        self.train_auprc = tm.AveragePrecision()
-        self.val_auprc = tm.AveragePrecision()
-        self.test_auprc = tm.AveragePrecision()
-
-        self.train_f1 = tm.F1(average='weighted', num_classes=10)
-        self.val_f1 = tm.F1(average='weighted', num_classes=10)
-        self.test_f1 = tm.F1(average='weighted', num_classes=10)
+        self.train_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
+        self.val_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
+        self.test_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
 
     def forward(self, x):
         # use forward for inference/predictions
@@ -65,48 +52,30 @@ def training_step(self, batch, batch_idx):
         y_hat = self(x)
         loss = self.loss_fn(y_hat, y)
         self.log('train_acc', self.train_acc(y_hat, y), sync_dist=True)
-        self.log('train_auroc', self.train_auroc(y_hat, y), sync_dist=True)
-        self.log('train_auprc', self.train_auprc(y_hat, y), sync_dist=True)
-        self.log('train_f1', self.train_f1(y_hat, y), sync_dist=True)
         return loss
 
     def training_epoch_end(self, outputs):
         self.train_acc.reset()
-        self.train_auroc.reset()
-        self.train_auprc.reset()
-        self.train_f1.reset()
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = self.loss_fn(y_hat, y)
         self.log('val_acc', self.val_acc(y_hat, y), sync_dist=True)
-        self.log('val_auroc', self.train_auroc(y_hat, y), sync_dist=True)
-        self.log('val_auprc', self.train_auprc(y_hat, y), sync_dist=True)
-        self.log('val_f1', self.train_f1(y_hat, y), sync_dist=True)
         return loss
 
     def validation_epoch_end(self, outputs):
         self.val_acc.reset()
-        self.val_auroc.reset()
-        self.val_auprc.reset()
-        self.val_f1.reset()
 
     def test_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = self.loss_fn(y_hat, y)
         self.log('test_acc', self.test_acc(y_hat, y), sync_dist=True)
-        self.log('test_auroc', self.train_auroc(y_hat, y), sync_dist=True)
-        self.log('test_auprc', self.train_auprc(y_hat, y), sync_dist=True)
-        self.log('test_f1', self.train_f1(y_hat, y), sync_dist=True)
         return loss
 
     def test_epoch_end(self, outputs):
         self.test_acc.reset()
-        self.test_auroc.reset()
-        self.test_auprc.reset()
-        self.test_f1.reset()
 
     # ---------------------
     # training setup
@@ -115,7 +84,7 @@ def configure_optimizers(self):
         # self.hparams available because we called self.save_hyperparameters()
         optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
         scheduler = CosineAnnealingLR(optimizer, self.hparams.num_epochs)
-        metric_to_track = 'val_auroc'
+        metric_to_track = 'val_acc'
         return {
             'optimizer': optimizer,
             'lr_scheduler': scheduler,
@@ -202,17 +171,17 @@ def cli_main():
     # Resume from checkpoint if path to a valid one is provided
     args.ckpt_name = args.ckpt_name \
         if args.ckpt_name is not None \
-        else 'LitClassifier-{epoch:02d}-{val_auroc:.2f}.ckpt'
+        else 'LitClassifier-{epoch:02d}-{val_acc:.2f}.ckpt'
     checkpoint_path = os.path.join(args.ckpt_dir, args.ckpt_name)
     trainer.resume_from_checkpoint = checkpoint_path if os.path.exists(checkpoint_path) else None
 
     # ------------
     # training
     # ------------
     # Create and use callbacks
-    early_stop_callback = EarlyStopping(monitor='val_auroc', mode='min', min_delta=0.00, patience=3)
-    checkpoint_callback = ModelCheckpoint(monitor='val_auroc', save_top_k=3, dirpath=args.ckpt_dir,
-                                          filename='LitClassifier-{epoch:02d}-{val_auroc:.2f}')
+    early_stop_callback = EarlyStopping(monitor='val_acc', mode='max', min_delta=0.01, patience=3)
+    checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=3, dirpath=args.ckpt_dir,
+                                          filename='LitClassifier-{epoch:02d}-{val_acc:.2f}')
     lr_callback = LearningRateMonitor(logging_interval='epoch')  # Use with a learning rate scheduler
     trainer.callbacks = [early_stop_callback, checkpoint_callback, lr_callback]
 
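Aside, not part of the commit: the configure_optimizers hunk above cuts off before the rest of the returned dictionary, but the metric_to_track switch to 'val_acc' presumably feeds Lightning's 'monitor' key, which names a logged metric for scheduler and checkpoint tracking. A minimal self-contained sketch of that contract, with a hypothetical TinyModule standing in for the real classifier:

import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import CosineAnnealingLR

class TinyModule(pl.LightningModule):
    # Hypothetical stand-in for LitClassifier, for illustration only.
    def __init__(self, num_epochs: int = 5, lr: float = 1e-4):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = CosineAnnealingLR(optimizer, self.hparams.num_epochs)
        # 'monitor' names a logged metric; after this commit that is 'val_acc'
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_acc'}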

requirements.txt

+1 −2
@@ -1,2 +1 @@
-pytorch-lightning==1.3.8
-torchmetrics==0.4.1
+pytorch-lightning==1.2.4
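Not part of the diff: a quick environment sanity check (messages are illustrative) that the pinned release still bundles the built-in pl.metrics module the classifier now relies on instead of the standalone torchmetrics package:

# Illustrative check, not part of this commit.
import pytorch_lightning as pl

assert pl.__version__ == '1.2.4', f'expected PL 1.2.4, found {pl.__version__}'
acc = pl.metrics.Accuracy()  # built-in metrics ship with the 1.2.x line
print('pl.metrics.Accuracy is available:', acc)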
