Skip to content

Commit

Permalink
Compute contact forces inside ode.system_acceleration
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 27, 2025
1 parent 1c219eb commit 8709aa1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 57 deletions.
15 changes: 9 additions & 6 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,21 @@
def semi_implicit_euler_integration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
base_acceleration_inertial: jtp.Vector,
joint_accelerations: jtp.Vector,
link_forces: jtp.Vector,
joint_torques: jtp.Vector,
) -> JaxSimModelData:
"""Integrate the system state using the semi-implicit Euler method."""
# Step the dynamics forward.
with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

W_v̇_WB, = js.ode.system_acceleration(
model=model,
data=data,
link_forces=link_forces,
joint_torques=joint_torques,
)

dt = model.time_step
W_v̇_WB = base_acceleration_inertial
= joint_accelerations

B_H_W = Transform.inverse(data._base_transform).at[:3, :3].set(jnp.eye(3))
BW_X_W = Adjoint.from_transform(B_H_W)
Expand Down Expand Up @@ -81,8 +86,6 @@ def semi_implicit_euler_integration(
def rk4_integration(
model: js.model.JaxSimModel,
data: JaxSimModelData,
base_acceleration_inertial: jtp.Vector,
joint_accelerations: jtp.Vector,
link_forces: jtp.Vector,
joint_torques: jtp.Vector,
) -> JaxSimModelData:
Expand Down
45 changes: 2 additions & 43 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,41 +2063,6 @@ def step(
model, data, joint_force_references=τ_references
)

# ======================
# Compute contact forces
# ======================

W_f_L_terrain = jnp.zeros_like(W_f_L_external)

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:

# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
# with the terrain.
W_f_L_terrain = js.contact_model.link_contact_forces(
model=model,
data=data,
link_forces=W_f_L_external,
joint_torques=τ_total,
)

# ==============================
# Compute the total link forces
# ==============================

W_f_L_total = W_f_L_external + W_f_L_terrain

# ===============================
# Compute the system acceleration
# ===============================

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
W_v̇_WB, = js.ode.system_acceleration(
model=model,
data=data,
link_forces=W_f_L_total,
joint_torques=τ_total,
)

# =============================
# Advance the simulation state
# =============================
Expand All @@ -2108,14 +2073,8 @@ def step(
data_tf = integrator_fn(
model=model,
data=data,
base_acceleration_inertial=W_v̇_WB,
joint_accelerations=,
# Pass link_forces and joint_torques if the integrator is rk4
**(
{"link_forces": W_f_L_total, "joint_torques": τ_total}
if model.integrator == IntegratorType.RungeKutta4
else {}
),
link_forces=W_f_L_external,
joint_torques=τ_total,
)

return data_tf
31 changes: 25 additions & 6 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ def system_acceleration(
and the joint accelerations.
"""

# ======================
# Compute contact forces
# ======================

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:

# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
# with the terrain.
W_f_L_terrain = js.contact_model.link_contact_forces(
model=model,
data=data,
link_forces=link_forces,
joint_torques=joint_torques,
)

# ====================
# Validate input data
# ====================
Expand All @@ -54,19 +69,23 @@ def system_acceleration(
link_forces=f_L,
)

with references.switch_velocity_representation(VelRepr.Inertial):
W_f_L_total = W_f_L_terrain + references.link_forces(model=model, data=data)

# Compute forward dynamics.
#
# - Joint accelerations: s̈ ∈ ℝⁿ
# - Base acceleration: v̇_WB ∈ ℝ⁶
#
# Note that ABA returns the base acceleration in the velocity representation
# stored in the `data` object.
v̇_WB, = js.model.forward_dynamics_aba(
model=model,
data=data,
joint_forces=joint_torques,
link_forces=references.link_forces(model=model, data=data),
)
with data.switch_velocity_representation(VelRepr.Inertial):
v̇_WB, = js.model.forward_dynamics_aba(
model=model,
data=data,
joint_forces=joint_torques,
link_forces=W_f_L_total,
)

return v̇_WB,

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,11 @@ def compute_contact_forces(
BW_ν = data.generalized_velocity

BW_ν̇_free = jnp.hstack(
js.ode.system_acceleration(
js.model.forward_dynamics_aba(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
joint_torques=references.joint_force_references(model=model),
joint_forces=references.joint_force_references(model=model),
)
)

Expand Down

0 comments on commit 8709aa1

Please sign in to comment.