
Commit cb79194

Add custom variable updater.
Allows customization of how variables are updated by the optimizer. The base optimizer simply defers to the variable's update handler to perform the update, allowing full customization. Can replace the existing `overwrite_with_gradient` attribute on variables, which is currently very application-specific. Also eliminates the creation of optimizer slot variables for variables that have custom updaters (including `overwrite_with_gradient`), since those slot variables are never used and would be wasteful.
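
A minimal usage sketch (not part of this commit): the `SignSGDUpdater` subclass, its `update_step` hook name, and the public `keras.optimizers` import path are assumptions for illustration; the diffs below only show the `VariableUpdater` base class export and the `variable.updater` property.

import keras
from keras import optimizers  # assumes VariableUpdater is publicly exported here


class SignSGDUpdater(optimizers.VariableUpdater):
    """Hypothetical updater: applies sign-of-gradient steps to one variable."""

    def update_step(self, gradient, variable, learning_rate):  # assumed hook name
        variable.assign_sub(learning_rate * keras.ops.sign(gradient))


layer = keras.layers.Dense(4)
layer.build((None, 8))
# Route this kernel's updates through the custom updater; per the commit
# message, the base optimizer defers to it and no longer creates its own
# (unused) slot variables for this variable.
layer.kernel.updater = SignSGDUpdater()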
1 parent 37eacb0 commit cb79194

19 files changed: +305 -213 lines

keras/src/backend/common/variables.py (+28)

@@ -150,6 +150,8 @@ def __init__(
         self._autocast = bool(autocast)
         self._aggregation = aggregation
         self._synchronization = synchronization
+        # Custom variable updater.
+        self._updater = None
         # `self._overwrite_with_gradient` is an internal property to determine
         # whether this variable should be overwritten by the computed gradient.
         # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py
@@ -334,6 +336,29 @@ def path(self):
         """The path of the variable within the Keras model or layer."""
         return self._path
 
+    @property
+    def updater(self):
+        """Custom variable updater.
+
+        This property is designed for special-casing variable updates during
+        training, such as quantized float8 `scale` and `amax_history`, where
+        the gradients represent updated scale factors, or for updating large
+        embedding tables, where we need to handle sparse updates to a dense
+        table.
+        """
+        return self._updater
+
+    @updater.setter
+    def updater(self, updater):
+        from keras.src import optimizers
+
+        if not isinstance(updater, optimizers.VariableUpdater):
+            raise TypeError(
+                "`updater` must be a `keras.optimizers.VariableUpdater`. "
+                f"Received: {updater.__class__.__name__}."
+            )
+        self._updater = updater
+
     @property
     def overwrite_with_gradient(self):
         """Whether this variable should be overwritten by the gradient.
@@ -355,6 +380,9 @@ def overwrite_with_gradient(self, value):
                 f"Received: {value}"
             )
         self._overwrite_with_gradient = value
+        from keras.src import optimizers
+
+        self._updater = optimizers.OverwriteScaleWithGradientUpdater()
 
     @property
     def regularizer(self):
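
As a rough illustration of the setter above (assuming eager variable creation is available in the active backend), flipping the legacy flag now also installs the equivalent updater object:

from keras.src import backend, optimizers

scale = backend.Variable(1.0, trainable=True, name="scale")
scale.overwrite_with_gradient = True
# Per the setter above, the variable now carries an updater instance.
assert isinstance(scale.updater, optimizers.OverwriteScaleWithGradientUpdater)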

keras/src/optimizers/__init__.py (+2)

@@ -11,6 +11,8 @@
 from keras.src.optimizers.muon import Muon
 from keras.src.optimizers.nadam import Nadam
 from keras.src.optimizers.optimizer import Optimizer
+from keras.src.optimizers.optimizer import OverwriteScaleWithGradientUpdater
+from keras.src.optimizers.optimizer import VariableUpdater
 from keras.src.optimizers.rmsprop import RMSprop
 from keras.src.optimizers.sgd import SGD
 from keras.src.saving import serialization_lib

keras/src/optimizers/adadelta.py (+6 -9)

@@ -75,15 +75,12 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulated_grads = []
-        self._accumulated_delta_vars = []
-        for var in var_list:
-            self._accumulated_grads.append(
-                self.add_variable_from_reference(var, "accumulated_grad")
-            )
-            self._accumulated_delta_vars.append(
-                self.add_variable_from_reference(var, "accumulated_delta_var")
-            )
+        self._accumulated_grads = self.add_optimizer_variables(
+            var_list, "accumulated_grad"
+        )
+        self._accumulated_delta_vars = self.add_optimizer_variables(
+            var_list, "accumulated_delta_var"
+        )
 
     def update_step(self, grad, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/adafactor.py (+13 -15)

@@ -1,4 +1,3 @@
-from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.optimizers import optimizer
@@ -97,16 +96,11 @@ def build(self, var_list):
         self._c = []
         self._v = []
         for var in var_list:
-            if len(var.shape) < 2:
-                # Don't factor if variable is of dimension < 2, but we still
-                # need to create dummy variables as placeholder.
-                with backend.name_scope(self.name, caller=self):
-                    self._r.append(
-                        backend.Variable(0, name=var.name, trainable=False)
-                    )
-                    self._c.append(
-                        backend.Variable(0, name=var.name, trainable=False)
-                    )
+            variable_updater = self._get_variable_updater(var)
+            if len(var.shape) < 2 or variable_updater is not None:
+                # Don't factor if variable is of dimension < 2.
+                self._r.append(None)
+                self._c.append(None)
             else:
                 # Always factor the last 2 dimensions.
                 r_shape = var.shape[:-1]
@@ -125,11 +119,15 @@ def build(self, var_list):
                         name=var.name,
                     )
                 )
-            self._v.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
+
+            if variable_updater is not None:
+                self._v.append(None)
+            else:
+                self._v.append(
+                    self.add_variable_from_reference(
+                        reference_variable=var, name="velocity"
+                    )
                 )
-            )
 
     def _rms(self, x):
         return ops.sqrt(ops.mean(ops.square(x)))

keras/src/optimizers/adagrad.py (+3 -10)

@@ -70,17 +70,10 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._accumulators = []
         initializer = initializers.Constant(self.initial_accumulator_value)
-        for var in var_list:
-            self._accumulators.append(
-                self.add_variable(
-                    shape=var.shape,
-                    initializer=initializer,
-                    dtype=var.dtype,
-                    name="accumulator",
-                )
-            )
+        self._accumulators = self.add_optimizer_variables(
+            var_list, "accumulator", initializer=initializer
+        )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/adam.py (+6 -20)

@@ -90,27 +90,13 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._momentums = []
-        self._velocities = []
-        for var in var_list:
-            self._momentums.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._velocities.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="velocity"
-                )
-            )
+        self._momentums = self.add_optimizer_variables(var_list, "momentum")
+        self._velocities = self.add_optimizer_variables(var_list, "velocity")
+
         if self.amsgrad:
-            self._velocity_hats = []
-            for var in var_list:
-                self._velocity_hats.append(
-                    self.add_variable_from_reference(
-                        reference_variable=var, name="velocity_hat"
-                    )
-                )
+            self._velocity_hats = self.add_optimizer_variables(
+                var_list, "velocity_hat"
+            )
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""

keras/src/optimizers/adamax.py (+2 -13)

@@ -98,19 +98,8 @@ def build(self, var_list):
         if self.built:
             return
         super().build(var_list)
-        self._m = []
-        self._u = []
-        for var in var_list:
-            self._m.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="momentum"
-                )
-            )
-            self._u.append(
-                self.add_variable_from_reference(
-                    reference_variable=var, name="norm"
-                )
-            )
+        self._m = self.add_optimizer_variables(var_list, "momentum")
+        self._u = self.add_optimizer_variables(var_list, "norm")
 
     def update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
