-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
30 lines (26 loc) · 1.24 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
def test(model, device, test_loader):
    """Evaluate ``model`` on every batch of ``test_loader`` and collect results.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Its output is passed to ``F.nll_loss``, so it is expected to be
    per-class log-probabilities with classes along ``dim=1``.

    Args:
        model: Callable module; ``model.eval()`` is invoked and it is then
            called once per batch as ``model(data)``.
        device: Device that each ``data``/``target`` batch is moved to before
            the forward pass.
        test_loader: Iterable yielding ``(data, target)`` tensor pairs and
            exposing a sized ``dataset`` attribute, whose length is used to
            average the loss and the accuracy.

    Returns:
        A tuple ``(test_loss, test_acc, test_pred, target_pred, target_data)``:
        the summed NLL loss divided by ``len(test_loader.dataset)``, the
        accuracy in percent, and CPU tensors holding the predicted class
        indices (one column per sample, from ``argmax(..., keepdim=True)``),
        the targets and the input batches, each concatenated along dim 0 in
        iteration order. If the loader yields no batches the three tensors
        are empty ``LongTensor``s.

    Raises:
        ZeroDivisionError: If ``len(test_loader.dataset)`` is 0.
    """
    model.eval()
    test_loss = 0.0
    correct = 0

    # Gather per-batch CPU tensors and concatenate once after the loop.
    # Growing a tensor with torch.cat on every iteration re-copies all earlier
    # samples each time, and seeding it with an empty LongTensor makes the
    # result dtype depend on type promotion against the real batches.
    pred_batches = []
    target_batches = []
    data_batches = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so the final division by the dataset
            # size gives a per-sample average even with a ragged last batch.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Index of the max log-probability; reused for both the accuracy
            # count and the returned predictions.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

            pred_batches.append(pred.cpu())
            target_batches.append(target.cpu())
            data_batches.append(data.cpu())

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    test_acc = 100. * correct / num_samples
    print(f'\nTest set: Average loss: {test_loss:.3f}, Accuracy: {test_acc:.2f}')

    # Keep the historical "empty LongTensor" result when nothing was iterated.
    test_pred = torch.cat(pred_batches, dim=0) if pred_batches else torch.LongTensor()
    target_pred = torch.cat(target_batches, dim=0) if target_batches else torch.LongTensor()
    target_data = torch.cat(data_batches, dim=0) if data_batches else torch.LongTensor()
    return test_loss, test_acc, test_pred, target_pred, target_data


# The name matches pytest's default ``test*`` collection pattern; mark this
# evaluation helper as "not a test" so importing it into a test module does
# not make pytest try to run it with fixtures named model/device/test_loader.
test.__test__ = False