From 3499116754c01106a4be803c05292bb11ded5a5b Mon Sep 17 00:00:00 2001
From: Antonio Sanchez
Date: Mon, 21 Apr 2025 16:07:55 -0700
Subject: [PATCH] Allow per-variable optimizer, add DispatchOptimizer.

- Adds a property `variable.optimizer` that defaults to `None`.
- Adds a `DispatchOptimizer` that scans the list of trainable variables
  during build, collects all unique per-variable optimizers, then
  dispatches `apply`/`stateless_apply` calls to the correct optimizer
  where applicable.
- Modifies the trainer so that, during the optimizer build stage, it
  checks whether any variables have a custom optimizer attached and, if
  so, inserts a `DispatchOptimizer` to handle them. This keeps the
  dispatch machinery hidden from the user.

Context: for large embedding tables, we need special optimizers so that
the tables can be updated in place rather than returning large
gradients. The layer handles setting the custom optimizers, but the
trainer needs to be aware of them and dispatch the embedding tables to
the appropriate optimizers.
---
 keras/src/backend/common/variables.py         |  22 ++
 keras/src/backend/common/variables_test.py    |  13 +
 keras/src/backend/tensorflow/optimizer.py     |  23 +-
 keras/src/optimizers/__init__.py              |   1 +
 keras/src/optimizers/base_optimizer.py        |  16 +-
 keras/src/optimizers/dispatch_optimizer.py    | 294 ++++++++++++++++++
 .../src/optimizers/dispatch_optimizer_test.py | 114 +++++++
 keras/src/trainers/trainer.py                 |  38 ++-
 keras/src/trainers/trainer_test.py            |  34 ++
 9 files changed, 548 insertions(+), 7 deletions(-)
 create mode 100644 keras/src/optimizers/dispatch_optimizer.py
 create mode 100644 keras/src/optimizers/dispatch_optimizer_test.py

diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py
index 47ce553f1e98..ec64bfd5957f 100644
--- a/keras/src/backend/common/variables.py
+++ b/keras/src/backend/common/variables.py
@@ -154,6 +154,8 @@ def __init__(
         # whether this variable should be overwritten by the computed gradient.
         # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py
         self._overwrite_with_gradient = False
+        # Per-variable optimizer.
+        self._optimizer = None
 
         if isinstance(initializer, str):
             from keras.src import initializers
@@ -372,6 +374,26 @@ def regularizer(self, value):
             )
         self._regularizer = value
 
+    @property
+    def optimizer(self):
+        """Per-variable custom optimizer."""
+        return self._optimizer
+
+    @optimizer.setter
+    def optimizer(self, value):
+        from keras.src import optimizers
+
+        if isinstance(value, str):
+            value = optimizers.get(value)
+
+        if value is not None and not isinstance(value, optimizers.Optimizer):
+            raise ValueError(
+                "Invalid value for attribute `optimizer`. Expected an "
+                "instance of `keras.optimizers.Optimizer`, or `None`. "
+                f"Received: optimizer={value}"
+            )
+        self._optimizer = value
+
     @property
     def constraint(self):
         return self._constraint
diff --git a/keras/src/backend/common/variables_test.py b/keras/src/backend/common/variables_test.py
index fe8e0f48bcf7..6f81b7772573 100644
--- a/keras/src/backend/common/variables_test.py
+++ b/keras/src/backend/common/variables_test.py
@@ -8,6 +8,7 @@
 from keras.src import backend
 from keras.src import initializers
 from keras.src import ops
+from keras.src import optimizers
 from keras.src.backend.common import dtypes
 from keras.src.backend.common.variables import AutocastScope
 from keras.src.backend.common.variables import shape_equal
@@ -419,6 +420,18 @@ def test_deferred_initialize_within_stateless_scope(self):
         ):
             v._deferred_initialize()
 
+    def test_optimizer_setter(self):
+        v = backend.Variable(
+            initializer=initializers.RandomNormal(),
+            shape=(2, 2),
+        )
+        self.assertIsNone(v.optimizer)
+        v.optimizer = "sgd"
+        self.assertTrue(isinstance(v.optimizer, optimizers.Optimizer))
+
+        with self.assertRaisesRegex(ValueError, "Invalid value"):
+            v.optimizer = True
+
 
 class VariableDtypeShapeNdimRepr(test_case.TestCase):
     """tests for dtype, shape, ndim, __repr__"""
diff --git a/keras/src/backend/tensorflow/optimizer.py b/keras/src/backend/tensorflow/optimizer.py
index f4497543d6ab..816e8dec8f21 100644
--- a/keras/src/backend/tensorflow/optimizer.py
+++ b/keras/src/backend/tensorflow/optimizer.py
@@ -71,9 +71,25 @@ def assign_sub(self, variable, value):
         else:
             variable.assign_sub(value)
 
-    def _var_key(self, variable):
+    def _convert_to_tf_variable(self, variable):
         if isinstance(variable, backend.Variable):
-            variable = variable.value  # Convert to tf.Variable
+            tf_variable = variable.value
+            # Copy additional properties.
+            if getattr(variable, "optimizer", None) is not None:
+                tf_variable.optimizer = variable.optimizer
+            if getattr(variable, "overwrite_with_gradient", False):
+                tf_variable.overwrite_with_gradient = True
+            return tf_variable
+        elif isinstance(variable, tf.Variable):
+            return variable
+        else:
+            raise ValueError(
+                f"Variable {variable} must be of type keras.Variable or "
+                f"tf.Variable, received {variable.__class__.__name__}."
+            )
+
+    def _var_key(self, variable):
+        variable = self._convert_to_tf_variable(variable)
         if hasattr(variable, "_distributed_container"):
             variable = variable._distributed_container()
         elif (
@@ -98,8 +114,7 @@ def weight_decay_fn(variable):
                     variable.assign_sub(variable * wd * lr)
 
             for variable in variables:
-                if isinstance(variable, backend.Variable):
-                    variable = variable.value  # Convert to tf.Variable
+                variable = self._convert_to_tf_variable(variable)
                 distribution.extended.update(
                     variable, weight_decay_fn, group=False
                 )
diff --git a/keras/src/optimizers/__init__.py b/keras/src/optimizers/__init__.py
index 4db5319793ea..15417d980324 100644
--- a/keras/src/optimizers/__init__.py
+++ b/keras/src/optimizers/__init__.py
@@ -5,6 +5,7 @@
 from keras.src.optimizers.adam import Adam
 from keras.src.optimizers.adamax import Adamax
 from keras.src.optimizers.adamw import AdamW
+from keras.src.optimizers.dispatch_optimizer import DispatchOptimizer
 from keras.src.optimizers.ftrl import Ftrl
 from keras.src.optimizers.lion import Lion
 from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py
index 57833afadc7e..8c79d891e782 100644
--- a/keras/src/optimizers/base_optimizer.py
+++ b/keras/src/optimizers/base_optimizer.py
@@ -204,6 +204,13 @@ def iterations(self):
     def _track_variable(self, variable):
         self._tracker.add_to_store("variables", variable)
 
+    def _has_custom_optimizer(self, variable):
+        return (
+            hasattr(variable, "optimizer")
+            and variable.optimizer is not None
+            and variable.optimizer != self
+        )
+
     @tracking.no_automatic_dependency_tracking
     def build(self, variables):
         if self.use_ema:
@@ -211,6 +218,13 @@ def build(self, variables):
         if self.gradient_accumulation_steps:
             self._accumulated_gradients = []
         for i, variable in enumerate(variables):
+            if self._has_custom_optimizer(variable):
+                warnings.warn(
+                    f"Variable {variable} has a custom optimizer "
+                    f"{variable.optimizer} that is being ignored. "
+                    "See `keras.optimizers.DispatchOptimizer` to allow "
+                    "dispatching to the correct per-variable optimizer."
+                )
             self._trainable_variables_indices[self._var_key(variable)] = i
             if self.use_ema:
                 self._model_variables_moving_average.append(
@@ -568,7 +582,7 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
             )
         if len(trainable_variables) != len(self._trainable_variables):
             raise ValueError(
-                "Argument `optimizer_variables` must be a list of tensors "
+                "Argument `trainable_variables` must be a list of tensors "
                 "corresponding 1:1 to the trainable variables list that "
                 "the optimizer was built with. Received "
                 f"len(trainable_variables) == {len(trainable_variables)} "
diff --git a/keras/src/optimizers/dispatch_optimizer.py b/keras/src/optimizers/dispatch_optimizer.py
new file mode 100644
index 000000000000..3f344c88f166
--- /dev/null
+++ b/keras/src/optimizers/dispatch_optimizer.py
@@ -0,0 +1,294 @@
+from keras.src import backend
+from keras.src.api_export import keras_export
+from keras.src import optimizers as keras_optimizers
+from keras.src.optimizers import optimizer as base_optimizer
+from keras.src.saving import serialization_lib
+from keras.src.utils import tracking
+
+
+@keras_export("keras.optimizers.DispatchOptimizer")
+class DispatchOptimizer(base_optimizer.Optimizer):
+    """Dispatches updates to per-variable optimizers where applicable.
+
+    Allows per-variable optimizer customizations by examining
+    each variable's `variable.optimizer` property and dispatching
+    to the appropriate underlying optimizer if applicable. If multiple
+    variables share the same optimizer, the variables are grouped
+    together for a single dispatched call. Variables without a
+    `variable.optimizer` are dispatched to the default optimizer.
+
+    Args:
+        default_optimizer: The `keras.optimizers.Optimizer` instance
+            to use by default.
+        {{base_optimizer_keyword_args}}
+    """
+
+    def __init__(
+        self,
+        default_optimizer="rmsprop",
+        **kwargs,
+    ):
+        if "learning_rate" in kwargs:
+            raise ValueError(
+                "DispatchOptimizer does not support a learning rate. "
+                "Instead, set the `learning_rate` directly on the "
+                "`default_optimizer`."
+            )
+        super().__init__(learning_rate=0.0, **kwargs)
+        self._default_optimizer = default_optimizer
+        self._optimizers = [self._default_optimizer]
+        self.built = False
+
+    def _has_custom_optimizer(self, variable):
+        return (
+            hasattr(variable, "optimizer")
+            and variable.optimizer is not None
+            and variable.optimizer != self  # Prevent infinite recursion.
+            and variable.optimizer != self._default_optimizer
+        )
+
+    def _separate_per_optimizer(self, var_list, value_list):
+        """Separate a list of values into per-optimizer lists.
+
+        Args:
+            var_list: List of variables to use for determining the optimizer.
+            value_list: List of values to separate per-optimizer.
+
+        Returns:
+            Nested lists of values, one list per optimizer.
+        """
+        lists = [[] for _ in range(len(self._optimizers))]
+        for var, value in zip(var_list, value_list):
+            odx = self._variable_to_optimizer_index[self._var_key(var)]
+            lists[odx].append(value)
+
+        return lists
+
+    @tracking.no_automatic_dependency_tracking
+    def build(self, var_list):
+        self._default_optimizer = keras_optimizers.get(self._default_optimizer)
+        # Extract optimizers and separate variables into groups.
+        optimizers = [self._default_optimizer]
+        optimizer_index = {id(self._default_optimizer): 0}
+        # Map of trainable variable -> optimizer index, required for apply().
+        variable_to_optimizer_index = {}
+        # Map of trainable variable index -> optimizer index, required for
+        # stateless_apply().
+        variable_index_to_optimizer_index = []
+
+        self._optimizers = optimizers
+        self._trainable_variables = var_list[:]
+        self._variable_to_optimizer_index = variable_to_optimizer_index
+        self._variable_index_to_optimizer_index = (
+            variable_index_to_optimizer_index
+        )
+
+        # First do a pass to check if we even need to dispatch
+        # to per-variable optimizers. If not, we can just build
+        # the default optimizer.
+        needs_dispatch = False
+        for var in var_list:
+            if self._has_custom_optimizer(var):
+                needs_dispatch = True
+                break
+
+        if needs_dispatch:
+            for var in var_list:
+                optimizer_idx = 0
+                if self._has_custom_optimizer(var):
+                    optimizer = var.optimizer
+                    optimizer_key = id(optimizer)
+                    optimizer_idx = optimizer_index.get(optimizer_key, None)
+                    if optimizer_idx is None:
+                        optimizer_idx = len(optimizers)
+                        optimizers.append(optimizer)
+                        optimizer_index[optimizer_key] = optimizer_idx
+
+                variable_to_optimizer_index[self._var_key(var)] = optimizer_idx
+                variable_index_to_optimizer_index.append(optimizer_idx)
+
+            # Build all optimizers.
+            vars_per_optimizer_lists = self._separate_per_optimizer(
+                var_list, var_list
+            )
+            for optimizer, optimizer_vars in zip(
+                optimizers, vars_per_optimizer_lists
+            ):
+                optimizer.build(optimizer_vars)
+        else:
+            self._default_optimizer.build(var_list)
+
+        # Separate optimizer variables for stateless_apply.
+        # Optimizer variables are simply stacked. See self.variables.
+        oidx = 0
+        optimizer_variable_offsets = []
+        for optimizer in optimizers:
+            optimizer_variable_offsets.append(oidx)
+            oidx += len(optimizer.variables)
+        optimizer_variable_offsets.append(oidx)
+        self._optimizer_variable_offsets = optimizer_variable_offsets
+
+        self.built = True
+
+    def set_weights(self, weights):
+        raise ValueError(
+            "DispatchOptimizer does not support setting weights. "
+            "All weights must be set in the underlying optimizers."
+        )
+
+    @property
+    def variables(self):
+        if not self.built:
+            return []
+
+        if len(self._optimizers) == 1:
+            return self._default_optimizer.variables
+
+        # Stack all optimizer variables.
+        variables = []
+        for optimizer in self._optimizers:
+            variables.extend(optimizer.variables)
+        return variables
+
+    def stateless_apply(self, optimizer_variables, grads, trainable_variables):
+        if not self.built:
+            raise ValueError(
+                f"To call `stateless_apply`, {self.__class__.__name__} "
+                "must be built (i.e. its variables must have been created). "
+                "You can build it via `optimizer.build(trainable_variables)`."
+            )
+        if len(optimizer_variables) != self._optimizer_variable_offsets[-1]:
+            raise ValueError(
+                "Argument `optimizer_variables` must be a list of tensors "
+                f"corresponding 1:1 to {self.__class__.__name__}().variables. "
+                f"Received list with length {len(optimizer_variables)}, but "
+                f"expected {self._optimizer_variable_offsets[-1]} variables."
+            )
+        if len(self._optimizers) > 1 and len(trainable_variables) != len(
+            self._variable_index_to_optimizer_index
+        ):
+            raise ValueError(
+                "Argument `trainable_variables` must be a list of tensors "
+                "corresponding 1:1 to the trainable variables list that "
+                "the optimizer was built with. Received "
+                f"len(trainable_variables) == {len(trainable_variables)} "
+                "whereas the optimizer was built with "
+                f"{len(self._variable_index_to_optimizer_index)} variables."
+            )
+
+        num_optimizers = len(self._optimizers)
+        if num_optimizers == 1:
+            return self._default_optimizer.stateless_apply(
+                optimizer_variables, grads, trainable_variables
+            )
+
+        # Separate into per-optimizer lists.
+        optimizer_params = []
+        for i in range(num_optimizers):
+            optimizer_params.append(
+                optimizer_variables[
+                    self._optimizer_variable_offsets[
+                        i
+                    ] : self._optimizer_variable_offsets[i + 1]
+                ]
+            )
+
+        per_optimizer_grads = [[] for _ in range(num_optimizers)]
+        per_optimizer_variables = [[] for _ in range(num_optimizers)]
+        reverse_map = [[] for _ in range(num_optimizers)]
+        for i in range(len(trainable_variables)):
+            oidx = self._variable_index_to_optimizer_index[i]
+            per_optimizer_grads[oidx].append(grads[i])
+            per_optimizer_variables[oidx].append(trainable_variables[i])
+            reverse_map[oidx].append(i)
+
+        # Apply and update lists.
+        updated_optimizer_variables = []
+        updated_trainable_variables = [None] * len(trainable_variables)
+        for optimizer, ovars, tgrads, tvars, tidxs in zip(
+            self._optimizers,
+            optimizer_params,
+            per_optimizer_grads,
+            per_optimizer_variables,
+            reverse_map,
+        ):
+            tvs, ovs = optimizer.stateless_apply(ovars, tgrads, tvars)
+            updated_optimizer_variables.extend(ovs)
+            # Scatter trainable variables back into their original positions.
+            for tv, idx in zip(tvs, tidxs):
+                updated_trainable_variables[idx] = tv
+
+        return updated_trainable_variables, updated_optimizer_variables
+
+    def apply(self, grads, trainable_variables=None):
+        # Build the optimizer if it has not been built yet.
+        if not self.built:
+            with backend.name_scope(self.name, caller=self):
+                self.build(trainable_variables)
+
+        if len(self._optimizers) == 1:
+            return self._default_optimizer.apply(grads, trainable_variables)
+
+        if trainable_variables is None:
+            params = self._separate_per_optimizer(
+                self._trainable_variables, grads
+            )
+            for optimizer, ograds in zip(self._optimizers, params):
+                optimizer.apply(ograds)
+        else:
+            params = self._separate_per_optimizer(
+                trainable_variables, zip(grads, trainable_variables)
+            )
+            for optimizer, apply_params in zip(self._optimizers, params):
+                ograds, ovars = zip(*apply_params)
+                optimizer.apply(ograds, ovars)
+
+    @property
+    def learning_rate(self):
+        return self._default_optimizer.learning_rate
+
+    @learning_rate.setter
+    def learning_rate(self, learning_rate):
+        self._default_optimizer.learning_rate = learning_rate
+
+    @property
+    def iterations(self):
+        # It's possible the default optimizer has no trainable
+        # variables, so its iteration count is never incremented.
+        # Take the maximum of all iteration counts.
+        return max([optimizer.iterations for optimizer in self._optimizers])
+
+    def finalize_variable_values(self, var_list):
+        if self.built:
+            if len(self._optimizers) == 1:
+                self._default_optimizer.finalize_variable_values(var_list)
+            else:
+                vars_per_optimizer_lists = self._separate_per_optimizer(
+                    var_list, var_list
+                )
+                for optimizer, optimizer_vars in zip(
+                    self._optimizers, vars_per_optimizer_lists
+                ):
+                    optimizer.finalize_variable_values(optimizer_vars)
+
+    def get_config(self):
+        config = super().get_config()
+        config["default_optimizer"] = serialization_lib.serialize_keras_object(
+            self._default_optimizer
+        )
+        del config["learning_rate"]
+        return config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        default_optimizer = serialization_lib.deserialize_keras_object(
+            config.pop("default_optimizer"),
+            custom_objects=custom_objects,
+        )
+        return cls(default_optimizer, **config)
+
+
+DispatchOptimizer.__doc__ = DispatchOptimizer.__doc__.replace(
+    "{{base_optimizer_keyword_args}}",
+    base_optimizer.base_optimizer_keyword_args,
+)
diff --git a/keras/src/optimizers/dispatch_optimizer_test.py b/keras/src/optimizers/dispatch_optimizer_test.py
new file mode 100644
index 000000000000..2a695aa019c9
--- /dev/null
+++ b/keras/src/optimizers/dispatch_optimizer_test.py
@@ -0,0 +1,114 @@
+from absl.testing import parameterized
+
+from keras.src import backend
+from keras.src import ops
+from keras.src import testing
+from keras.src.optimizers.dispatch_optimizer import DispatchOptimizer
+from keras.src.optimizers.sgd import SGD
+
+
+class DispatchOptimizerTest(testing.TestCase):
+    def _skip_test_for_stateless(self, stateless):
+        if not stateless and backend.backend() == "jax":
+            self.skipTest(
+                "DispatchOptimizer must use stateless_apply with JAX."
+            )
+        if stateless and backend.backend() == "tensorflow":
+            self.skipTest(
+                "stateless_apply is not supported with the TF backend."
+            )
+
+    def test_config(self):
+        default_optimizer = SGD(
+            learning_rate=0.5,
+            momentum=0.06,
+            nesterov=True,
+            weight_decay=0.004,
+        )
+        optimizer = DispatchOptimizer(default_optimizer)
+        self.run_class_serialization_test(optimizer)
+
+    @parameterized.named_parameters(("stateless", True), ("stateful", False))
+    def test_step(self, stateless):
+        self._skip_test_for_stateless(stateless)
+
+        default_optimizer = SGD(learning_rate=0.5)
+        optimizer = DispatchOptimizer(default_optimizer)
+        grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
+        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+        if stateless:
+            optimizer.build(vars)
+            vars, _ = optimizer.stateless_apply(
+                optimizer.variables, grads, vars
+            )
+        else:
+            optimizer.apply(grads, vars)
+        self.assertAllClose(
+            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+        )
+
+    @parameterized.named_parameters(("stateless", True), ("stateful", False))
+    def test_per_variable_optimizers(self, stateless):
+        self._skip_test_for_stateless(stateless)
+
+        default_optimizer = SGD(learning_rate=0.5)
+        per_variable_optimizer = SGD(learning_rate=0.1)
+
+        optimizer = DispatchOptimizer(default_optimizer)
+        grads = [
+            ops.array([1.0, 6.0, 7.0, 2.0]),
+            ops.array([2.0, 12.0, 14.0, 4.0]),
+            ops.array([3.0, 18.0, 21.0, 6.0]),
+        ]
+        vars = [
+            backend.Variable([1.0, 2.0, 3.0, 4.0]),
+            backend.Variable([5.0, 6.0, 7.0, 8.0]),
+            backend.Variable([9.0, 10.0, 11.0, 12.0]),
+        ]
+        # Two variables share the same optimizer.
+        vars[1].optimizer = per_variable_optimizer
+        vars[2].optimizer = per_variable_optimizer
+
+        if stateless:
+            optimizer.build(vars)
+            vars, _ = optimizer.stateless_apply(
+                optimizer.variables, grads, vars
+            )
+        else:
+            optimizer.apply(grads, vars)
+
+        # Verify the per-variable optimizer was used.
+        self.assertEqual(len(default_optimizer._trainable_variables), 1)
+        self.assertEqual(len(per_variable_optimizer._trainable_variables), 2)
+        self.assertAllClose(
+            vars,
+            [
+                [0.5, -1.0, -0.5, 3.0],
+                [4.8, 4.8, 5.6, 7.6],
+                [8.7, 8.2, 8.9, 11.4],
+            ],
+            rtol=1e-4,
+            atol=1e-4,
+        )
+
+    @parameterized.named_parameters(("stateless", True), ("stateful", False))
+    def test_iterations_update(self, stateless):
+        self._skip_test_for_stateless(stateless)
+
+        default_optimizer = SGD(learning_rate=0.5)
+        optimizer = DispatchOptimizer(default_optimizer)
+        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+        optimizer.build(vars)
+        opt_vars = optimizer.variables
+        grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
+
+        self.assertEqual(optimizer.iterations.value, 0)
+
+        for i in range(3):
+            if stateless:
+                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
+                for ref_v, v in zip(optimizer.variables, opt_vars):
+                    ref_v.assign(v)
+            else:
+                optimizer.apply(grads, vars)
+            self.assertEqual(optimizer.iterations.value, i + 1)
diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py
index a0de303faf86..92e00292e5ce 100644
--- a/keras/src/trainers/trainer.py
+++ b/keras/src/trainers/trainer.py
@@ -7,6 +7,7 @@
 from keras.src import ops
 from keras.src import optimizers
 from keras.src import tree
+from keras.src.optimizers.dispatch_optimizer import DispatchOptimizer
 from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
 from keras.src.saving import serialization_lib
 from keras.src.trainers.compile_utils import CompileLoss
@@ -946,6 +947,39 @@ def get_compile_config(self):
         if self.compiled and hasattr(self, "_compile_config"):
             return self._compile_config.serialize()
 
+    def _build_optimizer(self):
+        """Builds the trainer's optimizer.
+
+        If any trainable variables have a custom per-variable optimizer
+        attribute `variable.optimizer`, this also inserts a dispatcher to handle
+        them.
+        """
+        # Check if we should support per-variable optimizers.
+        has_per_variable_optimizers = False
+        for tv in self.trainable_variables:
+            if (
+                hasattr(tv, "optimizer")
+                and tv.optimizer is not None
+                and tv.optimizer != self.optimizer
+                and not isinstance(tv.optimizer, DispatchOptimizer)
+            ):
+                has_per_variable_optimizers = True
+                break
+
+        if has_per_variable_optimizers:
+            # To keep proper loss-scaling, we need to insert the
+            # dispatcher between the LossScaleOptimizer and
+            # its inner optimizer.
+            if isinstance(self.optimizer, LossScaleOptimizer):
+                inner_optimizer = self.optimizer.inner_optimizer
+                self.optimizer.inner_optimizer = DispatchOptimizer(
+                    inner_optimizer
+                )
+            else:
+                self.optimizer = DispatchOptimizer(self.optimizer)
+
+        self.optimizer.build(self.trainable_variables)
+
     def compile_from_config(self, config):
         """Compiles the model with the information given in config.
 
@@ -972,7 +1006,7 @@ def compile_from_config(self, config):
         self.compile(**config)
         if hasattr(self, "optimizer") and self.built:
             # Create optimizer variables.
-            self.optimizer.build(self.trainable_variables)
+            self._build_optimizer()
 
     def _should_eval(self, epoch, validation_freq):
         epoch = epoch + 1  # one-index the user-facing epoch.
@@ -1120,7 +1154,7 @@ def to_symbolic_input(v):
             )
             if optimizer_unbuilt:
                 # Build optimizer
-                self.optimizer.build(self.trainable_variables)
+                self._build_optimizer()
 
         self._post_build()
 
diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py
index 3bd8fd5e4ef5..915dfd2246d4 100644
--- a/keras/src/trainers/trainer_test.py
+++ b/keras/src/trainers/trainer_test.py
@@ -1757,6 +1757,40 @@ def call(self, x):
         # With autoscaling, the first dense will update.
         self.assertNotEqual(first_kernel, np.ones_like(first_kernel))
 
+    @pytest.mark.requires_trainable_backend
+    def test_adds_dispatch_optimizer(self):
+        # Basic Dense layer with custom optimizer for its variables.
+        class DenseWithCustomOptimizer(layers.Dense):
+            def __init__(self, optimizer, **kwargs):
+                super().__init__(**kwargs)
+                self.optimizer = optimizer
+
+            def add_weight(self, **kwargs):
+                w = super().add_weight(**kwargs)
+                w.optimizer = self.optimizer
+                return w
+
+        custom_optimizer = optimizers.SGD()
+        model = keras.Sequential(
+            [
+                layers.Dense(2),
+                DenseWithCustomOptimizer(custom_optimizer, units=1),
+            ]
+        )
+        loss = losses.MeanSquaredError()
+        model.compile(optimizer="adam", loss=loss)
+        x = np.ones((16, 2))
+        y = np.zeros((16, 1))
+        model.fit(x, y, batch_size=4)
+
+        # After fit() builds the optimizer, it is wrapped in a dispatcher.
+        self.assertIsInstance(model.optimizer, optimizers.DispatchOptimizer)
+        self.assertEqual(len(model.optimizer._optimizers), 2)
+        # The custom optimizer is shared between the two variables, and was
+        # used for `apply`.
+        self.assertEqual(len(custom_optimizer._trainable_variables), 2)
+        self.assertNotEqual(custom_optimizer.iterations, 0)
+
     @pytest.mark.requires_trainable_backend
     def test_training_arg(self):
         model = TrainingTestingLayer()
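
Usage sketch (illustrative, not part of the patch): one way a layer could
attach a per-variable optimizer and let the trainer dispatch it, assuming the
API added above (`variable.optimizer`, `keras.optimizers.DispatchOptimizer`)
is exported as shown. The `EmbeddingWithOptimizer` subclass and the use of the
public `Embedding.embeddings` property are for demonstration only.

    import numpy as np

    import keras
    from keras import layers
    from keras import optimizers

    # Optimizer intended for the embedding table only.
    table_optimizer = optimizers.SGD(learning_rate=0.1)


    class EmbeddingWithOptimizer(layers.Embedding):
        def build(self, input_shape):
            super().build(input_shape)
            # Route updates for the table through the per-variable optimizer.
            self.embeddings.optimizer = table_optimizer


    model = keras.Sequential(
        [
            EmbeddingWithOptimizer(input_dim=100, output_dim=8),
            layers.Flatten(),
            layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mse")

    x = np.random.randint(0, 100, size=(32, 4))
    y = np.zeros((32, 1))
    model.fit(x, y, batch_size=8)

    # Per the trainer change above, building the optimizer wraps "adam" in a
    # DispatchOptimizer; the embedding table is updated by `table_optimizer`
    # while all other variables fall through to the default optimizer.
    assert isinstance(model.optimizer, optimizers.DispatchOptimizer)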