Commit 79d71b8

nairbv authored and soumith committed
add a script to run all pytorch examples (pytorch#591)
1 parent 71a9207 commit 79d71b8

4 files changed, +193 -3 lines changed

dcgan/main.py (+3, -1)
@@ -31,6 +31,7 @@
 parser.add_argument('--netD', default='', help="path to netD (to continue training)")
 parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
 parser.add_argument('--manualSeed', type=int, help='manual seed')
+parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')

 opt = parser.parse_args()
 print(opt)
@@ -62,7 +63,8 @@
                            ]))
     nc=3
 elif opt.dataset == 'lsun':
-    dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
+    classes = [ c + '_train' for c in opt.classes.split(',')]
+    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.CenterCrop(opt.imageSize),
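With this change, main.py can train on more than one LSUN category at a time. A rough usage sketch (the class list below is illustrative; each name gets '_train' appended and the corresponding LSUN lmdb folders must already exist under --dataroot):

    python main.py --dataset lsun --dataroot lsun --classes bedroom,classroom --niter 1 --cuda

This mirrors how the new run_python_examples.sh below drives dcgan with --classes classroom.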

run_python_examples.sh (new file, +184)
@@ -0,0 +1,184 @@
+#!/bin/bash
+#
+# This script runs through the code in each of the python examples.
+# The purpose is just as an integration test, not to actually train
+# models in any meaningful way. For that reason, most of these set
+# epochs = 1.
+#
+# Optionally specify a comma separated list of examples to run.
+# It can be run as:
+# ./run_python_examples.sh "install_deps,run_all,clean"
+# to pip install dependencies (other than pytorch), run all examples,
+# and remove temporary/changed data files.
+# Expects pytorch to be installed.
+
+BASE_DIR=`pwd`"/"`dirname $0`
+EXAMPLES=`echo $1 | sed -e 's/ //g'`
+
+if which nvcc ; then
+  echo "using cuda"
+  CUDA=1
+  CUDA_FLAG="--cuda"
+else
+  echo "not using cuda"
+  CUDA=0
+  CUDA_FLAG=""
+fi
+
+ERRORS=""
+
+function error() {
+  ERR=$1
+  ERRORS="$ERRORS\n$ERR"
+  echo $ERR
+}
+
+function install_deps() {
+  echo "installing requirements"
+  cat $BASE_DIR/*/requirements.txt | \
+    sort -u | \
+    # testing the installed version of torch, so don't pip install it.
+    grep -vE '^torch$' | \
+    pip install -r /dev/stdin || \
+    { error "failed to install dependencies"; exit 1; }
+}
+
+function start() {
+  EXAMPLE=${FUNCNAME[1]}
+  cd $BASE_DIR/$EXAMPLE
+  echo "Running example: $EXAMPLE"
+}
+
+function dcgan() {
+  start
+  if [ ! -d "lsun" ]; then
+    echo "cloning repo to get lsun dataset"
+    git clone https://github.com/fyu/lsun || { error "couldn't clone lsun repo needed for dcgan"; return; }
+  fi
+  # 'classroom' is much smaller than the default 'bedroom' dataset.
+  DATACLASS="classroom"
+  if [ ! -d "lsun/${DATACLASS}_train_lmdb" ]; then
+    pushd lsun
+    python download.py -c $DATACLASS || { error "couldn't download $DATACLASS for dcgan"; return; }
+    unzip ${DATACLASS}_train_lmdb.zip || { error "couldn't unzip $DATACLASS"; return; }
+    popd
+  fi
+  python main.py --dataset lsun --dataroot lsun --classes $DATACLASS --niter 1 $CUDA_FLAG || error "dcgan failed"
+}
+
+function fast_neural_style() {
+  start
+  if [ ! -d "saved_models" ]; then
+    echo "downloading saved models for fast neural style"
+    python download_saved_models.py
+  fi
+  test -d "saved_models" || { error "saved models not found"; return; }
+
+  echo "running fast neural style model"
+  python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA || error "neural_style.py failed"
+}
+
+function imagenet() {
+  start
+  if [[ ! -d "sample/val" || ! -d "sample/train" ]]; then
+    mkdir -p sample/val/n
+    mkdir -p sample/train/n
+    wget "https://upload.wikimedia.org/wikipedia/commons/5/5a/Socks-clinton.jpg" || { error "couldn't download sample image for imagenet"; return; }
+    mv Socks-clinton.jpg sample/train/n
+    cp sample/train/n/* sample/val/n/
+  fi
+  python main.py --epochs 1 sample/ || error "imagenet example failed"
+}
+
+function mnist() {
+  start
+  python main.py --epochs 1 || error "mnist example failed"
+}
+
+function mnist_hogwild() {
+  start
+  python main.py --epochs 1 $CUDA_FLAG || error "mnist hogwild failed"
+}
+
+function regression() {
+  start
+  python main.py --epochs 1 $CUDA_FLAG || error "regression failed"
+}
+
+function reinforcement_learning() {
+  start
+  python reinforce.py || error "reinforcement learning failed"
+}
+
+function snli() {
+  start
+  echo "installing 'en' model if not installed"
+  python -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; }
+  echo "training..."
+  python train.py --epochs 1 --no-bidirectional || error "couldn't train snli"
+}
+
+function super_resolution() {
+  start
+  python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 || error "super resolution failed"
+}
+
+function time_sequence_prediction() {
+  start
+  python generate_sine_wave.py || { error "generate sine wave failed"; return; }
+  python train.py || error "time sequence prediction training failed"
+}
+
+function vae() {
+  start
+  python main.py --epochs 1 || error "vae failed"
+}
+
+function word_language_model() {
+  start
+  python main.py --epochs 1 $CUDA_FLAG || error "word_language_model failed"
+}
+
+function clean() {
+  cd $BASE_DIR
+  rm -rf dcgan/_cache_lsun_classroom_train_lmdb dcgan/fake_samples_epoch_000.png dcgan/lsun/ dcgan/netD_epoch_0.pth dcgan/netG_epoch_0.pth dcgan/real_samples.png fast_neural_style/saved_models.zip fast_neural_style/saved_models/ imagenet/checkpoint.pth.tar imagenet/lsun/ imagenet/model_best.pth.tar imagenet/sample/ snli/.data/ snli/.vector_cache/ snli/results/ super_resolution/dataset/ super_resolution/model_epoch_1.pth word_language_model/model.pt || error "couldn't clean up some files"
+
+  git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
+}
+
+function run_all() {
+  dcgan
+  fast_neural_style
+  imagenet
+  mnist
+  mnist_hogwild
+  regression
+  reinforcement_learning
+  snli
+  super_resolution
+  time_sequence_prediction
+  vae
+  word_language_model
+}
+
+# by default, run all examples
+if [ "" == "$EXAMPLES" ]; then
+  run_all
+else
+  for i in $(echo $EXAMPLES | sed "s/,/ /g")
+  do
+    $i
+  done
+fi
+
+if [ "" == "$ERRORS" ]; then
+  tput setaf 2
+  echo "Completed successfully"
+else
+  tput setaf 1
+  echo "Some examples failed:"
+  printf "$ERRORS"
+fi
+
+tput sgr0
+
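For reference, two invocation sketches based on the header comment and the dispatch loop above (assumptions: the script is run from the repository root and PyTorch is already installed):

    # run every example once
    ./run_python_examples.sh

    # install dependencies, run a subset of examples, then clean up
    ./run_python_examples.sh "install_deps,mnist,vae,clean"

Each comma separated entry is invoked directly as a shell function, so any of the function names above (dcgan, fast_neural_style, imagenet, ..., word_language_model, clean) can be listed.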

snli/requirements.txt (+1)
@@ -1,2 +1,3 @@
 torch
 torchtext
+spacy
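The new spacy entry is picked up by the script's install_deps step, which pipes the union of all requirements.txt files to pip. Installing just the snli example's dependencies by hand would look roughly like this (a sketch; the 'en' model download mirrors what the snli() function in run_python_examples.sh does):

    pip install -r snli/requirements.txt
    python -m spacy download en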

snli/train.py (+5, -2)
@@ -14,8 +14,11 @@


 args = get_args()
-torch.cuda.set_device(args.gpu)
-device = torch.device('cuda:{}'.format(args.gpu))
+if torch.cuda.is_available():
+    torch.cuda.set_device(args.gpu)
+    device = torch.device('cuda:{}'.format(args.gpu))
+else:
+    device = torch.device('cpu')

 inputs = data.Field(lower=args.lower, tokenize='spacy')
 answers = data.Field(sequential=False)
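To see which branch of the new device selection a given machine will take, a quick sanity check (not part of this commit) is:

    python -c "import torch; print(torch.cuda.is_available())"

If this prints False, train.py now selects the CPU device instead of calling torch.cuda.set_device, so the snli entry in run_python_examples.sh can also run on CPU-only machines.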
