diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 3d6dce113..ef50b5be6 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -1,5 +1,7 @@ import dataclasses +from collections.abc import Callable +import jax import jax.numpy as jnp import jaxsim @@ -74,3 +76,81 @@ def semi_implicit_euler_integration( data = data.replace(model=model) # update cache return data + + +def rk4_integration( + model: js.model.JaxSimModel, + data: JaxSimModelData, + base_acceleration_inertial: jtp.Vector, + joint_accelerations: jtp.Vector, + link_forces: jtp.Vector, + joint_torques: jtp.Vector, +) -> JaxSimModelData: + """Integrate the system state using the Runge-Kutta 4 method.""" + + dt = model.time_step + + def f(x) -> dict[str, jtp.Matrix]: + + with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + + data_ti = data.replace(model=model, **x) + + return js.ode.system_dynamics( + model=model, + data=data_ti, + link_forces=link_forces, + joint_torques=joint_torques, + ) + + base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion) + base_quaternion = data._base_quaternion / jnp.where( + base_quaternion_norm == 0, 1.0, base_quaternion_norm + ) + + x_t0 = dict( + base_position=data._base_position, + base_quaternion=base_quaternion, + joint_positions=data._joint_positions, + base_linear_velocity=data._base_linear_velocity, + base_angular_velocity=data._base_angular_velocity, + joint_velocities=data._joint_velocities, + ) + + euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt + euler_fin = lambda x, dxdt: x + dt * dxdt + + k1 = f(x_t0) + k2 = f(jax.tree.map(euler_mid, x_t0, k1)) + k3 = f(jax.tree.map(euler_mid, x_t0, k2)) + k4 = f(jax.tree.map(euler_fin, x_t0, k3)) + + # Average the slopes and compute the RK4 state derivative. + average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6 + + dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4) + + # Integrate the dynamics + x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt) + + data_tf = dataclasses.replace( + data, + **{ + "_base_position": x_tf["base_position"], + "_base_quaternion": x_tf["base_quaternion"], + "_joint_positions": x_tf["joint_positions"], + "_base_linear_velocity": x_tf["base_linear_velocity"], + "_base_angular_velocity": x_tf["base_angular_velocity"], + "_joint_velocities": x_tf["joint_velocities"], + }, + ) + + return data_tf.replace(model=model) + + +_INTEGRATORS_MAP: dict[ + js.model.IntegratorType, Callable[..., js.data.JaxSimModelData] +] = { + js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration, + js.model.IntegratorType.RungeKutta4: rk4_integration, +} diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index afce17e73..f9bb48fc7 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2,6 +2,7 @@ import copy import dataclasses +import enum import functools import pathlib from collections.abc import Sequence @@ -23,6 +24,13 @@ from .common import VelRepr +class IntegratorType(enum.IntEnum): + """The integrators available for the simulation.""" + + SemiImplicitEuler = enum.auto() + RungeKutta4 = enum.auto() + + @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class JaxSimModel(JaxsimDataclass): """ @@ -55,6 +63,10 @@ class JaxSimModel(JaxsimDataclass): dataclasses.field(default=None, repr=False) ) + integrator: Static[IntegratorType] = dataclasses.field( + default=IntegratorType.SemiImplicitEuler, repr=False + ) + built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( default=None, repr=False ) @@ -111,6 +123,7 @@ def build_from_model_description( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, + integrator: IntegratorType | None = None, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: @@ -131,6 +144,7 @@ def build_from_model_description( The contact model to consider. If not specified, a soft contacts model is used. contact_params: The parameters of the contact model. + integrator: The integrator to use for the simulation. is_urdf: The optional flag to force the model description to be parsed as a URDF. This is usually automatically inferred. @@ -164,6 +178,7 @@ def build_from_model_description( terrain=terrain, contact_model=contact_model, contacts_params=contact_params, + integrator=integrator, ) # Store the origin of the model, in case downstream logic needs it. @@ -182,6 +197,7 @@ def build( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None, + integrator: IntegratorType | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, ) -> JaxSimModel: """ @@ -202,6 +218,7 @@ def build( The contact model to consider. If not specified, a soft contacts model is used. contacts_params: The parameters of the soft contacts. + integrator: The integrator to use for the simulation. gravity: The gravity constant. Returns: @@ -237,6 +254,13 @@ def build( if contacts_params is None: contacts_params = contact_model._parameters_class() + # Consider the default integrator if not specified. + integrator = ( + integrator + if integrator is not None + else JaxSimModel.__dataclass_fields__["integrator"].default + ) + # Build the model. model = cls( model_name=model_name, @@ -247,6 +271,7 @@ def build( terrain=terrain, contact_model=contact_model, contacts_params=contacts_params, + integrator=integrator, gravity=gravity, # The following is wrapped as hashless since it's a static argument, and we # don't want to trigger recompilation if it changes. All relevant parameters @@ -449,6 +474,7 @@ def reduce( contact_model=model.contact_model, contacts_params=model.contacts_params, gravity=model.gravity, + integrator=model.integrator, ) # Store the origin of the model, in case downstream logic needs it. @@ -2075,12 +2101,21 @@ def step( # ============================= # Advance the simulation state # ============================= + from .integrators import _INTEGRATORS_MAP + + integrator_fn = _INTEGRATORS_MAP[model.integrator] - data_tf = js.integrators.semi_implicit_euler_integration( + data_tf = integrator_fn( model=model, data=data, base_acceleration_inertial=W_v̇_WB, joint_accelerations=s̈, + # Pass link_forces and joint_torques if the integrator is rk4 + **( + {"link_forces": W_f_L_total, "joint_torques": τ_total} + if model.integrator == IntegratorType.RungeKutta4 + else {} + ), ) return data_tf diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index a92dd890b..2f61e1d43 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -3,7 +3,6 @@ import jaxsim.api as js import jaxsim.typing as jtp -from jaxsim.api.data import JaxSimModelData from jaxsim.math import Quaternion, Skew from .common import VelRepr @@ -123,7 +122,7 @@ def system_dynamics( link_forces: jtp.Vector | None = None, joint_torques: jtp.Vector | None = None, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, -) -> JaxSimModelData: +) -> dict[str, jtp.Vector]: """ Compute the dynamics of the system. @@ -139,9 +138,9 @@ def system_dynamics( quaternion (only used in integrators not operating on the SO(3) manifold). Returns: - A tuple with an `JaxSimModelData` object storing in each of its attributes the - corresponding derivative, and the dictionary of auxiliary data returned - by the system dynamics evaluation. + A dictionary containing the derivatives of the base position, the base quaternion, + the joint positions, the base linear velocity, the base angular velocity, and the + joint velocities. """ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): @@ -158,8 +157,7 @@ def system_dynamics( baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, ) - ode_state_derivative = JaxSimModelData.build( - model=model, + return dict( base_position=W_ṗ_B, base_quaternion=W_Q̇_B, joint_positions=ṡ, @@ -167,5 +165,3 @@ def system_dynamics( base_angular_velocity=W_v̇_WB[3:6], joint_velocities=s̈, ) - - return ode_state_derivative diff --git a/src/jaxsim/rbda/utils.py b/src/jaxsim/rbda/utils.py index 9b2614a52..2ef119862 100644 --- a/src/jaxsim/rbda/utils.py +++ b/src/jaxsim/rbda/utils.py @@ -132,6 +132,12 @@ def process_inputs( if W_Q_B.shape != (4,): raise ValueError(W_Q_B.shape, (4,)) + # Check that the quaternion does not contain NaNs. + exceptions.raise_value_error_if( + condition=jnp.isnan(W_Q_B).any(), + msg="A RBDA received a quaternion that contains NaNs.", + ) + # Check that the quaternion is unary since our RBDAs make this assumption in order # to prevent introducing additional normalizations that would affect AD. exceptions.raise_value_error_if( diff --git a/tests/conftest.py b/tests/conftest.py index 3ad751e85..7c97a7f2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ import jaxsim import jaxsim.api as js +from jaxsim.api.model import IntegratorType def pytest_addoption(parser): @@ -127,6 +128,24 @@ def velocity_representation(request) -> jaxsim.VelRepr: return request.param +@pytest.fixture( + scope="function", + params=[ + pytest.param(IntegratorType.SemiImplicitEuler, id="semi_implicit_euler"), + pytest.param(IntegratorType.RungeKutta4, id="runge_kutta_4"), + ], +) +def integrator(request) -> str: + """ + Fixture providing the integrator to use in the simulation. + + Returns: + The integrator to use in the simulation. + """ + + return request.param + + @pytest.fixture(scope="session") def batch_size(request) -> int: """ diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 50fa4beb9..56a65ef0a 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -188,7 +188,7 @@ def run_simulation( def test_simulation_with_relaxed_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, + jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box @@ -206,6 +206,7 @@ def test_simulation_with_relaxed_rigid_contacts( model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) + model.integrator = integrator assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4