Disable 64 bit codegen to HLO for Neuron backend #24682
Comments
I don't think there's any global place where you can do this check in JAX (maybe @hawkinsp can correct me if I'm wrong), but I expect that a JAX version of this would require registering custom lowering rules for at least some primitives (for example, this is where the data types are chosen for the
@dfm Yes, I have explored the pass option. I can write a pass to handle any intermediate tensors in FP64/S64/U64, but I do not think a pass can change the types of graph inputs. I would prefer to fix it in the JAX->StableHLO codegen, as we did for torch-xla. I mainly encounter 64-bit values in the JAX parameter init graphs. I do not encounter them in train step graphs.
That makes sense! In that case, I would point you in the direction of individual lowering rules for JAX primitives. As a side note, by default JAX typically doesn't produce many 64-bit types unless the
Indeed, the equivalent big hammer in the JAX world is (We don't much love that config option, but it's very useful on hardware that doesn't have or want 64-bit types.)
I have seen that config flag. I cannot be sure, but it did not always work for me in JAX 0.4.21, as I still saw 64-bit types in parameter init HLO graphs. I have not tested on newer JAX versions.
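One way to turn "I still saw 64-bit types" into a reproducible report is to scan the lowered module text. The following is only a minimal sketch: the helper names `scan_hlo_text` and `find_64bit_types` are hypothetical, and the regular expression assumes the usual MLIR spelling of tensor element types (`tensor<4xf64>`, `tensor<i64>`), which may differ between JAX versions.

```python
import re

import jax
import jax.numpy as jnp

# Assumption: element types appear as tensor<...xf64>, tensor<i64>, etc.
_TENSOR_64 = re.compile(r"tensor<(?:[\d?]+x)*(f64|[su]?i64)>")


def scan_hlo_text(text):
    """Return the sorted set of 64-bit tensor element types found in `text`."""
    return sorted(set(_TENSOR_64.findall(text)))


def find_64bit_types(fn, *args):
    """Lower `fn` for `args` and report 64-bit tensor types in the module text."""
    text = jax.jit(fn).lower(*args).as_text()
    return scan_hlo_text(text)


if __name__ == "__main__":
    # Hypothetical stand-in for a parameter init function.
    def init(x):
        return x * 2

    print(find_64bit_types(init, jnp.ones(4)))
```

Hits are worth reviewing by hand: some JAX versions also encode attributes such as dimension lists as `i64` tensors, and those are not 64-bit data flowing through the graph.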
If you have an example of where that flag didn't work, I'd be interested to look at it. One thing to note is that setting the flag locally doesn't work at the moment: it must hold the same value for the entire program.
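As a rough illustration of setting the option once for the whole program, the sketch below configures it at the top of the entry point, before any arrays are created. It assumes the flag under discussion is `jax_enable_x64`; the thread does not spell out the name, so treat that as an assumption.

```python
import jax

# Assumption: the flag in question is `jax_enable_x64`.
# Set it once, at the very top of the program's entry point,
# before any other module creates JAX arrays or traces functions.
jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp
import numpy as np

# A float64 NumPy input is canonicalized to 32 bits when x64 is disabled.
x = jnp.asarray(np.arange(4, dtype=np.float64))
i = jnp.arange(3)

print(x.dtype, i.dtype)
```

In a codebase with many files, putting this in the single script that launches training (rather than in a module imported partway through) avoids the case where some arrays are created before the flag takes effect.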
@hawkinsp That is very interesting. I was setting it in my trainer file, but my codebase is probably 100 files. I will try with the latest JAX.
Hi, I want to open a PR that disables codegen of FP64, S64, or U64 types if jax_backend=="Neuron". Neuron hardware does not support 64-bit types, so I want JAX to emit 32-bit types in place of any 64-bit types, for the Neuron backend only.
We made the same change in torch-xla: pytorch/xla@7c7ad4e
My question: where in the JAX library is the HLO type chosen during codegen? I tried searching but cannot find it.