Skip to content

Commit 67cf5a9

Browse files
committed
add image_mobile.py
1 parent ab7e718 commit 67cf5a9

File tree

5 files changed

+936
-42
lines changed

5 files changed

+936
-42
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Single crop validation error on ImageNet-1k (center 224x224 crop from resized im
4646
#### ShuffleNetV2_1x
4747
```
4848
python -m torch.distributed.launch --nproc_per_node=8 imagenet_mobile.py --cos -a shufflenetv2_1x --data /path/to/imagenet1k/ \
49-
--epochs 300 --wd 4e-5 --gamma 0.1 -c checkpoints/imagenet/shufflenetv2_1x --train-batch 128 --opt-level O0 # Triaing
49+
--epochs 300 --wd 4e-5 --gamma 0.1 -c checkpoints/imagenet/shufflenetv2_1x --train-batch 128 --opt-level O0 --nowd-bn # Triaing
5050
5151
python -m torch.distributed.launch --nproc_per_node=2 imagenet_mobile.py -a shufflenetv2_1x --data /path/to/imagenet1k/ \
5252
-e --resume ../pretrain/shufflenetv2_1x.pth.tar --test-batch 100 --opt-level O0 # Testing, ~69.6% top-1 Acc

classification/imagenet_fast.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,8 @@ def main():
401401
print('==> Resuming from checkpoint..', args.resume)
402402
assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
403403
args.checkpoint = os.path.dirname(args.resume)
404-
checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
404+
#checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
405+
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
405406
best_acc = checkpoint['best_acc']
406407
start_epoch = checkpoint['epoch']
407408
# model may have more keys
@@ -529,7 +530,11 @@ def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
529530
# loss.backward()
530531
with amp.scale_loss(old_loss, optimizer) as loss:
531532
loss.backward()
532-
optimizer.step(print_flag=print_flag)
533+
534+
if args.el2:
535+
optimizer.step(print_flag=print_flag)
536+
else:
537+
optimizer.step()
533538

534539

535540
if batch_idx % args.print_freq == 0:

0 commit comments

Comments
 (0)