Skip to content

Commit 4b8fe70

Browse files
author
Taylor Robie
authored
Forbid ResNet v1 from running with fp16 (tensorflow#4207)
* forbid resnet v1 fp16 * address PR comments
1 parent 911a0d2 commit 4b8fe70

File tree

4 files changed

+27
-2
lines changed

4 files changed

+27
-2
lines changed

official/resnet/cifar10_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,13 @@ def test_cifar10_end_to_end_synthetic_v2(self):
164164
extra_flags=['-resnet_version', '2']
165165
)
166166

167+
def test_flag_restriction(self):
168+
with self.assertRaises(SystemExit):
169+
integration.run_synthetic(
170+
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
171+
extra_flags=['-resnet_version', '1', "-dtype", "fp16"]
172+
)
173+
167174

168175
if __name__ == '__main__':
169176
tf.test.main()

official/resnet/imagenet_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,13 @@ def test_imagenet_end_to_end_synthetic_v2_huge(self):
303303
extra_flags=['-resnet_version', '2', '-resnet_size', '200']
304304
)
305305

306+
def test_flag_restriction(self):
307+
with self.assertRaises(SystemExit):
308+
integration.run_synthetic(
309+
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
310+
extra_flags=['-resnet_version', '1', '-dtype', 'fp16']
311+
)
312+
306313

307314
if __name__ == '__main__':
308315
tf.test.main()

official/resnet/resnet_run_loop.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,17 @@
2525

2626
import os
2727

28+
# pylint: disable=g-bad-import-order
2829
from absl import flags
29-
import tensorflow as tf # pylint: disable=g-bad-import-order
30+
import tensorflow as tf
3031

3132
from official.resnet import resnet_model
3233
from official.utils.flags import core as flags_core
3334
from official.utils.export import export
3435
from official.utils.logs import hooks_helper
3536
from official.utils.logs import logger
3637
from official.utils.misc import model_helpers
38+
# pylint: enable=g-bad-import-order
3739

3840

3941
################################################################################
@@ -462,7 +464,6 @@ def define_resnet_flags(resnet_size_choices=None):
462464
help=flags_core.help_wrap(
463465
'Version of ResNet. (1 or 2) See README.md for details.'))
464466

465-
466467
choice_kwargs = dict(
467468
name='resnet_size', short_name='rs', default='50',
468469
help=flags_core.help_wrap('The size of the ResNet model to use.'))
@@ -471,3 +472,12 @@ def define_resnet_flags(resnet_size_choices=None):
471472
flags.DEFINE_string(**choice_kwargs)
472473
else:
473474
flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
475+
476+
# The current implementation of ResNet v1 is numerically unstable when run
477+
# with fp16 and will produce NaN errors soon after training begins.
478+
msg = ('ResNet version 1 is not currently supported with fp16. '
479+
'Please use version 2 instead.')
480+
@flags.multi_flags_validator(['dtype', 'resnet_version'], message=msg)
481+
def _forbid_v1_fp16(flag_values): # pylint: disable=unused-variable
482+
return (flags_core.DTYPE_MAP[flag_values['dtype']][0] != tf.float16 or
483+
flag_values['resnet_version'] != '1')

official/utils/flags/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,4 @@ def core_fn(*args, **kwargs):
8282
get_num_gpus = _base.get_num_gpus
8383
get_tf_dtype = _performance.get_tf_dtype
8484
get_loss_scale = _performance.get_loss_scale
85+
DTYPE_MAP = _performance.DTYPE_MAP

0 commit comments

Comments
 (0)