Skip to content

Commit

Permalink
Merge pull request #378 from ami-iit/fix/integrators_usage
Browse files Browse the repository at this point in the history
Uniformize usage of integrators API
  • Loading branch information
flferretti authored Feb 26, 2025
2 parents eab4885 + 7e6f942 commit f5e291f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 61 deletions.
75 changes: 38 additions & 37 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,81 +8,82 @@
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.data import JaxSimModelData
from jaxsim.math import Adjoint, Transform
from jaxsim.math import Skew


def semi_implicit_euler_integration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
base_acceleration_inertial: jtp.Vector,
joint_accelerations: jtp.Vector,
link_forces: jtp.Vector,
joint_torques: jtp.Vector,
) -> JaxSimModelData:
"""Integrate the system state using the semi-implicit Euler method."""
# Step the dynamics forward.

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

dt = model.time_step
W_v̇_WB = base_acceleration_inertial
= joint_accelerations
# Compute the system acceleration
W_v̇_WB, = js.ode.system_acceleration(
model=model,
data=data,
link_forces=link_forces,
joint_torques=joint_torques,
)

B_H_W = Transform.inverse(data._base_transform).at[:3, :3].set(jnp.eye(3))
BW_X_W = Adjoint.from_transform(B_H_W)
dt = model.time_step

# Compute the new generalized velocity.
new_generalized_acceleration = jnp.hstack([W_v̇_WB, ])

new_generalized_velocity = (
data.generalized_velocity + dt * new_generalized_acceleration
)

new_base_velocity_inertial = new_generalized_velocity[0:6]
new_joint_velocities = new_generalized_velocity[6:]
# Extract the new base and joint velocities.
W_v_B = new_generalized_velocity[0:6]
= new_generalized_velocity[6:]

base_lin_velocity_inertial = new_base_velocity_inertial[0:3]
# Compute the new base position and orientation.
W_ω_WB = new_generalized_velocity[3:6]

new_base_velocity_mixed = BW_X_W @ new_generalized_velocity[0:6]
base_lin_velocity_mixed = new_base_velocity_mixed[0:3]
base_ang_velocity_mixed = new_base_velocity_mixed[3:6]
# To obtain the derivative of the base position, we need to subtract
# the skew-symmetric matrix of the base angular velocity times the base position.
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
W_ṗ_B = new_generalized_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position

base_quaternion_derivative = jaxsim.math.Quaternion.derivative(
W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation,
omega=base_ang_velocity_mixed,
omega=W_ω_WB,
omega_in_body_fixed=False,
).squeeze()

new_base_position = data.base_position + dt * base_lin_velocity_mixed
new_base_quaternion = data.base_orientation + dt * base_quaternion_derivative
W_p_B = data.base_position + dt * W_ṗ_B
W_Q_B = data.base_orientation + dt * W_Q̇_B

base_quaternion_norm = jaxsim.math.safe_norm(new_base_quaternion)
base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B)

new_base_quaternion = new_base_quaternion / jnp.where(
base_quaternion_norm == 0, 1.0, base_quaternion_norm
)
W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm)

new_joint_position = data.joint_positions + dt * new_joint_velocities
s = data.joint_positions + dt *

# TODO: Avoid double replace, e.g. by computing cached value here
data = dataclasses.replace(
data,
_base_quaternion=new_base_quaternion,
_base_position=new_base_position,
_joint_positions=new_joint_position,
_joint_velocities=new_joint_velocities,
_base_linear_velocity=base_lin_velocity_inertial,
# Here we use the base angular velocity in mixed representation since
# it's equivalent to the one in inertial representation
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
_base_angular_velocity=base_ang_velocity_mixed,
_base_quaternion=W_Q_B,
_base_position=W_p_B,
_joint_positions=s,
_joint_velocities=,
_base_linear_velocity=W_v_B[0:3],
_base_angular_velocity=W_ω_WB,
)
data = data.replace(model=model) # update cache

# Update the cached computations.
data = data.replace(model=model)

return data


def rk4_integration(
model: js.model.JaxSimModel,
data: JaxSimModelData,
base_acceleration_inertial: jtp.Vector,
joint_accelerations: jtp.Vector,
link_forces: jtp.Vector,
joint_torques: jtp.Vector,
) -> JaxSimModelData:
Expand Down
26 changes: 2 additions & 24 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class JaxSimModel(JaxsimDataclass):
The JaxSim model defining the kinematics and dynamics of a robot.
"""

# link_spatial_inertial_matrices, motion_subspaces

model_name: Static[str]

time_step: float = dataclasses.field(
Expand Down Expand Up @@ -2086,36 +2084,16 @@ def step(

W_f_L_total = W_f_L_external + W_f_L_terrain

# ===============================
# Compute the system acceleration
# ===============================

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
W_v̇_WB, = js.ode.system_acceleration(
model=model,
data=data,
link_forces=W_f_L_total,
joint_torques=τ_total,
)

# =============================
# Advance the simulation state
# =============================

from .integrators import _INTEGRATORS_MAP

integrator_fn = _INTEGRATORS_MAP[model.integrator]

data_tf = integrator_fn(
model=model,
data=data,
base_acceleration_inertial=W_v̇_WB,
joint_accelerations=,
# Pass link_forces and joint_torques if the integrator is rk4
**(
{"link_forces": W_f_L_total, "joint_torques": τ_total}
if model.integrator == IntegratorType.RungeKutta4
else {}
),
model=model, data=data, link_forces=W_f_L_total, joint_torques=τ_total
)

return data_tf

0 comments on commit f5e291f

Please sign in to comment.