Skip to content

Commit 6d886d7

Browse files
Features: Add sharding and multi-host support for JAX backend
- Sharding support: Enable distributed arrays across JAX devices - Multi-host support: Coordinate checkpointing across multiple processes - Interoperability: Load sharded checkpoints to unsharded models and vice versa - Error handling: Proper validation and backend-specific restrictions - Comprehensive testing: 11 new test methods covering all scenarios
1 parent 1dee062 commit 6d886d7

File tree

2 files changed

+316
-298
lines changed

2 files changed

+316
-298
lines changed

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,15 @@ class OrbaxCheckpoint(MonitorCallback):
189189
190190
model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
191191
192+
# Or use a SaveDecisionPolicy for more control -
193+
from orbax.checkpoint import checkpoint_managers
194+
policy = checkpoint_managers.FixedIntervalPolicy(interval=5)
195+
orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
196+
directory=checkpoint_dir,
197+
save_decision_policy=policy) # Save every 5 epochs
198+
199+
model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
200+
192201
# JAX-specific features: Sharding and Multi-Host Checkpointing
193202
# Note: These features are only available with JAX backend
194203
@@ -336,7 +345,6 @@ def __init__(
336345
"sharding and multi_host parameters are only supported "
337346
"with JAX backend. Current backend: " + backend.backend()
338347
)
339-
340348
self._batches_seen_since_last_saving = 0
341349
self._last_batch_seen = 0
342350
self._current_epoch = 0 # Keep track of epoch
@@ -506,10 +514,7 @@ def _save_checkpoint(self, step, logs=None):
506514
composite_state["data_iterator"] = iterator_state
507515

508516
# --- Save Logic ---
509-
# In multi-host setups, only the primary process (rank 0) initiates the
510-
# save operation. Orbax internally coordinates distributed writing: each
511-
# process writes its own data shards in parallel while the primary
512-
# process manages metadata and coordination.
517+
# Only save on the primary process (rank 0) in distributed setups
513518
is_primary_host = backend.get_process_index() == 0
514519

515520
if is_primary_host:

0 commit comments

Comments
 (0)