Skip to content

Commit

Permalink
Implement Baumgarte stabilization for kinematic constraints in Relaxe…
Browse files Browse the repository at this point in the history
…dRigidContacts
  • Loading branch information
xela-95 committed Feb 21, 2025
1 parent 1958676 commit 5342ff2
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,10 @@ def compute_contact_forces(

W_H_F1 = js.frame.transform(model=model, data=data, frame_index=F1_idx)
W_H_F2 = js.frame.transform(model=model, data=data, frame_index=F2_idx)
W_p_F1 = W_H_F1[:3, 3]
W_p_F2 = W_H_F2[:3, 3]
W_R_F1 = W_H_F1[:3, :3]
W_R_F2 = W_H_F2[:3, :3]

J_WF1 = js.frame.jacobian(model=model, data=data, frame_index=F1_idx)
J_WF2 = js.frame.jacobian(model=model, data=data, frame_index=F2_idx)
Expand All @@ -370,12 +374,21 @@ def compute_contact_forces(
# jax.debug.print("R.shape: \n{}", R.shape)

num_zeros_kin_constr = J_WF1.shape[0]
zeros_array_a_ref = jnp.zeros((num_zeros_kin_constr,))

# Compute Baumgarte stabilization terms for the kinematic constraints

K_P = 1.0e2
K_D = 2 * jnp.sqrt(K_P)
vel_error = (J_WF1 - J_WF2) @ BW_ν
pos_error = W_p_F1 - W_p_F2
R_error = W_R_F1.T @ W_R_F2
omega_error = jaxsim.math.rotation.Rotation.log_SO3(R_error)
baumgarte_term = K_P * jnp.hstack([pos_error, omega_error]) + K_D * vel_error

R_ext = jnp.pad(
R, ((0, num_zeros_kin_constr), (0, num_zeros_kin_constr)), mode="constant"
)
a_ref_ext = jnp.hstack([a_ref, zeros_array_a_ref])
a_ref_ext = jnp.hstack([a_ref, -baumgarte_term])

# Compute the Delassus matrix and the free mixed linear acceleration of
# the collidable points.
Expand Down

0 comments on commit 5342ff2

Please sign in to comment.