Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add access to native mesh and layout distribution objects. #20897

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

hertschuh
Copy link
Collaborator

  • Added backend_mesh property to keras.distribution.DeviceMesh to access the native mesh object.
  • Added backend_layout property to keras.distribution.TensorLayout to access the native layout or sharding object.

The values are cached. Changed the code to access these directly instead of calling the convertion functions every time.

Made the following renames so that these functions can be used in backend agnostic code:

  • _to_jax_device to _to_backend_device
  • _to_jax_mesh and _to_dtensor_mesh to _to_backend_mesh
  • _to_jax_layout and _to_dtensor_layout to _to_backend_layout

- Added `backend_mesh` property to `keras.distribution.DeviceMesh` to access the native mesh object.
- Added `backend_layout` property to `keras.distribution.TensorLayout` to access the native layout or sharding object.

The values are cached. Changed the code to access these directly instead of calling the convertion functions every time.

Made the following renames so that these functions can be used in backend agnostic code:
- `_to_jax_device` to `_to_backend_device`
- `_to_jax_mesh` and `_to_dtensor_mesh` to `_to_backend_mesh`
- `_to_jax_layout` and `_to_dtensor_layout` to `_to_backend_layout`
@codecov-commenter
Copy link

codecov-commenter commented Feb 13, 2025

Codecov Report

Attention: Patch coverage is 86.95652% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.26%. Comparing base (b5261a0) to head (74dda3f).

Files with missing lines Patch % Lines
keras/src/backend/jax/distribution_lib.py 75.00% 2 Missing ⚠️
keras/src/backend/tensorflow/distribution_lib.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20897   +/-   ##
=======================================
  Coverage   82.25%   82.26%           
=======================================
  Files         561      561           
  Lines       52680    52690   +10     
  Branches     8144     8146    +2     
=======================================
+ Hits        43334    43344   +10     
  Misses       7342     7342           
  Partials     2004     2004           
Flag Coverage Δ
keras 82.07% <86.95%> (+<0.01%) ⬆️
keras-jax 64.20% <86.95%> (+<0.01%) ⬆️
keras-numpy 58.99% <39.13%> (-0.01%) ⬇️
keras-openvino 32.55% <39.13%> (+<0.01%) ⬆️
keras-tensorflow 64.81% <39.13%> (-0.01%) ⬇️
keras-torch 64.25% <39.13%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Why not mesh and layout? (not asking for a change, asking for clarification)

@hertschuh
Copy link
Collaborator Author

Thanks for the PR! Why not mesh and layout? (not asking for a change, asking for clarification)

First, I'm not a big fan of the names backend_mesh and backend_layout.

But when for instance you have a mesh called mesh, you'd have to write mesh.mesh, which is weird, and it's unclear how mesh is different from mesh.mesh (one is the keras mesh, the other one is the backend native mesh).

A common use case is when you want the native mesh of a layout, you'd write layout.mesh.mesh, or alternatively layout.layout.mesh. Both look weird...

I'm open to name suggestions:

  • mesh.mesh
  • mesh.backend_mesh
  • mesh.to_backend()
  • mesh.to_backend_mesh()
  • mesh.as_backend_mesh()
  • mesh.native_mesh
  • mesh.to_native()
  • mesh.to_native_mesh()
  • mesh.as_native_mesh()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants