From a7db0e984dd9ce32e46498d861be6a1baa4107bd Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 27 Feb 2025 15:32:14 +0100 Subject: [PATCH] Fix link forces representation in rigid contacts --- src/jaxsim/rbda/contacts/rigid.py | 43 ++++++++++++------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index b6e4bc7df..540017277 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -212,6 +212,7 @@ def compute_impact_velocity( return BW_ν_post_impact[0 : M.shape[0]] @jax.jit + @js.common.named_scope def compute_contact_forces( self, model: js.model.JaxSimModel, @@ -264,12 +265,27 @@ def compute_contact_forces( joint_force_references=joint_force_references, ) + # Compute the position and linear velocities (mixed representation) of + # all enabled collidable points belonging to the robot. + position, velocity = js.contact.collidable_point_kinematics( + model=model, data=data + ) + + # Compute the penetration depth and velocity of the collidable points. + # Note that this function considers the penetration in the normal direction. + δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( + position, velocity, model.terrain + ) + # Compute the transforms of the implicit frames corresponding to the # collidable points. W_H_C = js.contact.transforms(model=model, data=data) # Compute kin-dyn quantities used in the contact model. - with data.switch_velocity_representation(VelRepr.Mixed): + with ( + data.switch_velocity_representation(VelRepr.Mixed), + references.switch_velocity_representation(VelRepr.Mixed), + ): BW_ν = data.generalized_velocity M = js.model.free_floating_mass_matrix(model=model, data=data) @@ -289,18 +305,6 @@ def compute_contact_forces( ) ) - # Compute the position and linear velocities (mixed representation) of - # all enabled collidable points belonging to the robot. - position, velocity = js.contact.collidable_point_kinematics( - model=model, data=data - ) - - # Compute the penetration depth and velocity of the collidable points. - # Note that this function considers the penetration in the normal direction. - δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( - position, velocity, model.terrain - ) - # Compute the free linear acceleration of the collidable points. # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇l_WC @ BW_ν @@ -406,19 +410,6 @@ def _compute_ineq_constraint_matrix( return G -def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: - """ - Compute the inequality bounds for the contact forces. - - Note: - Do not JIT this function as the output shape depends on the number of collidable points. - """ - - n_constraints = 6 * n_collidable_points - - return jnp.zeros(shape=(n_constraints,)) - - @jax.jit @js.common.named_scope def _compute_baumgarte_stabilization_term(