Skip to content

Commit eaa4052

Browse files
Optimize spectral_normalization implementation.
PiperOrigin-RevId: 722757020
1 parent 3f00c13 commit eaa4052

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

tf_keras/layers/normalization/spectral_normalization.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def build(self, input_shape):
9595

9696
def call(self, inputs, training=False):
    """Run the wrapped layer, refreshing the normalized kernel first.

    Args:
        inputs: Input tensor forwarded unchanged to the wrapped layer.
        training: Python bool. When True, `self._update_weights()` is
            called before the forward pass so the wrapped layer sees the
            spectral-normalized kernel.

    Returns:
        The wrapped layer's output on `inputs`.
    """
    if training:
        # Updates self.kernel / self.vector_u in place; inference reuses
        # the values from the last training step.
        self._update_weights()
    output = self.layer(inputs)
    return output
@@ -105,35 +105,42 @@ def compute_output_shape(self, input_shape):
105105
self.layer.compute_output_shape(input_shape).as_list()
106106
)
107107

108+
def _update_weights(self):
109+
weights = self.kernel
110+
vector_u = self.vector_u
111+
112+
kernel_weights, vector_u = tf.cond(
113+
tf.reduce_all(tf.equal(weights, 0)),
114+
lambda: (weights, vector_u),
115+
lambda: self.normalize_weights(),
116+
)
117+
self.kernel.assign(kernel_weights)
118+
self.vector_u.assign(vector_u)
119+
def normalize_weights(self):
    """Generate spectral normalized weights via power iteration.

    Runs `self.power_iterations` rounds of power iteration against the
    kernel (reshaped to a 2-D matrix) to estimate its largest singular
    value `sigma`, then divides the kernel by it.

    Returns:
        A `(weights_normalized, vector_u)` tuple: the kernel divided by
        `sigma` and reshaped back to `self.kernel_shape`, and the
        updated power-iteration vector. This method no longer assigns
        to `self.kernel` itself; the caller (`_update_weights`) applies
        the returned values.
    """
    # Initialize vector_v up front so it is always defined (hints the
    # compiler that it exists even if power_iterations == 0).
    vector_u = self.vector_u
    vector_v = self.vector_u
    weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
    for _ in range(self.power_iterations):
        vector_v = tf.math.l2_normalize(
            tf.matmul(vector_u, weights, transpose_b=True)
        )
        vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
    # Treat the iteration vectors as constants w.r.t. gradients.
    vector_u = tf.stop_gradient(vector_u)
    vector_v = tf.stop_gradient(vector_v)
    # sigma approximates the largest singular value: v @ W @ u^T.
    sigma = tf.matmul(
        tf.matmul(vector_v, weights),
        vector_u,
        transpose_b=True,
    )
    weights_normalized = tf.reshape(weights / sigma, self.kernel_shape)
    return weights_normalized, vector_u
137144

138145
def get_config(self):
139146
config = {"power_iterations": self.power_iterations}

0 commit comments

Comments
 (0)