Skip to content

Commit

Permalink
Merge pull request #263 from ami-iit/feature/tpu_simulation
Browse files Browse the repository at this point in the history
Disable host callbacks and enforce 32bit precision when running on TPU
  • Loading branch information
flferretti authored Oct 11, 2024
2 parents b804caf + 779d37b commit 3ac2978
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
26 changes: 20 additions & 6 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,35 @@ def _jnp_options() -> None:

import jax

# Enable by default 64bit precision in JAX.
if os.environ.get("JAX_ENABLE_X64", "1") != "0":

logging.info("Enabling JAX to use 64bit precision")
# Check if running on TPU
is_tpu = jax.devices()[0].platform == "tpu"

# Enable by default 64-bit precision to get accurate physics.
# Users can enforce 32-bit precision by setting the following variable to 0.
use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"

# Notify the user if unsupported 64-bit precision was enforced on TPU.
if is_tpu and use_x64:
msg = "64-bit precision is not allowed on TPU. Enforcing 32bit precision."
logging.warning(msg)
use_x64 = False

# Enable 64-bit precision in JAX.
if use_x64:
logging.info("Enabling JAX to use 64-bit precision")
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

# Verify that 64-bit precision is correctly set.
if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
logging.warning("Failed to enable 64bit precision in JAX")
logging.warning("Failed to enable 64-bit precision in JAX")

# Warn about experimental usage of 32-bit precision.
else:
logging.warning(
"Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
"Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
)


Expand Down
4 changes: 4 additions & 0 deletions src/jaxsim/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def raise_if(
format string (fmt), whose fields are filled with the args and kwargs.
"""

# Disable host callback if running on TPU.
if jax.devices()[0].platform == "tpu":
return

# Check early that the format string is well-formed.
try:
_ = msg.format(*args, **kwargs)
Expand Down

0 comments on commit 3ac2978

Please sign in to comment.