Skip to content

Commit d98aa3d

Browse files
csferngtensorflow-copybara
authored andcommitted
Only pass defined parameters to base Estimator's model_fn.
PiperOrigin-RevId: 315571871
1 parent 14bf59e commit d98aa3d

File tree

2 files changed

+34
-17
lines changed

2 files changed

+34
-17
lines changed

neural_structured_learning/estimator/adversarial_regularization.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import functools
22+
import inspect
23+
2124
import neural_structured_learning.configs as nsl_configs
2225
import neural_structured_learning.lib as nsl_lib
23-
2426
import tensorflow as tf
2527

2628

@@ -55,6 +57,10 @@ def add_adversarial_regularization(estimator,
5557
adv_config = nsl_configs.AdvRegConfig()
5658

5759
base_model_fn = estimator._model_fn # pylint: disable=protected-access
60+
try:
61+
base_model_fn_args = inspect.signature(base_model_fn).parameters.keys()
62+
except AttributeError: # For Python 2 compatibility
63+
base_model_fn_args = inspect.getargspec(base_model_fn).args # pylint: disable=deprecated-method
5864

5965
def adv_model_fn(features, labels, mode, params=None, config=None):
6066
"""The adversarial-regularized model_fn.
@@ -82,19 +88,22 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
8288
Returns:
8389
A `tf.estimator.EstimatorSpec` with adversarial regularization.
8490
"""
91+
# Parameters 'params' and 'config' are optional. If they are not passed,
92+
# then it is possible for base_model_fn not to accept these arguments.
93+
# See documentation for tf.estimator.Estimator for additional context.
94+
kwargs = {'mode': mode}
95+
if 'params' in base_model_fn_args:
96+
kwargs['params'] = params
97+
if 'config' in base_model_fn_args:
98+
kwargs['config'] = config
99+
base_fn = functools.partial(base_model_fn, **kwargs)
85100

86101
# Uses the same variable scope for calculating the original objective and
87102
# adversarial regularization.
88103
with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope(),
89104
reuse=tf.compat.v1.AUTO_REUSE,
90105
auxiliary_name_scope=False):
91-
# If no 'params' is passed, then it is possible for base_model_fn not to
92-
# accept a 'params' argument. See documentation for tf.estimator.Estimator
93-
# for additional context.
94-
base_args = [mode, params, config] if params else [mode, config]
95-
spec_fn = lambda feature, label: base_model_fn(feature, label, *base_args)
96-
97-
original_spec = spec_fn(features, labels)
106+
original_spec = base_fn(features, labels)
98107

99108
# Adversarial regularization only happens in training.
100109
if mode != tf.estimator.ModeKeys.TRAIN:
@@ -107,11 +116,11 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
107116
# The pgd_model_fn is a dummy identity function since loss is
108117
# directly available from spec_fn.
109118
pgd_model_fn=lambda features: features,
110-
pgd_loss_fn=lambda labels, features: spec_fn(features, labels).loss,
119+
pgd_loss_fn=lambda labels, features: base_fn(features, labels).loss,
111120
pgd_labels=labels)
112121

113122
# Runs the base model again to compute loss on adv_neighbor.
114-
adv_spec = spec_fn(adv_neighbor, labels)
123+
adv_spec = base_fn(adv_neighbor, labels)
115124

116125
final_loss = original_spec.loss + adv_config.multiplier * adv_spec.loss
117126

neural_structured_learning/estimator/graph_regularization.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import inspect
21+
2022
import neural_structured_learning.configs as configs
2123
from neural_structured_learning.lib import distances
2224
from neural_structured_learning.lib import utils
@@ -51,6 +53,10 @@ def add_graph_regularization(estimator,
5153
graph_reg_config = configs.GraphRegConfig()
5254

5355
base_model_fn = estimator._model_fn # pylint: disable=protected-access
56+
try:
57+
base_model_fn_args = inspect.signature(base_model_fn).parameters.keys()
58+
except AttributeError: # For Python 2 compatibility
59+
base_model_fn_args = inspect.getargspec(base_model_fn).args # pylint: disable=deprecated-method
5460

5561
def graph_reg_model_fn(features, labels, mode, params=None, config=None):
5662
"""The graph-regularized model function.
@@ -79,6 +85,14 @@ def graph_reg_model_fn(features, labels, mode, params=None, config=None):
7985
Returns:
8086
A `tf.estimator.EstimatorSpec` with graph regularization.
8187
"""
88+
# Parameters 'params' and 'config' are optional. If they are not passed,
89+
# then it is possible for base_model_fn not to accept these arguments.
90+
# See documentation for tf.estimator.Estimator for additional context.
91+
kwargs = {'mode': mode}
92+
if 'params' in base_model_fn_args:
93+
kwargs['params'] = params
94+
if 'config' in base_model_fn_args:
95+
kwargs['config'] = config
8296

8397
# Uses the same variable scope for calculating the original objective and
8498
# the graph regularization loss term.
@@ -100,13 +114,7 @@ def graph_reg_model_fn(features, labels, mode, params=None, config=None):
100114
sample_features = utils.strip_neighbor_features(
101115
features, graph_reg_config.neighbor_config)
102116

103-
# If no 'params' is passed, then it is possible for base_model_fn not to
104-
# accept a 'params' argument. See documentation for tf.estimator.Estimator
105-
# for additional context.
106-
if params:
107-
base_spec = base_model_fn(sample_features, labels, mode, params, config)
108-
else:
109-
base_spec = base_model_fn(sample_features, labels, mode, config)
117+
base_spec = base_model_fn(sample_features, labels, **kwargs)
110118

111119
has_nbr_inputs = nbr_weights is not None and nbr_features
112120

0 commit comments

Comments
 (0)