
Commit 0f940f9

generate a kaggle submission file
1 parent 3fda2c0 commit 0f940f9

File tree

3 files changed: +10 -6 lines changed


.gitignore

+1 -5

@@ -1,9 +1,5 @@
 __pycache__
 *.pth
+*.csv
 pretrained/
 runs/
-
-checkpoint/
-train.py
-main.py
-utils.py

test_speech_commands.py

+8

@@ -5,6 +5,8 @@
 
 import argparse
 import time
+import csv
+import os
 
 from tqdm import *
 
@@ -125,3 +127,9 @@ def test():
 probabilities, predictions = test()
 if args.generate_kaggle_submission:
     print("generating kaggle submission file '%s'..." % args.output)
+    with open(args.output, 'w') as outfile:
+        fieldnames = ['fname', 'label']
+        writer = csv.DictWriter(outfile, fieldnames=fieldnames)
+        writer.writeheader()
+        for fname, pred in predictions.items():
+            writer.writerow({'fname': os.path.basename(fname), 'label': test_dataset.classes[pred]})
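
For reference, a minimal standalone sketch of the submission-writing logic added above. The label set and predictions dict here are made-up placeholders; in the actual script they come from test() and test_dataset.classes.

# Standalone sketch of the submission writer above.
# The label set and predictions dict are placeholder data, not from the repo;
# in test_speech_commands.py they come from test() and test_dataset.classes.
import csv
import os

classes = ['silence', 'unknown', 'yes', 'no']        # placeholder labels
predictions = {
    '/datasets/test/audio/clip_0001.wav': 2,         # placeholder fname -> class index
    '/datasets/test/audio/clip_0002.wav': 3,
}

with open('submission.csv', 'w') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=['fname', 'label'])
    writer.writeheader()
    for fname, pred in predictions.items():
        # Kaggle expects the bare file name, so strip the directory part
        writer.writerow({'fname': os.path.basename(fname), 'label': classes[pred]})

With the placeholder data, the resulting submission.csv would read:

fname,label
clip_0001.wav,yes
clip_0002.wav,no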

train_speech_commands.py

+1 -1

@@ -31,7 +31,7 @@
 parser.add_argument("--batch-size", type=int, default=128, help='batch size')
 parser.add_argument("--dataload-workers-nums", type=int, default=6, help='number of workers for dataloader')
 parser.add_argument("--weight-decay", type=float, default=1e-2, help='weight decay')
-parser.add_argument("--optim", choices=['sgd', 'adam'], default='adam', help='choices of optimization algorithms')
+parser.add_argument("--optim", choices=['sgd', 'adam'], default='sgd', help='choices of optimization algorithms')
 parser.add_argument("--learning-rate", type=float, default=1e-4, help='learning rate for optimization')
 parser.add_argument("--lr-scheduler", choices=['plateau', 'step'], default='plateau', help='method to adjust learning rate')
 parser.add_argument("--lr-scheduler-patience", type=int, default=5, help='lr scheduler plateau: Number of epochs with no improvement after which learning rate will be reduced')
