Commit 4826e7e

Add custom variable updater.
Allows customization of how variables are updated by the optimizer. The base optimizer simply defers to the update handler to do the update, allowing full customization. Can replace the existing `overwrite_with_gradient` attribute on variables, which is currently very application-specific.
1 parent 37eacb0 commit 4826e7e
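For orientation, a minimal usage sketch of the new API (not part of the diff; it mirrors the `test_custom_variable_updater` test added in `optimizer_test.py` below):

import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src import optimizers


# Toy updater (hypothetical): ignore the gradient and bump the variable by 1.
class IncrementVariable(optimizers.VariableUpdater):
    def update_step(self, gradient, variable):
        variable.assign_add(1.0)


v = backend.Variable(np.zeros((2, 2)))
v.updater = IncrementVariable()  # opt this variable in to custom handling

optimizer = optimizers.SGD(learning_rate=0.0001)
optimizer.apply_gradients([(ops.zeros((2, 2)), v)])
# v is now all ones: the optimizer deferred the update to the updater.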

File tree: 7 files changed (+178, -53 lines)


keras/src/backend/common/variables.py (+28)

@@ -150,6 +150,8 @@ def __init__(
         self._autocast = bool(autocast)
         self._aggregation = aggregation
         self._synchronization = synchronization
+        # Custom variable updater.
+        self._updater = None
         # `self._overwrite_with_gradient` is an internal property to determine
         # whether this variable should be overwritten by the computed gradient.
         # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py

@@ -334,6 +336,29 @@ def path(self):
         """The path of the variable within the Keras model or layer."""
         return self._path
 
+    @property
+    def updater(self):
+        """Custom variable updater.
+
+        This property is designed for special-casing variable updates during
+        training, such as quantized float8 `scale` and `amax_history`, where
+        the gradients represent updated scale factors, or for updating large
+        embedding tables, where we need to handle sparse updates to a dense
+        table.
+        """
+        return self._updater
+
+    @updater.setter
+    def updater(self, updater):
+        from keras.src import optimizers
+
+        if not isinstance(updater, optimizers.VariableUpdater):
+            raise TypeError(
+                "`updater` must be a `keras.optimizers.VariableUpdater`. "
+                f"Received: {updater.__class__.__name__}."
+            )
+        self._updater = updater
+
     @property
     def overwrite_with_gradient(self):
         """Whether this variable should be overwritten by the gradient.

@@ -355,6 +380,9 @@ def overwrite_with_gradient(self, value):
                 f"Received: {value}"
             )
         self._overwrite_with_gradient = value
+        from keras.src import optimizers
+
+        self._updater = optimizers.OverwriteScaleWithGradientUpdater()
 
     @property
     def regularizer(self):
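As a quick illustration of the new property's behavior (hypothetical snippet, not part of the commit): only `VariableUpdater` instances are accepted, and setting `overwrite_with_gradient` now also installs the dedicated updater.

from keras.src import backend
from keras.src import optimizers

v = backend.Variable([1.0, 2.0])

# Assigning anything other than a `keras.optimizers.VariableUpdater` raises.
try:
    v.updater = object()
except TypeError as e:
    print(e)  # `updater` must be a `keras.optimizers.VariableUpdater`. ...

# The boolean flag is kept for compatibility and wires up the new updater.
v.overwrite_with_gradient = True
print(type(v.updater).__name__)  # OverwriteScaleWithGradientUpdater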

keras/src/optimizers/__init__.py (+2)

@@ -11,6 +11,8 @@
 from keras.src.optimizers.muon import Muon
 from keras.src.optimizers.nadam import Nadam
 from keras.src.optimizers.optimizer import Optimizer
+from keras.src.optimizers.optimizer import OverwriteScaleWithGradientUpdater
+from keras.src.optimizers.optimizer import VariableUpdater
 from keras.src.optimizers.rmsprop import RMSprop
 from keras.src.optimizers.sgd import SGD
 from keras.src.saving import serialization_lib

keras/src/optimizers/base_optimizer.py (+21 -42)

@@ -204,6 +204,9 @@ def iterations(self):
     def _track_variable(self, variable):
         self._tracker.add_to_store("variables", variable)
 
+    def _get_variable_updater(self, variable):
+        return getattr(variable, "updater", None)
+
     @tracking.no_automatic_dependency_tracking
     def build(self, variables):
         if self.use_ema:

@@ -212,6 +215,11 @@ def build(self, variables):
         self._accumulated_gradients = []
         for i, variable in enumerate(variables):
             self._trainable_variables_indices[self._var_key(variable)] = i
+            custom_updater = self._get_variable_updater(variable)
+            if custom_updater is not None:
+                # Build the updater.
+                custom_updater.build(self, variable)
+
             if self.use_ema:
                 self._model_variables_moving_average.append(
                     self.add_variable_from_reference(

@@ -431,10 +439,8 @@ def apply(self, grads, trainable_variables=None):
 
         # Overwrite targeted variables directly with their gradients if
         # their `overwrite_with_gradient` is set.
-        grads, trainable_variables = (
-            self._overwrite_variables_directly_with_gradients(
-                grads, trainable_variables
-            )
+        grads, trainable_variables = self.__handle_custom_updaters(
+            grads, trainable_variables
         )
 
         if len(list(grads)) == 0:

@@ -698,21 +704,14 @@ def _get_current_learning_rate(self):
             return self._learning_rate()
         return self._learning_rate
 
-    def _overwrite_variables_directly_with_gradients(self, grads, vars):
-        """Overwrite the variables directly by their gradients.
-
-        This method is designed for a special case where we want to overwrite
-        the variable directly with its computed gradient. For example, in float8
-        training, new `scale` and `amax_history` are computed as gradients, and
-        we want to overwrite them directly instead of following the typical
-        procedure such as gradient descent with a learning rate, gradient
-        clipping and weight decaying.
+    def __handle_custom_updaters(self, grads, vars):
+        """Update any variable that has a custom updater.
 
         After the update, the processed pairs will be filtered out.
         """
         # Shortcut for `tf.Variable` because it doesn't have a
-        # `overwrite_with_gradient` attr
-        if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
+        # `updater` attr.
+        if not any(self._get_variable_updater(v) is not None for v in vars):
             return grads, vars
 
         # Shallow copies

@@ -722,33 +721,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         # Iterate from right to left for safe popping
         for i in range(len(filtered_grads) - 1, -1, -1):
             g, v = filtered_grads[i], filtered_vars[i]
-            if v.overwrite_with_gradient:
-                if self.gradient_accumulation_steps:
-                    # Utilize a stateless manner for JAX compatibility
-                    steps = self.gradient_accumulation_steps
-                    is_update_step = (self._iterations + 1) % steps == 0
-                    acc_g = self._accumulated_gradients[
-                        self._get_variable_index(v)
-                    ]
-                    # `ops.maximum` is utilized for gradient accumulation for
-                    # `overwrite_with_gradient=True` variables
-                    new_g_acc = ops.cond(
-                        is_update_step,
-                        lambda: ops.zeros(g.shape, dtype=g.dtype),
-                        lambda: ops.maximum(g, acc_g),
-                    )
-                    new_g = ops.cond(
-                        is_update_step,
-                        lambda: ops.maximum(g, acc_g),
-                        lambda: g,
-                    )
-                    new_v = ops.cond(
-                        is_update_step, lambda: new_g, lambda: v.value
-                    )
-                    v.assign(new_v)
-                    acc_g.assign(new_g_acc)
-                else:
-                    v.assign(g)
+            if v.updater:
+                v.updater.update_step(g, v)
                 filtered_grads.pop(i)
                 filtered_vars.pop(i)
         return filtered_grads, filtered_vars

@@ -926,6 +900,11 @@ def finalize_variable_values(self, var_list):
             # optimizer.
             self._overwrite_model_variables_with_average_value(var_list)
 
+        for var in var_list:
+            updater = self._get_variable_updater(var)
+            if updater is not None:
+                updater.finalize_variable_value(var)
+
     def _obj_type(self):
         return "Optimizer"
 
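To show how the routing in `apply` works in practice, a small sketch (assumed behavior based on the code above; `AssignGradient` is a hypothetical updater): variables with an updater are handled by `update_step` and filtered out before the regular optimizer update runs on the rest.

import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src import optimizers


class AssignGradient(optimizers.VariableUpdater):
    # Hypothetical updater: overwrite the variable with its "gradient",
    # bypassing learning rate, clipping and weight decay entirely.
    def update_step(self, gradient, variable):
        variable.assign(gradient)


normal = backend.Variable([1.0, 1.0])
special = backend.Variable([1.0, 1.0])
special.updater = AssignGradient()

opt = optimizers.SGD(learning_rate=1.0)
grads = [ops.array([0.5, 0.5]), ops.array([7.0, 7.0])]
opt.apply(grads, [normal, special])

print(np.array(normal))   # [0.5 0.5]  -- regular SGD step
print(np.array(special))  # [7. 7.]    -- handled by the custom updater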

keras/src/optimizers/loss_scale_optimizer.py (+5 -8)

@@ -60,6 +60,9 @@ def __init__(
         self.inner_optimizer = inner_optimizer
         self.initial_scale = initial_scale
         self.dynamic_growth_steps = dynamic_growth_steps
+        # Disable the inner optimizer's loss scaling, otherwise
+        # gradients will be scaled twice.
+        self.inner_optimizer.loss_scale_factor = None
 
     @tracking.no_automatic_dependency_tracking
     def build(self, var_list):

@@ -102,12 +105,6 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
             ),
         )
 
-    def _overwrite_variable_with_gradient(self, variable):
-        return (
-            hasattr(variable, "overwrite_with_gradient")
-            and variable.overwrite_with_gradient
-        )
-
     def _stateless_handle_finite_grads(
         self, optimizer_variables, grads, trainable_variables
     ):

@@ -137,7 +134,7 @@ def increment():
         scale = self.dynamic_scale
         unscaled_grads = [
             g
-            if g is None or self._overwrite_variable_with_gradient(v)
+            if g is None or self._get_variable_updater(v) is not None
             else ops.divide(g, scale)
             for g, v in zip(grads, trainable_variables)
         ]

@@ -183,7 +180,7 @@ def _stateful_handle_finite_grads(self, grads, trainable_variables):
         tvs = trainable_variables or self._trainable_variables
         unscaled_grads = [
             g
-            if g is None or self._overwrite_variable_with_gradient(v)
+            if g is None or self._get_variable_updater(v) is not None
             else ops.divide(g, scale)
             for g, v in zip(grads, tvs)
         ]
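The new `__init__` line above means a loss-scale factor configured on the inner optimizer no longer double-scales gradients. A rough sketch of the expected behavior (adapted from the new test below, not part of the diff):

from keras.src import backend
from keras.src import ops
from keras.src.optimizers import SGD
from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer

# The inner optimizer carries its own loss_scale_factor; the wrapper clears
# it so gradients are unscaled exactly once.
inner = SGD(learning_rate=0.5, loss_scale_factor=100)
optimizer = LossScaleOptimizer(inner)
assert inner.loss_scale_factor is None

grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.apply(grads, vars)
# vars is now approximately [0.5, -1.0, -0.5, 3.0], i.e. a plain SGD(0.5) step.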

keras/src/optimizers/loss_scale_optimizer_test.py (+20)

@@ -48,6 +48,26 @@ def test_finite_step(self, stateless):
             vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
         )
 
+    @parameterized.named_parameters(("stateless", True), ("stateful", False))
+    def test_finite_step_with_inner_loss_scale(self, stateless):
+        self._skip_test_for_stateless(stateless)
+
+        # Ensure that the inner loss scale does not interfere with the update.
+        inner_optimizer = SGD(learning_rate=0.5, loss_scale_factor=100)
+        optimizer = LossScaleOptimizer(inner_optimizer)
+        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]
+        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+        if stateless:
+            optimizer.build(vars)
+            vars, _ = optimizer.stateless_apply(
+                optimizer.variables, grads, vars
+            )
+        else:
+            optimizer.apply(grads, vars)
+        self.assertAllClose(
+            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+        )
+
     @parameterized.named_parameters(("stateless", True), ("stateful", False))
     def test_infinite_step(self, stateless):
         self._skip_test_for_stateless(stateless)

keras/src/optimizers/optimizer.py (+86)

@@ -1,4 +1,5 @@
 from keras.src import backend
+from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.optimizers import base_optimizer
 

@@ -23,5 +24,90 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
     pass
 
 
+@keras_export("keras.optimizers.VariableUpdater")
+class VariableUpdater:
+    """Allows special handling of variable updates."""
+
+    def build(self, optimizer, variable):
+        """Set up any state that might depend on the optimizer.
+
+        This may add variables directly to the optimizer for updating state.
+
+        Args:
+            optimizer: The optimizer used to update the variables during training.
+            variable: Variable to update.
+        """
+        pass
+
+    def update_step(self, gradient, variable):
+        """Update the variable state using the supplied gradient.
+
+        Args:
+            gradient: Gradient for the variable.
+            variable: Variable to update.
+        """
+        pass
+
+    def finalize_variable_value(self, variable):
+        """Set the final value of the trainable variable.
+
+        Sometimes there are some extra steps before ending the variable updates,
+        such as overriding the model variables with its average value.
+
+        Args:
+            variable: Variable to finalize.
+        """
+        pass
+
+
+class OverwriteScaleWithGradientUpdater(VariableUpdater):
+    """Special variable update handler for float8 quantization scales.
+
+    The "gradient" of the scale factor (scale, amax_history) is actually the
+    updated scale to assign to the variable. Supports gradient accumulation
+    steps, in which the maximum scale factor between intermediate gradient
+    steps is recorded.
+    """
+
+    def build(self, optimizer, variable):
+        # Keep reference copy of iterations so we can update gradient
+        # accumulators appropriately.
+        self._iterations = optimizer._iterations
+        # Support gradient accumulation by adding an accumulator directly
+        # to the optimizer.
+        self._gradient_accumulation_steps = (
+            optimizer.gradient_accumulation_steps
+        )
+        if self._gradient_accumulation_steps:
+            self.gradient_accumulator = optimizer.add_variable_from_reference(
+                reference_variable=variable, name="gradient_accumulation"
+            )
+
+    def update_step(self, gradient, variable):
+        if self._gradient_accumulation_steps:
+            # Utilize a stateless manner for JAX compatibility
+            steps = self._gradient_accumulation_steps
+            is_update_step = (self._iterations + 1) % steps == 0
+            # Keep track of the maximum scale factor encountered.
+            new_g_acc = ops.cond(
+                is_update_step,
+                lambda: ops.zeros(gradient.shape, dtype=gradient.dtype),
+                lambda: ops.maximum(gradient, self.gradient_accumulator),
+            )
+            new_g = ops.cond(
+                is_update_step,
+                lambda: ops.maximum(gradient, self.gradient_accumulator),
+                lambda: gradient,
+            )
+            new_v = ops.cond(
+                is_update_step, lambda: new_g, lambda: variable.value
+            )
+            variable.assign(new_v)
+            self.gradient_accumulator.assign(new_g_acc)
+        else:
+            # Assign scale "gradient" directly to variable.
+            variable.assign(gradient)
+
+
 Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__
 base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args
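As a sketch of how the hooks are meant to be used (a hypothetical subclass, not part of the commit): an updater can capture optimizer state in `build` and apply its own rule in `update_step`.

from keras.src import ops
from keras.src import optimizers


class FixedRateUpdater(optimizers.VariableUpdater):
    """Hypothetical updater: a plain SGD step with its own rate, bypassing
    the owning optimizer's update rule, clipping and weight decay."""

    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate

    def build(self, optimizer, variable):
        # Per-variable state (e.g. accumulators) could be added to the
        # optimizer here via `optimizer.add_variable_from_reference(...)`,
        # as `OverwriteScaleWithGradientUpdater` does above.
        pass

    def update_step(self, gradient, variable):
        variable.assign(
            ops.subtract(variable, ops.multiply(self.learning_rate, gradient))
        )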

keras/src/optimizers/optimizer_test.py (+16 -3)

@@ -152,6 +152,19 @@ def test_constraints_are_applied(self):
         optimizer.apply_gradients([(grad, v)])
         self.assertAlmostEqual(np.min(v), 0.0)
 
+    def test_custom_variable_updater(self):
+        class IncrementVariable(optimizers.VariableUpdater):
+            def update_step(self, gradient, variable):
+                variable.assign_add(1.0)
+
+        orig_value = np.random.random((2, 2)) - 1.0
+        v = backend.Variable(orig_value)
+        v.updater = IncrementVariable()
+        optimizer = optimizers.SGD(learning_rate=0.0001)
+        grad = backend.numpy.zeros((2, 2))
+        optimizer.apply_gradients([(grad, v)])
+        self.assertAllClose(v, orig_value + 1)
+
     def test_get_method(self):
         obj = optimizers.get("sgd")
         self.assertIsInstance(obj, optimizers.SGD)

@@ -298,7 +311,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self):
         self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])
         self.assertAllClose(v2, [[1.0, 2.0], [3.0, 4.0]])
         self.assertAllClose(
-            optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]
+            v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]]
         )
         self.assertAllClose(
             optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]

@@ -311,7 +324,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self):
         self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]])
         self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]])
         self.assertAllClose(
-            optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]]
+            v.updater.gradient_accumulator, [[0.0, 0.0], [0.0, 0.0]]
         )
         self.assertAllClose(
             optimizer._accumulated_gradients[1], [[0.0, 0.0], [0.0, 0.0]]

@@ -324,7 +337,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self):
         self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]])
         self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]])
         self.assertAllClose(
-            optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]
+            v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]]
         )
         self.assertAllClose(
             optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]
