InconclusiveDimensionOperation: Symbolic dimension comparison 'b' < '2147483647' is inconclusive. #24730
- `deepmd/jax/descriptor/__init__.py` imports SeT and DPA-2 so that they are found by the plugin;
- `deepmd/dpmodel/descriptor/dpa1.py` fixes the jit issue regarding the shape generated by `jnp.prod`. The shape should be static, which `math.prod` ensures.
- `deepmd/jax/model/ener_model.py` and `deepmd/jax/model/dp_zbl_model.py` stop the gradient of coordinates when rebuilding the neighbor list. The gradient of sort causes an error due to jax-ml/jax#24730.

Signed-off-by: Jinzhe Zeng <[email protected]>
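The `dpa1.py` change relies on the fact that inside `jax.jit` a `jnp.prod` result is a traced array, whereas `math.prod` over `x.shape` yields a static size that `reshape` accepts. A minimal sketch of the pattern, using a hypothetical function `merge_leading_axes` (not the actual deepmd code):

```python
import math

import jax
import jax.numpy as jnp


@jax.jit
def merge_leading_axes(x):
    # x.shape is a tuple of Python ints (or symbolic dimensions under shape
    # polymorphism), so math.prod returns a static size that reshape accepts.
    # A jnp.prod result would be a traced array inside jit instead.
    nrows = math.prod(x.shape[:-1])
    return x.reshape(nrows, x.shape[-1])


y = merge_leading_axes(jnp.ones((2, 3, 4)))
print(y.shape)  # (6, 4)
```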
Assigning @gnecula, who is most familiar with shape polymorphism and TF model exporting.
## Summary by CodeRabbit

- **New Features**
  - Introduced new `format_nlist` methods in the `DPZBLModel` and `EnergyModel` classes for improved neighbor list formatting.
  - Added new descriptors `DescrptDPA2` and `DescrptSeTTebd` to the public API.
- **Bug Fixes**
  - Enhanced attribute handling in `DPZBLModel` and `EnergyModel` to ensure proper serialization and deserialization of `atomic_model`.
- **Documentation**
  - Updated the public API to reflect new additions and maintain existing documentation accuracy.

Signed-off-by: Jinzhe Zeng <[email protected]>
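The deepmd workaround stops the gradient of the coordinates before the neighbor list is rebuilt, so the sort used to pick neighbors is never differentiated. A minimal sketch of that idea, assuming a hypothetical `build_nlist` helper and a toy `energy` function (neither is deepmd's real code): the neighbor indices are computed from `jax.lax.stop_gradient(coord)`, while the energy still depends differentiably on `coord` through the gathered neighbor positions.

```python
import jax
import jax.numpy as jnp


def build_nlist(coord, nsel):
    """Hypothetical helper: indices of the `nsel` nearest neighbors of each atom."""
    # Neighbor selection is a discrete choice, so no gradient is needed here;
    # stop_gradient keeps JAX from differentiating through the argsort.
    coord = jax.lax.stop_gradient(coord)
    n = coord.shape[0]
    diff = coord[:, None, :] - coord[None, :, :]
    d2 = jnp.sum(diff**2, axis=-1)
    # Exclude each atom from its own neighbor list.
    d2 = jnp.where(jnp.eye(n, dtype=bool), jnp.inf, d2)
    return jnp.argsort(d2, axis=-1)[:, :nsel]


def energy(coord, nsel=1):
    """Toy pair energy: sum of squared distances to the selected neighbors."""
    nlist = build_nlist(coord, nsel)  # (n, nsel) integer indices, no gradient
    neigh = coord[nlist]  # (n, nsel, 3), differentiable w.r.t. coord
    r2 = jnp.sum((neigh - coord[:, None, :]) ** 2, axis=-1)
    return jnp.sum(r2)


coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
grad = jax.grad(energy)(coords)
```

The gradient only flows through the gather `coord[nlist]`; treating the neighbor indices as constants is the usual assumption for a neighbor list that is rebuilt at each step.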
In the immediate term, you can unblock this by adding an explicit constraint 'b < 2147483647', as explained in the documentation linked from the error message. The issue is that JAX lowering for … I will be thinking about how to handle this more nicely. E.g., we could always use …
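A sketch of the suggested workaround, assuming the failing computation is the gradient of `jnp.sort` exported with a symbolic dimension `b` (the deepmd commit above attributes the error to the gradient of sort; the original reproducer is not shown here). The constraint is passed to `jax.export.symbolic_shape` and written as the equivalent integer bound `b <= 2147483646`. The function `loss` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import export


def loss(x):
    # Hypothetical objective that differentiates through a sort.
    return jnp.sum(jnp.sort(x) ** 2)


grad_fn = jax.jit(jax.grad(loss))

# Declare the symbolic dimension together with an upper bound, so that the
# comparison of 'b' against the int32 maximum can be decided during export.
(b,) = export.symbolic_shape("b", constraints=("b <= 2147483646",))
exp = export.export(grad_fn)(jax.ShapeDtypeStruct((b,), jnp.float32))

# Call the exported function with a concrete shape; b is resolved to 3.
# Since sum(sort(x)**2) == sum(x**2), the gradient is 2 * x.
out = exp.call(jnp.array([3.0, 1.0, 2.0], dtype=jnp.float32))
```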
This is probably a reasonable solution. The reason for the shape-dependent dtype was that we were exploring the possibility of getting rid of the X64 flag and making APIs default to 32-bit unless 64-bit is explicitly requested or required. That approach turned out not to be viable, but some vestiges of it (like this one) are still around.
Description
Simple code to reproduce:
Output:
System info (python version, jaxlib version, accelerator, etc.)