@@ -189,6 +189,15 @@ class OrbaxCheckpoint(MonitorCallback):

model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])

+ # Or use a SaveDecisionPolicy for more control -
+ from orbax.checkpoint import checkpoint_managers
+ policy = checkpoint_managers.FixedIntervalPolicy(interval=5)
+ orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+     directory=checkpoint_dir,
+     save_decision_policy=policy)  # Save every 5 epochs
+
+ model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
# JAX-specific features: Sharding and Multi-Host Checkpointing
# Note: These features are only available with JAX backend

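For reference, a minimal usage sketch of the two JAX-only options named in the docstring comment above. Only the sharding and multi_host parameter names are confirmed by the backend check in the next hunk; the value types (a jax.sharding.Sharding object and a boolean) are assumptions, and the snippet continues the model / checkpoint_dir / EPOCHS example from the docstring.

    # Hypothetical sketch: assumes the JAX backend, and assumes `sharding`
    # takes a jax.sharding.Sharding while `multi_host` is a boolean flag.
    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
    sharding = NamedSharding(mesh, PartitionSpec("model"))

    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        sharding=sharding,  # assumed: sharding spec applied to saved arrays
        multi_host=True,    # assumed: coordinate saving across hosts
    )
    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])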
@@ -336,7 +345,6 @@ def __init__(
336345 "sharding and multi_host parameters are only supported "
337346 "with JAX backend. Current backend: " + backend .backend ()
338347 )
339-
340348 self ._batches_seen_since_last_saving = 0
341349 self ._last_batch_seen = 0
342350 self ._current_epoch = 0 # Keep track of epoch
@@ -506,10 +514,7 @@ def _save_checkpoint(self, step, logs=None):
composite_state["data_iterator"] = iterator_state

# --- Save Logic ---
- # In multi-host setups, only the primary process (rank 0) initiates the
- # save operation. Orbax internally coordinates distributed writing: each
- # process writes its own data shards in parallel while the primary
- # process manages metadata and coordination.
+ # Only save on the primary process (rank 0) in distributed setups
is_primary_host = backend.get_process_index() == 0

if is_primary_host:
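The rewritten comment above keeps the same rank-0 gate; in plain JAX terms the guard amounts to the pattern sketched below (illustrative only; the backend helper get_process_index() used in the hunk plays the same role as jax.process_index()).

    # Sketch of the primary-host gate: in a multi-process JAX run,
    # jax.process_index() is 0 only on the coordinating host.
    import jax

    def is_primary_host() -> bool:
        return jax.process_index() == 0

    if is_primary_host():
        print("primary host: driving the checkpoint save")
    else:
        print("worker host: contributing shards to the distributed write")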