20
20
import argparse
21
21
import sys
22
22
23
- import tensorflow as tf
23
+ import tensorflow as tf # pylint: disable=g-bad-import-order
24
24
25
25
from official .mnist import dataset
26
26
from official .utils .arg_parsers import parsers
27
27
from official .utils .logging import hooks_helper
28
28
29
29
LEARNING_RATE = 1e-4
30
30
31
+
31
32
class Model (tf .keras .Model ):
32
33
"""Model to recognize digits in the MNIST dataset.
33
34
@@ -145,31 +146,36 @@ def model_fn(features, labels, mode, params):
145
146
146
147
147
148
def validate_batch_size_for_multi_gpu (batch_size ):
148
- """For multi-gpu, batch-size must be a multiple of the number of
149
- available GPUs.
149
+ """For multi-gpu, batch-size must be a multiple of the number of GPUs.
150
150
151
151
Note that this should eventually be handled by replicate_model_fn
152
152
directly. Multi-GPU support is currently experimental, however,
153
153
so doing the work here until that feature is in place.
154
+
155
+ Args:
156
+ batch_size: the number of examples processed in each training batch.
157
+
158
+ Raises:
159
+ ValueError: if no GPUs are found, or selected batch_size is invalid.
154
160
"""
155
- from tensorflow .python .client import device_lib
161
+ from tensorflow .python .client import device_lib # pylint: disable=g-import-not-at-top
156
162
157
163
local_device_protos = device_lib .list_local_devices ()
158
164
num_gpus = sum ([1 for d in local_device_protos if d .device_type == 'GPU' ])
159
165
if not num_gpus :
160
166
raise ValueError ('Multi-GPU mode was specified, but no GPUs '
161
- 'were found. To use CPU, run without --multi_gpu.' )
167
+ 'were found. To use CPU, run without --multi_gpu.' )
162
168
163
169
remainder = batch_size % num_gpus
164
170
if remainder :
165
171
err = ('When running with multiple GPUs, batch size '
166
- 'must be a multiple of the number of available GPUs. '
167
- 'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
168
- ).format (num_gpus , batch_size , batch_size - remainder )
172
+ 'must be a multiple of the number of available GPUs. '
173
+ 'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
174
+ ).format (num_gpus , batch_size , batch_size - remainder )
169
175
raise ValueError (err )
170
176
171
177
172
- def main (unused_argv ):
178
+ def main (_ ):
173
179
model_function = model_fn
174
180
175
181
if FLAGS .multi_gpu :
@@ -195,6 +201,8 @@ def main(unused_argv):
195
201
196
202
# Set up training and evaluation input functions.
197
203
def train_input_fn ():
204
+ """Prepare data for training."""
205
+
198
206
# When choosing shuffle buffer sizes, larger sizes result in better
199
207
# randomness, while smaller sizes use less memory. MNIST is a small
200
208
# enough dataset that we can easily shuffle the full epoch.
@@ -215,7 +223,7 @@ def eval_input_fn():
215
223
FLAGS .hooks , batch_size = FLAGS .batch_size )
216
224
217
225
# Train and evaluate model.
218
- for n in range (FLAGS .train_epochs // FLAGS .epochs_between_evals ):
226
+ for _ in range (FLAGS .train_epochs // FLAGS .epochs_between_evals ):
219
227
mnist_classifier .train (input_fn = train_input_fn , hooks = train_hooks )
220
228
eval_results = mnist_classifier .evaluate (input_fn = eval_input_fn )
221
229
print ('\n Evaluation results:\n \t %s\n ' % eval_results )
@@ -231,10 +239,11 @@ def eval_input_fn():
231
239
232
240
class MNISTArgParser (argparse .ArgumentParser ):
233
241
"""Argument parser for running MNIST model."""
242
+
234
243
def __init__ (self ):
235
244
super (MNISTArgParser , self ).__init__ (parents = [
236
- parsers .BaseParser (),
237
- parsers .ImageModelParser ()])
245
+ parsers .BaseParser (),
246
+ parsers .ImageModelParser ()])
238
247
239
248
self .add_argument (
240
249
'--export_dir' ,
0 commit comments