Skip to content

Commit

Permalink
Make static methods functions in RigidContacts
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 26, 2025
1 parent 9df22e2 commit 6e04314
Showing 1 changed file with 83 additions and 77 deletions.
160 changes: 83 additions & 77 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import qpax

import jaxsim.api as js
import jaxsim.typing as jtp
Expand Down Expand Up @@ -239,9 +240,6 @@ def compute_contact_forces(
A tuple containing as first element the computed contact forces.
"""

# Import qpax privately just in this method.
import qpax

# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
Expand Down Expand Up @@ -306,15 +304,15 @@ def compute_contact_forces(

# Compute the free linear acceleration of the collidable points.
# Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
free_contact_acc = _linear_acceleration_of_collidable_points(
BW_nu=BW_ν,
BW_nu_dot=BW_ν̇_free,
CW_J_WC_BW=J_WC,
CW_J_dot_WC_BW=J̇_WC,
).flatten()

# Compute stabilization term.
baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
baumgarte_term = _compute_baumgarte_stabilization_term(
inactive_collidable_points=(δ <= 0),
δ=δ,
δ_dot=δ_dot,
Expand All @@ -324,7 +322,7 @@ def compute_contact_forces(
).flatten()

# Compute the Delassus matrix.
delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)
delassus_matrix = _compute_delassus_matrix(M=M, J_WC=J_WC)

# Initialize regularization term of the Delassus matrix for
# better numerical conditioning.
Expand All @@ -335,12 +333,10 @@ def compute_contact_forces(
q = free_contact_acc - baumgarte_term

# Construct the inequality constraints.
G = RigidContacts._compute_ineq_constraint_matrix(
G = _compute_ineq_constraint_matrix(
inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu
)
h_bounds = RigidContacts._compute_ineq_bounds(
n_collidable_points=n_collidable_points
)
h_bounds = _compute_ineq_bounds(n_collidable_points=n_collidable_points)

# Construct the equality constraints.
A = jnp.zeros((0, 3 * n_collidable_points))
Expand Down Expand Up @@ -375,82 +371,92 @@ def compute_contact_forces(

return W_f_C, {}

@staticmethod
def _delassus_matrix(
M: jtp.MatrixLike,
J_WC: jtp.MatrixLike,
) -> jtp.Matrix:

sl = jnp.s_[:, 0:3, :]
J_WC_lin = jnp.vstack(J_WC[sl])
@jax.jit
@js.common.named_scope
def _compute_delassus_matrix(
M: jtp.MatrixLike,
J_WC: jtp.MatrixLike,
) -> jtp.Matrix:

sl = jnp.s_[:, 0:3, :]
J_WC_lin = jnp.vstack(J_WC[sl])

delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
return delassus_matrix


@jax.jit
@js.common.named_scope
def _compute_ineq_constraint_matrix(
inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
) -> jtp.Matrix:
"""
Compute the inequality constraint matrix for a single collidable point.
Rows 0-3: enforce the friction pyramid constraint,
Row 4: last one is for the non negativity of the vertical force
Row 5: contact complementarity condition
"""
G_single_point = jnp.array(
[
[1, 0, -mu],
[0, 1, -mu],
[-1, 0, -mu],
[0, -1, -mu],
[0, 0, -1],
[0, 0, 0],
]
)
G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1))
G = G.at[:, 5, 2].set(inactive_collidable_points)

delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
return delassus_matrix
G = jax.scipy.linalg.block_diag(*G)
return G

@staticmethod
def _compute_ineq_constraint_matrix(
inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
) -> jtp.Matrix:
"""
Compute the inequality constraint matrix for a single collidable point.

Rows 0-3: enforce the friction pyramid constraint,
Row 4: last one is for the non negativity of the vertical force
Row 5: contact complementarity condition
"""
G_single_point = jnp.array(
[
[1, 0, -mu],
[0, 1, -mu],
[-1, 0, -mu],
[0, -1, -mu],
[0, 0, -1],
[0, 0, 0],
]
)
G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1))
G = G.at[:, 5, 2].set(inactive_collidable_points)
@jax.jit
@js.common.named_scope
def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector:

G = jax.scipy.linalg.block_diag(*G)
return G
n_constraints = 6 * n_collidable_points
return jnp.zeros(shape=(n_constraints,))

@staticmethod
def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector:

n_constraints = 6 * n_collidable_points
return jnp.zeros(shape=(n_constraints,))
@jax.jit
@js.common.named_scope
def _linear_acceleration_of_collidable_points(
BW_nu: jtp.ArrayLike,
BW_nu_dot: jtp.ArrayLike,
CW_J_WC_BW: jtp.MatrixLike,
CW_J_dot_WC_BW: jtp.MatrixLike,
) -> jtp.Matrix:

@staticmethod
def _linear_acceleration_of_collidable_points(
BW_nu: jtp.ArrayLike,
BW_nu_dot: jtp.ArrayLike,
CW_J_WC_BW: jtp.MatrixLike,
CW_J_dot_WC_BW: jtp.MatrixLike,
) -> jtp.Matrix:
BW_ν = BW_nu
BW_ν̇ = BW_nu_dot
CW_J̇_WC_BW = CW_J_dot_WC_BW

BW_ν = BW_nu
BW_ν̇ = BW_nu_dot
CW_J̇_WC_BW = CW_J_dot_WC_BW
# Compute the linear acceleration of the collidable points.
# Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.
CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇

# Compute the linear acceleration of the collidable points.
# Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.
CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇
CW_a_WC = CW_a_WC.reshape(-1, 6)
return CW_a_WC[:, 0:3].squeeze()

CW_a_WC = CW_a_WC.reshape(-1, 6)
return CW_a_WC[:, 0:3].squeeze()

@staticmethod
def _compute_baumgarte_stabilization_term(
inactive_collidable_points: jtp.ArrayLike,
δ: jtp.ArrayLike,
δ_dot: jtp.ArrayLike,
n: jtp.ArrayLike,
K: jtp.FloatLike,
D: jtp.FloatLike,
) -> jtp.Array:

return jnp.where(
inactive_collidable_points[:, jnp.newaxis],
jnp.zeros_like(n),
(K * δ + D * δ_dot)[:, jnp.newaxis] * n,
)
@jax.jit
@js.common.named_scope
def _compute_baumgarte_stabilization_term(
inactive_collidable_points: jtp.ArrayLike,
δ: jtp.ArrayLike,
δ_dot: jtp.ArrayLike,
n: jtp.ArrayLike,
K: jtp.FloatLike,
D: jtp.FloatLike,
) -> jtp.Array:

return jnp.where(
inactive_collidable_points[:, jnp.newaxis],
jnp.zeros_like(n),
(K * δ + D * δ_dot)[:, jnp.newaxis] * n,
)

0 comments on commit 6e04314

Please sign in to comment.