Commit 000bdc3

add mixup result for vgg19_bn
1 parent a9a6a16 commit 000bdc3

4 files changed (+46 −7)

README.md (+14 −4)

@@ -18,7 +18,7 @@ For the training parameters, see [TRAINING.md](TRAINING.md). Earlier stopping th
 
 <table><tbody>
 <th valign="bottom"><sup><sub>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Model&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</sub></sup></th>
-<th valign="bottom"><sup><sub>CIFAR-10<br/>test set<br/>accuracy</sub></sup></th>
+<th valign="bottom"><sup><sub>CIFAR10<br/>test set<br/>accuracy</sub></sup></th>
 <th valign="bottom"><sup><sub>Speech Commands<br/>test set<br/>accuracy</sub></sup></th>
 <th valign="bottom"><sup><sub>Speech Commands<br/>test set<br/>accuracy with crop</sub></sup></th>
 <th valign="bottom"><sup><sub>Speech Commands<br/>Kaggle private LB<br/>score</sub></sup></th>
@@ -40,8 +40,8 @@ For the training parameters, see [TRAINING.md](TRAINING.md). Earlier stopping th
 <td align="center"><sup><sub>-</sub></sup></td>
 <td align="center"><sup><sub>97.937089%</sub></sup></td>
 <td align="center"><sup><sub>97.922458%</sub></sup></td>
-<td align="center"><sup><sub></sub></sup></td>
-<td align="center"><sup><sub></sub></sup></td>
+<td align="center"><sup><sub>0.88546</sub></sup></td>
+<td align="center"><sup><sub>0.88699</sub></sup></td>
 <td align="left"><sup><sub></sub></sup></td>
 </tr>
 
@@ -103,13 +103,23 @@ After the competition, some of the networks were retrained using [mixup: Beyond
 
 <table><tbody>
 <th valign="bottom"><sup><sub>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Model&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</sub></sup></th>
-<th valign="bottom"><sup><sub>CIFAR-10<br/>test set<br/>accuracy</sub></sup></th>
+<th valign="bottom"><sup><sub>CIFAR10<br/>test set<br/>accuracy</sub></sup></th>
 <th valign="bottom"><sup><sub>Speech Commands<br/>test set<br/>accuracy</sub></sup></th>
 <th valign="bottom"><sup><sub>Speech Commands<br/>test set<br/>accuracy with crop</sub></sup></th>
 <th valign="bottom"><sup><sub>Speech Commands<br/>Kaggle private LB<br/>score</sub></sup></th>
 <th valign="bottom"><sup><sub>Speech Commands<br/>Kaggle private LB<br/>score with crop</sub></sup></th>
 <th valign="bottom"><sup><sub>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Remarks&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</sub></sup></th>
 
+<tr>
+<td align="left"><sup><sub>VGG19 BN</sub></sup></td>
+<td align="center"><sup><sub>-</sub></sup></td>
+<td align="center"><sup><sub>97.483541%</sub></sup></td>
+<td align="center"><sup><sub>97.542063%</sub></sup></td>
+<td align="center"><sup><sub>0.89521</sub></sup></td>
+<td align="center"><sup><sub>0.89839</sub></sup></td>
+<td align="left"><sup><sub></sub></sup></td>
+</tr>
+
 <tr>
 <td align="left"><sup><sub>WRN-52-10</sub></sup></td>
 <td align="center"><sup><sub>-</sub></sup></td>

TRAINING.md (+7 −1)

@@ -6,8 +6,14 @@
 python train_speech_commands.py --model=vgg19_bn --optim=sgd --lr-scheduler=plateau --learning-rate=0.01 --lr-scheduler-patience=5 --max-epochs=70 --batch-size=96
 ```
 
+#### VGG19 BN with Mixup
+* accuracy: 97.483541%, 97.542063% with crop, Kaggle private LB score: 0.89521 and 0.89839 with crop, epoch time: 1m30s
+```sh
+python train_speech_commands.py --model=vgg19_bn --optim=sgd --lr-scheduler=plateau --learning-rate=0.01 --lr-scheduler-patience=5 --max-epochs=70 --batch-size=96 --mixup
+```
+
 #### WideResNet 28-10
-* accuracy: 97.937089%, 97.922458% with crop, Kaggle private LB score: and with crop, epoch time: ?
+* accuracy: 97.937089%, 97.922458% with crop, Kaggle private LB score: 0.88546 and 0.88699 with crop, epoch time: 2m5s
 ```sh
 python train_speech_commands.py --model=wideresnet28_10 --optim=sgd --lr-scheduler=plateau --learning-rate=0.01 --lr-scheduler-patience=5 --max-epochs=70 --batch-size=96
 ```

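TRAINING.md documents only the speech-commands runs, but `train_cifar10.py` (diffed below) gains the same `--mixup` switch. A hypothetical CIFAR-10 invocation, restricted to flags visible in that script's argument parser and leaving everything else at its defaults:

```sh
python train_cifar10.py --model=vgg19_bn --mixup
```
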
train_cifar10.py (+13 −1)

@@ -18,6 +18,7 @@
 from tensorboardX import SummaryWriter
 
 import models
+from mixup import *
 
 parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument("--comment", type=str, default='', help='comment in tensorboard title')

@@ -35,6 +36,7 @@
 parser.add_argument("--max-epochs", type=int, default=150, help='max number of epochs')
 parser.add_argument("--resume", type=str, help='checkpoint file to resume')
 parser.add_argument("--model", choices=models.available_models, default=models.available_models[0], help='model of NN')
+parser.add_argument('--mixup', action='store_true', help='use mixup')
 args = parser.parse_args()
 
 use_gpu = torch.cuda.is_available()

@@ -123,6 +125,10 @@ def train(epoch):
     pbar = tqdm(train_dataloader, unit="images", unit_scale=train_dataloader.batch_size)
     for batch in pbar:
         inputs, targets = batch
+
+        if args.mixup:
+            inputs, targets = mixup(inputs, targets, num_classes=len(CLASSES))
+
         inputs = Variable(inputs, requires_grad=True)
         targets = Variable(targets, requires_grad=False)

@@ -132,7 +138,10 @@ def train(epoch):
 
         # forward/backward
         outputs = model(inputs)
-        loss = criterion(outputs, targets)
+        if args.mixup:
+            loss = mixup_cross_entropy_loss(outputs, targets)
+        else:
+            loss = criterion(outputs, targets)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()

@@ -142,6 +151,9 @@ def train(epoch):
         global_step += 1
         running_loss += loss.data[0]
         pred = outputs.data.max(1, keepdim=True)[1]
+        if args.mixup:
+            _, targets = batch
+            targets = Variable(targets, requires_grad=False).cuda(async=True)
         correct += pred.eq(targets.data.view_as(pred)).sum()
         total += targets.size(0)
 
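The `mixup.py` module that both scripts import is not included in this commit. For reference, here is a minimal sketch of what the two helpers called above might look like, following the mixup paper: each batch is blended with a shuffled copy of itself using a Beta-distributed weight, and the loss becomes cross-entropy against the resulting soft targets. The function names and signatures mirror the call sites; the bodies, the `onehot` helper, and the `alpha=2.0` default are assumptions, written against a modern PyTorch API rather than the `Variable`-era one used in the scripts:

```python
# Hypothetical sketch of mixup.py -- not the repository's actual implementation.
import numpy as np
import torch
import torch.nn.functional as F

def onehot(targets, num_classes):
    """Turn a LongTensor of class indices into one-hot float vectors."""
    return torch.zeros(targets.size(0), num_classes).scatter_(1, targets.view(-1, 1), 1.0)

def mixup(inputs, targets, num_classes, alpha=2.0):
    """Blend the batch with a shuffled copy of itself.

    Returns mixed inputs and soft (one-hot mixture) targets, so the caller
    must switch to a soft-target loss such as mixup_cross_entropy_loss.
    """
    lam = np.random.beta(alpha, alpha)      # mixing weight drawn from Beta(alpha, alpha)
    perm = torch.randperm(inputs.size(0))   # random pairing within the batch
    inputs = lam * inputs + (1.0 - lam) * inputs[perm]
    targets = onehot(targets, num_classes)
    targets = lam * targets + (1.0 - lam) * targets[perm]
    return inputs, targets

def mixup_cross_entropy_loss(outputs, targets):
    """Cross-entropy against soft targets: batch mean of -sum(t * log p)."""
    log_probs = F.log_softmax(outputs, dim=1)
    return -(targets * log_probs).sum(dim=1).mean()
```

Drawing a single λ per batch matches the formulation in the paper; whether the repository's implementation draws λ per batch or per example is not visible from this commit.
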
train_speech_commands.py (+12 −1)

@@ -21,6 +21,7 @@
 import models
 from datasets import *
 from transforms import *
+from mixup import *
 
 parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument("--train-dataset", type=str, default='datasets/speech_commands/train', help='path of train dataset')

@@ -40,6 +41,7 @@
 parser.add_argument("--resume", type=str, help='checkpoint file to resume')
 parser.add_argument("--model", choices=models.available_models, default=models.available_models[0], help='model of NN')
 parser.add_argument("--input", choices=['mel32'], default='mel32', help='input of NN')
+parser.add_argument('--mixup', action='store_true', help='use mixup')
 args = parser.parse_args()
 
 use_gpu = torch.cuda.is_available()

@@ -142,6 +144,9 @@ def train(epoch):
         inputs = torch.unsqueeze(inputs, 1)
         targets = batch['target']
 
+        if args.mixup:
+            inputs, targets = mixup(inputs, targets, num_classes=len(CLASSES))
+
         inputs = Variable(inputs, requires_grad=True)
         targets = Variable(targets, requires_grad=False)

@@ -151,7 +156,10 @@ def train(epoch):
 
         # forward/backward
         outputs = model(inputs)
-        loss = criterion(outputs, targets)
+        if args.mixup:
+            loss = mixup_cross_entropy_loss(outputs, targets)
+        else:
+            loss = criterion(outputs, targets)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()

@@ -161,6 +169,9 @@ def train(epoch):
         global_step += 1
         running_loss += loss.data[0]
         pred = outputs.data.max(1, keepdim=True)[1]
+        if args.mixup:
+            targets = batch['target']
+            targets = Variable(targets, requires_grad=False).cuda(async=True)
         correct += pred.eq(targets.data.view_as(pred)).sum()
         total += targets.size(0)
 
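One bookkeeping detail worth noting in the final hunk of both scripts: `mixup()` replaces `targets` with soft label mixtures, so the running accuracy is computed against the original integer labels re-read from the batch, since `pred.eq(...)` needs hard class indices to compare with argmax predictions. A self-contained sketch of the same idea in current PyTorch, with a hypothetical helper name:

```python
import torch

def hard_label_accuracy(outputs, hard_targets):
    """Fraction of argmax predictions that match the original integer labels.

    Under mixup the training targets are soft mixtures, so accuracy is still
    measured against the unmixed labels kept from the dataloader batch.
    """
    pred = outputs.argmax(dim=1)  # predicted class index per example
    return pred.eq(hard_targets).float().mean().item()
```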