Commit adfd5a3

Use util functions hooks_helper and parser in mnist and wide_deep, and rename epochs_between_eval (from epochs_per_eval) (tensorflow#3650)
1 parent 875fcb3 commit adfd5a3

File tree

11 files changed: +171 −158 lines changed (five of the files appear below)
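
The thread running through these diffs: each model now trains in repeated cycles of `epochs_between_evals` epochs (the flag renamed from `epochs_per_eval`), evaluating between cycles, with logging hooks built by the shared `hooks_helper` module and flags composed from the shared `parsers` module. A minimal sketch of the cycle, condensed from the loops added below (the argument names stand in for each model's own objects):

```python
def train_and_evaluate(classifier, flags, train_input_fn, eval_input_fn,
                       train_hooks):
  """Condensed sketch of the train/eval cycle this commit standardizes."""
  for _ in range(flags.train_epochs // flags.epochs_between_evals):
    # Each train() call consumes the dataset epochs_between_evals times.
    classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    eval_results = classifier.evaluate(input_fn=eval_input_fn)
    print('\nEvaluation results:\n\t%s\n' % eval_results)
```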

official/mnist/README.md

Lines changed: 4 additions & 1 deletion

@@ -11,7 +11,10 @@ APIs.
 
 ## Setup
 
-To begin, you'll simply need the latest version of TensorFlow installed.
+To begin, you'll simply need the latest version of TensorFlow installed,
+and make sure to run the command to export the `/models` folder to the
+python path: https://github.com/tensorflow/models/tree/master/official#running-the-models
+
 Then to train the model, run the following:
 
 ```
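
The linked section of the repository README boils down to putting your `/models` clone on the Python path so that `official.*` imports resolve. A hedged in-script equivalent of that export (the path below is a placeholder for your local clone, not a real path):

```python
# Equivalent in spirit to the PYTHONPATH export described at the link above.
import sys
sys.path.append('/path/to/models')  # placeholder for your local clone

from official.mnist import mnist  # should now import cleanly
```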

official/mnist/mnist.py

Lines changed: 41 additions & 54 deletions
@@ -18,12 +18,15 @@
 from __future__ import print_function
 
 import argparse
-import os
 import sys
 
 import tensorflow as tf
+
 from official.mnist import dataset
+from official.utils.arg_parsers import parsers
+from official.utils.logging import hooks_helper
 
+LEARNING_RATE = 1e-4
 
 class Model(tf.keras.Model):
   """Model to recognize digits in the MNIST dataset.
@@ -104,7 +107,7 @@ def model_fn(features, labels, mode, params):
             'classify': tf.estimator.export.PredictOutput(predictions)
         })
   if mode == tf.estimator.ModeKeys.TRAIN:
-    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
+    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
 
     # If we are running multi-GPU, we need to wrap the optimizer.
     if params.get('multi_gpu'):
@@ -114,10 +117,15 @@ def model_fn(features, labels, mode, params):
     loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
     accuracy = tf.metrics.accuracy(
         labels=labels, predictions=tf.argmax(logits, axis=1))
-    # Name the accuracy tensor 'train_accuracy' to demonstrate the
-    # LoggingTensorHook.
+
+    # Name tensors to be logged with LoggingTensorHook.
+    tf.identity(LEARNING_RATE, 'learning_rate')
+    tf.identity(loss, 'cross_entropy')
     tf.identity(accuracy[1], name='train_accuracy')
+
+    # Save accuracy scalar to Tensorboard output.
     tf.summary.scalar('train_accuracy', accuracy[1])
+
     return tf.estimator.EstimatorSpec(
         mode=tf.estimator.ModeKeys.TRAIN,
         loss=loss,
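
The `tf.identity` calls above give values stable graph names precisely so that a `LoggingTensorHook` can fetch them later by name. A standalone sketch of that mechanism with the TF 1.x API (independent of this model's code):

```python
import tensorflow as tf

loss = tf.constant(0.25)
# tf.identity registers the value under a findable graph name.
tf.identity(loss, name='cross_entropy')

# The hook looks tensors up by those names at run time.
hook = tf.train.LoggingTensorHook(
    tensors={'cross_entropy': 'cross_entropy'}, every_n_iter=100)
```
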
@@ -185,30 +193,32 @@ def main(unused_argv):
           'multi_gpu': FLAGS.multi_gpu
       })
 
-  # Train the model
+  # Set up training and evaluation input functions.
   def train_input_fn():
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
     ds = dataset.train(FLAGS.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
-        FLAGS.train_epochs)
-    return ds
+    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)
 
-  # Set up training hook that logs the training accuracy every 100 steps.
-  tensors_to_log = {'train_accuracy': 'train_accuracy'}
-  logging_hook = tf.train.LoggingTensorHook(
-      tensors=tensors_to_log, every_n_iter=100)
-  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
+    # Iterate through the dataset a set number (`epochs_between_evals`) of times
+    # during each training session.
+    ds = ds.repeat(FLAGS.epochs_between_evals)
+    return ds
 
-  # Evaluate the model and print results
   def eval_input_fn():
     return dataset.test(FLAGS.data_dir).batch(
         FLAGS.batch_size).make_one_shot_iterator().get_next()
 
-  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
-  print()
-  print('Evaluation results:\n\t%s' % eval_results)
+  # Set up hook that outputs training logs every 100 steps.
+  train_hooks = hooks_helper.get_train_hooks(
+      FLAGS.hooks, batch_size=FLAGS.batch_size)
+
+  # Train and evaluate model.
+  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
+    print('\nEvaluation results:\n\t%s\n' % eval_results)
 
   # Export the model
   if FLAGS.export_dir is not None:
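
Per the call above, `hooks_helper.get_train_hooks` takes a list of hook names (the `--hooks` flag) plus keyword arguments and returns ready-made hooks. A sketch of the call with literal values in place of FLAGS; that `'LoggingTensorHook'` is an accepted name is an assumption about the `--hooks` default here:

```python
from official.utils.logging import hooks_helper

# Same call shape as in main() above, with FLAGS replaced by literals.
train_hooks = hooks_helper.get_train_hooks(
    ['LoggingTensorHook'],  # assumed default for the --hooks flag
    batch_size=100)
```
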
@@ -220,51 +230,28 @@ def eval_input_fn():
 
 
 class MNISTArgParser(argparse.ArgumentParser):
-
+  """Argument parser for running MNIST model."""
   def __init__(self):
-    super(MNISTArgParser, self).__init__()
+    super(MNISTArgParser, self).__init__(parents=[
+        parsers.BaseParser(),
+        parsers.ImageModelParser()])
 
-    self.add_argument(
-        '--multi_gpu', action='store_true',
-        help='If set, run across all available GPUs.')
-    self.add_argument(
-        '--batch_size',
-        type=int,
-        default=100,
-        help='Number of images to process in a batch')
-    self.add_argument(
-        '--data_dir',
-        type=str,
-        default='/tmp/mnist_data',
-        help='Path to directory containing the MNIST dataset')
-    self.add_argument(
-        '--model_dir',
-        type=str,
-        default='/tmp/mnist_model',
-        help='The directory where the model will be stored.')
-    self.add_argument(
-        '--train_epochs',
-        type=int,
-        default=40,
-        help='Number of epochs to train.')
-    self.add_argument(
-        '--data_format',
-        type=str,
-        default=None,
-        choices=['channels_first', 'channels_last'],
-        help='A flag to override the data format used in the model. '
-        'channels_first provides a performance boost on GPU but is not always '
-        'compatible with CPU. If left unspecified, the data format will be '
-        'chosen automatically based on whether TensorFlow was built for CPU or '
-        'GPU.')
     self.add_argument(
         '--export_dir',
         type=str,
-        help='The directory where the exported SavedModel will be stored.')
+        help='[default: %(default)s] If set, a SavedModel serialization of the '
+        'model will be exported to this directory at the end of training. '
+        'See the README for more details and relevant links.')
+
+    self.set_defaults(
+        data_dir='/tmp/mnist_data',
+        model_dir='/tmp/mnist_model',
+        batch_size=100,
+        train_epochs=40)
 
 
 if __name__ == '__main__':
-  parser = MNISTArgParser()
   tf.logging.set_verbosity(tf.logging.INFO)
+  parser = MNISTArgParser()
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
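
The `parents=[...]` construction is stock argparse: common flags are declared once in parent parsers (which must be created with `add_help=False`), and each model parser inherits them, adjusting defaults via `set_defaults`. A self-contained sketch with hypothetical parser classes, not the repo's actual `BaseParser`:

```python
import argparse

class SharedFlagsParser(argparse.ArgumentParser):
  """Hypothetical parent parser; parents require add_help=False."""

  def __init__(self):
    super(SharedFlagsParser, self).__init__(add_help=False)
    self.add_argument('--batch_size', type=int, default=32)
    self.add_argument('--train_epochs', type=int, default=1)

class ToyModelParser(argparse.ArgumentParser):
  """Inherits the shared flags and overrides their defaults."""

  def __init__(self):
    super(ToyModelParser, self).__init__(parents=[SharedFlagsParser()])
    self.set_defaults(batch_size=100, train_epochs=40)

flags, _ = ToyModelParser().parse_known_args(['--batch_size', '64'])
print(flags.batch_size, flags.train_epochs)  # 64 40
```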

official/mnist/mnist_eager.py

Lines changed: 60 additions & 52 deletions
@@ -33,8 +33,10 @@
 
 import tensorflow as tf
 import tensorflow.contrib.eager as tfe
+
 from official.mnist import mnist
 from official.mnist import dataset
+from official.utils.arg_parsers import parsers
 
 FLAGS = None
 
@@ -98,9 +100,13 @@ def test(model, dataset):
 def main(_):
   tfe.enable_eager_execution()
 
+  # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
   if FLAGS.no_gpu or tfe.num_gpus() <= 0:
     (device, data_format) = ('/cpu:0', 'channels_last')
+  # If data_format is defined in FLAGS, overwrite automatically set value.
+  if FLAGS.data_format is not None:
+    data_format = FLAGS.data_format
   print('Using device %s, and data format %s.' % (device, data_format))
 
   # Load the datasets
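
The override added above makes an explicit `--data_format` flag win over the automatically chosen layout. The same selection logic, isolated as a plain function for clarity; this is a sketch, not repo code:

```python
def pick_device_and_format(no_gpu, num_gpus, flag_data_format=None):
  """channels_first is usually faster on GPU; channels_last suits CPU."""
  device, data_format = '/gpu:0', 'channels_first'
  if no_gpu or num_gpus <= 0:
    device, data_format = '/cpu:0', 'channels_last'
  if flag_data_format is not None:
    data_format = flag_data_format  # explicit flag wins
  return device, data_format

print(pick_device_and_format(no_gpu=False, num_gpus=1))
# ('/gpu:0', 'channels_first')
```
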
@@ -112,6 +118,7 @@ def main(_):
   model = mnist.Model(data_format)
   optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)
 
+  # Create file writers for writing TensorBoard summaries.
   if FLAGS.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
@@ -126,15 +133,18 @@ def main(_):
       train_dir, flush_millis=10000)
   test_summary_writer = tf.contrib.summary.create_file_writer(
       test_dir, flush_millis=10000, name='test')
-  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
+
+  # Create and restore checkpoint (if one exists on the path)
+  checkpoint_prefix = os.path.join(FLAGS.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
-  # Train and evaluate for 10 epochs.
+  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.model_dir))
+
+  # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(10):
+    for _ in range(FLAGS.train_epochs):
       start = time.time()
       with summary_writer.as_default():
         train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
@@ -148,54 +158,52 @@ def main(_):
       checkpoint.save(checkpoint_prefix)
 
 
-if __name__ == '__main__':
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      '--data_dir',
-      type=str,
-      default='/tmp/tensorflow/mnist/input_data',
-      help='Directory for storing input data')
-  parser.add_argument(
-      '--batch_size',
-      type=int,
-      default=100,
-      metavar='N',
-      help='input batch size for training (default: 100)')
-  parser.add_argument(
-      '--log_interval',
-      type=int,
-      default=10,
-      metavar='N',
-      help='how many batches to wait before logging training status')
-  parser.add_argument(
-      '--output_dir',
-      type=str,
-      default=None,
-      metavar='N',
-      help='Directory to write TensorBoard summaries')
-  parser.add_argument(
-      '--checkpoint_dir',
-      type=str,
-      default='/tmp/tensorflow/mnist/checkpoints/',
-      metavar='N',
-      help='Directory to save checkpoints in (once per epoch)')
-  parser.add_argument(
-      '--lr',
-      type=float,
-      default=0.01,
-      metavar='LR',
-      help='learning rate (default: 0.01)')
-  parser.add_argument(
-      '--momentum',
-      type=float,
-      default=0.5,
-      metavar='M',
-      help='SGD momentum (default: 0.5)')
-  parser.add_argument(
-      '--no_gpu',
-      action='store_true',
-      default=False,
-      help='disables GPU usage even if a GPU is available')
+class MNISTEagerArgParser(argparse.ArgumentParser):
+  """Argument parser for running MNIST model with eager training loop."""
+  def __init__(self):
+    super(MNISTEagerArgParser, self).__init__(parents=[
+        parsers.BaseParser(epochs_between_evals=False, multi_gpu=False,
+                           hooks=False),
+        parsers.ImageModelParser()])
+
+    self.add_argument(
+        '--log_interval', '-li',
+        type=int,
+        default=10,
+        metavar='N',
+        help='[default: %(default)s] batches between logging training status')
+    self.add_argument(
+        '--output_dir', '-od',
+        type=str,
+        default=None,
+        metavar='<OD>',
+        help='[default: %(default)s] Directory to write TensorBoard summaries')
+    self.add_argument(
+        '--lr', '-lr',
+        type=float,
+        default=0.01,
+        metavar='<LR>',
+        help='[default: %(default)s] learning rate')
+    self.add_argument(
+        '--momentum', '-m',
+        type=float,
+        default=0.5,
+        metavar='<M>',
+        help='[default: %(default)s] SGD momentum')
+    self.add_argument(
+        '--no_gpu', '-nogpu',
+        action='store_true',
+        default=False,
+        help='disables GPU usage even if a GPU is available')
+
+    self.set_defaults(
+        data_dir='/tmp/tensorflow/mnist/input_data',
+        model_dir='/tmp/tensorflow/mnist/checkpoints/',
+        batch_size=100,
+        train_epochs=10,
+    )
 
+if __name__ == '__main__':
+  parser = MNISTEagerArgParser()
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

official/resnet/cifar10_main.py

Lines changed: 1 addition & 1 deletion

@@ -216,7 +216,7 @@ def main(argv):
       model_dir='/tmp/cifar10_model',
       resnet_size=32,
       train_epochs=250,
-      epochs_per_eval=10,
+      epochs_between_evals=10,
       batch_size=128)
 
   flags = parser.parse_args(args=argv[1:])
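
A quick check of the renamed flag's arithmetic: with these CIFAR-10 defaults, the loop in `resnet_main` (next file) runs 25 train/evaluate cycles of 10 epochs each.

```python
train_epochs = 250          # default set above
epochs_between_evals = 10   # renamed from epochs_per_eval
print(train_epochs // epochs_between_evals)  # 25 training cycles
```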

official/resnet/resnet_run_loop.py

Lines changed: 5 additions & 4 deletions

@@ -339,15 +339,16 @@ def resnet_main(flags, model_function, input_function):
           'version': flags.version,
       })
 
-  for _ in range(flags.train_epochs // flags.epochs_per_eval):
-    train_hooks = hooks_helper.get_train_hooks(flags.hooks, batch_size=flags.batch_size)
+  for _ in range(flags.train_epochs // flags.epochs_between_evals):
+    train_hooks = hooks_helper.get_train_hooks(flags.hooks,
+                                               batch_size=flags.batch_size)
 
     print('Starting a training cycle.')
 
     def input_fn_train():
       return input_function(True, flags.data_dir, flags.batch_size,
-                            flags.epochs_per_eval, flags.num_parallel_calls,
-                            flags.multi_gpu)
+                            flags.epochs_between_evals,
+                            flags.num_parallel_calls, flags.multi_gpu)
 
     classifier.train(input_fn=input_fn_train, hooks=train_hooks,
                      max_steps=flags.max_train_steps)
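
`input_fn_train` forwards `epochs_between_evals` as the pipeline's epoch count, so each cycle's `train()` call consumes the dataset that many times, mirroring the `ds.repeat(FLAGS.epochs_between_evals)` change in mnist.py. A toy stand-in with the same positional signature (illustrative only, not the real CIFAR-10/ImageNet pipeline):

```python
import tensorflow as tf

def toy_input_function(is_training, data_dir, batch_size, num_epochs=1,
                       num_parallel_calls=1, multi_gpu=False):
  """Toy pipeline matching the call shape used in input_fn_train above."""
  ds = tf.data.Dataset.range(1000).batch(batch_size)
  return ds.repeat(num_epochs)  # num_epochs == flags.epochs_between_evals
```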
