Skip to content

Commit 03ebd18

Browse files
committed
Remove torchmetrics
1 parent c433ee7 commit 03ebd18

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

project/lit_image_classifier.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import pytorch_lightning as pl
66
import torch
77
import torch.nn as nn
8-
import torchmetrics as tm
98
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
9+
from pytorch_lightning.metrics import Accuracy
1010
from pytorch_lightning.plugins import DDPPlugin
1111
from torch.optim.lr_scheduler import CosineAnnealingLR
1212
from torch.utils.data import DataLoader, random_split
@@ -42,9 +42,9 @@ def __init__(self, backbone, num_epochs: int = 5, lr=1e-4):
4242
self.loss_fn = nn.CrossEntropyLoss()
4343

4444
# Define cross-validation metrics
45-
self.train_acc = tm.Accuracy()
46-
self.val_acc = tm.Accuracy()
47-
self.test_acc = tm.Accuracy()
45+
self.train_acc = Accuracy()
46+
self.val_acc = Accuracy()
47+
self.test_acc = Accuracy()
4848

4949
def forward(self, x):
5050
# use forward for inference/predictions

0 commit comments

Comments
 (0)