From a7b71503155746effb39ba898439a6deca925d65 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 28 Jan 2025 15:54:38 +0100 Subject: [PATCH] Streamline new API changes to alternative contact models --- src/jaxsim/api/contact.py | 95 +++++++++++- src/jaxsim/api/contact_model.py | 14 +- src/jaxsim/api/data.py | 21 +++ src/jaxsim/api/model.py | 76 +++++++++- src/jaxsim/api/ode.py | 5 +- src/jaxsim/math/__init__.py | 2 +- src/jaxsim/rbda/contacts/rigid.py | 16 +- src/jaxsim/rbda/contacts/soft.py | 12 +- src/jaxsim/rbda/contacts/visco_elastic.py | 147 +++++++------------ src/jaxsim/rbda/utils.py | 6 + tests/test_automatic_differentiation.py | 15 +- tests/test_simulations.py | 170 +++++++++++++++++++++- 12 files changed, 443 insertions(+), 136 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 2859a3fdc..848d88e15 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -37,14 +37,11 @@ def collidable_point_kinematics( the linear component of the mixed 6D frame velocity. """ - # Switch to inertial-fixed since the RBDAs expect velocities in this representation. - with data.switch_velocity_representation(VelRepr.Inertial): - - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( - model=model, - link_transforms=data.link_transforms, - link_velocities=data.link_velocities, - ) + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( + model=model, + link_transforms=data.link_transforms, + link_velocities=data.link_velocities, + ) return W_p_Ci, W_ṗ_Ci @@ -164,7 +161,11 @@ def estimate_good_soft_contacts_parameters( def estimate_good_contact_parameters( model: js.model.JaxSimModel, *, + standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + max_penetration: jtp.FloatLike | None = None, **kwargs, ) -> jaxsim.rbda.contacts.ContactParamsTypes: """ @@ -172,7 +173,12 @@ def estimate_good_contact_parameters( Args: model: The model to consider. + standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. + number_of_active_collidable_points_steady_state: + The number of active collidable points in steady state. + damping_ratio: The damping ratio. + max_penetration: The maximum penetration allowed. kwargs: Additional model-specific parameters passed to the builder method of the parameters class. @@ -190,8 +196,81 @@ def estimate_good_contact_parameters( specific application. """ + def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: + """ + Displacement between the CoM and the lowest collidable point using zero + joint positions. + """ + + zero_data = js.data.JaxSimModelData.build( + model=model, + ) + + W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] + + if model.floating_base(): + W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1] + return 2 * (W_pz_CoM - W_pz_C.min()) + + return 2 * W_pz_CoM + + max_δ = ( + max_penetration + if max_penetration is not None + # Consider as default a 0.5% of the model height. + else 0.005 * estimate_model_height(model=model) + ) + + nc = number_of_active_collidable_points_steady_state + match model.contact_model: + case contacts.SoftContacts(): + assert isinstance(model.contact_model, contacts.SoftContacts) + + parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_δ, + number_of_active_collidable_points_steady_state=nc, + damping_ratio=damping_ratio, + **kwargs, + ) + + case contacts.ViscoElasticContacts(): + assert isinstance(model.contact_model, contacts.ViscoElasticContacts) + + parameters = ( + contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_δ, + number_of_active_collidable_points_steady_state=nc, + damping_ratio=damping_ratio, + **kwargs, + ) + ) + + case contacts.RigidContacts(): + assert isinstance(model.contact_model, contacts.RigidContacts) + + # Disable Baumgarte stabilization by default since it does not play + # well with the forward Euler integrator. + K = kwargs.get("K", 0.0) + + parameters = contacts.RigidContactsParams.build( + mu=static_friction_coefficient, + **( + dict( + K=K, + D=2 * jnp.sqrt(K), + ) + | kwargs + ), + ) + case contacts.RelaxedRigidContacts(): assert isinstance(model.contact_model, contacts.RelaxedRigidContacts) diff --git a/src/jaxsim/api/contact_model.py b/src/jaxsim/api/contact_model.py index 1eeca4c5d..cda5222ba 100644 --- a/src/jaxsim/api/contact_model.py +++ b/src/jaxsim/api/contact_model.py @@ -5,6 +5,7 @@ import jaxsim.api as js import jaxsim.typing as jtp +from jaxsim.rbda.contacts import SoftContacts @jax.jit @@ -15,7 +16,7 @@ def link_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_torques: jtp.VectorLike | None = None, -) -> jtp.Matrix: +) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]: """ Compute the 6D contact forces of all links of the model in inertial representation. @@ -33,11 +34,14 @@ def link_contact_forces( """ # Compute the contact forces for each collidable point with the active contact model. - W_f_C, _ = model.contact_model.compute_contact_forces( + W_f_C, aux_dict = model.contact_model.compute_contact_forces( model=model, data=data, - link_forces=link_forces, - joint_force_references=joint_torques, + **( + dict(link_forces=link_forces, joint_force_references=joint_torques) + if not isinstance(model.contact_model, SoftContacts) + else {} + ), ) # Compute the 6D forces applied to the links equivalent to the forces applied @@ -46,7 +50,7 @@ def link_contact_forces( model=model, data=data, contact_forces=W_f_C ) - return W_f_L + return W_f_L, aux_dict @staticmethod diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index cf64af93a..b543c758b 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -60,6 +60,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None) link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None) + # Extended state for soft and rigid contact models. + contact_state: dict[str, jtp.Array] = dataclasses.field(default=None) + @staticmethod def build( model: js.model.JaxSimModel, @@ -70,6 +73,8 @@ def build( base_angular_velocity: jtp.VectorLike | None = None, joint_velocities: jtp.VectorLike | None = None, velocity_representation: VelRepr = VelRepr.Inertial, + *, + contact_state: dict[str, jtp.Array] | None = None, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with the given state. @@ -85,6 +90,7 @@ def build( The base angular velocity in the selected representation. joint_velocities: The joint velocities. velocity_representation: The velocity representation to use. + contact_state: The optional contact state. Returns: A `JaxSimModelData` initialized with the given state. @@ -165,6 +171,20 @@ def build( joint_velocities=joint_velocities, ) + contact_state = ( + { + "tangential_deformation": jnp.zeros_like( + model.kin_dyn_parameters.contact_parameters.point + ) + } + if isinstance( + model.contact_model, + jaxsim.rbda.contacts.SoftContacts + | jaxsim.rbda.contacts.ViscoElasticContacts, + ) + else contact_state or {} + ) + model_data = JaxSimModelData( base_quaternion=base_quaternion, base_position=base_position, @@ -177,6 +197,7 @@ def build( joint_transforms=joint_transforms, link_transforms=link_transforms, link_velocities=link_velocities, + contact_state=contact_state or {}, ) if not model_data.valid(model=model): diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 3610e26f6..630dfc271 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -41,7 +41,7 @@ class JaxSimModel(JaxsimDataclass): default_factory=jaxsim.terrain.FlatTerrain.build, repr=False ) - gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY + gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field( default=None, repr=False @@ -111,6 +111,7 @@ def build_from_model_description( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, + gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: @@ -131,6 +132,7 @@ def build_from_model_description( The contact model to consider. If not specified, a soft contacts model is used. contact_params: The parameters of the contact model. + gravity: The gravity constant. is_urdf: The optional flag to force the model description to be parsed as a URDF. This is usually automatically inferred. @@ -164,6 +166,7 @@ def build_from_model_description( terrain=terrain, contact_model=contact_model, contacts_params=contact_params, + gravity=gravity, ) # Store the origin of the model, in case downstream logic needs it. @@ -247,7 +250,7 @@ def build( terrain=terrain, contact_model=contact_model, contacts_params=contacts_params, - gravity=gravity, + gravity=-gravity, # The following is wrapped as hashless since it's a static argument, and we # don't want to trigger recompilation if it changes. All relevant parameters # needed to compute kinematics and dynamics quantities are stored in the @@ -447,6 +450,8 @@ def reduce( time_step=model.time_step, terrain=model.terrain, contact_model=model.contact_model, + contacts_params=model.contacts_params, + gravity=model.gravity, ) # Store the origin of the model, in case downstream logic needs it. @@ -2045,7 +2050,7 @@ def step( # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_L_terrain = js.contact_model.link_contact_forces( + W_f_L_terrain, aux_dict = js.contact_model.link_contact_forces( model=model, data=data, link_forces=W_f_L_external, @@ -2058,6 +2063,33 @@ def step( W_f_L_total = W_f_L_external + W_f_L_terrain + # ============================= + # Update the contact state data + # ============================= + + contact_state = {} + + match model.contact_model: + + case jaxsim.rbda.contacts.SoftContacts(): + contact_state["tangential_deformation"] = aux_dict["m_dot"] + data = data.replace(contact_state=contact_state) + + case jaxsim.rbda.contacts.ViscoElasticContacts(): + contact_state["tangential_deformation"] = jnp.zeros_like( + jnp.array(model.kin_dyn_parameters.contact_parameters.point) + ) + data = data.replace(contact_state=contact_state) + + case ( + jaxsim.rbda.contacts.RigidContacts() + | jaxsim.rbda.contacts.RelaxedRigidContacts() + ): + pass + + case _: + raise ValueError(f"Invalid contact model: {model.contact_model}") + # =============================== # Compute the system acceleration # =============================== @@ -2081,6 +2113,44 @@ def step( joint_accelerations=s̈, ) + if isinstance(model.contact_model, jaxsim.rbda.contacts.RigidContacts): + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + W_p_C = js.contact.collidable_point_positions(model, data_tf)[ + indices_of_enabled_collidable_points + ] + + # Compute the penetration depth of the collidable points. + δ, *_ = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, + in_axes=(0, 0, None), + )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) + + with data_tf.switch_velocity_representation(VelRepr.Mixed): + J_WC = js.contact.jacobian(model, data_tf)[ + indices_of_enabled_collidable_points + ] + M = js.model.free_floating_mass_matrix(model, data_tf) + BW_ν_pre_impact = data_tf.generalized_velocity() + + # Compute the impact velocity. + # It may be discontinuous in case new contacts are made. + BW_ν_post_impact = ( + jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity( + generalized_velocity=BW_ν_pre_impact, + inactive_collidable_points=(δ <= 0), + M=M, + J_WC=J_WC, + ) + ) + + # Reset the generalized velocity. + data_tf = data_tf.reset_base_velocity(BW_ν_post_impact[0:6]) + data_tf = data_tf.reset_joint_velocities(BW_ν_post_impact[6:]) + # ne parliamo dopo # Restore the input velocity representation data_tf = data_tf.replace( diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index a7b96ec72..bfd06d758 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -62,7 +62,10 @@ def system_velocity_dynamics( link_forces=W_f_L, ) - return W_v̇_WB, s̈ + return ( + W_v̇_WB, + s̈, + ) def system_acceleration( diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py index e7c221742..cf0bcb107 100644 --- a/src/jaxsim/math/__init__.py +++ b/src/jaxsim/math/__init__.py @@ -11,4 +11,4 @@ # Define the default standard gravity constant. -STANDARD_GRAVITY = -9.81 +STANDARD_GRAVITY = 9.81 diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index d04a7b895..6f3ffbdde 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -294,19 +294,13 @@ def compute_contact_forces( ) # Compute the generalized free acceleration. - with ( - references.switch_velocity_representation(VelRepr.Mixed), - data.switch_velocity_representation(VelRepr.Mixed), - ): - + with data.switch_velocity_representation(VelRepr.Mixed): BW_ν̇_free = jnp.hstack( js.ode.system_acceleration( model=model, data=data, link_forces=references.link_forces(model=model, data=data), - joint_force_references=references.joint_force_references( - model=model - ), + joint_torques=references.joint_force_references(model=model), ) ) @@ -325,8 +319,8 @@ def compute_contact_forces( δ=δ, δ_dot=δ_dot, n=n̂, - K=data.contacts_params.K, - D=data.contacts_params.D, + K=model.contacts_params.K, + D=model.contacts_params.D, ).flatten() # Compute the Delassus matrix. @@ -342,7 +336,7 @@ def compute_contact_forces( # Construct the inequality constraints. G = RigidContacts._compute_ineq_constraint_matrix( - inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu + inactive_collidable_points=(δ <= 0), mu=model.contacts_params.mu ) h_bounds = RigidContacts._compute_ineq_bounds( n_collidable_points=n_collidable_points diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index dde16cfb2..d7402d670 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -11,7 +11,7 @@ import jaxsim.math import jaxsim.typing as jtp from jaxsim import logging -from jaxsim.math import StandardGravity +from jaxsim.math import STANDARD_GRAVITY from jaxsim.terrain import Terrain from . import common @@ -108,7 +108,7 @@ def build_default_from_jaxsim_model( cls: type[Self], model: js.model.JaxSimModel, *, - standard_gravity: jtp.FloatLike = StandardGravity, + standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, max_penetration: jtp.FloatLike = 0.001, number_of_active_collidable_points_steady_state: jtp.IntLike = 1, @@ -456,7 +456,11 @@ def compute_contact_forces( W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) # Extract the material deformation corresponding to the collidable points. - m = data.state.extended["tangential_deformation"] + m = ( + data.contact_state["tangential_deformation"] + if "tangential_deformation" in data.contact_state + else jnp.zeros_like(W_p_C) + ) m_enabled = m[indices_of_enabled_collidable_points] @@ -470,7 +474,7 @@ def compute_contact_forces( position=p, velocity=v, tangential_deformation=m, - parameters=data.contacts_params, + parameters=model.contacts_params, terrain=model.terrain, ) )(W_p_C, W_ṗ_C, m_enabled) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 40ad4ab61..51e90e978 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -14,7 +14,7 @@ import jaxsim.typing as jtp from jaxsim import logging from jaxsim.api.common import ModelDataWithVelocityRepresentation -from jaxsim.math import StandardGravity +from jaxsim.math import STANDARD_GRAVITY from jaxsim.terrain import Terrain from . import common @@ -90,7 +90,7 @@ def build_default_from_jaxsim_model( cls: type[Self], model: js.model.JaxSimModel, *, - standard_gravity: jtp.FloatLike = StandardGravity, + standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, max_penetration: jtp.FloatLike = 0.001, number_of_active_collidable_points_steady_state: jtp.IntLike = 1, @@ -375,7 +375,7 @@ def _compute_contact_forces_with_exponential_integration( # ================================== p_t0, v_t0 = js.contact.collidable_point_kinematics(model, data) - m_t0 = data.state.extended["tangential_deformation"][indices, :] + m_t0 = data.contact_state["tangential_deformation"][indices, :] p_t0 = p_t0[indices, :] v_t0 = v_t0[indices, :] @@ -525,7 +525,7 @@ def _contact_points_dynamics( m_t0 = jnp.atleast_2d( m_t0 if m_t0 is not None - else data.state.extended["tangential_deformation"][ + else data.contact_state["tangential_deformation"][ indices_of_enabled_collidable_points, : ] ) @@ -551,7 +551,7 @@ def _contact_points_dynamics( position=p, velocity=v, tangential_deformation=m, - parameters=data.contacts_params, + parameters=model.contacts_params, terrain=model.terrain, ) )(p_t0, v_t0, m_t0) @@ -599,12 +599,11 @@ def _contact_points_dynamics( CW_Jl_WC = js.contact.jacobian( model=model, data=data, - output_vel_repr=jaxsim.VelRepr.Mixed, )[indices_of_enabled_collidable_points, 0:3, :] - CW_J̇l_WC = js.contact.jacobian_derivative( - model=model, data=data, output_vel_repr=jaxsim.VelRepr.Mixed - )[indices_of_enabled_collidable_points, 0:3, :] + CW_J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[ + indices_of_enabled_collidable_points, 0:3, : + ] # Compute the Delassus matrix. ψ = jnp.vstack(CW_Jl_WC) @ jnp.linalg.lstsq(M, jnp.vstack(CW_Jl_WC).T)[0] @@ -627,17 +626,12 @@ def _contact_points_dynamics( J̇ = jnp.vstack(CW_J̇l_WC) # Compute the free system acceleration components. - with ( - data.switch_velocity_representation(jaxsim.VelRepr.Mixed), - references.switch_velocity_representation(jaxsim.VelRepr.Mixed), - ): - - BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( - model=model, - data=data, - link_forces=references.link_forces(model=model, data=data), - joint_force_references=references.joint_force_references(model=model), - ) + BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_torques=references.joint_force_references(model=model), + ) # Pack the free system acceleration in mixed representation. ν̇_free = jnp.hstack([BW_v̇_free_WB, s̈_free]) @@ -802,8 +796,8 @@ def integrate_data_with_average_contact_forces( dt: jtp.FloatLike, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - average_link_contact_forces_inertial: jtp.MatrixLike | None = None, - average_of_average_link_contact_forces_mixed: jtp.MatrixLike | None = None, + average_link_contact_forces: jtp.MatrixLike | None = None, + average_of_average_link_contact_forces: jtp.MatrixLike | None = None, ) -> js.data.JaxSimModelData: """ Advance the system state by integrating the dynamics. @@ -816,22 +810,21 @@ def integrate_data_with_average_contact_forces( The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. joint_force_references: The joint force references to apply. - average_link_contact_forces_inertial: - The average contact forces computed with the exponential integrator and - expressed in the inertial-fixed frame. - average_of_average_link_contact_forces_mixed: + average_link_contact_forces: + The average contact forces computed with the exponential integrator. + average_of_average_link_contact_forces: The average of the average contact forces computed with the exponential - integrator and expressed in the mixed frame. + integrator. Returns: The data object storing the system state at the final time. """ - s_t0 = data.joint_positions() - W_p_B_t0 = data.base_position() - W_Q_B_t0 = data.base_orientation(dcm=False) + s_t0 = data.joint_positions + W_p_B_t0 = data.base_position + W_Q_B_t0 = data.base_quaternion - ṡ_t0 = data.joint_velocities() + ṡ_t0 = data.joint_velocities with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): W_ṗ_B_t0 = data.base_velocity()[0:3] W_ω_WB_t0 = data.base_velocity()[3:6] @@ -850,53 +843,39 @@ def integrate_data_with_average_contact_forces( ) W_f̅_L = ( - jnp.array(average_link_contact_forces_inertial) - if average_link_contact_forces_inertial is not None + jnp.array(average_link_contact_forces) + if average_link_contact_forces is not None else jnp.zeros_like(references._link_forces) ).astype(float) LW_f̿_L = ( - jnp.array(average_of_average_link_contact_forces_mixed) - if average_of_average_link_contact_forces_mixed is not None + jnp.array(average_of_average_link_contact_forces) + if average_of_average_link_contact_forces is not None else W_f̅_L ).astype(float) # Compute the system inertial acceleration, used to integrate the system velocity. # It considers the average contact forces computed with the exponential integrator. - with ( - data.switch_velocity_representation(jaxsim.VelRepr.Inertial), - references.switch_velocity_representation(jaxsim.VelRepr.Inertial), - ): - - W_ν̇_pr = jnp.hstack( - js.ode.system_acceleration( - model=model, - data=data, - joint_force_references=references.joint_force_references( - model=model - ), - link_forces=W_f̅_L + references.link_forces(model=model, data=data), - ) + W_ν̇_pr = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_torques=references.joint_force_references(model=model), + link_forces=W_f̅_L + references.link_forces(model=model, data=data), ) + ) # Compute the system mixed acceleration, used to integrate the system position. # It considers the average of the average contact forces computed with the # exponential integrator. - with ( - data.switch_velocity_representation(jaxsim.VelRepr.Mixed), - references.switch_velocity_representation(jaxsim.VelRepr.Mixed), - ): - - BW_ν̇_pr2 = jnp.hstack( - js.ode.system_acceleration( - model=model, - data=data, - joint_force_references=references.joint_force_references( - model=model - ), - link_forces=LW_f̿_L + references.link_forces(model=model, data=data), - ) + BW_ν̇_pr2 = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_torques=references.joint_force_references(model=model), + link_forces=LW_f̿_L + references.link_forces(model=model, data=data), ) + ) # Integrate the system velocity using the inertial-fixed acceleration. W_ν_plus = W_ν_t0 + dt * W_ν̇_pr @@ -917,8 +896,7 @@ def integrate_data_with_average_contact_forces( ) # Create the data at the final time. - data_tf = data.copy() - data_tf = data_tf.reset_joint_positions(q_plus[7:]) + data_tf = data.reset_joint_positions(q_plus[7:]) data_tf = data_tf.reset_base_position(q_plus[0:3]) data_tf = data_tf.reset_base_quaternion(q_plus[3:7]) data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:]) @@ -926,6 +904,8 @@ def integrate_data_with_average_contact_forces( W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial ) + data_tf = data_tf.update_cached(model=model) + return data_tf.replace( velocity_representation=data.velocity_representation, validate=False ) @@ -958,13 +938,7 @@ def step( """ assert isinstance(model.contact_model, ViscoElasticContacts) - assert isinstance(data.contacts_params, ViscoElasticContactsParams) - - # Compute the contact forces in inertial-fixed representation. - # TODO: understand what's wrong in other representations. - data_inertial_fixed = data.replace( - velocity_representation=jaxsim.VelRepr.Inertial, validate=False - ) + assert isinstance(model.contacts_params, ViscoElasticContactsParams) # Create the references object. references = js.references.JaxSimModelReferences.build( @@ -981,7 +955,7 @@ def step( # Compute the contact forces with the exponential integrator. W_f̅_C, aux_data = model.contact_model.compute_contact_forces( model=model, - data=data_inertial_fixed, + data=data, dt=jnp.array(dt).astype(float), link_forces=references.link_forces(model=model, data=data), joint_force_references=references.joint_force_references(model=model), @@ -999,17 +973,13 @@ def step( # Get the link contact forces by summing the forces of contact points belonging # to the same link. W_f̅_L, W_f̿_L = jax.vmap( - lambda W_f_C: model.contact_model.link_forces_from_contact_forces( - model=model, data=data_inertial_fixed, contact_forces=W_f_C + lambda W_f_C: js.contact_model.link_forces_from_contact_forces( + model=model, data=data, contact_forces=W_f_C ) )(jnp.stack([W_f̅_C, W_f̿_C])) # Compute the link transforms. - W_H_L = ( - js.model.forward_kinematics(model=model, data=data) - if data.velocity_representation is not jaxsim.VelRepr.Inertial - else jnp.zeros(shape=(model.number_of_links(), 4, 4)) - ) + W_H_L = data.link_transforms # For integration purpose, we need the average of average forces expressed in # mixed representation. @@ -1032,12 +1002,12 @@ def step( data_tf: js.data.JaxSimModelData = ( model.contact_model.integrate_data_with_average_contact_forces( model=model, - data=data_inertial_fixed, + data=data, dt=dt, link_forces=references.link_forces(model=model, data=data), joint_force_references=references.joint_force_references(model=model), - average_link_contact_forces_inertial=W_f̅_L, - average_of_average_link_contact_forces_mixed=LW_f̿_L, + average_link_contact_forces=W_f̅_L, + average_of_average_link_contact_forces=LW_f̿_L, ) ) @@ -1052,15 +1022,10 @@ def step( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) - data_tf.state.extended |= { - "tangential_deformation": data_tf.state.extended["tangential_deformation"] + data_tf.contact_state |= { + "tangential_deformation": data_tf.contact_state["tangential_deformation"] .at[indices_of_enabled_collidable_points] .set(m_tf) } - # Restore the original velocity representation. - data_tf = data_tf.replace( - velocity_representation=data.velocity_representation, validate=False - ) - - return data_tf, {} + return data_tf diff --git a/src/jaxsim/rbda/utils.py b/src/jaxsim/rbda/utils.py index 9b2614a52..b2435dbf4 100644 --- a/src/jaxsim/rbda/utils.py +++ b/src/jaxsim/rbda/utils.py @@ -132,6 +132,12 @@ def process_inputs( if W_Q_B.shape != (4,): raise ValueError(W_Q_B.shape, (4,)) + # Check that the quaternion does not contain NaN values. + exceptions.raise_value_error_if( + condition=jnp.isnan(W_Q_B).any(), + msg="A RBDA received a quaternion that contains NaN values.", + ) + # Check that the quaternion is unary since our RBDAs make this assumption in order # to prevent introducing additional normalizations that would affect AD. exceptions.raise_value_error_if( diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 83418e20a..aea446a67 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -296,6 +296,9 @@ def test_ad_soft_contacts( model = jaxsim_models_types + with model.editable(validate=False) as model: + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build(model=model) + _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) p = jax.random.uniform(subkey1, shape=(3,), minval=-1) v = jax.random.uniform(subkey2, shape=(3,), minval=-1) @@ -355,8 +358,7 @@ def test_ad_integration( W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions W_v_WB = data.base_velocity() - ṡ = data.joint_velocities(model=model) - m = data.extended_state["tangential_deformation"] + ṡ = data.joint_velocities # Inputs. W_f_L = references.link_forces(model=model) @@ -373,10 +375,9 @@ def step( s: jtp.Vector, W_v_WB: jtp.Vector, ṡ: jtp.Vector, - m: jtp.Vector, τ: jtp.Vector, W_f_L: jtp.Matrix, - ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: # When JAX tests against finite differences, the injected ε will make the # quaternion non-unitary, which will cause the AD check to fail. @@ -389,7 +390,6 @@ def step( base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, - extended_state={"tangential_deformation": m}, ) data.update_cached(model) @@ -406,16 +406,15 @@ def step( xf_s = data_xf.joint_positions xf_W_v_WB = data_xf.base_velocity() xf_ṡ = data_xf.joint_velocities - xf_m = data_xf.extended_state["tangential_deformation"] - return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m + return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ # Check derivatives against finite differences. # We set forward mode only because the backward mode is not supported by the # current implementation of `optax` optimizers in the relaxed rigid contact model. check_grads( f=step, - args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, m, τ, W_f_L), + args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L), order=AD_ORDER, modes=["fwd"], eps=ε, diff --git a/tests/test_simulations.py b/tests/test_simulations.py index e56c2c1cf..0b79a3500 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -61,7 +61,7 @@ def test_box_with_external_forces( additive=False, ) - # Initialize the integrator. + # Initialize the simulation horizon. tf = 0.5 T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int) @@ -175,12 +175,172 @@ def run_simulation( for _ in T_ns: - data = js.model.step( + match model.contact_model: + + case jaxsim.rbda.contacts.ViscoElasticContacts(): + + data = jaxsim.rbda.contacts.visco_elastic.step( + model=model, + data=data, + ) + + case _: + + data = js.model.step( + model=model, + data=data, + ) + return data + + +def test_simulation_with_soft_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + # Define the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() + model.contacts_params = js.contact.estimate_good_contact_parameters( model=model, - data=data, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=max_penetration, ) - return data + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_visco_elastic_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + # Define the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.ViscoElasticContacts.build() + model.contacts_params = js.contact.estimate_good_contact_parameters( + model=model, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=max_penetration, + ) + + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_rigid_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + # In order to achieve almost no penetration, we need to use a fairly large + # Baumgarte stabilization term. + model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( + solver_options={"solver_tol": 1e-3} + ) + model.contacts_params = model.contact_model._parameters_class(K=1e5) + + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 + + # Initialize the maximum penetration of each collidable point at steady state. + # This model is rigid, so we expect (almost) no penetration. + max_penetration = 0.000 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) def test_simulation_with_relaxed_rigid_contacts( @@ -194,6 +354,8 @@ def test_simulation_with_relaxed_rigid_contacts( model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( solver_options={"tol": 1e-3}, ) + model.contacts_params = model.contact_model._parameters_class() + # Enable a subset of the collidable points. enabled_collidable_points_mask = np.zeros( len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool