diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index e6bed6f9d..2a1f6d569 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -144,7 +144,8 @@ def collidable_point_dynamics( The joint force references to apply to the joints. Returns: - The 6D force applied to each collidable point and additional data based on the contact model configured: + The 6D force applied to each collidable point and additional data based + on the contact model configured: - Soft: the material deformation rate. - Rigid: no additional data. - QuasiRigid: no additional data. @@ -156,21 +157,13 @@ def collidable_point_dynamics( """ # Import privately the contacts classes. - from jaxsim.rbda.contacts import ( - RelaxedRigidContacts, - RelaxedRigidContactsState, - RigidContacts, - RigidContactsState, - SoftContacts, - SoftContactsState, - ) + from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts # Build the soft contact model. match model.contact_model: case SoftContacts(): assert isinstance(model.contact_model, SoftContacts) - assert isinstance(data.state.contact, SoftContactsState) # Compute the 6D force expressed in the inertial frame and applied to each # collidable point, and the corresponding material deformation rate. @@ -187,7 +180,6 @@ def collidable_point_dynamics( case RigidContacts(): assert isinstance(model.contact_model, RigidContacts) - assert isinstance(data.state.contact, RigidContactsState) # Compute the 6D force expressed in the inertial frame and applied to each # collidable point. @@ -203,7 +195,6 @@ def collidable_point_dynamics( case RelaxedRigidContacts(): assert isinstance(model.contact_model, RelaxedRigidContacts) - assert isinstance(data.state.contact, RelaxedRigidContactsState) # Compute the 6D force expressed in the inertial frame and applied to each # collidable point. diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 5d1f21bea..b1f2230de 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -13,7 +13,6 @@ import jaxsim.math import jaxsim.rbda import jaxsim.typing as jtp -from jaxsim.rbda.contacts import SoftContacts from jaxsim.utils import Mutability from jaxsim.utils.tracing import not_tracing @@ -107,17 +106,17 @@ def zero( @staticmethod def build( model: js.model.JaxSimModel, - base_position: jtp.Vector | None = None, - base_quaternion: jtp.Vector | None = None, - joint_positions: jtp.Vector | None = None, - base_linear_velocity: jtp.Vector | None = None, - base_angular_velocity: jtp.Vector | None = None, - joint_velocities: jtp.Vector | None = None, + base_position: jtp.VectorLike | None = None, + base_quaternion: jtp.VectorLike | None = None, + joint_positions: jtp.VectorLike | None = None, + base_linear_velocity: jtp.VectorLike | None = None, + base_angular_velocity: jtp.VectorLike | None = None, + joint_velocities: jtp.VectorLike | None = None, standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, - contact: jaxsim.rbda.contacts.ContactsState | None = None, contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None, velocity_representation: VelRepr = VelRepr.Inertial, time: jtp.FloatLike | None = None, + extended_ode_state: dict[str, jtp.PyTree] | None = None, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with the given state. @@ -133,56 +132,73 @@ def build( The base angular velocity in the selected representation. joint_velocities: The joint velocities. standard_gravity: The standard gravity constant. - contact: The state of the soft contacts. contacts_params: The parameters of the soft contacts. velocity_representation: The velocity representation to use. time: The time at which the state is created. + extended_ode_state: + Additional user-defined state variables that are not part of the + standard `ODEState` object. Useful to extend the system dynamics + considered by default in JaxSim. Returns: - A `JaxSimModelData` object with the given state. + A `JaxSimModelData` initialized with the given state. """ base_position = jnp.array( - base_position if base_position is not None else jnp.zeros(3) + base_position if base_position is not None else jnp.zeros(3), + dtype=float, ).squeeze() base_quaternion = jnp.array( - base_quaternion - if base_quaternion is not None - else jnp.array([1.0, 0, 0, 0]) + ( + base_quaternion + if base_quaternion is not None + else jnp.array([1.0, 0, 0, 0]) + ), + dtype=float, ).squeeze() base_linear_velocity = jnp.array( - base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3) + base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3), + dtype=float, ).squeeze() base_angular_velocity = jnp.array( - base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3) + ( + base_angular_velocity + if base_angular_velocity is not None + else jnp.zeros(3) + ), + dtype=float, ).squeeze() gravity = jnp.zeros(3).at[2].set(-standard_gravity) joint_positions = jnp.atleast_1d( - joint_positions.squeeze() - if joint_positions is not None - else jnp.zeros(model.dofs()) + jnp.array( + ( + joint_positions + if joint_positions is not None + else jnp.zeros(model.dofs()) + ), + dtype=float, + ).squeeze() ) joint_velocities = jnp.atleast_1d( - joint_velocities.squeeze() - if joint_velocities is not None - else jnp.zeros(model.dofs()) + jnp.array( + ( + joint_velocities + if joint_velocities is not None + else jnp.zeros(model.dofs()) + ), + dtype=float, + ).squeeze() ) - time_ns = ( - jnp.array( - time * 1e9, - dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32, - ) - if time is not None - else jnp.array( - 0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32 - ) + time_ns = jnp.array( + time * 1e9 if time is not None else 0.0, + dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32, ) W_H_B = jaxsim.math.Transform.from_quaternion_and_translation( @@ -194,21 +210,22 @@ def build( other_representation=velocity_representation, transform=W_H_B, is_force=False, - ) + ).astype(float) ode_state = ODEState.build_from_jaxsim_model( model=model, - base_position=base_position.astype(float), - base_quaternion=base_quaternion.astype(float), - joint_positions=joint_positions.astype(float), - base_linear_velocity=v_WB[0:3].astype(float), - base_angular_velocity=v_WB[3:6].astype(float), - joint_velocities=joint_velocities.astype(float), - tangential_deformation=( - contact.tangential_deformation - if contact is not None and isinstance(model.contact_model, SoftContacts) - else None - ), + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + base_linear_velocity=v_WB[0:3], + base_angular_velocity=v_WB[3:6], + joint_velocities=joint_velocities, + # Unpack all the additional ODE states. If the contact model requires an + # additional state that is not explicitly passed to this builder, ODEState + # automatically populates that state with zeroed variables. + # This is not true for any other custom state that the user might want to + # pass to the integrator. + **(extended_ode_state if extended_ode_state else {}), ) if not ode_state.valid(model=model): @@ -220,13 +237,14 @@ def build( contacts_params = js.contact.estimate_good_soft_contacts_parameters( model=model, standard_gravity=standard_gravity ) + else: contacts_params = model.contact_model.parameters return JaxSimModelData( time_ns=time_ns, state=ode_state, - gravity=gravity.astype(float), + gravity=gravity, contacts_params=contacts_params, velocity_representation=velocity_representation, ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 291e1cc97..989420684 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -33,7 +33,7 @@ class JaxSimModel(JaxsimDataclass): model_name: Static[str] terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field( - default=jaxsim.terrain.FlatTerrain(), repr=False + default_factory=jaxsim.terrain.FlatTerrain.build, repr=False ) contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field( @@ -101,13 +101,14 @@ def build_from_model_description( A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. model_name: - The optional name of the model that overrides the one in - the description. - terrain: - The optional terrain to consider. + The name of the model. If not specified, it is read from the description. + terrain: The terrain to consider (the default is a flat infinite plane). + contact_model: + The contact model to consider. + If not specified, a soft contacts model is used. is_urdf: - The optional flag to force the model description to be parsed as a - URDF or a SDF. This is otherwise automatically inferred. + The optional flag to force the model description to be parsed as a URDF. + This is usually automatically inferred. considered_joints: The list of joints to consider. If None, all joints are considered. @@ -120,7 +121,7 @@ def build_from_model_description( # Parse the input resource (either a path to file or a string with the URDF/SDF) # and build the -intermediate- model description. intermediate_description = jaxsim.parsers.rod.build_model_description( - model_description=model_description + model_description=model_description, is_urdf=is_urdf ) # Lump links together if not all joints are considered. @@ -160,11 +161,11 @@ def build( The intermediate model description defining the kinematics and dynamics of the model. model_name: - The optional name of the model overriding the physics model name. - terrain: - The optional terrain to consider. + The name of the model. If not specified, it is read from the description. + terrain: The terrain to consider (the default is a flat infinite plane). contact_model: - The optional contact model to consider. If None, the soft contact model is used. + The contact model to consider. + If not specified, a soft contacts model is used. Returns: The built Model object. @@ -173,21 +174,31 @@ def build( # Set the model name (if not provided, use the one from the model description). model_name = model_name if model_name is not None else model_description.name - # Set the terrain (if not provided, use the default flat terrain). - terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default - contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts( - terrain=terrain + # Consider the default terrain (a flat infinite plane) if not specified. + terrain = ( + terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory() + ) + + # Create the default contact model. + # It will be populated with an initial estimation of good parameters. + # While these might not be the best, they are a good starting point. + contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts.build( + terrain=terrain, parameters=None ) # Build the model. model = JaxSimModel( model_name=model_name, - _description=wrappers.HashlessObject(obj=model_description), kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build( model_description=model_description ), terrain=terrain, contact_model=contact_model, + # The following is wrapped as hashless since it's a static argument, and we + # don't want to trigger recompilation if it changes. All relevant parameters + # needed to compute kinematics and dynamics quantities are stored in the + # kin_dyn_parameters attribute. + _description=wrappers.HashlessObject(obj=model_description), ) return model diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index d065ac712..408650b28 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -370,9 +370,8 @@ def system_dynamics( corresponding derivative, and the dictionary of auxiliary data returned by the system dynamics evaluation. """ - from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts - from jaxsim.rbda.contacts.rigid import RigidContacts - from jaxsim.rbda.contacts.soft import SoftContacts + + from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts # Compute the accelerations and the material deformation rate. W_v̇_WB, s̈, aux_dict = system_velocity_dynamics( @@ -382,17 +381,20 @@ def system_dynamics( link_forces=link_forces, ) - ode_state_kwargs = {} + # Initialize the dictionary storing the derivative of the additional state variables + # that extend the state vector of the integrated ODE system. + extended_ode_state = {} match model.contact_model: + case SoftContacts(): - ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"] + extended_ode_state["tangential_deformation"] = aux_dict["m_dot"] case RigidContacts() | RelaxedRigidContacts(): pass case _: - raise ValueError("Unable to determine contact state class prefix.") + raise ValueError(f"Invalid contact model {model.contact_model}") # Extract the velocities. W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( @@ -412,7 +414,7 @@ def system_dynamics( base_linear_velocity=W_v̇_WB[0:3], base_angular_velocity=W_v̇_WB[3:6], joint_velocities=s̈, - **ode_state_kwargs, + **extended_ode_state, ) return ode_state_derivative, aux_dict diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index 622e0e196..d9e6ae3f1 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -1,19 +1,13 @@ from __future__ import annotations +import dataclasses + +import jax import jax.numpy as jnp import jax_dataclasses import jaxsim.api as js import jaxsim.typing as jtp -from jaxsim.rbda.contacts import ( - ContactsState, - RelaxedRigidContacts, - RelaxedRigidContactsState, - RigidContacts, - RigidContactsState, - SoftContacts, - SoftContactsState, -) from jaxsim.utils import JaxsimDataclass # ============================================================================= @@ -125,15 +119,18 @@ class ODEState(JaxsimDataclass): Attributes: physics_model: The state of the physics model. - contact: The state of the contacts model. + extended: + Additional state variables extending the state vector corresponding to + equations of motion. These extended variables are passed to the integrator. """ physics_model: PhysicsModelState - contact: ContactsState + + extended: dict[str, jtp.PyTree] = dataclasses.field(default_factory=dict) @staticmethod def build_from_jaxsim_model( - model: js.model.JaxSimModel | None = None, + model: js.model.JaxSimModel, joint_positions: jtp.Vector | None = None, joint_velocities: jtp.Vector | None = None, base_position: jtp.Vector | None = None, @@ -155,7 +152,15 @@ def build_from_jaxsim_model( The linear velocity of the base link in inertial-fixed representation. base_angular_velocity: The angular velocity of the base link in inertial-fixed representation. - kwargs: Additional arguments needed to build the contact state. + kwargs: + Additional arguments corresponding variables extending the default + state vector of the physics model. + + Note: + Kwargs can be used to supply any additional state variables that are passed + to the integrator. This is useful to extend the default system dynamics, + for example if the contact model requires additional state variables or to + simulate additional dynamics like actuators or muscoloskeletal models. Returns: The `ODEState` built from the `JaxSimModel`. @@ -165,29 +170,11 @@ def build_from_jaxsim_model( `JaxSimModel` and initialized to zero. """ - # Get the contact model from the `JaxSimModel`. - match model.contact_model: - - case SoftContacts(): - - tangential_deformation = kwargs.get("tangential_deformation", None) - - contact = SoftContactsState.build_from_jaxsim_model( - model=model, - **( - dict(tangential_deformation=tangential_deformation) - if tangential_deformation is not None - else dict() - ), - ) - case RigidContacts(): - contact = RigidContactsState.build() + # Initialize the extended state with the optional contact state. + extended_state = model.contact_model.zero_state_variables(model=model) - case RelaxedRigidContacts(): - contact = RelaxedRigidContactsState.build() - - case _: - raise ValueError("Unsupported contact model.") + # Override the default extended state with optional kwargs. + extended_state |= kwargs return ODEState.build( model=model, @@ -200,13 +187,13 @@ def build_from_jaxsim_model( base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, ), - contact=contact, + extended_state=extended_state, ) @staticmethod def build( physics_model_state: PhysicsModelState | None = None, - contact: ContactsState | None = None, + extended_state: dict[str, jtp.PyTree] | None = None, model: js.model.JaxSimModel | None = None, ) -> ODEState: """ @@ -214,62 +201,60 @@ def build( Args: physics_model_state: The state of the physics model. - contact: The state of the contacts model. + extended_state: Additional state variables extending the state vector. model: The `JaxSimModel` associated with the ODE state. Returns: A `ODEState` instance. """ + # Build a zero state for the physics model if not provided. physics_model_state = ( physics_model_state if physics_model_state is not None else PhysicsModelState.zero(model=model) ) - # Get the contact model from the `JaxSimModel`. - match contact: - case ( - SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState() - ): - pass - case None: - contact = SoftContactsState.zero(model=model) - case _: - raise ValueError("Unable to determine contact state class prefix.") - - return ODEState(physics_model=physics_model_state, contact=contact) + return ODEState( + physics_model=physics_model_state, + extended=extended_state, + ) @staticmethod def zero(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> ODEState: """ - Build a zero `ODEState` from a `JaxSimModel`. + Build a zero `ODEState` corresponding to a `JaxSimModel`. Args: - model: The `JaxSimModel` associated with the ODE state. + model: The model to consider. + data: The data of the considered model. Returns: A zero `ODEState` instance. """ - model_state = ODEState.build( - model=model, contact=data.state.contact.zero(model=model) + ode_state = ODEState.build( + model=model, + extended_state=jax.tree.map( + lambda x: jnp.zeros_like(x), data.state.extended + ), ) - return model_state + return ode_state def valid(self, model: js.model.JaxSimModel) -> bool: """ Check if the `ODEState` is valid for a given `JaxSimModel`. Args: - model: The `JaxSimModel` to validate the `ODEState` against. + model: The model to validate this `ODEState` against. Returns: `True` if the ODE state is valid for the given model, `False` otherwise. """ - return self.physics_model.valid(model=model) and self.contact.valid(model=model) + # TODO: should we validate the extended state? + return self.physics_model.valid(model=model) # ================================================== diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 1c6806304..d99014818 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,9 +1,5 @@ from . import relaxed_rigid, rigid, soft -from .common import ContactModel, ContactsParams, ContactsState -from .relaxed_rigid import ( - RelaxedRigidContacts, - RelaxedRigidContactsParams, - RelaxedRigidContactsState, -) -from .rigid import RigidContacts, RigidContactsParams, RigidContactsState -from .soft import SoftContacts, SoftContactsParams, SoftContactsState +from .common import ContactModel, ContactsParams +from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams +from .rigid import RigidContacts, RigidContactsParams +from .soft import SoftContacts, SoftContactsParams diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 728319c27..cad3fbf32 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -14,41 +14,6 @@ from typing_extensions import Self -class ContactsState(JaxsimDataclass): - """ - Abstract class storing the state of the contacts model. - """ - - @classmethod - @abc.abstractmethod - def build(cls: type[Self], **kwargs) -> Self: - """ - Build the contact state object. - - Returns: - The contact state object. - """ - pass - - @classmethod - @abc.abstractmethod - def zero(cls: type[Self], **kwargs) -> Self: - """ - Build a zero contact state. - - Returns: - The zero contact state. - """ - pass - - @abc.abstractmethod - def valid(self, **kwargs) -> jtp.BoolLike: - """ - Check if the contacts state is valid. - """ - pass - - class ContactsParams(JaxsimDataclass): """ Abstract class representing the parameters of a contact model. @@ -88,6 +53,27 @@ class ContactModel(JaxsimDataclass): parameters: ContactsParams terrain: jaxsim.terrain.Terrain + @classmethod + @abc.abstractmethod + def build( + cls: type[Self], + parameters: ContactsParams, + terrain: jaxsim.terrain.Terrain, + **kwargs, + ) -> Self: + """ + Create a `ContactModel` instance with specified parameters. + + Args: + parameters: The parameters of the contact model. + terrain: The considered terrain. + + Returns: + The `ContactModel` instance. + """ + + pass + @abc.abstractmethod def compute_contact_forces( self, @@ -109,6 +95,27 @@ def compute_contact_forces( pass + @classmethod + def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: + """ + Build zero state variables of the contact model. + + Args: + model: The robot model considered by the contact model. + + Note: + There are contact models that require to extend the state vector of the + integrated ODE system with additional variables. Our integrators are + capable of operating on a generic state, as long as it is a PyTree. + This method builds the zero state variables of the contact model as a + dictionary of JAX arrays. + + Returns: + A dictionary storing the zero state variables of the contact model. + """ + + return {} + def initialize_model_and_data( self, model: js.model.JaxSimModel, diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 022295abc..9699b97a6 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -11,11 +11,12 @@ import jaxsim.api as js import jaxsim.typing as jtp +from jaxsim import logging from jaxsim.api.common import VelRepr from jaxsim.math import Adjoint from jaxsim.terrain.terrain import FlatTerrain, Terrain -from .common import ContactModel, ContactsParams, ContactsState +from .common import ContactModel, ContactsParams try: from typing import Self @@ -156,41 +157,46 @@ def valid(self) -> jtp.BoolLike: ) -@jax_dataclasses.pytree_dataclass -class RelaxedRigidContactsState(ContactsState): - """Class storing the state of the relaxed rigid contacts model.""" - - def __eq__(self, other: RelaxedRigidContactsState) -> bool: - return hash(self) == hash(other) - - @classmethod - def build(cls: type[Self]) -> Self: - """Create a `RelaxedRigidContactsState` instance""" - - return cls() - - @classmethod - def zero(cls: type[Self], **kwargs) -> Self: - """Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.""" - - return cls.build() - - def valid(self, **kwargs) -> jtp.BoolLike: - return True - - @jax_dataclasses.pytree_dataclass class RelaxedRigidContacts(ContactModel): """Relaxed rigid contacts model.""" parameters: RelaxedRigidContactsParams = dataclasses.field( - default_factory=RelaxedRigidContactsParams + default_factory=RelaxedRigidContactsParams.build ) terrain: jax_dataclasses.Static[Terrain] = dataclasses.field( - default_factory=FlatTerrain + default_factory=FlatTerrain.build ) + @classmethod + def build( + cls: type[Self], + parameters: RelaxedRigidContactsParams | None = None, + terrain: Terrain | None = None, + **kwargs, + ) -> Self: + """ + Create a `RelaxedRigidContacts` instance with specified parameters. + + Args: + parameters: The parameters of the rigid contacts model. + terrain: The considered terrain. + + Returns: + The `RelaxedRigidContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + return cls( + parameters=( + parameters or cls.__dataclass_fields__["parameters"].default_factory() + ), + terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + ) + @jax.jit def compute_contact_forces( self, diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 2a3db3f58..7c6d1f6ac 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -9,10 +9,11 @@ import jaxsim.api as js import jaxsim.typing as jtp +from jaxsim import logging from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr from jaxsim.terrain import FlatTerrain, Terrain -from .common import ContactModel, ContactsParams, ContactsState +from .common import ContactModel, ContactsParams try: from typing import Self @@ -78,29 +79,6 @@ def valid(self) -> jtp.BoolLike: ) -@jax_dataclasses.pytree_dataclass -class RigidContactsState(ContactsState): - """Class storing the state of the rigid contacts model.""" - - def __eq__(self, other: RigidContactsState) -> bool: - return hash(self) == hash(other) - - @classmethod - def build(cls: type[Self]) -> Self: - """Create a `RigidContactsState` instance""" - - return cls() - - @classmethod - def zero(cls: type[Self], **kwargs) -> Self: - """Build a zero `RigidContactsState` instance from a `JaxSimModel`.""" - - return cls.build() - - def valid(self, **kwargs) -> jtp.BoolLike: - return True - - @jax_dataclasses.pytree_dataclass class RigidContacts(ContactModel): """Rigid contacts model.""" @@ -110,9 +88,37 @@ class RigidContacts(ContactModel): ) terrain: jax_dataclasses.Static[Terrain] = dataclasses.field( - default_factory=FlatTerrain + default_factory=FlatTerrain.build ) + @classmethod + def build( + cls: type[Self], + parameters: RigidContactsParams | None = None, + terrain: Terrain | None = None, + **kwargs, + ) -> Self: + """ + Create a `RigidContacts` instance with specified parameters. + + Args: + parameters: The parameters of the rigid contacts model. + terrain: The considered terrain. + + Returns: + The `RigidContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + return cls( + parameters=( + parameters or cls.__dataclass_fields__["parameters"].default_factory() + ), + terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + ) + @staticmethod def detect_contacts( W_p_C: jtp.ArrayLike, diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 18c1cf401..7019344d1 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -10,10 +10,11 @@ import jaxsim.api as js import jaxsim.math import jaxsim.typing as jtp +from jaxsim import logging from jaxsim.math import StandardGravity from jaxsim.terrain import FlatTerrain, Terrain -from .common import ContactModel, ContactsParams, ContactsState +from .common import ContactModel, ContactsParams try: from typing import Self @@ -192,13 +193,67 @@ class SoftContacts(ContactModel): """Soft contacts model.""" parameters: SoftContactsParams = dataclasses.field( - default_factory=SoftContactsParams + default_factory=SoftContactsParams.build ) terrain: jax_dataclasses.Static[Terrain] = dataclasses.field( - default_factory=FlatTerrain + default_factory=FlatTerrain.build ) + @classmethod + def build( + cls: type[Self], + parameters: SoftContactsParams | None = None, + terrain: Terrain | None = None, + model: js.model.JaxSimModel | None = None, + **kwargs, + ) -> Self: + """ + Create a `SoftContacts` instance with specified parameters. + + Args: + parameters: The parameters of the soft contacts model. + terrain: The considered terrain. + model: + The robot model considered by the contact model. + If passed, it is used to estimate good default parameters. + + Returns: + The `SoftContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + # Build the contact parameters if not provided. Use the model to estimate + # good default parameters, if passed. Users can later override these default + # parameters with their own values -- possibly tuned better. + if parameters is None: + parameters = ( + SoftContactsParams.build_default_from_jaxsim_model(model=model) + if model is not None + else cls.__dataclass_fields__["parameters"].default_factory() + ) + + return SoftContacts( + parameters=parameters, + terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + ) + + @classmethod + def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: + """ + Build zero state variables of the contact model. + """ + + # Initialize the material deformation to zero. + tangential_deformation = jnp.zeros( + shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), + dtype=float, + ) + + return {"tangential_deformation": tangential_deformation} + @staticmethod @functools.partial(jax.jit, static_argnames=("terrain",)) def hunt_crossley_contact_model( @@ -380,8 +435,7 @@ def compute_contact_forces( W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) # Extract the material deformation corresponding to the collidable points. - assert isinstance(data.state.contact, SoftContactsState) - m = data.state.contact.tangential_deformation + m = data.state.extended["tangential_deformation"] # Compute the contact forces for all collidable points. # Since we treat them as independent, we can vmap the computation. @@ -423,131 +477,3 @@ def compute_penetration_data( δ̇ = jnp.where(δ > 0, δ̇, 0.0) return δ, δ̇, n̂ - - -@jax_dataclasses.pytree_dataclass -class SoftContactsState(ContactsState): - """ - Class storing the state of the soft contacts model. - - Attributes: - tangential_deformation: - The matrix of 3D tangential material deformations corresponding to - each collidable point. - """ - - tangential_deformation: jtp.Matrix - - def __hash__(self) -> int: - - return hash( - tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist()) - ) - - def __eq__(self: Self, other: Self) -> bool: - - if not isinstance(other, type(self)): - return False - - return hash(self) == hash(other) - - @classmethod - def build_from_jaxsim_model( - cls: type[Self], - model: js.model.JaxSimModel | None = None, - tangential_deformation: jtp.MatrixLike | None = None, - ) -> Self: - """ - Build a `SoftContactsState` from a `JaxSimModel`. - - Args: - model: The `JaxSimModel` associated with the soft contacts state. - tangential_deformation: The matrix of 3D tangential material deformations. - - Returns: - The `SoftContactsState` built from the `JaxSimModel`. - - Note: - If any of the state components are not provided, they are built from the - `JaxSimModel` and initialized to zero. - """ - - return cls.build( - tangential_deformation=tangential_deformation, - number_of_collidable_points=len( - model.kin_dyn_parameters.contact_parameters.body - ), - ) - - @classmethod - def build( - cls: type[Self], - *, - tangential_deformation: jtp.MatrixLike | None = None, - number_of_collidable_points: int | None = None, - ) -> Self: - """ - Create a `SoftContactsState`. - - Args: - tangential_deformation: - The matrix of 3D tangential material deformations corresponding to - each collidable point. - number_of_collidable_points: The number of collidable points. - - Returns: - A `SoftContactsState` instance. - """ - - tangential_deformation = ( - jnp.atleast_2d(tangential_deformation) - if tangential_deformation is not None - else jnp.zeros(shape=(number_of_collidable_points, 3)) - ).astype(float) - - if tangential_deformation.shape[1] != 3: - raise RuntimeError("The tangential deformation matrix must have 3 columns.") - - if ( - number_of_collidable_points is not None - and tangential_deformation.shape[0] != number_of_collidable_points - ): - msg = "The number of collidable points must match the number of rows " - msg += "in the tangential deformation matrix." - raise RuntimeError(msg) - - return cls(tangential_deformation=tangential_deformation) - - @classmethod - def zero(cls: type[Self], *, model: js.model.JaxSimModel) -> Self: - """ - Build a zero `SoftContactsState` from a `JaxSimModel`. - - Args: - model: The `JaxSimModel` associated with the soft contacts state. - - Returns: - A zero `SoftContactsState` instance. - """ - - return cls.build_from_jaxsim_model(model=model) - - def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike: - """ - Check if the `SoftContactsState` is valid for a given `JaxSimModel`. - - Args: - model: The `JaxSimModel` to validate the `SoftContactsState` against. - - Returns: - `True` if the soft contacts state is valid for the given `JaxSimModel`, - `False` otherwise. - """ - - shape = self.tangential_deformation.shape - expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3) - - if shape != expected: - return False - - return True diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index e4c6d31dc..770e4a876 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -49,7 +49,7 @@ class FlatTerrain(Terrain): _height: float = dataclasses.field(default=0.0, kw_only=True) @staticmethod - def build(height: jtp.FloatLike) -> FlatTerrain: + def build(height: jtp.FloatLike = 0.0) -> FlatTerrain: return FlatTerrain(_height=float(height)) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 024ee6b87..84b1e498e 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -8,7 +8,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr -from jaxsim.rbda.contacts import SoftContacts, SoftContactsParams, SoftContactsState +from jaxsim.rbda.contacts import SoftContacts, SoftContactsParams # All JaxSim algorithms, excluding the variable-step integrators, should support # being automatically differentiated until second order, both in FWD and REV modes. @@ -343,16 +343,13 @@ def test_ad_integration( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) - # Make sure that the active contact model is SoctContacts. - assert isinstance(data.state.contact, SoftContactsState) - # State in VelRepr.Inertial representation. W_p_B = data.base_position() W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions(model=model) W_v_WB = data.base_velocity() ṡ = data.joint_velocities(model=model) - m = data.state.contact.tangential_deformation + m = data.state.extended["tangential_deformation"] # Inputs. W_f_L = references.link_forces(model=model) @@ -406,7 +403,7 @@ def step( base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, ), - contact=js.ode_data.SoftContactsState.build(tangential_deformation=m), + extended_state={"tangential_deformation": m}, ), ) @@ -425,7 +422,7 @@ def step( xf_s = data_xf.joint_positions(model=model) xf_W_v_WB = data_xf.base_velocity() xf_ṡ = data_xf.joint_velocities(model=model) - xf_m = data_xf.state.contact.tangential_deformation + xf_m = data_xf.state.extended["tangential_deformation"] return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m