18
18
from __future__ import division
19
19
from __future__ import print_function
20
20
21
+ import functools
22
+ import inspect
23
+
21
24
import neural_structured_learning .configs as nsl_configs
22
25
import neural_structured_learning .lib as nsl_lib
23
-
24
26
import tensorflow as tf
25
27
26
28
@@ -55,6 +57,10 @@ def add_adversarial_regularization(estimator,
55
57
adv_config = nsl_configs .AdvRegConfig ()
56
58
57
59
base_model_fn = estimator ._model_fn # pylint: disable=protected-access
60
+ try :
61
+ base_model_fn_args = inspect .signature (base_model_fn ).parameters .keys ()
62
+ except AttributeError : # For Python 2 compatibility
63
+ base_model_fn_args = inspect .getargspec (base_model_fn ).args # pylint: disable=deprecated-method
58
64
59
65
def adv_model_fn (features , labels , mode , params = None , config = None ):
60
66
"""The adversarial-regularized model_fn.
@@ -82,19 +88,22 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
82
88
Returns:
83
89
A `tf.estimator.EstimatorSpec` with adversarial regularization.
84
90
"""
91
+ # Parameters 'params' and 'config' are optional. If they are not passed,
92
+ # then it is possible for base_model_fn not to accept these arguments.
93
+ # See documentation for tf.estimator.Estimator for additional context.
94
+ kwargs = {'mode' : mode }
95
+ if 'params' in base_model_fn_args :
96
+ kwargs ['params' ] = params
97
+ if 'config' in base_model_fn_args :
98
+ kwargs ['config' ] = config
99
+ base_fn = functools .partial (base_model_fn , ** kwargs )
85
100
86
101
# Uses the same variable scope for calculating the original objective and
87
102
# adversarial regularization.
88
103
with tf .compat .v1 .variable_scope (tf .compat .v1 .get_variable_scope (),
89
104
reuse = tf .compat .v1 .AUTO_REUSE ,
90
105
auxiliary_name_scope = False ):
91
- # If no 'params' is passed, then it is possible for base_model_fn not to
92
- # accept a 'params' argument. See documentation for tf.estimator.Estimator
93
- # for additional context.
94
- base_args = [mode , params , config ] if params else [mode , config ]
95
- spec_fn = lambda feature , label : base_model_fn (feature , label , * base_args )
96
-
97
- original_spec = spec_fn (features , labels )
106
+ original_spec = base_fn (features , labels )
98
107
99
108
# Adversarial regularization only happens in training.
100
109
if mode != tf .estimator .ModeKeys .TRAIN :
@@ -107,11 +116,11 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
107
116
# The pgd_model_fn is a dummy identity function since loss is
108
117
# directly available from spec_fn.
109
118
pgd_model_fn = lambda features : features ,
110
- pgd_loss_fn = lambda labels , features : spec_fn (features , labels ).loss ,
119
+ pgd_loss_fn = lambda labels , features : base_fn (features , labels ).loss ,
111
120
pgd_labels = labels )
112
121
113
122
# Runs the base model again to compute loss on adv_neighbor.
114
- adv_spec = spec_fn (adv_neighbor , labels )
123
+ adv_spec = base_fn (adv_neighbor , labels )
115
124
116
125
final_loss = original_spec .loss + adv_config .multiplier * adv_spec .loss
117
126
0 commit comments