|
| 1 | +#!/bin/sh |
| 2 | +# |
| 3 | +# This script runs through the code in each of the python examples. |
| 4 | +# The purpose is just as an integrtion test, not to actually train |
| 5 | +# models in any meaningful way. For that reason, most of these set |
| 6 | +# epochs = 1. |
| 7 | +# |
| 8 | +# Optionally specify a comma separated list of examples to run. |
| 9 | +# can be run as: |
| 10 | +# ./run_python_examples.sh "install_deps,run_all,clean" |
| 11 | +# to pip install dependencies (other than pytorch), run all examples, |
| 12 | +# and remove temporary/changed data files. |
| 13 | +# Expects pytorch to be installed. |
| 14 | + |
| 15 | +BASE_DIR=`pwd`"/"`dirname $0` |
| 16 | +EXAMPLES=`echo $1 | sed -e 's/ //g'` |
| 17 | + |
| 18 | +if which nvcc ; then |
| 19 | + echo "using cuda" |
| 20 | + CUDA=1 |
| 21 | + CUDA_FLAG="--cuda" |
| 22 | +else |
| 23 | + echo "not using cuda" |
| 24 | + CUDA=0 |
| 25 | + CUDA_FLAG="" |
| 26 | +fi |
| 27 | + |
| 28 | +ERRORS="" |
| 29 | + |
| 30 | +function error() { |
| 31 | + ERR=$1 |
| 32 | + ERRORS="$ERRORS\n$ERR" |
| 33 | + echo $ERR |
| 34 | +} |
| 35 | + |
| 36 | +function install_deps() { |
| 37 | + echo "installing requirements" |
| 38 | + cat $BASE_DIR/*/requirements.txt | \ |
| 39 | + sort -u | \ |
| 40 | + # testing the installed version of torch, so don't pip install it. |
| 41 | + grep -vE '^torch$' | \ |
| 42 | + pip install -r /dev/stdin || \ |
| 43 | + { error "failed to install dependencies"; exit 1; } |
| 44 | +} |
| 45 | + |
| 46 | +function start() { |
| 47 | + EXAMPLE=${FUNCNAME[1]} |
| 48 | + cd $BASE_DIR/$EXAMPLE |
| 49 | + echo "Running example: $EXAMPLE" |
| 50 | +} |
| 51 | + |
| 52 | +function dcgan() { |
| 53 | + start |
| 54 | + if [ ! -d "lsun" ]; then |
| 55 | + echo "cloning repo to get lsun dataset" |
| 56 | + git clone https://github.com/fyu/lsun || { error "couldn't clone lsun repo needed for dcgan"; return; } |
| 57 | + fi |
| 58 | + # 'classroom' much smaller than the default 'bedroom' dataset. |
| 59 | + DATACLASS="classroom" |
| 60 | + if [ ! -d "lsun/${DATACLASS}_train_lmdb" ]; then |
| 61 | + pushd lsun |
| 62 | + python download.py -c $DATACLASS || { error "couldn't download $DATACLASS for dcgan"; return; } |
| 63 | + unzip ${DATACLASS}_train_lmdb.zip || { error "couldn't unzip $DATACLASS"; return; } |
| 64 | + popd |
| 65 | + fi |
| 66 | + python main.py --dataset lsun --dataroot lsun --classes $DATACLASS --niter 1 $CUDA_FLAG || error "dcgan failed" |
| 67 | +} |
| 68 | + |
| 69 | +function fast_neural_style() { |
| 70 | + start |
| 71 | + if [ ! -d "saved_models" ]; then |
| 72 | + echo "downloading saved models for fast neural style" |
| 73 | + python download_saved_models.py |
| 74 | + fi |
| 75 | + test -d "saved_models" || { error "saved models not found"; return; } |
| 76 | + |
| 77 | + echo "running fast neural style model" |
| 78 | + python neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg --cuda $CUDA || error "neural_style.py failed" |
| 79 | +} |
| 80 | + |
| 81 | +function imagenet() { |
| 82 | + start |
| 83 | + if [[ ! -d "sample/val" || ! -d "sample/train" ]]; then |
| 84 | + mkdir -p sample/val/n |
| 85 | + mkdir -p sample/train/n |
| 86 | + wget "https://upload.wikimedia.org/wikipedia/commons/5/5a/Socks-clinton.jpg" || { error "couldn't download sample image for imagenet"; return; } |
| 87 | + mv Socks-clinton.jpg sample/train/n |
| 88 | + cp sample/train/n/* sample/val/n/ |
| 89 | + fi |
| 90 | + python main.py --epochs 1 sample/ || error "imagenet example failed" |
| 91 | +} |
| 92 | + |
| 93 | +function mnist() { |
| 94 | + start |
| 95 | + python main.py --epochs 1 || error "mnist example failed" |
| 96 | +} |
| 97 | + |
| 98 | +function mnist_hogwild() { |
| 99 | + start |
| 100 | + python main.py --epochs 1 $CUDA_FLAG || error "mnist hogwild failed" |
| 101 | +} |
| 102 | + |
| 103 | +function regression() { |
| 104 | + start |
| 105 | + python main.py --epochs 1 $CUDA_FLAG || error "regression failed" |
| 106 | +} |
| 107 | + |
| 108 | +function reinforcement_learning() { |
| 109 | + start |
| 110 | + python reinforce.py || error "reinforcement learning failed" |
| 111 | +} |
| 112 | + |
| 113 | +function snli() { |
| 114 | + start |
| 115 | + echo "installing 'en' model if not installed" |
| 116 | + python -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; } |
| 117 | + echo "training..." |
| 118 | + python train.py --epochs 1 --no-bidirectional || error "couldn't train snli" |
| 119 | +} |
| 120 | + |
| 121 | +function super_resolution() { |
| 122 | + start |
| 123 | + python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 || error "super resolution failed" |
| 124 | +} |
| 125 | + |
| 126 | +function time_sequence_prediciton() { |
| 127 | + start |
| 128 | + python generate_sine_wave.py || { error "generate sine wave failed"; return; } |
| 129 | + python train.py || error "time sequence prediction training failed" |
| 130 | +} |
| 131 | + |
| 132 | +function vae() { |
| 133 | + start |
| 134 | + python main.py --epochs 1 || error "vae failed" |
| 135 | +} |
| 136 | + |
| 137 | +function word_language_model() { |
| 138 | + start |
| 139 | + python main.py --epochs 1 $CUDA_FLAG || error "word_language_model failed" |
| 140 | +} |
| 141 | + |
| 142 | +function clean() { |
| 143 | + cd $BASE_DIR |
| 144 | + rm -rf dcgan/_cache_lsun_classroom_train_lmdb dcgan/fake_samples_epoch_000.png dcgan/lsun/ dcgan/netD_epoch_0.pth dcgan/netG_epoch_0.pth dcgan/real_samples.png fast_neural_style/saved_models.zip fast_neural_style/saved_models/ imagenet/checkpoint.pth.tar imagenet/lsun/ imagenet/model_best.pth.tar imagenet/sample/ snli/.data/ snli/.vector_cache/ snli/results/ super_resolution/dataset/ super_resolution/model_epoch_1.pth word_language_model/model.pt || error "couldn't clean up some files" |
| 145 | + |
| 146 | + git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image" |
| 147 | +} |
| 148 | + |
| 149 | +function run_all() { |
| 150 | + dcgan |
| 151 | + fast_neural_style |
| 152 | + imagenet |
| 153 | + mnist |
| 154 | + mnist_hogwild |
| 155 | + regression |
| 156 | + reinforcement_learning |
| 157 | + snli |
| 158 | + super_resolution |
| 159 | + time_sequence_prediction |
| 160 | + vae |
| 161 | + word_language_model |
| 162 | +} |
| 163 | + |
| 164 | +# by default, run all examples |
| 165 | +if [ "" == "$EXAMPLES" ]; then |
| 166 | + run_all |
| 167 | +else |
| 168 | + for i in $(echo $EXAMPLES | sed "s/,/ /g") |
| 169 | + do |
| 170 | + $i |
| 171 | + done |
| 172 | +fi |
| 173 | + |
| 174 | +if [ "" == "$ERRORS" ]; then |
| 175 | + tput setaf 2 |
| 176 | + echo "Completed successfully" |
| 177 | +else |
| 178 | + tput setaf 1 |
| 179 | + echo "Some examples failed:" |
| 180 | + printf "$ERRORS" |
| 181 | +fi |
| 182 | + |
| 183 | +tput sgr0 |
| 184 | + |
0 commit comments