Commit 49ec0bd (parent 59caa16)

tweak running examples without cuda (#794)

* tweak running examples without cuda
* rework dry_run handling in mnist, mnist_hogwild

10 files changed: +120 -57 lines


dcgan/main.py (+6)

@@ -26,6 +26,7 @@
 parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
 parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
 parser.add_argument('--cuda', action='store_true', help='enables cuda')
+parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works')
 parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
 parser.add_argument('--netG', default='', help="path to netG (to continue training)")
 parser.add_argument('--netD', default='', help="path to netD (to continue training)")

@@ -211,6 +212,9 @@ def forward(self, input):
 optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
 optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

+if opt.dry_run:
+    opt.niter = 1
+
 for epoch in range(opt.niter):
     for i, data in enumerate(dataloader, 0):
         ############################

@@ -261,6 +265,8 @@ def forward(self, input):
                     '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                     normalize=True)

+        if opt.dry_run:
+            break
     # do checkpointing
     torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
     torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
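
The pattern added here recurs through the rest of the commit: expose a --dry-run switch, clamp the epoch count, and break out of the loop after one cycle. A minimal, self-contained sketch of that pattern (the training step is elided; everything except the --dry-run/--niter names is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--niter', type=int, default=25, help='number of epochs')
parser.add_argument('--dry-run', action='store_true',
                    help='check a single training cycle works')
opt = parser.parse_args()

if opt.dry_run:
    opt.niter = 1  # one epoch is enough to prove the loop runs

for epoch in range(opt.niter):
    # ... one training epoch would go here ...
    if opt.dry_run:
        break  # bail out after the first cycle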

imagenet/main.py (+7 -3)

@@ -136,7 +136,9 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()

-    if args.distributed:
+    if not torch.cuda.is_available():
+        print('using CPU, this will be slow')
+    elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.

@@ -281,7 +283,8 @@ def train(train_loader, model, criterion, optimizer, epoch, args):

         if args.gpu is not None:
             images = images.cuda(args.gpu, non_blocking=True)
-        target = target.cuda(args.gpu, non_blocking=True)
+        if torch.cuda.is_available():
+            target = target.cuda(args.gpu, non_blocking=True)

         # compute output
         output = model(images)

@@ -324,7 +327,8 @@ def validate(val_loader, model, criterion, args):
         for i, (images, target) in enumerate(val_loader):
             if args.gpu is not None:
                 images = images.cuda(args.gpu, non_blocking=True)
-            target = target.cuda(args.gpu, non_blocking=True)
+            if torch.cuda.is_available():
+                target = target.cuda(args.gpu, non_blocking=True)

             # compute output
             output = model(images)
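
The imagenet changes all follow one rule: never call .cuda() unless torch.cuda.is_available() says there is a device to move to. A minimal sketch of the guard, using throwaway tensors in place of the loader's real batches:

import torch

images = torch.randn(4, 3, 224, 224)   # stand-in for a batch of images
target = torch.randint(0, 1000, (4,))  # stand-in for a batch of labels

if torch.cuda.is_available():
    images = images.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)
else:
    print('using CPU, this will be slow')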

mnist/main.py (+21 -14)

@@ -47,6 +47,8 @@ def train(args, model, device, train_loader, optimizer, epoch):
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()))
+            if args.dry_run:
+                break


 def test(model, device, test_loader):

@@ -83,6 +85,8 @@ def main():
                         help='Learning rate step gamma (default: 0.7)')
     parser.add_argument('--no-cuda', action='store_true', default=False,
                         help='disables CUDA training')
+    parser.add_argument('--dry-run', action='store_true', default=False,
+                        help='quickly check a single pass')
     parser.add_argument('--seed', type=int, default=1, metavar='S',
                         help='random seed (default: 1)')
     parser.add_argument('--log-interval', type=int, default=10, metavar='N',

@@ -96,20 +100,23 @@ def main():

     device = torch.device("cuda" if use_cuda else "cpu")

-    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
-    train_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=True, download=True,
-                       transform=transforms.Compose([
-                           transforms.ToTensor(),
-                           transforms.Normalize((0.1307,), (0.3081,))
-                       ])),
-        batch_size=args.batch_size, shuffle=True, **kwargs)
-    test_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=False, transform=transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize((0.1307,), (0.3081,))
-        ])),
-        batch_size=args.test_batch_size, shuffle=True, **kwargs)
+    kwargs = {'batch_size': args.batch_size}
+    if use_cuda:
+        kwargs.update({'num_workers': 1,
+                       'pin_memory': True,
+                       'shuffle': True})
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+    dataset1 = datasets.MNIST('../data', train=True, download=True,
+                              transform=transform)
+    dataset2 = datasets.MNIST('../data', train=False,
+                              transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

     model = Net().to(device)
     optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
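
The loader rework boils down to building one kwargs dict, adding the CUDA-only options conditionally, and sharing the dict between both DataLoaders. A condensed sketch of the same idea (batch size 64 stands in for args.batch_size):

import torch
from torchvision import datasets, transforms

use_cuda = torch.cuda.is_available()
kwargs = {'batch_size': 64}
if use_cuda:
    # These options only pay off when batches are headed to a GPU.
    kwargs.update({'num_workers': 1, 'pin_memory': True, 'shuffle': True})

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transform),
    **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transform),
    **kwargs)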

mnist_hogwild/main.py (+22 -3)

@@ -4,6 +4,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.multiprocessing as mp
+from torch.utils.data.sampler import Sampler
+from torchvision import datasets, transforms

 from train import train, test

@@ -27,6 +29,8 @@
                     help='how many training processes to use (default: 2)')
 parser.add_argument('--cuda', action='store_true', default=False,
                     help='enables CUDA training')
+parser.add_argument('--dry-run', action='store_true', default=False,
+                    help='quickly check a single pass')

 class Net(nn.Module):
     def __init__(self):

@@ -46,12 +50,26 @@ def forward(self, x):
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)

+
 if __name__ == '__main__':
     args = parser.parse_args()

     use_cuda = args.cuda and torch.cuda.is_available()
     device = torch.device("cuda" if use_cuda else "cpu")
-    dataloader_kwargs = {'pin_memory': True} if use_cuda else {}
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+    dataset1 = datasets.MNIST('../data', train=True, download=True,
+                              transform=transform)
+    dataset2 = datasets.MNIST('../data', train=False,
+                              transform=transform)
+    kwargs = {'batch_size': args.batch_size,
+              'shuffle': True}
+    if use_cuda:
+        kwargs.update({'num_workers': 1,
+                       'pin_memory': True})

     torch.manual_seed(args.seed)
     mp.set_start_method('spawn')

@@ -61,12 +79,13 @@ def forward(self, x):

     processes = []
     for rank in range(args.num_processes):
-        p = mp.Process(target=train, args=(rank, args, model, device, dataloader_kwargs))
+        p = mp.Process(target=train, args=(rank, args, model, device,
+                                           dataset1, kwargs))
         # We first train the model across `num_processes` processes
         p.start()
         processes.append(p)
     for p in processes:
         p.join()

     # Once training is complete, we can test the model
-    test(args, model, device, dataloader_kwargs)
+    test(args, model, device, dataset2, kwargs)

mnist_hogwild/train.py (+6 -18)

@@ -2,36 +2,22 @@
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
-from torchvision import datasets, transforms


-def train(rank, args, model, device, dataloader_kwargs):
+def train(rank, args, model, device, dataset, dataloader_kwargs):
     torch.manual_seed(args.seed + rank)

-    train_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=True, download=True,
-                       transform=transforms.Compose([
-                           transforms.ToTensor(),
-                           transforms.Normalize((0.1307,), (0.3081,))
-                       ])),
-        batch_size=args.batch_size, shuffle=True, num_workers=1,
-        **dataloader_kwargs)
+    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

     optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
     for epoch in range(1, args.epochs + 1):
         train_epoch(epoch, args, model, device, train_loader, optimizer)


-def test(args, model, device, dataloader_kwargs):
+def test(args, model, device, dataset, dataloader_kwargs):
     torch.manual_seed(args.seed)

-    test_loader = torch.utils.data.DataLoader(
-        datasets.MNIST('../data', train=False, transform=transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize((0.1307,), (0.3081,))
-        ])),
-        batch_size=args.batch_size, shuffle=True, num_workers=1,
-        **dataloader_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

     test_epoch(model, device, test_loader)

@@ -49,6 +35,8 @@ def train_epoch(epoch, args, model, device, data_loader, optimizer):
             print('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 pid, epoch, batch_idx * len(data), len(data_loader.dataset),
                 100. * batch_idx / len(data_loader), loss.item()))
+            if args.dry_run:
+                break


 def test_epoch(model, device, data_loader):
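
Together, the two hogwild changes move dataset construction out of the workers: main.py builds the MNIST datasets once and hands them to each process, and train.py simply wraps whatever dataset it receives in a DataLoader. A stripped-down sketch of that hand-off (the worker body is reduced to reading a single batch):

import torch
import torch.multiprocessing as mp
from torchvision import datasets, transforms

def worker(rank, dataset, loader_kwargs):
    torch.manual_seed(rank)
    # Each process builds its own loader over the shared dataset object.
    loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)
    data, target = next(iter(loader))
    print(rank, data.shape)

if __name__ == '__main__':
    mp.set_start_method('spawn')
    dataset = datasets.MNIST('../data', train=True, download=True,
                             transform=transforms.ToTensor())
    kwargs = {'batch_size': 64, 'shuffle': True}
    processes = [mp.Process(target=worker, args=(rank, dataset, kwargs))
                 for rank in range(2)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()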

run_python_examples.sh (+45 -18)

@@ -3,27 +3,34 @@
 # This script runs through the code in each of the python examples.
 # The purpose is just as an integration test, not to actually train
 # models in any meaningful way. For that reason, most of these set
-# epochs = 1.
+# epochs = 1 and --dry-run.
 #
 # Optionally specify a comma separated list of examples to run.
 # can be run as:
 # ./run_python_examples.sh "install_deps,run_all,clean"
 # to pip install dependencies (other than pytorch), run all examples,
 # and remove temporary/changed data files.
-# Expects pytorch to be installed.
+# Expects pytorch, torchvision to be installed.

 BASE_DIR=`pwd`"/"`dirname $0`
 EXAMPLES=`echo $1 | sed -e 's/ //g'`

-if which nvcc ; then
-  echo "using cuda"
-  CUDA=1
-  CUDA_FLAG="--cuda"
-else
-  echo "not using cuda"
-  CUDA=0
-  CUDA_FLAG=""
-fi
+USE_CUDA=$(python -c "import torchvision, torch; print(torch.cuda.is_available())")
+case $USE_CUDA in
+  "True")
+    echo "using cuda"
+    CUDA=1
+    CUDA_FLAG="--cuda"
+    ;;
+  "False")
+    echo "not using cuda"
+    CUDA=0
+    CUDA_FLAG=""
+    ;;
+  "")
+    exit 1;
+    ;;
+esac

 ERRORS=""

@@ -63,7 +70,7 @@ function dcgan() {
     unzip ${DATACLASS}_train_lmdb.zip || { error "couldn't unzip $DATACLASS"; return; }
     popd
   fi
-  python main.py --dataset lsun --dataroot lsun --classes $DATACLASS --niter 1 $CUDA_FLAG || error "dcgan failed"
+  python main.py --dataset lsun --dataroot lsun --classes $DATACLASS --niter 1 $CUDA_FLAG --dry-run || error "dcgan failed"
 }

 function fast_neural_style() {

@@ -92,12 +99,12 @@ function imagenet() {

 function mnist() {
   start
-  python main.py --epochs 1 || error "mnist example failed"
+  python main.py --epochs 1 --dry-run || error "mnist example failed"
 }

 function mnist_hogwild() {
   start
-  python main.py --epochs 1 $CUDA_FLAG || error "mnist hogwild failed"
+  python main.py --epochs 1 --dry-run $CUDA_FLAG || error "mnist hogwild failed"
 }

 function regression() {

@@ -115,7 +122,7 @@ function snli() {
   echo "installing 'en' model if not installed"
   python -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; }
   echo "training..."
-  python train.py --epochs 1 --no-bidirectional || error "couldn't train snli"
+  python train.py --epochs 1 --dev_every 1 --no-bidirectional --dry-run || error "couldn't train snli"
 }

 function super_resolution() {

@@ -126,7 +133,7 @@ function super_resolution() {
 function time_sequence_prediction() {
   start
   python generate_sine_wave.py || { error "generate sine wave failed"; return; }
-  python train.py || error "time sequence prediction training failed"
+  python train.py --steps 2 || error "time sequence prediction training failed"
 }

 function vae() {

@@ -136,18 +143,38 @@ function vae() {

 function word_language_model() {
   start
-  python main.py --epochs 1 $CUDA_FLAG || error "word_language_model failed"
+  python main.py --epochs 1 --dry-run $CUDA_FLAG || error "word_language_model failed"
 }

 function clean() {
   cd $BASE_DIR
-  rm -rf dcgan/_cache_lsun_classroom_train_lmdb dcgan/fake_samples_epoch_000.png dcgan/lsun/ dcgan/netD_epoch_0.pth dcgan/netG_epoch_0.pth dcgan/real_samples.png fast_neural_style/saved_models.zip fast_neural_style/saved_models/ imagenet/checkpoint.pth.tar imagenet/lsun/ imagenet/model_best.pth.tar imagenet/sample/ snli/.data/ snli/.vector_cache/ snli/results/ super_resolution/dataset/ super_resolution/model_epoch_1.pth word_language_model/model.pt || error "couldn't clean up some files"
+  rm -rf dcgan/_cache_lsun_classroom_train_lmdb \
+    dcgan/fake_samples_epoch_000.png dcgan/lsun/ \
+    dcgan/_cache_lsunclassroomtrainlmdb \
+    dcgan/netD_epoch_0.pth dcgan/netG_epoch_0.pth \
+    dcgan/real_samples.png \
+    fast_neural_style/saved_models.zip \
+    fast_neural_style/saved_models/ \
+    imagenet/checkpoint.pth.tar \
+    imagenet/lsun/ \
+    imagenet/model_best.pth.tar \
+    imagenet/sample/ \
+    snli/.data/ \
+    snli/.vector_cache/ \
+    snli/results/ \
+    super_resolution/dataset/ \
+    super_resolution/model_epoch_1.pth \
+    time_sequence_prediction/predict*.pdf \
+    time_sequence_prediction/traindata.pt \
+    word_language_model/model.pt || error "couldn't clean up some files"

   git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
 }

 function run_all() {
+  # cpp
   dcgan
+  # distributed
   fast_neural_style
   imagenet
   mnist
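
Replacing `which nvcc` with a probe of the runtime closes a real gap: a machine can have the CUDA toolkit on PATH but no usable GPU, or a working driver with no toolkit. The script's USE_CUDA line runs, in effect, this short Python check, printing True or False (an empty result means the import itself failed, which the case statement treats as a fatal error):

import torchvision  # importing this first means a missing torchvision also fails the probe
import torch

print(torch.cuda.is_available())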

snli/train.py (+2)

@@ -140,3 +140,5 @@
         print(log_template.format(time.time()-start,
             epoch, iterations, 1+batch_idx, len(train_iter),
             100. * (1+batch_idx) / len(train_iter), loss.item(), ' '*8, n_correct/n_total*100, ' '*12))
+        if args.dry_run:
+            break

snli/util.py (+2)

@@ -65,5 +65,7 @@ def get_args():
                              'glove.6B.50d glove.6B.100d glove.6B.200d glove.6B.300d')
     parser.add_argument('--resume_snapshot', type=str, default='',
                         help='model snapshot to resume.')
+    parser.add_argument('--dry-run', action='store_true',
+                        help='run only a few iterations')
     args = parser.parse_args()
     return args
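
One detail worth noting: argparse converts dashes in option names to underscores on the resulting namespace, which is why snli/train.py above checks args.dry_run for an option declared as --dry-run. A self-contained demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dry-run', action='store_true',
                    help='run only a few iterations')
args = parser.parse_args(['--dry-run'])
assert args.dry_run  # '--dry-run' is stored as 'dry_run'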

time_sequence_prediction/train.py (+5 -1)

@@ -1,4 +1,5 @@
 from __future__ import print_function
+import argparse
 import torch
 import torch.nn as nn
 import torch.optim as optim

@@ -36,6 +37,9 @@ def forward(self, input, future = 0):


 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--steps', type=int, default=15, help='steps to run')
+    opt = parser.parse_args()
     # set random seed to 0
     np.random.seed(0)
     torch.manual_seed(0)

@@ -52,7 +56,7 @@ def forward(self, input, future = 0):
     # use LBFGS as optimizer since we can load the whole data to train
     optimizer = optim.LBFGS(seq.parameters(), lr=0.8)
     # begin to train
-    for i in range(15):
+    for i in range(opt.steps):
         print('STEP: ', i)
         def closure():
             optimizer.zero_grad()
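
With the flag in place, the integration test above can cap the expensive LBFGS loop at two iterations via `python train.py --steps 2`, while a bare `python train.py` keeps the original 15 steps; the hard-coded constant simply becomes the argparse default.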
