Skip to content

Commit 7cfb6bb

Browse files
authored
Glint everything (tensorflow#3654)
* Glint everything * Adding rcfile and pylinting * Extra newline * Few last lints
1 parent adfd5a3 commit 7cfb6bb

27 files changed

+382
-162
lines changed

official/__init__.py

-14
Original file line numberDiff line numberDiff line change
@@ -1,14 +0,0 @@
1-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
# ==============================================================================

official/mnist/dataset.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import gzip
2021
import os
2122
import shutil
22-
import gzip
2323

2424
import numpy as np
2525
from six.moves import urllib
@@ -36,7 +36,7 @@ def check_image_file_header(filename):
3636
"""Validate that filename corresponds to images for the MNIST dataset."""
3737
with tf.gfile.Open(filename, 'rb') as f:
3838
magic = read32(f)
39-
num_images = read32(f)
39+
read32(f) # num_images, unused
4040
rows = read32(f)
4141
cols = read32(f)
4242
if magic != 2051:
@@ -52,7 +52,7 @@ def check_labels_file_header(filename):
5252
"""Validate that filename corresponds to labels for the MNIST dataset."""
5353
with tf.gfile.Open(filename, 'rb') as f:
5454
magic = read32(f)
55-
num_items = read32(f)
55+
read32(f) # num_items, unused
5656
if magic != 2049:
5757
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
5858
f.name))
@@ -77,6 +77,8 @@ def download(directory, filename):
7777

7878

7979
def dataset(directory, images_file, labels_file):
80+
"""Download and parse MNIST dataset."""
81+
8082
images_file = download(directory, images_file)
8183
labels_file = download(directory, labels_file)
8284

official/mnist/mnist.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
import argparse
2121
import sys
2222

23-
import tensorflow as tf
23+
import tensorflow as tf # pylint: disable=g-bad-import-order
2424

2525
from official.mnist import dataset
2626
from official.utils.arg_parsers import parsers
2727
from official.utils.logging import hooks_helper
2828

2929
LEARNING_RATE = 1e-4
3030

31+
3132
class Model(tf.keras.Model):
3233
"""Model to recognize digits in the MNIST dataset.
3334
@@ -145,31 +146,36 @@ def model_fn(features, labels, mode, params):
145146

146147

147148
def validate_batch_size_for_multi_gpu(batch_size):
148-
"""For multi-gpu, batch-size must be a multiple of the number of
149-
available GPUs.
149+
"""For multi-gpu, batch-size must be a multiple of the number of GPUs.
150150
151151
Note that this should eventually be handled by replicate_model_fn
152152
directly. Multi-GPU support is currently experimental, however,
153153
so doing the work here until that feature is in place.
154+
155+
Args:
156+
batch_size: the number of examples processed in each training batch.
157+
158+
Raises:
159+
ValueError: if no GPUs are found, or selected batch_size is invalid.
154160
"""
155-
from tensorflow.python.client import device_lib
161+
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top
156162

157163
local_device_protos = device_lib.list_local_devices()
158164
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
159165
if not num_gpus:
160166
raise ValueError('Multi-GPU mode was specified, but no GPUs '
161-
'were found. To use CPU, run without --multi_gpu.')
167+
'were found. To use CPU, run without --multi_gpu.')
162168

163169
remainder = batch_size % num_gpus
164170
if remainder:
165171
err = ('When running with multiple GPUs, batch size '
166-
'must be a multiple of the number of available GPUs. '
167-
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
168-
).format(num_gpus, batch_size, batch_size - remainder)
172+
'must be a multiple of the number of available GPUs. '
173+
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
174+
).format(num_gpus, batch_size, batch_size - remainder)
169175
raise ValueError(err)
170176

171177

172-
def main(unused_argv):
178+
def main(_):
173179
model_function = model_fn
174180

175181
if FLAGS.multi_gpu:
@@ -195,6 +201,8 @@ def main(unused_argv):
195201

196202
# Set up training and evaluation input functions.
197203
def train_input_fn():
204+
"""Prepare data for training."""
205+
198206
# When choosing shuffle buffer sizes, larger sizes result in better
199207
# randomness, while smaller sizes use less memory. MNIST is a small
200208
# enough dataset that we can easily shuffle the full epoch.
@@ -215,7 +223,7 @@ def eval_input_fn():
215223
FLAGS.hooks, batch_size=FLAGS.batch_size)
216224

