Skip to content

Commit 858d36c

Browse files
committed
Remove averaging in Accuracy
1 parent a50ba3b commit 858d36c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

project/lit_image_classifier.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def __init__(self, backbone, num_epochs: int = 5, lr=1e-4):
3838
self.loss_fn = nn.CrossEntropyLoss()
3939

4040
# Define cross-validation metrics
41-
self.train_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
42-
self.val_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
43-
self.test_acc = pl.metrics.Accuracy(average='weighted', num_classes=10)
41+
self.train_acc = pl.metrics.Accuracy()
42+
self.val_acc = pl.metrics.Accuracy()
43+
self.test_acc = pl.metrics.Accuracy()
4444

4545
def forward(self, x):
4646
# use forward for inference/predictions

0 commit comments

Comments
 (0)