
Commit ece595d

Features: Add sharding and multi-host support for JAX backend
- Sharding support: Enable distributed arrays across JAX devices
- Multi-host support: Coordinate checkpointing across multiple processes
- Interoperability: Load sharded checkpoints to unsharded models and vice versa
- Error handling: Proper validation and backend-specific restrictions
1 parent 19d2495 commit ece595d
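
For orientation, here is a minimal sketch of the two new constructor options described above, mirroring the docstring examples added in this commit. The checkpoint directory is a placeholder; `DeviceMesh`, `TensorLayout`, and `backend_layout` are the Keras distribution APIs the diff itself uses.

```python
import keras
from keras.distribution import DeviceMesh, TensorLayout

# Build a 1D device mesh over all available devices and a replicated layout,
# as in the docstring examples added by this commit.
devices = keras.distribution.list_devices()
mesh = DeviceMesh(shape=(len(devices),), axis_names=("x",), devices=devices)
layout = TensorLayout(axes=(None,), device_mesh=mesh)

orbax_cb = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/ckpts",          # placeholder checkpoint directory
    sharding=layout.backend_layout,  # new: JAX sharding for saved arrays
    multi_host=True,                 # new: coordinate writes across hosts
)
```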

File tree: 2 files changed (+635, -2 lines)


keras/src/callbacks/orbax_checkpoint.py

Lines changed: 106 additions & 2 deletions
@@ -196,6 +196,40 @@ class OrbaxCheckpoint(MonitorCallback):
         directory=checkpoint_dir,
         save_decision_policy=policy) # Save every 5 epochs
 
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # JAX-specific features: Sharding and Multi-Host Checkpointing
+    # Note: These features are only available with JAX backend
+
+    # Example with sharding support (JAX only):
+    from keras.distribution import DeviceMesh, TensorLayout
+    devices = keras.distribution.list_devices()
+    device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',),
+                             devices=devices)
+    tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh)
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        sharding=tensor_layout.backend_layout
+    ) # Enable sharding for distributed arrays
+
+    # Example with multi-host checkpointing (JAX only):
+    # Enables distributed checkpointing where each host writes its data shards
+    # while the primary process coordinates metadata and finalization
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        multi_host=True) # Enable multi-host checkpointing
+
+    # Combined sharding and multi-host (JAX only):
+    from keras.distribution import DeviceMesh, TensorLayout
+    devices = keras.distribution.list_devices()
+    device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',),
+                             devices=devices)
+    tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh)
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        sharding=tensor_layout.backend_layout,
+        multi_host=True) # Enable both features
+
     model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
     ```
 
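
The docstring examples above build the sharding spec through `keras.distribution`. For readers more familiar with raw JAX, here is a hedged sketch of the assumed equivalent, on the assumption that `TensorLayout.backend_layout` resolves to a `jax.sharding.NamedSharding` on the JAX backend:

```python
import jax
import keras
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A replicated sharding over a 1D mesh of all local devices; assumed to be
# interchangeable with TensorLayout(axes=(None,), ...).backend_layout.
mesh = Mesh(jax.devices(), axis_names=("x",))
replicated = NamedSharding(mesh, PartitionSpec())

orbax_cb = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/ckpts",  # placeholder path
    sharding=replicated,     # raw JAX sharding passed directly
)
```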
@@ -241,6 +275,16 @@ class OrbaxCheckpoint(MonitorCallback):
             overrides the default save frequency logic. Defaults to None.
         save_interval: Integer, save checkpoints every N steps. If provided,
             overrides save_freq. Defaults to None.
+        sharding: JAX sharding specification for distributed checkpointing.
+            Only supported with JAX backend. If provided with TensorFlow or
+            PyTorch backends, will raise an error. Defaults to None.
+        multi_host: Boolean, whether to enable multi-host checkpointing for
+            distributed training across multiple processes/hosts. When enabled,
+            the primary process (rank 0) coordinates the checkpoint operation
+            while all processes write their data shards in parallel to create a
+            complete distributed checkpoint. Only supported with JAX backend.
+            If enabled with TensorFlow or PyTorch backends, will raise an error.
+            Defaults to False.
     """
 
     def __init__(
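
Since both arguments are documented to raise an error on TensorFlow and PyTorch, a backend-agnostic training script might attach them conditionally. A small illustrative sketch, not part of this commit:

```python
import keras

kwargs = {}
if keras.backend.backend() == "jax":
    # Attach the JAX-only options only when the JAX backend is active.
    from keras.distribution import DeviceMesh, TensorLayout

    devices = keras.distribution.list_devices()
    mesh = DeviceMesh(shape=(len(devices),), axis_names=("x",), devices=devices)
    layout = TensorLayout(axes=(None,), device_mesh=mesh)
    kwargs.update(sharding=layout.backend_layout, multi_host=True)

orbax_cb = keras.callbacks.OrbaxCheckpoint(directory="/tmp/ckpts", **kwargs)
```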
@@ -265,6 +309,8 @@ def __init__(
         save_transforms=None,
         save_decision_policy=None,
         save_interval=None,
+        sharding=None,
+        multi_host=False,
     ):
         # Ensure orbax is available
         ocp.initialize()
@@ -287,6 +333,18 @@ def __init__(
         self.save_transforms = save_transforms
         self.save_decision_policy = save_decision_policy
         self.save_interval = save_interval
+
+        # JAX-specific features validation
+        self.sharding = sharding
+        self.multi_host = multi_host
+
+        # Validate JAX-only features
+        if sharding is not None or multi_host:
+            if backend.backend() != "jax":
+                raise ValueError(
+                    "sharding and multi_host parameters are only supported "
+                    "with JAX backend. Current backend: " + backend.backend()
+                )
         self._batches_seen_since_last_saving = 0
         self._last_batch_seen = 0
         self._current_epoch = 0 # Keep track of epoch
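
A hedged, hypothetical pytest-style check of the restriction enforced above (the test name and structure are assumptions, not code from this commit):

```python
import pytest
import keras


def test_jax_only_options_rejected_on_other_backends(tmp_path):
    # Hypothetical test; exercises the ValueError raised by the validation above.
    if keras.backend.backend() == "jax":
        pytest.skip("restriction only applies to non-JAX backends")
    with pytest.raises(ValueError, match="only supported"):
        keras.callbacks.OrbaxCheckpoint(directory=str(tmp_path), multi_host=True)
```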
@@ -326,6 +384,28 @@ def __init__(
             should_save_fn=should_save_fn,
             save_decision_policy=save_decision_policy,
         )
+
+        # Multi-host setup for JAX
+        if self.multi_host and backend.backend() == "jax":
+            try:
+                # Enable multi-host checkpointing using Keras distribution API
+                from keras.src import distribution
+
+                distribution.initialize()
+            except RuntimeError as e:
+                # If distributed cannot be initialized (e.g., JAX already
+                # initialized), continue anyway - the multi_host flag is mainly
+                # a hint to Orbax
+                if "must be called before" in str(e):
+                    pass # This is expected in test environments
+                else:
+                    raise
+            # Orbax will automatically handle multi-host coordination:
+            # - Primary process (rank 0) coordinates and writes
+            #   metadata/manifest
+            # - All processes write their data shards in parallel to the
+            #   checkpoint directory
+
         # Ensure directory exists (only needed on one process in multi-host)
         if backend.get_process_index() == 0:
             os.makedirs(directory, exist_ok=True)
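
The `distribution.initialize()` call above hooks into JAX's distributed runtime, which must be initialized once per process before the callback is built. Here is a sketch of what each process in a two-host job might run; the coordinator address, process count, and process id are placeholders normally supplied by the launcher, and the checkpoint directory is assumed to be on storage visible to every host:

```python
import jax
import keras

# One call per process, before any JAX computation; values are placeholders
# normally provided by the job launcher or environment variables.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",
    num_processes=2,
    process_id=0,
)

# Every process builds the same callback and writes to shared storage;
# process 0 finalizes the checkpoint while all hosts write their shards.
orbax_cb = keras.callbacks.OrbaxCheckpoint(
    directory="/shared/ckpts",  # assumed to be visible to all hosts
    multi_host=True,
)
print(f"process {jax.process_index()} of {jax.process_count()}")
```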
@@ -447,6 +527,16 @@ def _save_checkpoint(self, step, logs=None):
         save_args = ocp.args.StandardSave(
             composite_state, save_args=self.save_transforms
         )
+
+        # Apply sharding if specified (JAX only)
+        if self.sharding is not None and backend.backend() == "jax":
+            # For JAX sharding, we need to ensure the data is properly
+            # sharded
+            # This is typically handled automatically by Orbax when JAX
+            # arrays with sharding metadata are saved
+            if hasattr(save_args, "sharding"):
+                save_args.sharding = self.sharding
+
         self.manager.save(step, args=save_args)
 
     def on_train_batch_end(self, batch, logs=None):
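
To illustrate the comment above about Orbax handling sharded arrays automatically: a JAX array placed with a sharding keeps that sharding as metadata, which is what the checkpointer relies on. A small self-contained sketch:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
mesh = Mesh(devices, axis_names=("x",))
x = jnp.arange(4.0 * len(devices))  # length divisible by the mesh size
sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("x")))

# The sharding travels with the array, so a checkpointer that receives it
# can write the shards without any extra arguments.
print(sharded.sharding)
```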
@@ -539,8 +629,15 @@ def load_checkpoint(self, step, model=None):
             was successful, False otherwise, and iterator_state is the saved
             data iterator state dict if available, None otherwise.
         """
-        # In distributed training, only load on primary process
-        if backend.get_process_index() != 0:
+        # In multi-host distributed training, all processes participate in
+        # loading to read their respective data shards in parallel. Only the
+        # primary process coordinates the metadata reading and broadcasting.
+        if self.multi_host and backend.backend() == "jax":
+            # Multi-host loading: all processes participate
+            pass # Continue with loading on all processes
+        elif backend.get_process_index() != 0:
+            # Single-host or non-multi-host distributed: only primary
+            # process loads
             return True # Return True to indicate no error, but no loading
 
         if self.verbose > 0:
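
With `multi_host=True`, the load path above no longer short-circuits on non-primary processes, so every process issues the same `load_checkpoint` call and reads its shards. A usage sketch; the step number and model are placeholders, and the `(success, iterator_state)` return shape follows the docstring quoted above:

```python
import keras

model = keras.Sequential([keras.layers.Dense(1)])  # stand-in model
orbax_cb = keras.callbacks.OrbaxCheckpoint(directory="/shared/ckpts", multi_host=True)

# Every process in the multi-host job issues the same restore call.
success, iterator_state = orbax_cb.load_checkpoint(step=10, model=model)
if success and iterator_state is not None:
    pass  # resume the input pipeline from the saved iterator state
```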
@@ -552,6 +649,13 @@ def load_checkpoint(self, step, model=None):
         # template
         restore_args = ocp.args.StandardRestore()
 
+        # Apply sharding if specified (JAX only)
+        if self.sharding is not None and backend.backend() == "jax":
+            # For JAX sharding, we need to ensure the data is properly restored
+            # with the same sharding specification used during save
+            if hasattr(restore_args, "sharding"):
+                restore_args.sharding = self.sharding
+
         # Load the checkpoint
         checkpoint_data = self.manager.restore(step, args=restore_args)
 
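The commit message also advertises interoperability between sharded and unsharded checkpoints. A hedged sketch of that round trip; the paths, step number, and model are placeholders, and restoring without a `sharding=` spec into an unsharded model follows the commit description rather than the code shown here:

```python
import keras
from keras.distribution import DeviceMesh, TensorLayout

devices = keras.distribution.list_devices()
mesh = DeviceMesh(shape=(len(devices),), axis_names=("x",), devices=devices)
layout = TensorLayout(axes=(None,), device_mesh=mesh)

# Writer: saves checkpoints with a sharding spec attached.
writer = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/ckpts", sharding=layout.backend_layout
)

# Reader: restores the same checkpoints without a sharding spec, i.e. into an
# unsharded model, relying on the interoperability described in the commit.
reader = keras.callbacks.OrbaxCheckpoint(directory="/tmp/ckpts")
success, _ = reader.load_checkpoint(
    step=0, model=keras.Sequential([keras.layers.Dense(1)])
)
```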