@@ -196,6 +196,40 @@ class OrbaxCheckpoint(MonitorCallback):
196196        directory=checkpoint_dir, 
197197        save_decision_policy=policy)  # Save every 5 epochs 
198198
199+     model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) 
200+ 
201+     # JAX-specific features: Sharding and Multi-Host Checkpointing 
202+     # Note: These features are only available with JAX backend 
203+ 
204+     # Example with sharding support (JAX only): 
205+     from keras.distribution import DeviceMesh, TensorLayout 
206+     devices = keras.distribution.list_devices() 
207+     device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',), 
208+                              devices=devices) 
209+     tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh) 
210+     orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( 
211+         directory=checkpoint_dir, 
212+         sharding=tensor_layout.backend_layout 
213+     )  # Enable sharding for distributed arrays 
214+ 
215+     # Example with multi-host checkpointing (JAX only): 
216+     # Enables distributed checkpointing where each host writes its data shards 
217+     # while the primary process coordinates metadata and finalization 
218+     orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( 
219+         directory=checkpoint_dir, 
220+         multi_host=True)  # Enable multi-host checkpointing 
221+ 
222+     # Combined sharding and multi-host (JAX only): 
223+     from keras.distribution import DeviceMesh, TensorLayout 
224+     devices = keras.distribution.list_devices() 
225+     device_mesh = DeviceMesh(shape=(len(devices),), axis_names=('x',), 
226+                              devices=devices) 
227+     tensor_layout = TensorLayout(axes=(None,), device_mesh=device_mesh) 
228+     orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( 
229+         directory=checkpoint_dir, 
230+         sharding=tensor_layout.backend_layout, 
231+         multi_host=True)  # Enable both features 
232+ 
199233    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) 
200234    ``` 
201235
@@ -241,6 +275,16 @@ class OrbaxCheckpoint(MonitorCallback):
241275            overrides the default save frequency logic. Defaults to None. 
242276        save_interval: Integer, save checkpoints every N steps. If provided, 
243277            overrides save_freq. Defaults to None. 
278+         sharding: JAX sharding specification for distributed checkpointing. 
279+             Only supported with JAX backend. If provided with TensorFlow or 
280+             PyTorch backends, will raise an error. Defaults to None. 
281+         multi_host: Boolean, whether to enable multi-host checkpointing for 
282+             distributed training across multiple processes/hosts. When enabled, 
283+             the primary process (rank 0) coordinates the checkpoint operation 
284+             while all processes write their data shards in parallel to create a 
285+             complete distributed checkpoint. Only supported with JAX backend. 
286+             If enabled with TensorFlow or PyTorch backends, will raise an error. 
287+             Defaults to False. 
244288    """ 
245289
246290    def  __init__ (
@@ -265,6 +309,8 @@ def __init__(
265309        save_transforms = None ,
266310        save_decision_policy = None ,
267311        save_interval = None ,
312+         sharding = None ,
313+         multi_host = False ,
268314    ):
269315        # Ensure orbax is available 
270316        ocp .initialize ()
@@ -287,6 +333,18 @@ def __init__(
287333        self .save_transforms  =  save_transforms 
288334        self .save_decision_policy  =  save_decision_policy 
289335        self .save_interval  =  save_interval 
336+ 
337+         # JAX-specific features validation 
338+         self .sharding  =  sharding 
339+         self .multi_host  =  multi_host 
340+ 
341+         # Validate JAX-only features 
342+         if  sharding  is  not None  or  multi_host :
343+             if  backend .backend () !=  "jax" :
344+                 raise  ValueError (
345+                     "sharding and multi_host parameters are only supported " 
346+                     "with JAX backend. Current backend: "  +  backend .backend ()
347+                 )
290348        self ._batches_seen_since_last_saving  =  0 
291349        self ._last_batch_seen  =  0 
292350        self ._current_epoch  =  0   # Keep track of epoch 
@@ -326,6 +384,28 @@ def __init__(
326384            should_save_fn = should_save_fn ,
327385            save_decision_policy = save_decision_policy ,
328386        )
387+ 
388+         # Multi-host setup for JAX 
389+         if  self .multi_host  and  backend .backend () ==  "jax" :
390+             try :
391+                 # Enable multi-host checkpointing using Keras distribution API 
392+                 from  keras .src  import  distribution 
393+ 
394+                 distribution .initialize ()
395+             except  RuntimeError  as  e :
396+                 # If distributed cannot be initialized (e.g., JAX already 
397+                 # initialized), continue anyway - the multi_host flag is mainly 
398+                 # a hint to Orbax 
399+                 if  "must be called before"  in  str (e ):
400+                     pass   # This is expected in test environments 
401+                 else :
402+                     raise 
403+             # Orbax will automatically handle multi-host coordination: 
404+             # - Primary process (rank 0) coordinates and writes 
405+             #   metadata/manifest 
406+             # - All processes write their data shards in parallel to the 
407+             #   checkpoint directory 
408+ 
329409        # Ensure directory exists (only needed on one process in multi-host) 
330410        if  backend .get_process_index () ==  0 :
331411            os .makedirs (directory , exist_ok = True )
@@ -447,6 +527,16 @@ def _save_checkpoint(self, step, logs=None):
447527            save_args  =  ocp .args .StandardSave (
448528                composite_state , save_args = self .save_transforms 
449529            )
530+ 
531+             # Apply sharding if specified (JAX only) 
532+             if  self .sharding  is  not None  and  backend .backend () ==  "jax" :
533+                 # For JAX sharding, we need to ensure the data is properly 
534+                 # sharded 
535+                 # This is typically handled automatically by Orbax when JAX 
536+                 # arrays with sharding metadata are saved 
537+                 if  hasattr (save_args , "sharding" ):
538+                     save_args .sharding  =  self .sharding 
539+ 
450540            self .manager .save (step , args = save_args )
451541
452542    def  on_train_batch_end (self , batch , logs = None ):
@@ -539,8 +629,15 @@ def load_checkpoint(self, step, model=None):
539629            was successful, False otherwise, and iterator_state is the saved 
540630            data iterator state dict if available, None otherwise. 
541631        """ 
542-         # In distributed training, only load on primary process 
543-         if  backend .get_process_index () !=  0 :
632+         # In multi-host distributed training, all processes participate in 
633+         # loading to read their respective data shards in parallel. Only the 
634+         # primary process coordinates the metadata reading and broadcasting. 
635+         if  self .multi_host  and  backend .backend () ==  "jax" :
636+             # Multi-host loading: all processes participate 
637+             pass   # Continue with loading on all processes 
638+         elif  backend .get_process_index () !=  0 :
639+             # Single-host or non-multi-host distributed: only primary 
640+             # process loads 
544641            return  True   # Return True to indicate no error, but no loading 
545642
546643        if  self .verbose  >  0 :
@@ -552,6 +649,13 @@ def load_checkpoint(self, step, model=None):
552649        # template 
553650        restore_args  =  ocp .args .StandardRestore ()
554651
652+         # Apply sharding if specified (JAX only) 
653+         if  self .sharding  is  not None  and  backend .backend () ==  "jax" :
654+             # For JAX sharding, we need to ensure the data is properly restored 
655+             # with the same sharding specification used during save 
656+             if  hasattr (restore_args , "sharding" ):
657+                 restore_args .sharding  =  self .sharding 
658+ 
555659        # Load the checkpoint 
556660        checkpoint_data  =  self .manager .restore (step , args = restore_args )
557661
0 commit comments