Conversation

@amitsrivastava78 (Collaborator)

Features (a rough usage sketch follows the list below):
- Sharding support: Enable distributed arrays across JAX devices (sharding and multi-host support are available only on the JAX backend)
- Multi-host support: Coordinate checkpointing across multiple processes
- Interoperability: Load sharded checkpoints to unsharded models and vice versa
- Error handling: Proper validation and backend-specific restrictions
- Comprehensive testing: 5 new test methods covering all scenarios
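
For orientation, a rough usage sketch of the new callback. The argument names below (`directory`, `save_freq`, `max_to_keep`) are assumptions inferred from this feature list, not the callback's actual signature:

```python
import numpy as np
import keras

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = keras.Sequential([keras.layers.Input((4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Hypothetical arguments -- illustrative only, not the callback's real API.
ckpt = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/orbax_ckpts",  # where checkpoints would be written
    save_freq="epoch",             # assumed save-frequency option
    max_to_keep=3,                 # assumed retention setting
)
model.fit(x, y, epochs=2, callbacks=[ckpt])
```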

…re Supports the following features:
- Asynchronous Checkpointing
- Composite Checkpointing
- Preservation Policies
- Save Decision Policies
- Transformations
- Custom Handlers
- Remove conditional export decorator to ensure OrbaxCheckpoint is always available
- Remove unnecessary exception handling in state tree operations
- Update process index check comment for clarity
- Format code to comply with 80-character line limit
- Add distribution_lib modules for backend-specific distributed training support
- Remove unused 'result' variable in _reconstruct_state_tree_with_values
- Fix long comment line in test file
- Apply code formatting changes
…st handling

- Implement OrbaxCheckpoint callback for async checkpointing with state tree handling
- Add conditional exports for optional orbax-checkpoint dependency
- Use pytest.importorskip for clean optional dependency testing
- Ensure graceful handling when orbax-checkpoint is not installed
@gemini-code-assist (Contributor)

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances Keras 3.0's checkpointing functionality by integrating the OrbaxCheckpoint callback. This new callback leverages the Orbax library to provide asynchronous saving, flexible checkpoint retention policies, and critical support for distributed training environments, particularly for JAX. It introduces mechanisms for sharding and multi-host coordination, ensuring that models can be efficiently saved and restored in complex distributed setups while maintaining interoperability between sharded and unsharded model states. The changes also include a unified way to retrieve the process index across different backends, laying a foundation for more robust distributed operations.

Highlights

  • OrbaxCheckpoint Callback: Introduces a new OrbaxCheckpoint callback for Keras 3.0, providing advanced and asynchronous checkpointing capabilities leveraging the Orbax library.
  • Distributed Training Support: Adds sharding and multi-host checkpointing features, specifically designed for the JAX backend, to enable robust distributed training for large-scale models.
  • Checkpoint Interoperability: Allows for flexible loading of checkpoints, supporting the restoration of sharded checkpoints into unsharded models and vice versa, enhancing model deployment and experimentation.
  • Backend Process Indexing: Implements a unified get_process_index() function across all Keras backends (JAX, TensorFlow, Torch, NumPy, OpenVINO) to accurately identify the current process in a distributed setup (a sketch of the idea follows this list).
  • Comprehensive Testing: Includes a new dedicated test file with extensive unit and integration tests, covering basic functionality, retention policies, error handling, metadata, data iterator state, and JAX-specific distributed features.
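
A minimal sketch of the idea behind the unified process-index helper described in the Backend Process Indexing bullet above. Only `jax.process_index()` is a real API call here; the dispatch and the fallback-to-zero behavior are assumptions for illustration, not the PR's actual implementation:

```python
# Sketch only -- NOT the PR's dispatch code.
def get_process_index():
    """Index of the current process in a multi-host job (0 on a single host)."""
    try:
        import jax

        # Real JAX API: identifies this host in a multi-process setup.
        return jax.process_index()
    except ImportError:
        # Assumed fallback: backends without a multi-process notion report 0.
        return 0
```

Backends with no multi-process concept (e.g. NumPy, OpenVINO) would presumably always report process 0.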

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces OrbaxCheckpoint for multi-host and sharded checkpointing, primarily for the JAX backend, and adds a backend-agnostic get_process_index utility. The implementation is well-structured and includes a comprehensive suite of tests. However, I've identified a few issues that need attention. The core sharding functionality in OrbaxCheckpoint appears to be non-functional due to incorrect argument passing to the underlying Orbax library. Additionally, the get_process_index function returns an incorrect type for the TensorFlow backend, and there's a minor robustness concern in the error handling within OrbaxCheckpoint. Addressing these points will significantly improve the reliability of this new feature.

```python
try:
    import tensorflow as tf

    return tf.distribute.get_replica_context().replica_id_in_sync_group
```
Severity: high

The TensorFlow implementation of get_process_index returns a tf.Tensor, but the function's docstring specifies an int return type. This can lead to unexpected behavior, for instance when comparing with an integer. The tensor should be converted to a Python integer to match the documented return type and ensure consistent behavior across backends.

Suggested change:
```diff
- return tf.distribute.get_replica_context().replica_id_in_sync_group
+ return int(tf.distribute.get_replica_context().replica_id_in_sync_group)
```
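
For illustration only (not part of the PR), a small eager-mode snippet showing the normalization; depending on version and context, `replica_id_in_sync_group` can come back as a scalar `tf.Tensor`, and wrapping it in `int(...)` keeps the return type consistent with the other backends:

```python
import tensorflow as tf

# Outside strategy.run(), TF supplies a default replica context.
ctx = tf.distribute.get_replica_context()

# This attribute may be a Python int or a scalar tf.Tensor depending on
# context; int() normalizes it so callers always see a plain integer.
index = int(ctx.replica_id_in_sync_group)
assert index == 0  # single replica when no distribution strategy is active
```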

@codecov-commenter commented Oct 27, 2025

Codecov Report

❌ Patch coverage is 65.98837% with 117 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.54%. Comparing base (47fcb39) to head (c6b3753).
⚠️ Report is 11 commits behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/callbacks/orbax_checkpoint.py | 68.70% | 61 Missing and 31 partials ⚠️ |
| keras/src/backend/__init__.py | 56.52% | 8 Missing and 2 partials ⚠️ |
| keras/src/backend/torch/distribution_lib.py | 12.50% | 7 Missing ⚠️ |
| keras/src/backend/tensorflow/distribution_lib.py | 16.66% | 5 Missing ⚠️ |
| keras/api/_tf_keras/keras/callbacks/__init__.py | 0.00% | 1 Missing ⚠️ |
| keras/src/backend/numpy/distribution_lib.py | 50.00% | 1 Missing ⚠️ |
| keras/src/backend/openvino/distribution_lib.py | 50.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21786      +/-   ##
==========================================
- Coverage   82.69%   82.54%   -0.15%     
==========================================
  Files         573      581       +8     
  Lines       58888    59658     +770     
  Branches     9218     9388     +170     
==========================================
+ Hits        48696    49245     +549     
- Misses       7845     7992     +147     
- Partials     2347     2421      +74     
| Flag | Coverage Δ |
|---|---|
| keras | 82.36% <65.11%> (-0.14%) ⬇️ |
| keras-jax | 63.31% <60.75%> (+0.06%) ⬆️ |
| keras-numpy | 57.30% <13.37%> (-0.42%) ⬇️ |
| keras-openvino | 34.16% <13.08%> (-0.24%) ⬇️ |
| keras-tensorflow | 64.08% <58.13%> (+0.06%) ⬆️ |
| keras-torch | 63.59% <57.55%> (+0.03%) ⬆️ |


@amitsrivastava78 force-pushed the multi-host-sharding-support branch from 248a35a to 6d886d7 on October 27, 2025 at 05:56
- Sharding support: Enable distributed arrays across JAX devices
- Multi-host support: Coordinate checkpointing across multiple processes
- Interoperability: Load sharded checkpoints to unsharded models and vice versa
- Error handling: Proper validation and backend-specific restrictions
@amitsrivastava78 force-pushed the multi-host-sharding-support branch from 6d886d7 to ece595d on October 27, 2025 at 06:01
- Fix sharding parameter passing in save/restore operations by passing it as kwargs instead of setting attributes on StandardSave/StandardRestore objects (see the sketch below)
- Add robust error handling for distribution initialization with multiple error message patterns
- Add proper test skipping for JAX-only features when distribution module unavailable
- Add sharding parameter validation in constructor to prevent invalid types
- Update test expectations to match corrected sharding validation behavior

These changes ensure proper sharding support for JAX multi-host checkpointing while maintaining backward compatibility.
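
As a rough illustration of the kwargs-vs-attributes point above: with Orbax's `CheckpointManager` API, the save/restore structure (including any sharding, via an abstract target) travels through `ocp.args.*` objects rather than being set as attributes afterwards. The sketch below is a single-host, assumption-laden example; the directory path and the shape of the state tree are illustrative only, and this is not the callback's actual code:

```python
import jax
import numpy as np
import orbax.checkpoint as ocp

mngr = ocp.CheckpointManager("/tmp/orbax_demo")  # illustrative path

# Save a simple state tree; real Keras state trees are nested dicts of weights.
state = {"kernel": np.ones((4, 4), dtype=np.float32)}
mngr.save(0, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()  # saving is asynchronous

# Restore against an abstract target: per-leaf shape/dtype (and, in the
# sharded case, layout) is carried inside the target passed to
# StandardRestore, not set on the StandardRestore object itself.
abstract_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), state
)
restored = mngr.restore(0, args=ocp.args.StandardRestore(abstract_state))
```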