Skip to content

Commit

Permalink
Merge pull request #373 from ami-iit/feature/rk4
Browse files Browse the repository at this point in the history
Implement Runge-Kutta 4 integrator
  • Loading branch information
flferretti authored Feb 18, 2025
2 parents a4a7153 + 7a2f1b9 commit 1b31e66
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 11 deletions.
80 changes: 80 additions & 0 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dataclasses
from collections.abc import Callable

import jax
import jax.numpy as jnp

import jaxsim
Expand Down Expand Up @@ -74,3 +76,81 @@ def semi_implicit_euler_integration(
data = data.replace(model=model) # update cache

return data


def rk4_integration(
model: js.model.JaxSimModel,
data: JaxSimModelData,
base_acceleration_inertial: jtp.Vector,
joint_accelerations: jtp.Vector,
link_forces: jtp.Vector,
joint_torques: jtp.Vector,
) -> JaxSimModelData:
"""Integrate the system state using the Runge-Kutta 4 method."""

dt = model.time_step

def f(x) -> dict[str, jtp.Matrix]:

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

data_ti = data.replace(model=model, **x)

return js.ode.system_dynamics(
model=model,
data=data_ti,
link_forces=link_forces,
joint_torques=joint_torques,
)

base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion)
base_quaternion = data._base_quaternion / jnp.where(
base_quaternion_norm == 0, 1.0, base_quaternion_norm
)

x_t0 = dict(
base_position=data._base_position,
base_quaternion=base_quaternion,
joint_positions=data._joint_positions,
base_linear_velocity=data._base_linear_velocity,
base_angular_velocity=data._base_angular_velocity,
joint_velocities=data._joint_velocities,
)

euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt
euler_fin = lambda x, dxdt: x + dt * dxdt

k1 = f(x_t0)
k2 = f(jax.tree.map(euler_mid, x_t0, k1))
k3 = f(jax.tree.map(euler_mid, x_t0, k2))
k4 = f(jax.tree.map(euler_fin, x_t0, k3))

# Average the slopes and compute the RK4 state derivative.
average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6

dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4)

# Integrate the dynamics
x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt)

data_tf = dataclasses.replace(
data,
**{
"_base_position": x_tf["base_position"],
"_base_quaternion": x_tf["base_quaternion"],
"_joint_positions": x_tf["joint_positions"],
"_base_linear_velocity": x_tf["base_linear_velocity"],
"_base_angular_velocity": x_tf["base_angular_velocity"],
"_joint_velocities": x_tf["joint_velocities"],
},
)

return data_tf.replace(model=model)


_INTEGRATORS_MAP: dict[
js.model.IntegratorType, Callable[..., js.data.JaxSimModelData]
] = {
js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration,
js.model.IntegratorType.RungeKutta4: rk4_integration,
}
37 changes: 36 additions & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import dataclasses
import enum
import functools
import pathlib
from collections.abc import Sequence
Expand All @@ -23,6 +24,13 @@
from .common import VelRepr


class IntegratorType(enum.IntEnum):
"""The integrators available for the simulation."""

SemiImplicitEuler = enum.auto()
RungeKutta4 = enum.auto()


