#!/usr/bin/env python
"""Test a pretrained CNN for Google speech commands."""

__author__ = 'Yuan Xu, Erdene-Ochir Tuguldur'

import argparse
import time

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from torchvision.transforms import Compose
import torchnet

from speech_commands_dataset import *
from transforms_wav import *
from transforms_stft import *

parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--dataset-dir", type=str, default='datasets/speech_commands/test', help='path of test dataset')
parser.add_argument("--batch-size", type=int, default=128, help='batch size')
parser.add_argument("--dataload-workers-nums", type=int, default=3, help='number of workers for dataloader')
parser.add_argument("--input", choices=['mel32', 'mel40'], default='mel32', help='input of NN')
parser.add_argument('--multi-crop', action='store_true', help='apply crops and average the results')
parser.add_argument('--generate-kaggle-submission', action='store_true', help='generate kaggle submission file')
parser.add_argument('--output', type=str, help='save output to file for the kaggle competition', default='kaggle_submission.csv')
#parser.add_argument('--prob-output', type=str, help='save probabilities to file', default='probabilities.json')
parser.add_argument("model", help='a pretrained neural network model')
args = parser.parse_args()
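# Example invocation (the script and checkpoint names below are illustrative,
# not fixed by this file):
#   python test_speech_commands.py --multi-crop --generate-kaggle-submission model.pth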

print("loading model...")
model = torch.load(args.model)
model.float()

use_gpu = torch.cuda.is_available()
print('use_gpu', use_gpu)
if use_gpu:
    torch.backends.cudnn.benchmark = True
    model.cuda()

n_mels = 32
if args.input == 'mel40':
    n_mels = 40

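# Feature pipeline built from this repo's transforms_wav/transforms_stft modules:
# load the wav, fix it to a constant length, turn it into an n_mels-band mel
# spectrogram, and expose the result as the 'input' tensor of each sample.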
feature_transform = Compose([ToMelSpectrogram(n_mels=n_mels), ToTensor('mel_spectrogram', 'input')])
transform = Compose([LoadAudio(), FixAudioLength(), feature_transform])
test_dataset = SpeechCommandsDataset(args.dataset_dir, transform, silence_percentage=0)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, sampler=None,
                             pin_memory=use_gpu, num_workers=args.dataload_workers_nums)

criterion = torch.nn.CrossEntropyLoss()

def multi_crop(inputs):
    """Build time-shifted crops of each spectrogram (shifted left, centered and
    shifted right by one frame), pad them back to the original width, and
    concatenate them after the original batch along the batch dimension."""
    b = 1
    size = inputs.size(3) - b * 2
    patches = [inputs[:, :, :, i*b:size+i*b] for i in range(3)]
    outputs = torch.stack(patches)
    outputs = outputs.view(-1, inputs.size(1), inputs.size(2), size)
    outputs = torch.nn.functional.pad(outputs, (b, b, 0, 0), mode='replicate')
    return torch.cat((inputs, outputs))
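# For example, with 32x32 mel inputs a (128, 1, 32, 32) batch becomes
# (512, 1, 32, 32): the 128 originals followed by their three crops each.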

def test():
    model.eval()  # Set model to evaluate mode

    #running_loss = 0.0
    #it = 0
    correct = 0
    total = 0
    confusion_matrix = torchnet.meter.ConfusionMeter(len(CLASSES))
    predictions = {}
    probabilities = {}

    pbar = tqdm(test_dataloader, unit="audios", unit_scale=test_dataloader.batch_size)
    for batch in pbar:
        inputs = batch['input']
        inputs = torch.unsqueeze(inputs, 1)  # add a channel dimension: (N, 1, n_mels, T)
        targets = batch['target']

        n = inputs.size(0)
        if args.multi_crop:
            inputs = multi_crop(inputs)

        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda(non_blocking=True)

        # forward (no gradients needed at test time)
        with torch.no_grad():
            outputs = model(inputs)
            #loss = criterion(outputs, targets)
        outputs = torch.nn.functional.softmax(outputs, dim=1)
        if args.multi_crop:
            # regroup as (num_crops, n, num_classes) and average the class
            # probabilities over the crops of each sample; the mean of softmax
            # outputs is already a valid distribution
            outputs = outputs.view(-1, n, outputs.size(1))
            outputs = torch.mean(outputs, dim=0)

        # statistics
        #it += 1
        #running_loss += loss.item()
        pred = outputs.max(1, keepdim=True)[1]
        correct += pred.eq(targets.view_as(pred)).sum().item()
        total += targets.size(0)
        confusion_matrix.add(pred.squeeze(1), targets)

        filenames = batch['path']
        for j in range(len(pred)):
            fn = filenames[j]
            predictions[fn] = pred[j][0].item()
            probabilities[fn] = outputs[j].tolist()

    accuracy = correct / total
    #epoch_loss = running_loss / it
    print("accuracy: %f%%" % (100*accuracy))
    print("confusion matrix:")
    print(confusion_matrix.value())

    return probabilities, predictions

print("testing...")
probabilities, predictions = test()
if args.generate_kaggle_submission:
    print("generating kaggle submission file '%s'..." % args.output)
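    # The diff ends here. A minimal sketch of the missing submission writer,
    # assuming the Kaggle "fname,label" CSV format and that CLASSES (imported
    # from speech_commands_dataset) maps class indices to label strings:
    import os
    with open(args.output, 'w') as f:
        f.write('fname,label\n')
        for fn, label_idx in predictions.items():
            f.write('%s,%s\n' % (os.path.basename(fn), CLASSES[label_idx]))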