217225
# Train and evaluate model.
218-
for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
226+
for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
219227
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
220228
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
221229
print('\nEvaluation results:\n\t%s\n' % eval_results)
@@ -231,10 +239,11 @@ def eval_input_fn():
231239

232240
class MNISTArgParser(argparse.ArgumentParser):
233241
"""Argument parser for running MNIST model."""
242+
234243
def __init__(self):
235244
super(MNISTArgParser, self).__init__(parents=[
236-
parsers.BaseParser(),
237-
parsers.ImageModelParser()])
245+
parsers.BaseParser(),
246+
parsers.ImageModelParser()])
238247

239248
self.add_argument(
240249
'--export_dir',

official/mnist/mnist_eager.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131
import sys
3232
import time
3333

34-
import tensorflow as tf
35-
import tensorflow.contrib.eager as tfe
34+
import tensorflow as tf # pylint: disable=g-bad-import-order
35+
import tensorflow.contrib.eager as tfe # pylint: disable=g-bad-import-order
3636

37+
from official.mnist import dataset as mnist_dataset
3738
from official.mnist import mnist
38-
from official.mnist import dataset
3939
from official.utils.arg_parsers import parsers
4040

4141
FLAGS = None
@@ -110,9 +110,9 @@ def main(_):
110110
print('Using device %s, and data format %s.' % (device, data_format))
111111

112112
# Load the datasets
113-
train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
113+
train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
114114
FLAGS.batch_size)
115-
test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
115+
test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
116116

117117
# Create the model and optimizer
118118
model = mnist.Model(data_format)
@@ -159,12 +159,13 @@ def main(_):
159159

160160

161161
class MNISTEagerArgParser(argparse.ArgumentParser):
162-
"""Argument parser for running MNIST model with eager trainng loop."""
162+
"""Argument parser for running MNIST model with eager training loop."""
163+
163164
def __init__(self):
164165
super(MNISTEagerArgParser, self).__init__(parents=[
165-
parsers.BaseParser(epochs_between_evals=False, multi_gpu=False,
166-
hooks=False),
167-
parsers.ImageModelParser()])
166+
parsers.BaseParser(
167+
epochs_between_evals=False, multi_gpu=False, hooks=False),
168+
parsers.ImageModelParser()])
168169

