Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Runge-Kutta 4 integrator #373

Merged
merged 12 commits into from
Feb 18, 2025
80 changes: 80 additions & 0 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dataclasses
from collections.abc import Callable

import jax
import jax.numpy as jnp

import jaxsim
Expand Down Expand Up @@ -74,3 +76,81 @@ def semi_implicit_euler_integration(
data = data.replace(model=model) # update cache

return data


def rk4_integration(
model: js.model.JaxSimModel,
data: JaxSimModelData,
base_acceleration_inertial: jtp.Vector,
joint_accelerations: jtp.Vector,
link_forces: jtp.Vector,
joint_torques: jtp.Vector,
) -> JaxSimModelData:
"""Integrate the system state using the Runge-Kutta 4 method."""

dt = model.time_step

def f(x) -> dict[str, jtp.Matrix]:

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

data_ti = data.replace(model=model, **x)

return js.ode.system_dynamics(
model=model,
data=data_ti,
link_forces=link_forces,
joint_torques=joint_torques,
)

base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion)
base_quaternion = data._base_quaternion / jnp.where(
base_quaternion_norm == 0, 1.0, base_quaternion_norm
)

x_t0 = dict(
base_position=data._base_position,
base_quaternion=base_quaternion,
joint_positions=data._joint_positions,
base_linear_velocity=data._base_linear_velocity,
base_angular_velocity=data._base_angular_velocity,
joint_velocities=data._joint_velocities,
)

euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt
euler_fin = lambda x, dxdt: x + dt * dxdt

k1 = f(x_t0)
k2 = f(jax.tree.map(euler_mid, x_t0, k1))
k3 = f(jax.tree.map(euler_mid, x_t0, k2))
k4 = f(jax.tree.map(euler_fin, x_t0, k3))

# Average the slopes and compute the RK4 state derivative.
average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6

dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4)

# Integrate the dynamics
x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt)

data_tf = dataclasses.replace(
data,
**{
"_base_position": x_tf["base_position"],
"_base_quaternion": x_tf["base_quaternion"],
"_joint_positions": x_tf["joint_positions"],
"_base_linear_velocity": x_tf["base_linear_velocity"],
"_base_angular_velocity": x_tf["base_angular_velocity"],
"_joint_velocities": x_tf["joint_velocities"],
},
)

return data_tf.replace(model=model)


_INTEGRATORS_MAP: dict[
js.model.IntegratorType, Callable[..., js.data.JaxSimModelData]
] = {
js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration,
js.model.IntegratorType.RungeKutta4: rk4_integration,
}
37 changes: 36 additions & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import dataclasses
import enum
import functools
import pathlib
from collections.abc import Sequence
Expand All @@ -23,6 +24,13 @@
from .common import VelRepr


class IntegratorType(enum.IntEnum):
"""The integrators available for the simulation."""

SemiImplicitEuler = enum.auto()
RungeKutta4 = enum.auto()


