Skip to content

Commit 8a3317d

Browse files
csferngtensorflow-copybara
authored andcommitted
Implement GraphRegularization.save().
PiperOrigin-RevId: 458086350
1 parent df810dd commit 8a3317d

File tree

3 files changed

+112
-9
lines changed

3 files changed

+112
-9
lines changed

neural_structured_learning/keras/adversarial_regularization_test.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import print_function
1919

2020
import collections
21+
import os
2122

2223
from absl.testing import parameterized
2324
import neural_structured_learning.configs as configs
@@ -64,10 +65,12 @@ def build_linear_keras_functional_model(input_shape,
6465
def build_linear_keras_subclassed_model(input_shape, weights, dynamic=False):
6566
del input_shape
6667

67-
class LinearModel(tf.keras.Model):
68+
class CustomLinearModel(tf.keras.Model):
6869

69-
def __init__(self):
70-
super(LinearModel, self).__init__(dynamic=dynamic)
70+
def __init__(self, weights, name=None, dynamic=False):
71+
super(CustomLinearModel, self).__init__(name=name, dynamic=dynamic)
72+
self.init_weights = weights
73+
self.init_dynamic = dynamic
7174
self.dense = tf.keras.layers.Dense(
7275
weights.shape[-1],
7376
use_bias=False,
@@ -77,7 +80,14 @@ def __init__(self):
7780
def call(self, inputs):
7881
return self.dense(inputs['feature'])
7982

80-
return LinearModel()
83+
def get_config(self):
84+
return {
85+
'name': self.name,
86+
'weights': self.init_weights,
87+
'dynamic': self.init_dynamic
88+
}
89+
90+
return CustomLinearModel(weights, dynamic=dynamic)
8191

8292

8393
def build_linear_keras_dynamic_model(input_shape, weights):
@@ -728,6 +738,42 @@ def test_perturb_on_batch_pgd(self, model_fn):
728738
self.assertAllClose(x_adv, adv_inputs['feature'])
729739
self.assertAllClose(y0, adv_inputs['label'])
730740

741+
def _test_adv_model_save(self, model_fn):
742+
"""Template for testing model saving and loading."""
743+
w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
744+
model = model_fn(input_shape=(2,), weights=w)
745+
adv_model = adversarial_regularization.AdversarialRegularization(
746+
model, label_keys=['label'], adv_config=adv_config)
747+
adv_model.compile(optimizer=tf.keras.optimizers.SGD(lr), loss=['MAE'])
748+
749+
# Run the model before saving it. This is necessary for subclassed models.
750+
inputs = {'feature': x0, 'label': y0}
751+
adv_model.evaluate(inputs, steps=1)
752+
753+
saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
754+
adv_model.save(saved_model_dir)
755+
756+
loaded_model = tf.keras.models.load_model(saved_model_dir)
757+
self.assertEqual(
758+
len(loaded_model.trainable_weights), len(adv_model.trainable_weights))
759+
for w_loaded, w_adv in zip(loaded_model.trainable_weights,
760+
adv_model.trainable_weights):
761+
self.assertAllClose(
762+
tf.keras.backend.get_value(w_loaded),
763+
tf.keras.backend.get_value(w_adv))
764+
765+
@parameterized.named_parameters([
766+
('sequential', build_linear_keras_sequential_model),
767+
('functional', build_linear_keras_functional_model),
768+
])
769+
def test_adv_model_save(self, model_fn):
770+
self._test_adv_model_save(model_fn)
771+
772+
# Saving subclassed models are only supported in TF v2.
773+
@test_util.run_v2_only
774+
def test_adv_model_save_subclassed(self):
775+
self._test_adv_model_save(build_linear_keras_subclassed_model)
776+
731777

732778
if __name__ == '__main__':
733779
tf.test.main()

neural_structured_learning/keras/graph_regularization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,9 @@ def call(self, inputs, training=False, **kwargs):
139139
self.add_loss(scaled_graph_loss)
140140

141141
return base_output
142+
143+
def save(self, *args, **kwargs):
144+
"""Saves the base model. See base class for details of the interface."""
145+
# Graph regularization doesn't introduce new model variables, so saving the
146+
# base model can capture all variables in the model.
147+
self.base_model.save(*args, **kwargs)

neural_structured_learning/keras/graph_regularization_test.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import os
21+
2022
from absl.testing import parameterized
2123
import neural_structured_learning.configs as configs
2224
from neural_structured_learning.keras import graph_regularization
23-
2425
import numpy as np
2526
import tensorflow as tf
2627

@@ -88,10 +89,12 @@ def build_linear_functional_model(input_shape, weights, num_output=1):
8889
def build_linear_subclass_model(input_shape, weights, num_output=1):
8990
del input_shape
9091

91-
class LinearModel(tf.keras.Model):
92+
class CustomLinearModel(tf.keras.Model):
9293

93-
def __init__(self):
94-
super(LinearModel, self).__init__()
94+
def __init__(self, weights, num_output, name=None):
95+
super(CustomLinearModel, self).__init__(name=name)
96+
self.init_weights = weights
97+
self.num_output = num_output
9598
self.dense = tf.keras.layers.Dense(
9699
num_output,
97100
use_bias=False,
@@ -101,7 +104,14 @@ def __init__(self):
101104
def call(self, inputs):
102105
return self.dense(inputs[FEATURE_NAME])
103106

104-
return LinearModel()
107+
def get_config(self):
108+
return {
109+
'name': self.name,
110+
'weights': self.init_weights,
111+
'num_output': self.num_output
112+
}
113+
114+
return CustomLinearModel(weights, num_output)
105115

106116

107117
def make_dataset(example_proto, input_shape, training, max_neighbors):
@@ -481,6 +491,47 @@ def test_graph_reg_model_evaluate(self, model_fn):
481491
weight=w,
482492
distributed_strategy=None)
483493

494+
def _test_graph_reg_model_save(self, model_fn):
495+
"""Template for testing model saving and loading."""
496+
w = np.array([[4.0], [-3.0]])
497+
base_model = model_fn((2,), w)
498+
graph_reg_config = configs.make_graph_reg_config(
499+
max_neighbors=1, multiplier=1)
500+
graph_reg_model = graph_regularization.GraphRegularization(
501+
base_model, graph_reg_config)
502+
graph_reg_model.compile(
503+
optimizer=tf.keras.optimizers.SGD(LEARNING_RATE),
504+
loss='MSE',
505+
metrics=['accuracy'])
506+
507+
# Run the model before saving it. This is necessary for subclassed models.
508+
inputs = {FEATURE_NAME: tf.constant([[5.0, 3.0]])}
509+
graph_reg_model.predict(inputs, steps=1, batch_size=1)
510+
saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
511+
graph_reg_model.save(saved_model_dir)
512+
513+
loaded_model = tf.keras.models.load_model(saved_model_dir)
514+
self.assertEqual(
515+
len(loaded_model.trainable_weights),
516+
len(graph_reg_model.trainable_weights))
517+
for w_loaded, w_graph_reg in zip(loaded_model.trainable_weights,
518+
graph_reg_model.trainable_weights):
519+
self.assertAllClose(
520+
tf.keras.backend.get_value(w_loaded),
521+
tf.keras.backend.get_value(w_graph_reg))
522+
523+
@parameterized.named_parameters([
524+
('_sequential', build_linear_sequential_model),
525+
('_functional', build_linear_functional_model),
526+
])
527+
def test_graph_reg_model_save(self, model_fn):
528+
self._test_graph_reg_model_save(model_fn)
529+
530+
# Saving subclassed models are only supported in TF v2.
531+
@test_util.run_v2_only
532+
def test_graph_reg_model_save_subclass(self):
533+
self._test_graph_reg_model_save(build_linear_subclass_model)
534+
484535

485536
if __name__ == '__main__':
486537
tf.test.main()

0 commit comments

Comments
 (0)