From 6ea7c18ae5dab3213a2ee1c6d0f5018617f82c3e Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 5 Mar 2025 12:13:09 +0100 Subject: [PATCH] Refactor RelaxedRigidContacts to improve kinematic constraint handling and adding Baumgarte terms --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 55 +++++++++++------------ 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index cddf6fff6..b9fa96631 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -349,14 +349,14 @@ def compute_contact_forces( W_R_F1 = W_H_F1[:3, :3] W_R_F2 = W_H_F2[:3, :3] - J_WF1 = js.frame.jacobian(model=model, data=data, frame_index=F1_idx)[:3, :] - J_WF2 = js.frame.jacobian(model=model, data=data, frame_index=F2_idx)[:3, :] + J_WF1 = js.frame.jacobian(model=model, data=data, frame_index=F1_idx) + J_WF2 = js.frame.jacobian(model=model, data=data, frame_index=F2_idx) J̇_WF1 = js.frame.jacobian_derivative( model=model, data=data, frame_index=F1_idx - )[:3, :] + ) J̇_WF2 = js.frame.jacobian_derivative( model=model, data=data, frame_index=F2_idx - )[:3, :] + ) # Add the kinematic constraint terms to the Jacobians Jl_WC = jnp.vstack([Jl_WC, J_WF1 - J_WF2]) @@ -373,26 +373,26 @@ def compute_contact_forces( # jax.debug.print("a_ref.shape: \n{}", a_ref.shape) # jax.debug.print("R.shape: \n{}", R.shape) - num_zeros_kin_constr = J_WF1.shape[0] + dim_kin_constr = J_WF1.shape[0] # Compute Baumgarte stabilization terms for the kinematic constraints - # K_P = 0 # 1e1 - # K_D = 0 # 2 * jnp.sqrt(K_P) - # vel_error = (J_WF1 - J_WF2) @ BW_ν - # position_error = 0 * (W_p_F1 - W_p_F2) - # # R_error = W_R_F1.T @ W_R_F2 - # R_error = W_R_F2.T @ W_R_F1 - # orientation_error = jaxsim.math.rotation.Rotation.log_SO3(R_error) - # baumgarte_term = ( - # K_P * jnp.hstack([position_error, orientation_error]) + K_D * vel_error - # ) - - R_ext = jnp.pad( - R, ((0, num_zeros_kin_constr), (0, num_zeros_kin_constr)), mode="constant" + K_P = 100 # 1e1 + K_D = 2 * jnp.sqrt(K_P) + K_D = jnp.diag(jnp.array([K_D, K_D, K_D, K_D, K_D, K_D])) + vel_error = (J_WF1 - J_WF2) @ BW_ν + position_error = W_p_F1 - W_p_F2 + R_error = W_R_F2.T @ W_R_F1 + orientation_error = jaxsim.math.rotation.Rotation.log_vee(R_error) + baumgarte_term = ( + K_P * jnp.hstack([position_error, orientation_error]) + K_D @ vel_error ) - # a_ref_ext = jnp.hstack([a_ref, -baumgarte_term]) - a_ref_ext = jnp.hstack([a_ref, jnp.zeros((3,))]) + + # baumgarte_term = K_P * position_error + K_D * vel_error + + R_ext = jnp.pad(R, ((0, dim_kin_constr), (0, dim_kin_constr)), mode="constant") + a_ref_ext = jnp.hstack([a_ref, -baumgarte_term]) + # a_ref_ext = jnp.hstack([a_ref, jnp.zeros((3,))]) # Compute the Delassus matrix and the free mixed linear acceleration of # the collidable points. @@ -492,7 +492,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: init_params = jnp.hstack([ init_params, jnp.zeros( - 3, + dim_kin_constr, ), ]) @@ -514,15 +514,10 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: ) # Extract the last 6 values from the solution - kin_constr_force_mixed_linear = solution[-3:] - jax.debug.print("f_loop_mixed: \n{}", kin_constr_force_mixed_linear) + kin_constr_force_mixed_F1 = solution[-dim_kin_constr:] + # jax.debug.print("f_loop_mixed: \n{}", kin_constr_force_mixed_linear) - kin_constr_force_mixed_F1 = ( - jnp.zeros(6).at[0:3].set(kin_constr_force_mixed_linear) - ) - kin_constr_force_mixed_F2 = ( - jnp.zeros(6).at[0:3].set(-kin_constr_force_mixed_linear) - ) + kin_constr_force_mixed_F2 = -kin_constr_force_mixed_F1 # Transform the wrench in inertial representation kin_constr_force_inertial_F1 = ( @@ -546,7 +541,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: jax.debug.print("f_loop_inertial_F2: \n{}", kin_constr_force_inertial_F2) # Reshape the optimized solution to be a matrix of 3D contact forces. - CW_fl_C = solution[0:-3].reshape(-1, 3) + CW_fl_C = solution[0:-dim_kin_constr].reshape(-1, 3) # Convert the contact forces from mixed to inertial-fixed representation. W_f_C = jax.vmap(