From 5af6b8471c2570c9ea936feab50f59b45ebf0222 Mon Sep 17 00:00:00 2001 From: "Omar G. Younis" Date: Thu, 23 Jan 2025 16:01:58 +0100 Subject: [PATCH 1/6] fix device transfers --- src/jaxsim/api/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 4af9469be..6088b449f 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -91,7 +91,7 @@ def __hash__(self) -> int: return hash( ( hash(self.model_name), - hash(float(self.time_step)), + hash(self.time_step), hash(self.kin_dyn_parameters), hash(self.contact_model), ) @@ -317,7 +317,7 @@ def floating_base(self) -> bool: True if the model is floating-base, False otherwise. """ - return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6) + return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6 def base_link(self) -> str: """ @@ -348,7 +348,7 @@ def dofs(self) -> int: the number of joints. In the future, this could be different. """ - return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])) + return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]) def joint_names(self) -> tuple[str, ...]: """ @@ -431,7 +431,7 @@ def reduce( for joint_name in set(model.joint_names()) - set(considered_joints): j = intermediate_description.joints_dict[joint_name] with j.mutable_context(): - j.initial_position = float(locked_joint_positions.get(joint_name, 0.0)) + j.initial_position = locked_joint_positions.get(joint_name, 0.0) # Reduce the model description. # If `considered_joints` contains joints not existing in the model, From 15e31c75e3058bdb0a94e7b5f6baab4f1d85fa2d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Jan 2025 10:59:40 +0100 Subject: [PATCH 2/6] Disable exceptions by default --- docs/guide/configuration.rst | 4 ++-- src/jaxsim/exceptions.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/guide/configuration.rst b/docs/guide/configuration.rst index 4e4c87e16..a1b72ab1c 100644 --- a/docs/guide/configuration.rst +++ b/docs/guide/configuration.rst @@ -61,9 +61,9 @@ The logging and exceptions configurations is controlled by the following environ *Default:* ``DEBUG`` for development, ``WARNING`` for production. -- ``JAXSIM_DISABLE_EXCEPTIONS``: Disables the runtime checks and exceptions. +- ``JAXSIM_DISABLE_EXCEPTIONS``: Disables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required. - *Default:* ``False``. + *Default:* ``True``. .. note:: Runtime exceptions are disabled by default on TPU. diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py index 16590d051..4cfd1a221 100644 --- a/src/jaxsim/exceptions.py +++ b/src/jaxsim/exceptions.py @@ -24,7 +24,7 @@ def raise_if( # Disable host callback if running on unsupported hardware or if the user # explicitly disabled it. if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get( - "JAXSIM_DISABLE_EXCEPTIONS", 0 + "JAXSIM_DISABLE_EXCEPTIONS", 1 ): return From e974a303755d12824d1b7447b95957d8516d1972 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Jan 2025 11:01:51 +0100 Subject: [PATCH 3/6] Change flag name to enable exceptions --- docs/guide/configuration.rst | 4 ++-- src/jaxsim/exceptions.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/guide/configuration.rst b/docs/guide/configuration.rst index a1b72ab1c..4061b30ff 100644 --- a/docs/guide/configuration.rst +++ b/docs/guide/configuration.rst @@ -61,9 +61,9 @@ The logging and exceptions configurations is controlled by the following environ *Default:* ``DEBUG`` for development, ``WARNING`` for production. -- ``JAXSIM_DISABLE_EXCEPTIONS``: Disables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required. +- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required. - *Default:* ``True``. + *Default:* ``False``. .. note:: Runtime exceptions are disabled by default on TPU. diff --git a/src/jaxsim/exceptions.py b/src/jaxsim/exceptions.py index 4cfd1a221..ecebbbe9a 100644 --- a/src/jaxsim/exceptions.py +++ b/src/jaxsim/exceptions.py @@ -23,8 +23,8 @@ def raise_if( # Disable host callback if running on unsupported hardware or if the user # explicitly disabled it. - if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get( - "JAXSIM_DISABLE_EXCEPTIONS", 1 + if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get( + "JAXSIM_ENABLE_EXCEPTIONS", 0 ): return From 56d080eed2db141d21e0fffc7e156b212ab2933b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Jan 2025 11:15:35 +0100 Subject: [PATCH 4/6] Enable exceptions in tests --- tests/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 9fcb9be4b..3ad751e85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,7 @@ import os + +os.environ["JAXSIM_ENABLE_EXCEPTIONS"] = "1" + import pathlib import subprocess From 478660bfb3acef75884e8badff17860ee2f94106 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Jan 2025 11:15:51 +0100 Subject: [PATCH 5/6] Fix argument name in AD integration test --- tests/test_automatic_differentiation.py | 2 +- tests/test_simulations.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 84a50d6c5..d3153f2ed 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -344,7 +344,7 @@ def step( model=model, data=data_x0, joint_force_references=τ, - link_forces=W_f_L, + link_forces_inertial=W_f_L, ) xf_W_p_B = data_xf.base_position diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 86e3696e3..e56c2c1cf 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -74,7 +74,7 @@ def test_box_with_external_forces( data = js.model.step( model=model, data=data, - link_forces=references.link_forces(model=model, data=data), + link_forces_inertial=references._link_forces, ) # Check that the box didn't move. @@ -148,7 +148,7 @@ def test_box_with_zero_gravity( data = js.model.step( model=model, data=data, - link_forces=references.link_forces(model=model, data=data), + link_forces_inertial=references.link_forces(model=model, data=data), ) # Check that the box moved as expected. From 7aff90138ecfc71337f4dea483ffb854075bebc6 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 Jan 2025 17:23:11 +0100 Subject: [PATCH 6/6] Save `time_step` and reduced joint positions as float --- src/jaxsim/api/model.py | 6 +++--- tests/test_api_model.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 6088b449f..3610e26f6 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -33,8 +33,8 @@ class JaxSimModel(JaxsimDataclass): model_name: Static[str] - time_step: jtp.FloatLike = dataclasses.field( - default_factory=lambda: jnp.array(0.001, dtype=float), + time_step: float = dataclasses.field( + default=0.001, ) terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field( @@ -222,7 +222,7 @@ def build( time_step = ( time_step if time_step is not None - else JaxSimModel.__dataclass_fields__["time_step"].default_factory() + else JaxSimModel.__dataclass_fields__["time_step"].default ) # Create the default contact model. diff --git a/tests/test_api_model.py b/tests/test_api_model.py index bfd6eb0b9..ac97d616e 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -81,7 +81,7 @@ def test_model_creation_and_reduction( locked_joint_positions=dict( zip( model_full.joint_names(), - data_full.joint_positions, + data_full.joint_positions.tolist(), strict=True, ) ),