Skip to content

Allow per-variable optimizer, add DispatchOptimizer. #21196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions keras/src/backend/common/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def __init__(
# whether this variable should be overwritten by the computed gradient.
# Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py
self._overwrite_with_gradient = False
# Per-variable optimizer.
self._optimizer = None
if isinstance(initializer, str):
from keras.src import initializers

Expand Down Expand Up @@ -372,6 +374,26 @@ def regularizer(self, value):
)
self._regularizer = value

@property
def optimizer(self):
"""Per-variable custom optimizer."""
return self._optimizer

@optimizer.setter
def optimizer(self, value):
from keras.src import optimizers

if isinstance(value, str):
value = optimizers.get(value)

if value is not None and not isinstance(value, optimizers.Optimizer):
raise ValueError(
"Invalid value for attribute `optimizer`. Expected an "
"instance of `keras.optimizers.Optimizer`, or `None`. "
f"Received: regularizer={value}"
)
self._optimizer = value

@property
def constraint(self):
return self._constraint
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/common/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src import backend
from keras.src import initializers
from keras.src import ops
from keras.src import optimizers
from keras.src.backend.common import dtypes
from keras.src.backend.common.variables import AutocastScope
from keras.src.backend.common.variables import shape_equal
Expand Down Expand Up @@ -419,6 +420,18 @@ def test_deferred_initialize_within_stateless_scope(self):
):
v._deferred_initialize()

def test_optimizer_setter(self):
v = backend.Variable(
initializer=initializers.RandomNormal(),
shape=(2, 2),
)
self.assertIsNone(v.optimizer)
v.optimizer = "sgd"
self.assertTrue(isinstance(v.optimizer, optimizers.Optimizer))

with self.assertRaisesRegex(ValueError, "Invalid value"):
v.optimizer = True


class VariableDtypeShapeNdimRepr(test_case.TestCase):
"""tests for dtype, shape, ndim, __repr__"""
Expand Down
23 changes: 19 additions & 4 deletions keras/src/backend/tensorflow/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,25 @@ def assign_sub(self, variable, value):
else:
variable.assign_sub(value)

def _var_key(self, variable):
def _convert_to_tf_variable(self, variable):
if isinstance(variable, backend.Variable):
variable = variable.value # Convert to tf.Variable
tf_variable = variable.value
# Copy additional properties.
if getattr(variable, "optimizer", None) is not None:
tf_variable.optimizer = variable.optimizer
if getattr(variable, "overwrite_with_gradient", False):
tf_variable.overwrite_with_gradient = True
return tf_variable
elif isinstance(variable, tf.Variable):
return variable
else:
raise ValueError(
f"Variable {variable} must be of type keras.Variable or "
f"tf.Variable, received {value.__class__.__name__}."
)

def _var_key(self, variable):
variable = self._convert_to_tf_variable(variable)
if hasattr(variable, "_distributed_container"):
variable = variable._distributed_container()
elif (
Expand All @@ -98,8 +114,7 @@ def weight_decay_fn(variable):
variable.assign_sub(variable * wd * lr)

for variable in variables:
if isinstance(variable, backend.Variable):
variable = variable.value # Convert to tf.Variable
variable = self._convert_to_tf_variable(variable)
distribution.extended.update(
variable, weight_decay_fn, group=False
)
Expand Down
1 change: 1 addition & 0 deletions keras/src/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src.optimizers.adam import Adam
from keras.src.optimizers.adamax import Adamax
from keras.src.optimizers.adamw import AdamW
from keras.src.optimizers.dispatch_optimizer import DispatchOptimizer
from keras.src.optimizers.ftrl import Ftrl
from keras.src.optimizers.lion import Lion
from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
Expand Down
16 changes: 15 additions & 1 deletion keras/src/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,27 @@ def iterations(self):
def _track_variable(self, variable):
self._tracker.add_to_store("variables", variable)

def _has_custom_optimizer(self, variable):
return (
hasattr(variable, "optimizer")
and variable.optimizer is not None
and variable.optimizer != self
)

@tracking.no_automatic_dependency_tracking
def build(self, variables):
if self.use_ema:
self._model_variables_moving_average = []
if self.gradient_accumulation_steps:
self._accumulated_gradients = []
for i, variable in enumerate(variables):
if self._has_custom_optimizer(variable):
warnings.warn(
f"Variable {variable} has a custom optimizer "
f"{variable.optimizer} that is being ignored. "
"See `keras.optimizers.DispatchOptimizer` to allow "
"dispatching to the correct per-variable optimizer."
)
self._trainable_variables_indices[self._var_key(variable)] = i
if self.use_ema:
self._model_variables_moving_average.append(
Expand Down Expand Up @@ -568,7 +582,7 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
)
if len(trainable_variables) != len(self._trainable_variables):
raise ValueError(
"Argument `optimizer_variables` must be a list of tensors "
"Argument `trainable_variables` must be a list of tensors "
"corresponding 1:1 to the trainable variables list that "
"the optimizer was built with. Received "
f"len(trainable_variables) == {len(trainable_variables)} "
Expand Down
Loading