Skip to content

Commit 0d3fe14

Browse files
seungwonparksoumith
authored andcommitted
Use argmax() instead of max()[1] in mnist/main.py (pytorch#494)
Thanks for great examples! For beginners, I thought using `.argmax()` will be better than `.max()[1]`.
1 parent 29a38c6 commit 0d3fe14

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mnist/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test(args, model, device, test_loader):
4848
data, target = data.to(device), target.to(device)
4949
output = model(data)
5050
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
51-
pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
51+
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
5252
correct += pred.eq(target.view_as(pred)).sum().item()
5353

5454
test_loss /= len(test_loader.dataset)

0 commit comments

Comments
 (0)