@@ -95,7 +95,7 @@ def build(self, input_shape):

    def call(self, inputs, training=False):
        if training:
-            self.normalize_weights()
+            self._update_weights()

        output = self.layer(inputs)
        return output
@@ -105,35 +105,42 @@ def compute_output_shape(self, input_shape):
            self.layer.compute_output_shape(input_shape).as_list()
        )

+    def _update_weights(self):
+        weights = self.kernel
+        vector_u = self.vector_u
+
+        kernel_weights, vector_u = tf.cond(
+            tf.reduce_all(tf.equal(weights, 0)),
+            lambda: (weights, vector_u),
+            lambda: self.normalize_weights(),
+        )
+        self.kernel.assign(kernel_weights)
+        self.vector_u.assign(vector_u)
+
    def normalize_weights(self):
        """Generate spectral normalized weights.

        This method will update the value of `self.kernel` with the
        spectral normalized value, so that the layer is ready for `call()`.
        """
-
-        weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
+        # Initialize vector_v to hint the compiler that it always exists.
        vector_u = self.vector_u
-
-        # check for zeroes weights
-        if not tf.reduce_all(tf.equal(weights, 0.0)):
-            for _ in range(self.power_iterations):
-                vector_v = tf.math.l2_normalize(
-                    tf.matmul(vector_u, weights, transpose_b=True)
-                )
-                vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
-            vector_u = tf.stop_gradient(vector_u)
-            vector_v = tf.stop_gradient(vector_v)
-            sigma = tf.matmul(
-                tf.matmul(vector_v, weights), vector_u, transpose_b=True
-            )
-            self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype))
-            self.kernel.assign(
-                tf.cast(
-                    tf.reshape(self.kernel / sigma, self.kernel_shape),
-                    self.kernel.dtype,
-                )
+        vector_v = self.vector_u
+        weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
+        for _ in range(self.power_iterations):
+            vector_v = tf.math.l2_normalize(
+                tf.matmul(vector_u, weights, transpose_b=True)
            )
+            vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
+        vector_u = tf.stop_gradient(vector_u)
+        vector_v = tf.stop_gradient(vector_v)
+        sigma = tf.matmul(
+            tf.matmul(vector_v, weights),
+            vector_u,
+            transpose_b=True,
+        )
+        weights_normalized = tf.reshape(weights / sigma, self.kernel_shape)
+        return weights_normalized, vector_u

    def get_config(self):
        config = {"power_iterations": self.power_iterations}
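
A minimal standalone sketch of what this patch does, for context: the same power-iteration scheme estimates the kernel's largest singular value (sigma), and the zero-kernel guard uses `tf.cond` instead of a Python `if`, which keeps the data-dependent branch inside the graph. The helper name `spectral_norm_estimate` and the tensor shapes below are illustrative assumptions, not part of the patch.

import tensorflow as tf

def spectral_norm_estimate(weights, vector_u, power_iterations=1):
    # Hypothetical helper mirroring the patched normalize_weights():
    # estimate sigma (the largest singular value) of a 2-D `weights`
    # matrix by power iteration, then return the normalized weights
    # together with the updated vector `vector_u`.
    vector_v = vector_u
    for _ in range(power_iterations):
        vector_v = tf.math.l2_normalize(
            tf.matmul(vector_u, weights, transpose_b=True)
        )
        vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
    vector_u = tf.stop_gradient(vector_u)
    vector_v = tf.stop_gradient(vector_v)
    sigma = tf.matmul(tf.matmul(vector_v, weights), vector_u, transpose_b=True)
    return weights / sigma, vector_u

# Illustrative shapes: a flattened kernel of shape [rows, channels] and a
# [1, channels] power-iteration vector, as the layer stores them.
weights = tf.random.normal([16, 8])
vector_u = tf.math.l2_normalize(tf.random.normal([1, 8]))

# Mirror _update_weights(): skip the update when the kernel is all zeros.
# Both tf.cond branches must return tensors with matching shapes and dtypes.
normalized, vector_u = tf.cond(
    tf.reduce_all(tf.equal(weights, 0)),
    lambda: (weights, vector_u),
    lambda: spectral_norm_estimate(weights, vector_u, power_iterations=5),
)

# With enough iterations, the largest singular value of the normalized
# kernel should be close to 1.
print(tf.linalg.svd(normalized, compute_uv=False)[0].numpy())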