169170
self.add_argument(
170171
'--log_interval', '-li',

official/mnist/mnist_eager_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import tensorflow as tf
21-
import tensorflow.contrib.eager as tfe
20+
import tensorflow as tf # pylint: disable=g-bad-import-order
21+
import tensorflow.contrib.eager as tfe # pylint: disable=g-bad-import-order
2222

2323
from official.mnist import mnist
2424
from official.mnist import mnist_eager
@@ -60,6 +60,7 @@ def evaluate(defun=False):
6060

6161

6262
class MNISTTest(tf.test.TestCase):
63+
"""Run tests for MNIST eager loop."""
6364

6465
def test_train(self):
6566
train(defun=False)

official/mnist/mnist_test.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import tensorflow as tf
2120
import time
2221

22+
import tensorflow as tf # pylint: disable=g-bad-import-order
23+
2324
from official.mnist import mnist
2425

2526
BATCH_SIZE = 100
@@ -42,6 +43,7 @@ def make_estimator():
4243

4344

4445
class Tests(tf.test.TestCase):
46+
"""Run tests for MNIST model."""
4547

4648
def test_mnist(self):
4749
classifier = make_estimator()
@@ -57,7 +59,7 @@ def test_mnist(self):
5759

5860
input_fn = lambda: tf.random_uniform([3, 784])
5961
predictions_generator = classifier.predict(input_fn)
60-
for i in range(3):
62+
for _ in range(3):
6163
predictions = next(predictions_generator)
6264
self.assertEqual(predictions['probabilities'].shape, (10,))
6365
self.assertEqual(predictions['classes'].shape, ())
@@ -103,6 +105,7 @@ def test_mnist_model_fn_predict_mode(self):
103105

104106

105107
class Benchmarks(tf.test.Benchmark):
108+
"""Simple speed benchmarking for MNIST."""
106109

107110
def benchmark_train_step_time(self):
108111
classifier = make_estimator()

official/mnist/mnist_tpu.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from __future__ import division
2424
from __future__ import print_function
2525

26-
import tensorflow as tf
26+
import tensorflow as tf # pylint: disable=g-bad-import-order
27+
2728
from official.mnist import dataset
2829
from official.mnist import mnist
2930

@@ -132,7 +133,7 @@ def main(argv):
132133
tf.logging.set_verbosity(tf.logging.INFO)
133134

134135
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
135-
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
136+
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
136137

137138
run_config = tf.contrib.tpu.RunConfig(
138139
cluster=tpu_cluster_resolver,

official/resnet/cifar10_download_and_extract.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
help='Directory to download data and extract the tarball')
3737

3838

39-
def main(unused_argv):
39+
def main(_):
4040
"""Download and extract the tarball from Alex's website."""
4141
if not os.path.exists(FLAGS.data_dir):
4242
os.makedirs(FLAGS.data_dir)

official/resnet/cifar10_main.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import os
2222
import sys
2323

24-
import tensorflow as tf
24+
import tensorflow as tf # pylint: disable=g-bad-import-order
2525

2626
from official.resnet import resnet_model
2727
from official.resnet import resnet_run_loop
@@ -127,22 +127,25 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
127127

128128
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
129129

130-
return resnet_run_loop.process_record_dataset(dataset, is_training, batch_size,
131-
_NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls,
132-
examples_per_epoch=num_images, multi_gpu=multi_gpu)
130+
return resnet_run_loop.process_record_dataset(
131+
dataset, is_training, batch_size, _NUM_IMAGES['train'],
132+
parse_record, num_epochs, num_parallel_calls,
133+
examples_per_epoch=num_images, multi_gpu=multi_gpu)
133134

134135

135136
def get_synth_input_fn():
136-
return resnet_run_loop.get_synth_input_fn(_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
137+
return resnet_run_loop.get_synth_input_fn(
138+
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
137139

138140

139141
###############################################################################
140142
# Running the model
141143
###############################################################################
142144
class Cifar10Model(resnet_model.Model):
145+
"""Model class with appropriate defaults for CIFAR-10 data."""
143146

144147
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
145-
version=resnet_model.DEFAULT_VERSION):
148+
version=resnet_model.DEFAULT_VERSION):
146149
"""These are the parameters that work for CIFAR-10 data.
147150
148151
Args:
@@ -153,6 +156,9 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
153156
enables users to extend the same model to their own datasets.
154157
version: Integer representing which version of the ResNet network to use.
155158
See README for details. Valid values: [1, 2]
159+
160+
Raises:
161+
ValueError: if invalid resnet_size is chosen
156162
"""
157163
if resnet_size % 6 != 2:
158164
raise ValueError('resnet_size must be 6n + 2:', resnet_size)
@@ -195,7 +201,7 @@ def cifar10_model_fn(features, labels, mode, params):
195201
# for the CIFAR-10 dataset, perhaps because the regularization prevents
196202
# overfitting on the small data set. We therefore include all vars when
197203
# regularizing and computing loss during training.
198-
def loss_filter_fn(name):
204+
def loss_filter_fn(_):
199205
return True
200206

201207
return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,

official/resnet/cifar10_test.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tempfile import mkstemp
2121

2222
import numpy as np
23-
import tensorflow as tf
23+
import tensorflow as tf # pylint: disable=g-bad-import-order
2424

2525
from official.resnet import cifar10_main
2626
from official.utils.testing import integration
@@ -34,6 +34,8 @@
3434

3535

3636
class BaseTest(tf.test.TestCase):
37+
"""Tests for the Cifar10 version of Resnet.
38+
"""
3739

3840
def tearDown(self):
3941
super(BaseTest, self).tearDown()
@@ -52,7 +54,7 @@ def test_dataset_input_fn(self):
5254
data_file.close()
5355

5456
fake_dataset = tf.data.FixedLengthRecordDataset(
55-
filename, cifar10_main._RECORD_BYTES)
57+
filename, cifar10_main._RECORD_BYTES) # pylint: disable=protected-access
5658
fake_dataset = fake_dataset.map(
5759
lambda val: cifar10_main.parse_record(val, False))
5860
image, label = fake_dataset.make_one_shot_iterator().get_next()
@@ -133,9 +135,11 @@ def test_cifar10model_shape(self):
133135
num_classes = 246
134136

135137
for version in (1, 2):
136-
model = cifar10_main.Cifar10Model(32, data_format='channels_last',
137-
num_classes=num_classes, version=version)
138-
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
138+
model = cifar10_main.Cifar10Model(
139+
32, data_format='channels_last', num_classes=num_classes,
140+
version=version)
141+
fake_input = tf.random_uniform(
142+
[batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
139143
output = model(fake_input, training=True)
140144

141145
self.assertAllEqual(output.shape, (batch_size, num_classes))

0 commit comments

Comments
 (0)