MrVI slowdown due to JAX compilation update #3179
@justjhong lmk if you want me to upper-bound jax for now, and to which version.
Hi @ori-kron-wis, thanks for checking. I took some time this morning to try to debug it but was not able to find a solution.
For now, pinning jax<0.4.36. This is potentially related to jax-ml/jax#26162; check again once that is addressed. Leaving this open, since pinning circumvents the problem but might create issues in the near future.
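A pin like jax<0.4.36 is normally enforced by the package's requirement specifier. For environments where that might be bypassed, a hypothetical runtime guard could compare the installed version against the pin. The helper names and the warning text below are illustrative assumptions, not part of scvi-tools. The sketch compares numeric release components (so 0.4.9 correctly sorts below 0.4.36) and treats pre-releases such as 0.4.36rc1 as outside the pin, matching how `<` specifiers exclude pre-releases of the bound.

```python
import warnings

# Exclusive upper bound taken from the pin jax<0.4.36 (hypothetical constant name).
JAX_UPPER_BOUND = "0.4.36"


def parse_release(version: str) -> tuple:
    """Parse the leading numeric release part, e.g. '0.4.35' -> (0, 4, 35).

    Parsing stops at the first non-numeric piece, so '0.4.36rc1' -> (0, 4, 36)
    and '0.4.36.dev20250101' -> (0, 4, 36).
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):
            break  # suffix like 'rc1' ends the release segment
    return tuple(parts)


def is_within_pin(version: str, upper: str = JAX_UPPER_BOUND) -> bool:
    """Return True if `version` satisfies `< upper` on numeric release components."""
    return parse_release(version) < parse_release(upper)


def warn_if_outside_pin(version: str) -> None:
    """Emit a warning when the given jax version is not covered by the pin."""
    if not is_within_pin(version):
        warnings.warn(
            f"jax {version} is outside the pin jax<{JAX_UPPER_BOUND}; "
            "MrVI training may be slower",
            stacklevel=2,
        )
```

In practice this would be called as `warn_if_outside_pin(jax.__version__)` after `import jax`.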
@ori-kron-wis can you try with the fix suggested in jax-ml/jax#26162? We can also wait for the next jax release that contains it.
The fix is in the XLA compiler used by JAX (i.e., jaxlib), so we can't just apply it ourselves, and it is not merged yet AFAIK. I did try the JAX nightly release (which includes the jaxlib nightly), but the issue remains. Unfortunately, we need to wait for the next version that contains the fix.
With recent updates to JAX, MrVI trains significantly slower than before. We suspect it is due to the new AOT compilation strategy (https://jax.readthedocs.io/en/latest/aot.html).
Steps to reproduce: run any basic MrVI training on a fresh install. Reproduced by @PierreBoyeau and myself.
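One way to check whether the slowdown comes from compilation rather than execution is to time the two phases separately with JAX's ahead-of-time API (`jax.jit(f).lower(...).compile()`, described in the AOT docs linked in the issue). The sketch below uses a hypothetical toy MLP step, `toy_step`, in place of MrVI's actual training step, so the numbers are only indicative. Running the same script under jax<0.4.36 and a newer release would show which phase regressed.

```python
import time

import jax
import jax.numpy as jnp


def toy_step(params, x):
    # Hypothetical stand-in for a MrVI training step: small MLP forward pass + loss.
    h = jnp.tanh(x @ params["w1"])
    out = h @ params["w2"]
    return jnp.mean(out ** 2)


def time_compile_and_run(fn, *args, n_runs=10):
    """Return (compile_seconds, mean_run_seconds) using JAX's AOT lowering/compilation."""
    jitted = jax.jit(fn)

    t0 = time.perf_counter()
    lowered = jitted.lower(*args)  # trace and lower; no XLA compilation yet
    compiled = lowered.compile()   # XLA compilation happens here
    compile_s = time.perf_counter() - t0

    # One warm-up call, then time repeated calls of the already-compiled executable.
    compiled(*args).block_until_ready()
    t0 = time.perf_counter()
    for _ in range(n_runs):
        compiled(*args).block_until_ready()  # wait for async dispatch to finish
    run_s = (time.perf_counter() - t0) / n_runs
    return compile_s, run_s


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(key1, (64, 128)),
    "w2": jax.random.normal(key2, (128, 8)),
}
x = jnp.ones((32, 64))
compile_s, run_s = time_compile_and_run(toy_step, params, x)
print(f"jax {jax.__version__}: compile {compile_s:.3f}s, per-step run {run_s * 1e3:.3f}ms")
```

If compile time jumps between versions while per-step run time stays flat, that would support the compilation hypothesis.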