-
Notifications
You must be signed in to change notification settings - Fork 12
api_gen
now excludes backend specific code.
#103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This allows not exploring some directories for the purpose of finding symbols. Reason: - keras-team/keras#21321 - keras-team/keras-rs#103
This allows not exploring some directories for the purpose of finding symbols. Reason: - keras-team/keras#21321 - keras-team/keras-rs#103
In draft for now, requires fchollet/namex#7 |
This: - Allows development (`api_gen` / git presubmit hooks) without all backends and backend specific dependencies installed and working. For instance, jax_tpu_embedding currently doesn't import on MacOS Sequoia, this allows running `api_gen` regardless. - Makes sure we don't accidentally create and honor exports that are backend specific.
669d203
to
1287f4e
Compare
"keras_rs", | ||
code_directory="src", | ||
exclude_directories=[ | ||
os.path.join("src", "layers", "embedding", "jax"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we create a backend/
subfolder somewhere and move these files there? Might make it more explicit, and easier to exclude future backend-specific code.
I'm okay with it either way - just thought I'd make the suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I was thinking about this option too. On the one hand, it will be cleaner and more consistent with Keras, especially if we end up having more backend dependent code (but I'm hoping we won't). On the other hand, it moves the DistributedEmbedding
subclasses far away from the location of the super class, which is not intuitive and harder to navigate.
In the case of keras, it doesn't actually make it easier to exclude backends: https://github.com/keras-team/keras/pull/21321/files because we need to keep the backend
folder.
* Ignore shard_map attr error in mypy. (#97) * Added TF specific documentation to `DistributedEmbedding`. (#94) * Fix symbolic calls for `EmbedReduce`. (#98) `EmbedReduce` was inheriting the behavior from `Embedding` and not correctly applying the reduction. * Move `DistributedEmbedding` declaration to its own file. (#99) Having it in `__init__.py` doesn't play nice with pytype. * Remove dependency on `tree` and use `keras.tree`. (#100) Keras can already depend on either `dmtree` or `optree` and use whichever is best or available on the current platform. * Only enable JAX on linux_x86_64. (#101) * Add out_sharding argument WrappedKerasInitializer. (#102) This is for forward-compatibility. Latest versions of JAX introduce the `out_sharding` argument. * Use Python 3.10 style type annotations. (#104) Now that we require Python 3.10, we can use the shorter annotation style, which should improve the readability of the documentation. * Do not bundle test utils in wheel. (#105) * Update version number to 0.2.1 (#106) As 0.2.0 was just released. * Fix invalid escape sequence in unit test. (#108) * Replace leftover `unflatten_as` to `pack_sequence_as`. (#109) This instance was missed as it is only run on TPU. * Make the declaration of `Nested` compatible with pytype. (#110) Which doesn't support `|` between forward declarations using a string. * Add ragged support for default_device placement on JAX. (#107) Requires calling `preprocess`. Internally, we currently convert ragged inputs to dense before passing to the embedding call(...) function. * Add documentation for using DistributedEmbedding with JAX. (#111) * `api_gen` now excludes backend specific code. (#103) This: - Allows development (`api_gen` / git presubmit hooks) without all backends and backend specific dependencies installed and working. For instance, jax_tpu_embedding currently doesn't import on MacOS Sequoia, this allows running `api_gen` regardless. - Makes sure we don't accidentally create and honor exports that are backend specific. * Enable preprocess calls with symbolic input tensors. (#113) This allows us to more-easily create functional models via: ```python preprocessed_inputs = distributed_embedding.preprocess(symbolic_inputs, symbolic_weights) outputs = distributed_embedding(preprocessed_inputs) model = keras.Model(inputs=preprocessed_inputs, outputs=outputs) ``` * Check for jax_tpu_embedding on JAX backend. (#114) This is to allow users to potentially run Keras RS _without_ the dependency. If a user doesn't have `jax-tpu-embedding` installed, but are on `linux_x86_64` and has a sparsecore-capable TPU available, and if they try to use `auto` or `sparsecore` placement with distributed embedding, will raise an error informing them to install the dependency. --------- Co-authored-by: C. Antonio Sánchez <[email protected]> Co-authored-by: hertschuh <[email protected]>
This:
api_gen
/ git presubmit hooks) without all backends and backend specific dependencies installed and working. For instance, jax_tpu_embedding currently doesn't import on MacOS Sequoia, this allows runningapi_gen
regardless.