@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class JaxSimModel(JaxsimDataclass):
"""
Expand Down Expand Up @@ -55,6 +63,10 @@ class JaxSimModel(JaxsimDataclass):
dataclasses.field(default=None, repr=False)
)

integrator: Static[IntegratorType] = dataclasses.field(
default=IntegratorType.SemiImplicitEuler, repr=False
)

built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
default=None, repr=False
)
Expand Down Expand Up @@ -111,6 +123,7 @@ def build_from_model_description(
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
integrator: IntegratorType | None = None,
is_urdf: bool | None = None,
considered_joints: Sequence[str] | None = None,
) -> JaxSimModel:
Expand All @@ -131,6 +144,7 @@ def build_from_model_description(
The contact model to consider.
If not specified, a soft contacts model is used.
contact_params: The parameters of the contact model.
integrator: The integrator to use for the simulation.
is_urdf:
The optional flag to force the model description to be parsed as a URDF.
This is usually automatically inferred.
Expand Down Expand Up @@ -164,6 +178,7 @@ def build_from_model_description(
terrain=terrain,
contact_model=contact_model,
contacts_params=contact_params,
integrator=integrator,
)

# Store the origin of the model, in case downstream logic needs it.
Expand All @@ -182,6 +197,7 @@ def build(
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
integrator: IntegratorType | None = None,
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
) -> JaxSimModel:
"""
Expand All @@ -202,6 +218,7 @@ def build(
The contact model to consider.
If not specified, a soft contacts model is used.
contacts_params: The parameters of the soft contacts.
integrator: The integrator to use for the simulation.
gravity: The gravity constant.

Returns:
Expand Down Expand Up @@ -237,6 +254,13 @@ def build(
if contacts_params is None:
contacts_params = contact_model._parameters_class()

# Consider the default integrator if not specified.
integrator = (
integrator
if integrator is not None
else JaxSimModel.__dataclass_fields__["integrator"].default
)

# Build the model.
model = cls(
model_name=model_name,
Expand All @@ -247,6 +271,7 @@ def build(
terrain=terrain,
contact_model=contact_model,
contacts_params=contacts_params,
integrator=integrator,
gravity=gravity,
# The following is wrapped as hashless since it's a static argument, and we
# don't want to trigger recompilation if it changes. All relevant parameters
Expand Down Expand Up @@ -449,6 +474,7 @@ def reduce(
contact_model=model.contact_model,
contacts_params=model.contacts_params,
gravity=model.gravity,
integrator=model.integrator,
)

# Store the origin of the model, in case downstream logic needs it.
Expand Down Expand Up @@ -2075,12 +2101,21 @@ def step(
# =============================
# Advance the simulation state
# =============================
from .integrators import _INTEGRATORS_MAP

integrator_fn = _INTEGRATORS_MAP[model.integrator]

data_tf = js.integrators.semi_implicit_euler_integration(
data_tf = integrator_fn(
model=model,
data=data,
base_acceleration_inertial=W_v̇_WB,
joint_accelerations=s̈,
# Pass link_forces and joint_torques if the integrator is rk4
**(
{"link_forces": W_f_L_total, "joint_torques": τ_total}
if model.integrator == IntegratorType.RungeKutta4
else {}
),
)

return data_tf
14 changes: 5 additions & 9 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.data import JaxSimModelData
from jaxsim.math import Quaternion, Skew

from .common import VelRepr
Expand Down Expand Up @@ -123,7 +122,7 @@ def system_dynamics(
link_forces: jtp.Vector | None = None,
joint_torques: jtp.Vector | None = None,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> JaxSimModelData:
) -> dict[str, jtp.Vector]:
"""
Compute the dynamics of the system.

Expand All @@ -139,9 +138,9 @@ def system_dynamics(
quaternion (only used in integrators not operating on the SO(3) manifold).

Returns:
A tuple with an `JaxSimModelData` object storing in each of its attributes the
corresponding derivative, and the dictionary of auxiliary data returned
by the system dynamics evaluation.
A dictionary containing the derivatives of the base position, the base quaternion,
the joint positions, the base linear velocity, the base angular velocity, and the
joint velocities.
"""

with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
Expand All @@ -158,14 +157,11 @@ def system_dynamics(
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
)

ode_state_derivative = JaxSimModelData.build(
model=model,
return dict(
base_position=W_ṗ_B,
base_quaternion=W_Q̇_B,
joint_positions=ṡ,
base_linear_velocity=W_v̇_WB[0:3],
base_angular_velocity=W_v̇_WB[3:6],
joint_velocities=s̈,
)

return ode_state_derivative
6 changes: 6 additions & 0 deletions src/jaxsim/rbda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ def process_inputs(
if W_Q_B.shape != (4,):
raise ValueError(W_Q_B.shape, (4,))

# Check that the quaternion does not contain NaNs.
exceptions.raise_value_error_if(
condition=jnp.isnan(W_Q_B).any(),
msg="A RBDA received a quaternion that contains NaNs.",
)

# Check that the quaternion is unary since our RBDAs make this assumption in order
# to prevent introducing additional normalizations that would affect AD.
exceptions.raise_value_error_if(
Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import jaxsim
import jaxsim.api as js
from jaxsim.api.model import IntegratorType


def pytest_addoption(parser):
Expand Down Expand Up @@ -127,6 +128,24 @@ def velocity_representation(request) -> jaxsim.VelRepr:
return request.param


@pytest.fixture(
scope="function",
params=[
pytest.param(IntegratorType.SemiImplicitEuler, id="semi_implicit_euler"),
pytest.param(IntegratorType.RungeKutta4, id="runge_kutta_4"),
],
)
def integrator(request) -> str:
"""
Fixture providing the integrator to use in the simulation.

Returns:
The integrator to use in the simulation.
"""

return request.param


@pytest.fixture(scope="session")
def batch_size(request) -> int:
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def run_simulation(


def test_simulation_with_relaxed_rigid_contacts(
jaxsim_model_box: js.model.JaxSimModel,
jaxsim_model_box: js.model.JaxSimModel, integrator
):

model = jaxsim_model_box
Expand All @@ -206,6 +206,7 @@ def test_simulation_with_relaxed_rigid_contacts(
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
model.integrator = integrator

assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4

Expand Down
Loading