Skip to content

Commit

Permalink
Speed up rigid contact model
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 26, 2025
1 parent 645c4e8 commit 2ea67ea
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 71 deletions.
17 changes: 7 additions & 10 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ def update_velocity_after_impact(
in_axes=(0, 0, None),
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

original_representation = data.velocity_representation

with data.switch_velocity_representation(VelRepr.Mixed):
J_WC = js.contact.jacobian(model, data)[
indices_of_enabled_collidable_points
Expand All @@ -344,13 +342,12 @@ def update_velocity_after_impact(
is_force=False,
)

# Reset the generalized velocity.
data = dataclasses.replace(
data,
velocity_representation=original_representation,
_base_linear_velocity=BW_ν_post_impact_inertial[0:3],
_base_angular_velocity=BW_ν_post_impact_inertial[3:6],
_joint_velocities=BW_ν_post_impact[6:],
)
# Reset the generalized velocity.
data = dataclasses.replace(
data,
_base_linear_velocity=BW_ν_post_impact_inertial[0:3],
_base_angular_velocity=BW_ν_post_impact_inertial[3:6],
_joint_velocities=BW_ν_post_impact[6:],
)

return data
92 changes: 31 additions & 61 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,7 @@ def compute_impact_velocity(

# Zero out the jacobian rows of inactive points.
Jl_WC = jnp.vstack(
jnp.where(
inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
jnp.zeros_like(Jl_WC),
Jl_WC,
)
jax.vmap(lambda J, δ: J * δ)(Jl_WC, inactive_collidable_points)
)

A = jnp.vstack(
Expand Down Expand Up @@ -259,16 +255,39 @@ def compute_contact_forces(
else jnp.zeros((model.number_of_joints(),))
)

# Build a references object to simplify converting link forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data,
velocity_representation=data.velocity_representation,
link_forces=link_forces,
joint_force_references=joint_force_references,
)

# Compute the transforms of the implicit frames corresponding to the
# collidable points.
W_H_C = js.contact.transforms(model=model, data=data)

# Compute kin-dyn quantities used in the contact model.
with data.switch_velocity_representation(VelRepr.Mixed):
BW_ν = data.generalized_velocity

M = js.model.free_floating_mass_matrix(model=model, data=data)

J_WC = js.contact.jacobian(model=model, data=data)
J̇_WC = js.contact.jacobian_derivative(model=model, data=data)
Jl_WC = jnp.vstack(js.contact.jacobian(model=model, data=data)[:, :3, :])
J̇l_WC = jnp.vstack(
js.contact.jacobian_derivative(model=model, data=data)[:, :3, :]
)

W_H_C = js.contact.transforms(model=model, data=data)
# Compute the generalized free acceleration.
BW_ν̇_free = jnp.hstack(
js.model.forward_dynamics_aba(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
joint_forces=references.joint_force_references(model=model),
)
)

# Compute the position and linear velocities (mixed representation) of
# all enabled collidable points belonging to the robot.
Expand All @@ -282,34 +301,9 @@ def compute_contact_forces(
position, velocity, model.terrain
)

# Build a references object to simplify converting link forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data,
velocity_representation=data.velocity_representation,
link_forces=link_forces,
joint_force_references=joint_force_references,
)

# Compute the generalized free acceleration.
with data.switch_velocity_representation(VelRepr.Mixed):
BW_ν̇_free = jnp.hstack(
js.model.forward_dynamics_aba(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
joint_forces=references.joint_force_references(model=model),
)
)

# Compute the free linear acceleration of the collidable points.
# Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
free_contact_acc = _linear_acceleration_of_collidable_points(
BW_nu=BW_ν,
BW_nu_dot=BW_ν̇_free,
CW_J_WC_BW=J_WC,
CW_J_dot_WC_BW=J̇_WC,
).flatten()
CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇l_WC @ BW_ν

# Compute stabilization term.
baumgarte_term = _compute_baumgarte_stabilization_term(
Expand All @@ -322,15 +316,15 @@ def compute_contact_forces(
).flatten()

# Compute the Delassus matrix.
delassus_matrix = _compute_delassus_matrix(M=M, J_WC=J_WC)
delassus_matrix = _compute_delassus_matrix(M=M, J_WC=Jl_WC)

# Initialize regularization term of the Delassus matrix for
# better numerical conditioning.
= self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])

# Construct the quadratic cost function.
Q = delassus_matrix +
q = free_contact_acc - baumgarte_term
q = CW_al_free_WC - baumgarte_term

# Construct the inequality constraints.
G = _compute_ineq_constraint_matrix(
Expand Down Expand Up @@ -379,10 +373,7 @@ def _compute_delassus_matrix(
J_WC: jtp.MatrixLike,
) -> jtp.Matrix:

sl = jnp.s_[:, 0:3, :]
J_WC_lin = jnp.vstack(J_WC[sl])

delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
delassus_matrix = J_WC @ jnp.linalg.pinv(M) @ J_WC.T
return delassus_matrix


Expand Down Expand Up @@ -428,27 +419,6 @@ def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector:
return jnp.zeros(shape=(n_constraints,))


@jax.jit
@js.common.named_scope
def _linear_acceleration_of_collidable_points(
BW_nu: jtp.ArrayLike,
BW_nu_dot: jtp.ArrayLike,
CW_J_WC_BW: jtp.MatrixLike,
CW_J_dot_WC_BW: jtp.MatrixLike,
) -> jtp.Matrix:

BW_ν = BW_nu
BW_ν̇ = BW_nu_dot
CW_J̇_WC_BW = CW_J_dot_WC_BW

# Compute the linear acceleration of the collidable points.
# Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.
CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇

CW_a_WC = CW_a_WC.reshape(-1, 6)
return CW_a_WC[:, 0:3].squeeze()


@jax.jit
@js.common.named_scope
def _compute_baumgarte_stabilization_term(
Expand Down

0 comments on commit 2ea67ea

Please sign in to comment.