diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 8a878d382..b1f2230de 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -172,7 +172,7 @@ def build( dtype=float, ).squeeze() - gravity = jnp.zeros(3, dtype=float).at[2].set(-standard_gravity) + gravity = jnp.zeros(3).at[2].set(-standard_gravity) joint_positions = jnp.atleast_1d( jnp.array( diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 52dc0a5b3..7c6d1f6ac 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -88,7 +88,7 @@ class RigidContacts(ContactModel): ) terrain: jax_dataclasses.Static[Terrain] = dataclasses.field( - default_factory=FlatTerrain + default_factory=FlatTerrain.build ) @classmethod