17
17
from __future__ import division
18
18
from __future__ import print_function
19
19
20
+ import os
21
+
20
22
from absl .testing import parameterized
21
23
import neural_structured_learning .configs as configs
22
24
from neural_structured_learning .keras import graph_regularization
23
-
24
25
import numpy as np
25
26
import tensorflow as tf
26
27
@@ -88,10 +89,12 @@ def build_linear_functional_model(input_shape, weights, num_output=1):
88
89
def build_linear_subclass_model (input_shape , weights , num_output = 1 ):
89
90
del input_shape
90
91
91
- class LinearModel (tf .keras .Model ):
92
+ class CustomLinearModel (tf .keras .Model ):
92
93
93
- def __init__ (self ):
94
- super (LinearModel , self ).__init__ ()
94
+ def __init__ (self , weights , num_output , name = None ):
95
+ super (CustomLinearModel , self ).__init__ (name = name )
96
+ self .init_weights = weights
97
+ self .num_output = num_output
95
98
self .dense = tf .keras .layers .Dense (
96
99
num_output ,
97
100
use_bias = False ,
@@ -101,7 +104,14 @@ def __init__(self):
101
104
def call (self , inputs ):
102
105
return self .dense (inputs [FEATURE_NAME ])
103
106
104
- return LinearModel ()
107
+ def get_config (self ):
108
+ return {
109
+ 'name' : self .name ,
110
+ 'weights' : self .init_weights ,
111
+ 'num_output' : self .num_output
112
+ }
113
+
114
+ return CustomLinearModel (weights , num_output )
105
115
106
116
107
117
def make_dataset (example_proto , input_shape , training , max_neighbors ):
@@ -481,6 +491,47 @@ def test_graph_reg_model_evaluate(self, model_fn):
481
491
weight = w ,
482
492
distributed_strategy = None )
483
493
494
+ def _test_graph_reg_model_save (self , model_fn ):
495
+ """Template for testing model saving and loading."""
496
+ w = np .array ([[4.0 ], [- 3.0 ]])
497
+ base_model = model_fn ((2 ,), w )
498
+ graph_reg_config = configs .make_graph_reg_config (
499
+ max_neighbors = 1 , multiplier = 1 )
500
+ graph_reg_model = graph_regularization .GraphRegularization (
501
+ base_model , graph_reg_config )
502
+ graph_reg_model .compile (
503
+ optimizer = tf .keras .optimizers .SGD (LEARNING_RATE ),
504
+ loss = 'MSE' ,
505
+ metrics = ['accuracy' ])
506
+
507
+ # Run the model before saving it. This is necessary for subclassed models.
508
+ inputs = {FEATURE_NAME : tf .constant ([[5.0 , 3.0 ]])}
509
+ graph_reg_model .predict (inputs , steps = 1 , batch_size = 1 )
510
+ saved_model_dir = os .path .join (self .get_temp_dir (), 'saved_model' )
511
+ graph_reg_model .save (saved_model_dir )
512
+
513
+ loaded_model = tf .keras .models .load_model (saved_model_dir )
514
+ self .assertEqual (
515
+ len (loaded_model .trainable_weights ),
516
+ len (graph_reg_model .trainable_weights ))
517
+ for w_loaded , w_graph_reg in zip (loaded_model .trainable_weights ,
518
+ graph_reg_model .trainable_weights ):
519
+ self .assertAllClose (
520
+ tf .keras .backend .get_value (w_loaded ),
521
+ tf .keras .backend .get_value (w_graph_reg ))
522
+
523
+ @parameterized .named_parameters ([
524
+ ('_sequential' , build_linear_sequential_model ),
525
+ ('_functional' , build_linear_functional_model ),
526
+ ])
527
+ def test_graph_reg_model_save (self , model_fn ):
528
+ self ._test_graph_reg_model_save (model_fn )
529
+
530
+ # Saving subclassed models are only supported in TF v2.
531
+ @test_util .run_v2_only
532
+ def test_graph_reg_model_save_subclass (self ):
533
+ self ._test_graph_reg_model_save (build_linear_subclass_model )
534
+
484
535
485
536
if __name__ == '__main__' :
486
537
tf .test .main ()
0 commit comments