Adding Tensor_layout for Tensor parallelism for Autosharding #21792
Conversation
Summary of Changes
Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request establishes foundational components for tensor parallelism within the JAX backend, which are crucial for Autosharding. It provides core collective communication primitives such as all_reduce and all_gather, along with a utility for splitting tensors across devices.
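For readers unfamiliar with the all_reduce primitive, here is a minimal sketch of its semantics using raw JAX (jax.lax.psum under pmap). This is illustrative only and is not the PR's implementation; the PR wraps equivalent behavior in the Keras JAX backend.

```python
import jax
import jax.numpy as jnp

# One local value per device; psum sums them so every device ends up
# holding the same total, which is the all_reduce(sum) behavior used
# for synchronizing gradients across devices.
local_vals = jnp.arange(jax.local_device_count(), dtype=jnp.float32)

summed = jax.pmap(
    lambda v: jax.lax.psum(v, axis_name="devices"),
    axis_name="devices",
)(local_vals)

print(summed)  # identical entry on every device: 0 + 1 + ... + (n-1)
```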
Code Review
This pull request introduces foundational components for tensor parallelism in Keras, specifically for the JAX backend. It adds all_reduce and all_gather collective operations, which are essential for distributed computations. Additionally, it provides a split_tensor_for_parallelism utility for sharding tensors across devices. The changes are well-tested, covering both even and uneven tensor splitting. My review includes a few suggestions to improve documentation accuracy and code simplicity, and to align with the repository's style guide regarding docstring examples.
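Similarly, a minimal sketch of all_gather semantics with raw JAX primitives, again purely illustrative and not the PR's code: each device contributes its local slice and receives the full stacked tensor back.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
# Each device holds one distinct (4,)-shaped slice of a larger tensor.
local_rows = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

gathered = jax.pmap(
    lambda row: jax.lax.all_gather(row, axis_name="devices"),
    axis_name="devices",
)(local_rows)

# Every device now sees the full stacked tensor of shape (n, 4).
print(gathered.shape)  # (n, n, 4)
```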
This change introduces core building blocks for tensor parallelism by adding two key components.
First, it adds crucial collective operations, all_reduce and all_gather, to the JAX backend. These allow multiple devices to synchronize data by summing tensors (such as gradients) or by gathering individual slices back into a full tensor. Second, it adds the high-level tensor sharding logic (split_tensor_for_parallelism), which uses ops.array_split to slice large tensors, even unevenly, for distribution across devices. New tests confirm that this parallel logic, including uneven splitting, works as expected.
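As a rough illustration of the uneven-splitting behavior described above, here is a minimal sketch using NumPy's array_split as a stand-in; the split_for_devices name is hypothetical, and the PR's actual helper is split_tensor_for_parallelism built on ops.array_split.

```python
import numpy as np

def split_for_devices(tensor, num_devices, axis=0):
    # Hypothetical stand-in for the PR's split_tensor_for_parallelism.
    # array_split tolerates dimensions that are not evenly divisible:
    # the first (length % num_devices) shards get one extra row.
    return np.array_split(tensor, num_devices, axis=axis)

x = np.arange(10 * 2, dtype="float32").reshape(10, 2)
shards = split_for_devices(x, num_devices=3, axis=0)
print([s.shape for s in shards])  # [(4, 2), (3, 2), (3, 2)]
```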
The tests in this PR will pass once PR #21697 is merged.