Disable 64 bit codegen to HLO for Neuron backend #24682

Open
ptoulme-aws opened this issue Nov 4, 2024 · 7 comments
Labels
enhancement New feature or request

Comments

@ptoulme-aws

ptoulme-aws commented Nov 4, 2024

Hi, I'd like to open a PR that disables code generation of FP64, S64, and U64 types when jax_backend=="Neuron". Neuron hardware does not support 64-bit types, so for the Neuron backend only, I want JAX to lower any 64-bit type to the corresponding 32-bit type.

We made an equivalent PR to torch-xla: pytorch/xla@7c7ad4e

My question: where in the JAX library is the HLO element type chosen during codegen? I searched but could not find it.

@ptoulme-aws ptoulme-aws added the enhancement New feature or request label Nov 4, 2024
@dfm
Collaborator

dfm commented Nov 4, 2024

I don't think there's any global place where you can do this check in JAX (maybe @hawkinsp can correct me if I'm wrong), but I expect that a JAX version of this would require registering custom lowering rules for at least some primitives (for example, this is where the data types are chosen for the lax.dot_general operation). It might also be possible to add an XLA pass for this, but I'm less familiar with where exactly that would sit. If @hawkinsp doesn't have better suggestions, it might be worth opening a similar issue on https://github.com/openxla/xla. What do you think?

@ptoulme-aws
Author

@dfm Yes, I have explored the pass option. I can write a pass that handles any intermediate tensors in FP64/S64/U64, but I don't think a pass can change the types of the graph inputs. I would prefer to fix it in the JAX -> StableHLO codegen, as we did for torch-xla.
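Until a codegen-level fix exists, one user-side workaround for the graph-input problem is to cast 64-bit inputs to 32-bit before they ever reach a jitted function. This is a minimal sketch, assuming the inputs are NumPy arrays held in a pytree; `downcast_64bit_inputs` is a hypothetical helper, not a JAX API, and the narrowing casts silently wrap or round values that do not fit in 32 bits.

```python
import jax
import numpy as np

# Hypothetical mapping (not a JAX API): each 64-bit dtype -> 32-bit dtype of the same kind.
_DOWNCAST = {
    np.dtype(np.float64): np.dtype(np.float32),
    np.dtype(np.int64): np.dtype(np.int32),
    np.dtype(np.uint64): np.dtype(np.uint32),
}


def downcast_64bit_inputs(tree):
    """Return a copy of `tree` in which every NumPy array leaf with a 64-bit
    dtype is cast to its 32-bit counterpart; all other leaves pass through.
    Note: values outside the 32-bit range wrap (ints) or lose precision (floats)."""
    def cast(leaf):
        if isinstance(leaf, np.ndarray) and leaf.dtype in _DOWNCAST:
            return leaf.astype(_DOWNCAST[leaf.dtype])
        return leaf

    return jax.tree_util.tree_map(cast, tree)


batch = {"x": np.ones((2, 3)), "step": np.array(7, dtype=np.int64)}
print(jax.tree_util.tree_map(lambda a: a.dtype, downcast_64bit_inputs(batch)))
```

Doing the cast on the host side means the traced graph only ever sees 32-bit parameters, which sidesteps the question of rewriting input types inside a pass.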

I mainly encounter 64-bit values in JAX parameter-initialization graphs, not in train-step graphs.
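To find out which graphs still contain 64-bit types, one option is to lower a function with `jax.jit(...).lower(...)` and scan the text returned by `as_text()`. This is a rough sketch; `sixty_four_bit_types`, `find_64bit_types`, and the toy `init_params` are hypothetical helpers, and the regex only looks for `f64`, `i64`, and `ui64` element types inside `tensor<...>` types (it does not catch, for example, `complex<f64>`).

```python
import re

import jax
import jax.numpy as jnp

# Matches 64-bit element types in MLIR tensor types, e.g. tensor<4x4xf64>,
# tensor<i64>, tensor<?x8xui64>. Only types inside tensor<...> are checked,
# so attributes such as array<i64: ...> are not counted.
_TENSOR_64 = re.compile(r"tensor<(?:[\d?]+x)*(f64|i64|ui64)>")


def sixty_four_bit_types(stablehlo_text: str) -> set:
    """Return the set of 64-bit element types found in StableHLO/MLIR text."""
    return set(_TENSOR_64.findall(stablehlo_text))


def find_64bit_types(fn, *args) -> set:
    """jit-lower `fn` for `args` and report any 64-bit tensor element types."""
    return sixty_four_bit_types(jax.jit(fn).lower(*args).as_text())


def init_params(key):
    # Toy stand-in for a parameter-init graph: random weights plus a zero bias.
    w = jax.random.normal(key, (4, 4))
    b = jnp.zeros((4,))
    return {"w": w, "b": b}


print(find_64bit_types(init_params, jax.random.PRNGKey(0)))
```

An empty set means no 64-bit tensors were found in that lowering; running the same check on each init function in a large codebase can narrow down where the 64-bit values come from.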

@dfm
Collaborator

dfm commented Nov 6, 2024

I would prefer I fix it in the Jax->StableHlo codegen, as we did for torch-xla.

That makes sense! In that case, I would point you in the direction of individual lowering rules for JAX primitives.

As a side note, by default JAX doesn't typically produce many 64-bit types unless the enable_x64 config flag is set, so you might get most of the way there by making sure enable_x64 is set to false and then fixing any primitives that still emit higher-precision types (I could believe there are still some, although I don't know which ones off the top of my head!).
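A minimal sketch of that suggestion, assuming a recent JAX release: turn `jax_enable_x64` off with `jax.config.update` at program start, then confirm that 64-bit requests are canonicalized to 32-bit dtypes.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Disable 64-bit mode once, at program start, before creating any arrays.
# (False is already JAX's default unless JAX_ENABLE_X64 is set in the environment.)
jax.config.update("jax_enable_x64", False)

# With x64 disabled, JAX canonicalizes 64-bit dtypes to their 32-bit counterparts.
print(jnp.arange(3).dtype)                         # int32
print(jnp.asarray(np.zeros(3, np.float64)).dtype)  # float32
print(jnp.asarray(np.zeros(3, np.uint64)).dtype)   # uint32
```

Because NumPy inputs are canonicalized on the way in, this also covers float64/int64 host arrays passed as graph inputs; any primitive that still emits a 64-bit type despite the flag would be a candidate for a targeted fix.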

@hawkinsp
Collaborator

hawkinsp commented Nov 6, 2024

Indeed, the equivalent big hammer in the JAX world is enable_x64: leaving it disabled prevents the generation of all 64-bit types. Does that solve your problem for now?

(We don't much love that config option, but it's very useful on hardware that doesn't have or want 64-bit types.)

@ptoulme-aws
Author

I have seen that config flag. I can't be certain, but it did not always work for me in JAX 0.4.21: I still saw 64-bit types in the parameter-init HLO graphs. I have not tested newer JAX versions.
I will retest the flag.

@hawkinsp
Collaborator

hawkinsp commented Nov 7, 2024

If you have an example of where that flag didn't work, I'd be interested to look at it.

One thing to note: setting that flag locally doesn't currently work; the setting must hold for the entire program.
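Because the flag has to apply to the whole program, one option for a large codebase is to set it through the `JAX_ENABLE_X64` environment variable at the very top of the entry point, before any module imports JAX, and then fail fast if 64-bit mode is still on. This is a sketch; `main.py` and `check_no_x64` are hypothetical names, and the environment variable only takes effect if it is set before JAX is first imported in the process.

```python
# main.py: hypothetical entry point for a multi-module trainer.
import os

# Set the flag via the environment *before* anything imports JAX, so every
# module in the program sees the same value. "0" is parsed as false.
os.environ["JAX_ENABLE_X64"] = "0"

import jax.numpy as jnp  # noqa: E402  (deliberately imported after the env var)
import numpy as np  # noqa: E402


def check_no_x64() -> None:
    """Fail fast if 64-bit mode is on anyway, e.g. because JAX was imported
    before the env var was set, or some module re-enabled jax_enable_x64."""
    if jnp.arange(3).dtype != np.int32:
        raise RuntimeError("jax_enable_x64 is enabled; 64-bit types may be emitted")


check_no_x64()
# ... import the rest of the trainer and start training here ...
```

Calling `check_no_x64()` again right before the first `jit` call can also catch a module that flips the flag later during import.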

@ptoulme-aws
Author

@hawkinsp that is very interesting. I was setting it in my trainer file, but my codebase is roughly 100 files.
I also tried the environment variable approach and still saw 64-bit types generated.

I will retry with the latest JAX.
