@@ -396,7 +396,7 @@ def _custom_getter(getter, *args, **kwargs):
     yield varscope


-def set_precision_policy(policy_name=None, loss_scale=False):
+def set_precision_policy(policy_name=None):
   """Set precision policy according to the name.

   Args:
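This hunk drops the `loss_scale` argument: in the non-experimental Keras mixed precision API, loss scaling is no longer a property of the policy and is instead attached to the optimizer. A minimal sketch of the replacement pattern, assuming TF 2.4+; the SGD optimizer here is illustrative and not part of this commit:

```python
import tensorflow as tf

# Set the global policy; layers will compute in float16 while keeping
# float32 variables.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Loss scaling now wraps the optimizer instead of being passed to the
# Policy; LossScaleOptimizer applies dynamic loss scaling by default.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
```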
@@ -410,14 +410,8 @@ def set_precision_policy(policy_name=None, loss_scale=False):
   assert policy_name in ('mixed_float16', 'mixed_bfloat16', 'float32')
   logging.info('use mixed precision policy name %s', policy_name)
   tf.compat.v1.keras.layers.enable_v2_dtype_behavior()
-  # mixed_float16 training is not supported for now, so disable loss_scale.
-  # float32 and mixed_bfloat16 do not need loss scale for training.
-  if loss_scale:
-    policy = tf.keras.mixed_precision.experimental.Policy(policy_name)
-  else:
-    policy = tf.keras.mixed_precision.experimental.Policy(
-        policy_name, loss_scale=None)
-  tf.keras.mixed_precision.experimental.set_policy(policy)
+  policy = tf.keras.mixed_precision.Policy(policy_name)
+  tf.keras.mixed_precision.set_policy(policy)


 def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs):
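Here the experimental `Policy`/`set_policy` calls are replaced by their graduated counterparts, and the two-branch construction collapses to one line since `loss_scale=None` no longer needs to be passed. Note that in released TF 2.4+ the setter is exposed as `tf.keras.mixed_precision.set_global_policy`; `set_policy` as written here may reflect the nightly build this commit targeted. A hedged sketch of what the updated helper establishes:

```python
import tensorflow as tf

policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
print(policy.compute_dtype)   # 'bfloat16': dtype used for layer math
print(policy.variable_dtype)  # 'float32': dtype used for layer weights
```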
@@ -438,14 +432,15 @@ def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs):
   Returns:
     the output of mm model.
   """
+  del tt
   if pp == 'mixed_bfloat16':
     set_precision_policy(pp)
     inputs = tf.cast(ii, tf.bfloat16)
     with tf.compat.v1.tpu.bfloat16_scope():
       outputs = mm(inputs, *args, **kwargs)
     set_precision_policy('float32')
   elif pp == 'mixed_float16':
-    set_precision_policy(pp, loss_scale=tt)
+    set_precision_policy(pp)
     inputs = tf.cast(ii, tf.float16)
     with float16_scope():
       outputs = mm(inputs, *args, **kwargs)
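With `loss_scale` gone from `set_precision_policy`, `build_model_with_precision` no longer threads `tt` through, hence the added `del tt`. A hypothetical usage sketch; `my_model_fn` and the input shape are invented for illustration and are not part of this commit:

```python
import tensorflow as tf

def my_model_fn(feats):
  # Stand-in model; under mixed_float16 the Dense layer computes in
  # float16 while keeping float32 weights.
  return tf.keras.layers.Dense(10)(feats)

images = tf.random.normal([8, 224])
# tt is now ignored, so any placeholder works; the caller is responsible
# for loss scaling (e.g. via tf.keras.mixed_precision.LossScaleOptimizer).
outputs = build_model_with_precision('mixed_float16', my_model_fn, images, None)
```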