Skip to content

Commit 3fda2c0

Browse files
committed
add test script for the Google speech commands
1 parent c5d3c56 commit 3fda2c0

File tree

2 files changed

+132
-2
lines changed

2 files changed

+132
-2
lines changed

test_cifar10.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from torch.autograd import Variable
13+
from torch.utils.data import DataLoader
1314

1415
import torchvision
1516
from torchvision.transforms import *
@@ -40,7 +41,7 @@
4041
])
4142

4243
test_dataset = torchvision.datasets.CIFAR10(root=args.dataset_root, train=False, download=True, transform=to_tensor_and_normalize)
43-
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.dataload_workers_nums)
44+
test_dataloader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.dataload_workers_nums)
4445

4546
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
4647

@@ -83,7 +84,9 @@ def test():
8384
'acc': "%.02f%%" % (100*correct/total)
8485
})
8586

86-
print("accuracy: %f%%, loss: %f" % (100*correct/total, running_loss / it))
87+
accuracy = correct/total
88+
epoch_loss = running_loss / it
89+
print("accuracy: %f%%, loss: %f" % (100*accuracy, epoch_loss))
8790
print("confusion matrix:")
8891
print(confusion_matrix.value())
8992

test_speech_commands.py

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#!/usr/bin/env python
2+
"""Test a pretrained CNN for Google speech commands."""
3+
4+
__author__ = 'Yuan Xu, Erdene-Ochir Tuguldur'
5+
6+
import argparse
7+
import time
8+
9+
from tqdm import *
10+
11+
import torch
12+
from torch.autograd import Variable
13+
from torch.utils.data import DataLoader
14+
15+
from torchvision.transforms import *
16+
import torchnet
17+
18+
from speech_commands_dataset import *
19+
from transforms_wav import *
20+
from transforms_stft import *
21+
22+
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
23+
parser.add_argument("--dataset-dir", type=str, default='datasets/speech_commands/test', help='path of test dataset')
24+
parser.add_argument("--batch-size", type=int, default=128, help='batch size')
25+
parser.add_argument("--dataload-workers-nums", type=int, default=3, help='number of workers for dataloader')
26+
parser.add_argument("--input", choices=['mel32'], default='mel32', help='input of NN')
27+
parser.add_argument('--multi-crop', action='store_true', help='apply crop and average the results')
28+
parser.add_argument('--generate-kaggle-submission', action='store_true', help='generate kaggle submission file')
29+
parser.add_argument('--output', type=str, help='save output to file for the kaggle competition', default='kaggle_submission.csv')
30+
#parser.add_argument('--prob-output', type=str, help='save probabilities to file', default='probabilities.json')
31+
parser.add_argument("model", help='a pretrained neural network model')
32+
args = parser.parse_args()
33+
34+
print("loading model...")
35+
model = torch.load(args.model)
36+
model.float()
37+
38+
use_gpu = torch.cuda.is_available()
39+
print('use_gpu', use_gpu)
40+
if use_gpu:
41+
torch.backends.cudnn.benchmark = True
42+
model.cuda()
43+
44+
n_mels = 32
45+
if args.input == 'mel40':
46+
n_mels = 40
47+
48+
feature_transform = Compose([ToMelSpectrogram(n_mels=n_mels), ToTensor('mel_spectrogram', 'input')])
49+
transform = Compose([LoadAudio(), FixAudioLength(), feature_transform])
50+
test_dataset = SpeechCommandsDataset(args.dataset_dir, transform, silence_percentage=0)
51+
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, sampler=None,
52+
pin_memory=use_gpu, num_workers=args.dataload_workers_nums)
53+
54+
criterion = torch.nn.CrossEntropyLoss()
55+
56+
def multi_crop(inputs):
57+
b = 1
58+
size = inputs.size(3) - b * 2
59+
patches = [inputs[:, :, :, i*b:size+i*b] for i in range(3)]
60+
outputs = torch.stack(patches)
61+
outputs = outputs.view(-1, inputs.size(1), inputs.size(2), size)
62+
outputs = torch.nn.functional.pad(outputs, (b, b, 0, 0), mode='replicate')
63+
return torch.cat((inputs, outputs.data))
64+
65+
def test():
66+
model.eval() # Set model to evaluate mode
67+
68+
#running_loss = 0.0
69+
#it = 0
70+
correct = 0
71+
total = 0
72+
confusion_matrix = torchnet.meter.ConfusionMeter(len(CLASSES))
73+
predictions = {}
74+
probabilities = {}
75+
76+
pbar = tqdm(test_dataloader, unit="audios", unit_scale=test_dataloader.batch_size)
77+
for batch in pbar:
78+
inputs = batch['input']
79+
inputs = torch.unsqueeze(inputs, 1)
80+
targets = batch['target']
81+
82+
n = inputs.size(0)
83+
if args.multi_crop:
84+
inputs = multi_crop(inputs)
85+
86+
inputs = Variable(inputs, volatile = True)
87+
targets = Variable(targets, requires_grad=False)
88+
89+
if use_gpu:
90+
inputs = inputs.cuda()
91+
targets = targets.cuda(async=True)
92+
93+
# forward
94+
outputs = model(inputs)
95+
#loss = criterion(outputs, targets)
96+
outputs = torch.nn.functional.softmax(outputs, dim=1)
97+
if args.multi_crop:
98+
outputs = outputs.view(-1, n, outputs.size(1))
99+
outputs = torch.mean(outputs, dim=0)
100+
outputs = torch.nn.functional.softmax(outputs, dim=1)
101+
102+
# statistics
103+
#it += 1
104+
#running_loss += loss.data[0]
105+
pred = outputs.data.max(1, keepdim=True)[1]
106+
correct += pred.eq(targets.data.view_as(pred)).sum()
107+
total += targets.size(0)
108+
confusion_matrix.add(pred, targets.data)
109+
110+
filenames = batch['path']
111+
for j in range(len(pred)):
112+
fn = filenames[j]
113+
predictions[fn] = pred[j][0]
114+
probabilities[fn] = outputs.data[j].tolist()
115+
116+
accuracy = correct/total
117+
#epoch_loss = running_loss / it
118+
print("accuracy: %f%%" % (100*accuracy))
119+
print("confusion matrix:")
120+
print(confusion_matrix.value())
121+
122+
return probabilities, predictions
123+
124+
print("testing...")
125+
probabilities, predictions = test()
126+
if args.generate_kaggle_submission:
127+
print("generating kaggle submission file '%s'..." % args.output)

0 commit comments

Comments
 (0)