@@ -204,6 +204,9 @@ def iterations(self):
     def _track_variable(self, variable):
         self._tracker.add_to_store("variables", variable)
 
+    def _get_variable_updater(self, variable):
+        return getattr(variable, "updater", None)
+
     @tracking.no_automatic_dependency_tracking
     def build(self, variables):
         if self.use_ema:
@@ -212,6 +215,11 @@ def build(self, variables):
             self._accumulated_gradients = []
         for i, variable in enumerate(variables):
             self._trainable_variables_indices[self._var_key(variable)] = i
+            custom_updater = self._get_variable_updater(variable)
+            if custom_updater is not None:
+                # Build the updater.
+                custom_updater.build(self, variable)
+
             if self.use_ema:
                 self._model_variables_moving_average.append(
                     self.add_variable_from_reference(
@@ -431,10 +439,8 @@ def apply(self, grads, trainable_variables=None):
 
         # Overwrite targeted variables directly with their gradients if
         # their `overwrite_with_gradient` is set.
-        grads, trainable_variables = (
-            self._overwrite_variables_directly_with_gradients(
-                grads, trainable_variables
-            )
+        grads, trainable_variables = self.__handle_custom_updaters(
+            grads, trainable_variables
         )
 
         if len(list(grads)) == 0:
@@ -698,21 +704,14 @@ def _get_current_learning_rate(self):
             return self._learning_rate()
         return self._learning_rate
 
-    def _overwrite_variables_directly_with_gradients(self, grads, vars):
-        """Overwrite the variables directly by their gradients.
-
-        This method is designed for a special case where we want to overwrite
-        the variable directly with its computed gradient. For example, in float8
-        training, new `scale` and `amax_history` are computed as gradients, and
-        we want to overwrite them directly instead of following the typical
-        procedure such as gradient descent with a learning rate, gradient
-        clipping and weight decaying.
+    def __handle_custom_updaters(self, grads, vars):
+        """Update any variable that has a custom updater.
 
         After the update, the processed pairs will be filtered out.
         """
         # Shortcut for `tf.Variable` because it doesn't have a
-        # `overwrite_with_gradient` attr
-        if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
+        # `updater` attr.
+        if not any(self._get_variable_updater(v) is not None for v in vars):
             return grads, vars
 
         # Shallow copies
@@ -722,33 +721,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         # Iterate from right to left for safe popping
         for i in range(len(filtered_grads) - 1, -1, -1):
             g, v = filtered_grads[i], filtered_vars[i]
-            if v.overwrite_with_gradient:
-                if self.gradient_accumulation_steps:
-                    # Utilize a stateless manner for JAX compatibility
-                    steps = self.gradient_accumulation_steps
-                    is_update_step = (self._iterations + 1) % steps == 0
-                    acc_g = self._accumulated_gradients[
-                        self._get_variable_index(v)
-                    ]
-                    # `ops.maximum` is utilized for gradient accumulation for
-                    # `overwrite_with_gradient=True` variables
-                    new_g_acc = ops.cond(
-                        is_update_step,
-                        lambda: ops.zeros(g.shape, dtype=g.dtype),
-                        lambda: ops.maximum(g, acc_g),
-                    )
-                    new_g = ops.cond(
-                        is_update_step,
-                        lambda: ops.maximum(g, acc_g),
-                        lambda: g,
-                    )
-                    new_v = ops.cond(
-                        is_update_step, lambda: new_g, lambda: v.value
-                    )
-                    v.assign(new_v)
-                    acc_g.assign(new_g_acc)
-                else:
-                    v.assign(g)
+            if v.updater:
+                v.updater.update_step(g, v)
                 filtered_grads.pop(i)
                 filtered_vars.pop(i)
         return filtered_grads, filtered_vars
@@ -926,6 +900,11 @@ def finalize_variable_values(self, var_list):
             # optimizer.
             self._overwrite_model_variables_with_average_value(var_list)
 
+        for var in var_list:
+            updater = self._get_variable_updater(var)
+            if updater is not None:
+                updater.finalize_variable_value(var)
+
     def _obj_type(self):
         return "Optimizer"
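For reference, below is a minimal sketch of what a custom updater could look like under the protocol this diff calls into: `build(optimizer, variable)` from `Optimizer.build()`, `update_step(gradient, variable)` from `apply()`, and `finalize_variable_value(variable)` from `finalize_variable_values()`. The class name `OverwriteWithGradientUpdater` and the way it is attached to a variable are assumptions for illustration only; the update logic simply mirrors the `overwrite_with_gradient` branch removed above.

from keras.src import ops


class OverwriteWithGradientUpdater:
    """Hypothetical updater that overwrites a variable with its gradient.

    Mirrors the `overwrite_with_gradient` code path removed in this diff,
    expressed through the new updater hooks.
    """

    def build(self, optimizer, variable):
        # Called once from `Optimizer.build()`; keep a handle on the
        # optimizer so accumulation state can be looked up later.
        self.optimizer = optimizer

    def update_step(self, gradient, variable):
        optimizer = self.optimizer
        if optimizer.gradient_accumulation_steps:
            # Stateless-style update for JAX compatibility.
            steps = optimizer.gradient_accumulation_steps
            is_update_step = (optimizer._iterations + 1) % steps == 0
            acc_g = optimizer._accumulated_gradients[
                optimizer._get_variable_index(variable)
            ]
            # Accumulate with `ops.maximum`, as the removed branch did.
            new_g_acc = ops.cond(
                is_update_step,
                lambda: ops.zeros(gradient.shape, dtype=gradient.dtype),
                lambda: ops.maximum(gradient, acc_g),
            )
            new_g = ops.cond(
                is_update_step,
                lambda: ops.maximum(gradient, acc_g),
                lambda: gradient,
            )
            new_v = ops.cond(
                is_update_step, lambda: new_g, lambda: variable.value
            )
            variable.assign(new_v)
            acc_g.assign(new_g_acc)
        else:
            variable.assign(gradient)

    def finalize_variable_value(self, variable):
        # Called from `finalize_variable_values()`; nothing to finalize here.
        pass


# Hypothetical attachment: any variable exposing an `updater` attribute is
# picked up by `_get_variable_updater()` in `Optimizer.build()` and `apply()`.
# scale_variable.updater = OverwriteWithGradientUpdater()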