- 
                Notifications
    You must be signed in to change notification settings 
- Fork 19.6k
Multi host & sharding support for keras 3.0 #21786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Multi host & sharding support for keras 3.0 #21786
Conversation
         amitsrivastava78
  
      
      
      commented
      
            amitsrivastava78
  
      
      
      commented
        Oct 27, 2025 
      
    
  
…re Supports following feature - Asynchronous Checkpointing - Composite Checkpointing - Preservation Policies - Save Decision Policies - Transformations - Custom Handlers
- Remove conditional export decorator to ensure OrbaxCheckpoint is always available - Remove unnecessary exception handling in state tree operations - Update process index check comment for clarity - Format code to comply with 80-character line limit - Add distribution_lib modules for backend-specific distributed training support
- Remove unused 'result' variable in _reconstruct_state_tree_with_values - Fix long comment line in test file - Apply code formatting changes
…st handling - Implement OrbaxCheckpoint callback for async checkpointing with state tree handling - Add conditional exports for optional orbax-checkpoint dependency - Use pytest.importorskip for clean optional dependency testing - Ensure graceful handling when orbax-checkpoint is not installed
| Summary of ChangesHello @amitsrivastava78, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances Keras 3.0's checkpointing functionality by integrating the  Highlights
 Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either  
 Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a  Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces OrbaxCheckpoint for multi-host and sharded checkpointing, primarily for the JAX backend, and adds a backend-agnostic get_process_index utility. The implementation is well-structured and includes a comprehensive suite of tests. However, I've identified a few issues that need attention. The core sharding functionality in OrbaxCheckpoint appears to be non-functional due to incorrect argument passing to the underlying Orbax library. Additionally, the get_process_index function returns an incorrect type for the TensorFlow backend, and there's a minor robustness concern in the error handling within OrbaxCheckpoint. Addressing these points will significantly improve the reliability of this new feature.
| try: | ||
| import tensorflow as tf | ||
|  | ||
| return tf.distribute.get_replica_context().replica_id_in_sync_group | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The TensorFlow implementation of get_process_index returns a tf.Tensor, but the function's docstring specifies an int return type. This can lead to unexpected behavior, for instance when comparing with an integer. The tensor should be converted to a Python integer to match the documented return type and ensure consistent behavior across backends.
| return tf.distribute.get_replica_context().replica_id_in_sync_group | |
| return int(tf.distribute.get_replica_context().replica_id_in_sync_group) | 
| Codecov Report❌ Patch coverage is  Additional details and impacted files@@            Coverage Diff             @@
##           master   #21786      +/-   ##
==========================================
- Coverage   82.69%   82.54%   -0.15%     
==========================================
  Files         573      581       +8     
  Lines       58888    59658     +770     
  Branches     9218     9388     +170     
==========================================
+ Hits        48696    49245     +549     
- Misses       7845     7992     +147     
- Partials     2347     2421      +74     
 Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
 | 
248a35a    to
    6d886d7      
    Compare
  
    - Sharding support: Enable distributed arrays across JAX devices - Multi-host support: Coordinate checkpointing across multiple processes - Interoperability: Load sharded checkpoints to unsharded models and vice versa - Error handling: Proper validation and backend-specific restrictions
6d886d7    to
    ece595d      
    Compare
  
    - Fix sharding parameter passing in save/restore operations by passing as kwargs instead of setting attributes on StandardSave/StandardRestore objects - Add robust error handling for distribution initialization with multiple error message patterns - Add proper test skipping for JAX-only features when distribution module unavailable - Add sharding parameter validation in constructor to prevent invalid types - Update test expectations to match corrected sharding validation behavior These changes ensure proper sharding support for JAX multi-host checkpointing while maintaining backward compatibility.
7caff51    to
    c6b3753      
    Compare