@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class JaxSimModel(JaxsimDataclass):
"""
Expand Down Expand Up @@ -55,6 +63,10 @@ class JaxSimModel(JaxsimDataclass):
dataclasses.field(default=None, repr=False)
)

integrator: Static[IntegratorType] = dataclasses.field(
default=IntegratorType.SemiImplicitEuler, repr=False
)

built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
default=None, repr=False
)
Expand Down Expand Up @@ -111,6 +123,7 @@ def build_from_model_description(
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
integrator: IntegratorType | None = None,
is_urdf: bool | None = None,
considered_joints: Sequence[str] | None = None,
) -> JaxSimModel:
Expand All @@ -131,6 +144,7 @@ def build_from_model_description(
The contact model to consider.
If not specified, a soft contacts model is used.
contact_params: The parameters of the contact model.
integrator: The integrator to use for the simulation.
is_urdf:
The optional flag to force the model description to be parsed as a URDF.
This is usually automatically inferred.
Expand Down Expand Up @@ -164,6 +178,7 @@ def build_from_model_description(
terrain=terrain,
contact_model=contact_model,
contacts_params=contact_params,
integrator=integrator,
)

# Store the origin of the model, in case downstream logic needs it.
Expand All @@ -182,6 +197,7 @@ def build(
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
integrator: IntegratorType | None = None,
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
) -> JaxSimModel:
"""
Expand All @@ -202,6 +218,7 @@ def build(
The contact model to consider.
If not specified, a soft contacts model is used.
contacts_params: The parameters of the soft contacts.
integrator: The integrator to use for the simulation.
gravity: The gravity constant.
Returns:
Expand Down Expand Up @@ -237,6 +254,13 @@ def build(
if contacts_params is None:
contacts_params = contact_model._parameters_class()

# Consider the default integrator if not specified.
integrator = (
integrator
if integrator is not None
else JaxSimModel.__dataclass_fields__["integrator"].default
)

# Build the model.
model = cls(
model_name=model_name,
Expand All @@ -247,6 +271,7 @@ def build(
terrain=terrain,
contact_model=contact_model,
contacts_params=contacts_params,
integrator=integrator,
gravity=gravity,
# The following is wrapped as hashless since it's a static argument, and we
# don't want to trigger recompilation if it changes. All relevant parameters
Expand Down Expand Up @@ -449,6 +474,7 @@ def reduce(
contact_model=model.contact_model,
contacts_params=model.contacts_params,
gravity=model.gravity,
integrator=model.integrator,
)

# Store the origin of the model, in case downstream logic needs it.
Expand Down Expand Up @@ -2075,12 +2101,21 @@ def step(
# =============================
# Advance the simulation state
# =============================
from .integrators import _INTEGRATORS_MAP

integrator_fn = _INTEGRATORS_MAP[model.integrator]

data_tf = js.integrators.semi_implicit_euler_integration(
data_tf = integrator_fn(
model=model,
data=data,
base_acceleration_inertial=W_v̇_WB,
joint_accelerations=,
# Pass link_forces and joint_torques if the integrator is rk4
**(
{"link_forces": W_f_L_total, "joint_torques": τ_total}
if model.integrator == IntegratorType.RungeKutta4
else {}
),
)

return data_tf
14 changes: 5 additions & 9 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.data import JaxSimModelData
from jaxsim.math import Quaternion, Skew

from .common import VelRepr
Expand Down Expand Up @@ -123,7 +122,7 @@ def system_dynamics(
link_forces: jtp.Vector | None = None,
joint_torques: jtp.Vector | None = None,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> JaxSimModelData:
) -> dict[str, jtp.Vector]:
"""
Compute the dynamics of the system.
Expand All @@ -139,9 +138,9 @@ def system_dynamics(
quaternion (only used in integrators not operating on the SO(3) manifold).
Returns:
A tuple with an `JaxSimModelData` object storing in each of its attributes the
corresponding derivative, and the dictionary of auxiliary data returned
by the system dynamics evaluation.
A dictionary containing the derivatives of the base position, the base quaternion,
the joint positions, the base linear velocity, the base angular velocity, and the
joint velocities.
"""

with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
Expand All @@ -158,14 +157,11 @@ def system_dynamics(
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
)

ode_state_derivative = JaxSimModelData.build(
model=model,
return dict(
base_position=W_ṗ_B,
base_quaternion=W_Q̇_B,
joint_positions=,
base_linear_velocity=W_v̇_WB[0:3],
base_angular_velocity=W_v̇_WB[3:6],
joint_velocities=,
)

return ode_state_derivative
6 changes: 6 additions & 0 deletions src/jaxsim/rbda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ def process_inputs(
if W_Q_B.shape != (4,):
raise ValueError(W_Q_B.shape, (4,))

# Check that the quaternion does not contain NaNs.
exceptions.raise_value_error_if(
condition=jnp.isnan(W_Q_B).any(),
msg="A RBDA received a quaternion that contains NaNs.",
)

# Check that the quaternion is unary since our RBDAs make this assumption in order
# to prevent introducing additional normalizations that would affect AD.
exceptions.raise_value_error_if(
Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import jaxsim
import jaxsim.api as js
from jaxsim.api.model import IntegratorType


def pytest_addoption(parser):
Expand Down Expand Up @@ -127,6 +128,24 @@ def velocity_representation(request) -> jaxsim.VelRepr:
return request.param


@pytest.fixture(
scope="function",
params=[
pytest.param(IntegratorType.SemiImplicitEuler, id="semi_implicit_euler"),
pytest.param(IntegratorType.RungeKutta4, id="runge_kutta_4"),
],
)
def integrator(request) -> str:
"""
Fixture providing the integrator to use in the simulation.
Returns:
The integrator to use in the simulation.
"""

return request.param


@pytest.fixture(scope="session")
def batch_size(request) -> int:
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def run_simulation(


def test_simulation_with_relaxed_rigid_contacts(
jaxsim_model_box: js.model.JaxSimModel,
jaxsim_model_box: js.model.JaxSimModel, integrator
):

model = jaxsim_model_box
Expand All @@ -206,6 +206,7 @@ def test_simulation_with_relaxed_rigid_contacts(
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
model.integrator = integrator

assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4

Expand Down

0 comments on commit 1b31e66

Please sign in to comment.