From 340e2b48c2ffd3840e432b1f81fd2403d7b7e174 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Feb 2025 17:23:51 +0100 Subject: [PATCH 01/12] Make `js.ode.system_dynamics` return a PyTree --- src/jaxsim/api/ode.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index a92dd890b..e4cf2658a 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -158,8 +158,7 @@ def system_dynamics( baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, ) - ode_state_derivative = JaxSimModelData.build( - model=model, + return dict( base_position=W_ṗ_B, base_quaternion=W_Q̇_B, joint_positions=ṡ, @@ -167,5 +166,3 @@ def system_dynamics( base_angular_velocity=W_v̇_WB[3:6], joint_velocities=s̈, ) - - return ode_state_derivative From 79ac4544c56d410c795214ca67a66db9e819330c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Feb 2025 17:34:15 +0100 Subject: [PATCH 02/12] Implement Runge-Kutta 4 integration --- src/jaxsim/api/integrators.py | 69 +++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 3d6dce113..300c1dd6b 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -1,5 +1,6 @@ import dataclasses +import jax import jax.numpy as jnp import jaxsim @@ -74,3 +75,71 @@ def semi_implicit_euler_integration( data = data.replace(model=model) # update cache return data + + +def rk4_integration( + model: js.model.JaxSimModel, + data: JaxSimModelData, + base_acceleration_inertial: jtp.Vector, + joint_accelerations: jtp.Vector, + link_forces: jtp.Vector, + joint_torques: jtp.Vector, +) -> JaxSimModelData: + """Integrate the system state using the Runge-Kutta 4 method.""" + + dt = model.time_step + + def get_state_derivative(data_ode: JaxSimModelData) -> dict: + + # Safe normalize the quaternion. + base_quaternion_norm = jaxsim.math.safe_norm(data_ode.base_quaternion) + base_quaternion = data_ode.base_quaternion / jnp.where( + base_quaternion_norm == 0, 1.0, base_quaternion_norm + ) + + return dict( + base_position=data_ode.base_position, + base_quaternion=base_quaternion, + joint_positions=data_ode.joint_positions, + base_linear_velocity=data_ode.base_velocity[0:3], + base_angular_velocity=data_ode.base_velocity[3:6], + joint_velocities=data_ode.joint_velocities, + ) + + def f(x) -> dict[str, jtp.Matrix]: + + with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + + data_ti = data.replace( + model=model, **{k.lstrip("_"): v for k, v in x.items()} + ) + + return js.ode.system_dynamics( + model=model, + data=data_ti, + link_forces=link_forces, + joint_torques=joint_torques, + ) + + with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + x_t0 = get_state_derivative(data) + + euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt + euler_fin = lambda x, dxdt: x + dt * dxdt + + k1 = f(x_t0) + k2 = f(jax.tree.map(euler_mid, x_t0, k1)) + k3 = f(jax.tree.map(euler_mid, x_t0, k2)) + k4 = f(jax.tree.map(euler_fin, x_t0, k3)) + + # Average the slopes and compute the RK4 state derivative. + average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6 + + dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4) + + # Integrate the dynamics + x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt) + + data_tf = dataclasses.replace(data, **{"_" + k: v for k, v in x_tf.items()}) + + return data_tf.replace(model=model) From 476579b3ef9ed236693b1c5b1ede738e793b6a97 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Feb 2025 17:48:52 +0100 Subject: [PATCH 03/12] Add `Integrator` enum to define available simulation integrators --- src/jaxsim/api/integrators.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 300c1dd6b..3c7f05bc7 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -1,4 +1,5 @@ import dataclasses +from collections.abc import Callable import jax import jax.numpy as jnp @@ -143,3 +144,9 @@ def f(x) -> dict[str, jtp.Matrix]: data_tf = dataclasses.replace(data, **{"_" + k: v for k, v in x_tf.items()}) return data_tf.replace(model=model) + + +_INTEGRATORS_MAP: dict[js.model.Integrator, Callable[..., js.data.JaxSimModelData]] = { + js.model.Integrator.SemiImplicitEuler: semi_implicit_euler_integration, + js.model.Integrator.RungeKutta4: rk4_integration, +} From 3f7ba0131965cea9f1d9fa4d9d48abdc853e7d68 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Feb 2025 17:50:04 +0100 Subject: [PATCH 04/12] Add NaN check for quaternions --- src/jaxsim/rbda/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/jaxsim/rbda/utils.py b/src/jaxsim/rbda/utils.py index 9b2614a52..2ef119862 100644 --- a/src/jaxsim/rbda/utils.py +++ b/src/jaxsim/rbda/utils.py @@ -132,6 +132,12 @@ def process_inputs( if W_Q_B.shape != (4,): raise ValueError(W_Q_B.shape, (4,)) + # Check that the quaternion does not contain NaNs. + exceptions.raise_value_error_if( + condition=jnp.isnan(W_Q_B).any(), + msg="A RBDA received a quaternion that contains NaNs.", + ) + # Check that the quaternion is unary since our RBDAs make this assumption in order # to prevent introducing additional normalizations that would affect AD. exceptions.raise_value_error_if( From 4eb684f47bfce77623cb6127a2114cbeacfc5d61 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Feb 2025 17:59:24 +0100 Subject: [PATCH 05/12] Add `Integrator` option in `JaxSimModel` --- src/jaxsim/api/model.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index afce17e73..d587dc96e 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2,6 +2,7 @@ import copy import dataclasses +import enum import functools import pathlib from collections.abc import Sequence @@ -23,6 +24,13 @@ from .common import VelRepr +class Integrator(enum.IntEnum): + """The integrators available for the simulation.""" + + SemiImplicitEuler = enum.auto() + RungeKutta4 = enum.auto() + + @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class JaxSimModel(JaxsimDataclass): """ @@ -55,6 +63,10 @@ class JaxSimModel(JaxsimDataclass): dataclasses.field(default=None, repr=False) ) + integrator: Static[Integrator] = dataclasses.field( + default=Integrator.SemiImplicitEuler, repr=False + ) + built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( default=None, repr=False ) @@ -111,6 +123,7 @@ def build_from_model_description( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, + integrator: Integrator | None = None, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: @@ -131,6 +144,7 @@ def build_from_model_description( The contact model to consider. If not specified, a soft contacts model is used. contact_params: The parameters of the contact model. + integrator: The integrator to use for the simulation. is_urdf: The optional flag to force the model description to be parsed as a URDF. This is usually automatically inferred. @@ -164,6 +178,7 @@ def build_from_model_description( terrain=terrain, contact_model=contact_model, contacts_params=contact_params, + integrator=integrator, ) # Store the origin of the model, in case downstream logic needs it. @@ -182,6 +197,7 @@ def build( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None, + integrator: Integrator | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, ) -> JaxSimModel: """ @@ -202,6 +218,7 @@ def build( The contact model to consider. If not specified, a soft contacts model is used. contacts_params: The parameters of the soft contacts. + integrator: The integrator to use for the simulation. gravity: The gravity constant. Returns: @@ -237,6 +254,13 @@ def build( if contacts_params is None: contacts_params = contact_model._parameters_class() + # Consider the default integrator if not specified. + integrator = ( + integrator + if integrator is not None + else JaxSimModel.__dataclass_fields__["integrator"].default + ) + # Build the model. model = cls( model_name=model_name, @@ -247,6 +271,7 @@ def build( terrain=terrain, contact_model=contact_model, contacts_params=contacts_params, + integrator=integrator, gravity=gravity, # The following is wrapped as hashless since it's a static argument, and we # don't want to trigger recompilation if it changes. All relevant parameters @@ -449,6 +474,7 @@ def reduce( contact_model=model.contact_model, contacts_params=model.contacts_params, gravity=model.gravity, + integrator=model.integrator, ) # Store the origin of the model, in case downstream logic needs it. @@ -2075,12 +2101,21 @@ def step( # ============================= # Advance the simulation state # ============================= + from .integrators import _INTEGRATORS_MAP + + integrator_fn = _INTEGRATORS_MAP[model.integrator] - data_tf = js.integrators.semi_implicit_euler_integration( + data_tf = integrator_fn( model=model, data=data, base_acceleration_inertial=W_v̇_WB, joint_accelerations=s̈, + # Pass link_forces and joint_torques if the integrator is rk4 + **( + {"link_forces": W_f_L_total, "joint_torques": τ_total} + if model.integrator == js.integrators.rk4_integration + else {} + ), ) return data_tf From 0ebd39e130f58b6550c9be6ae7fd406a05b1a109 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Feb 2025 17:59:59 +0100 Subject: [PATCH 06/12] Add `pytest` fixture for simulation integrators --- tests/conftest.py | 19 +++++++++++++++++++ tests/test_simulations.py | 3 ++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3ad751e85..19cb36599 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ import jaxsim import jaxsim.api as js +from jaxsim.api.model import Integrator def pytest_addoption(parser): @@ -127,6 +128,24 @@ def velocity_representation(request) -> jaxsim.VelRepr: return request.param +@pytest.fixture( + scope="function", + params=[ + pytest.param(Integrator.SemiImplicitEuler, id="semi_implicit_euler"), + pytest.param(Integrator.RungeKutta4, id="runge_kutta_4"), + ], +) +def integrator(request) -> str: + """ + Fixture providing the integrator to use in the simulation. + + Returns: + The integrator to use in the simulation. + """ + + return request.param + + @pytest.fixture(scope="session") def batch_size(request) -> int: """ diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 50fa4beb9..56a65ef0a 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -188,7 +188,7 @@ def run_simulation( def test_simulation_with_relaxed_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, + jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box @@ -206,6 +206,7 @@ def test_simulation_with_relaxed_rigid_contacts( model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) + model.integrator = integrator assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 From 8974f9a763424541cb17b4f8db1339b3033187cb Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 17 Feb 2025 11:39:45 +0100 Subject: [PATCH 07/12] Refactor state derivative function to use private attributes for data access --- src/jaxsim/api/integrators.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 3c7f05bc7..7d5f5b1c7 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -93,18 +93,18 @@ def rk4_integration( def get_state_derivative(data_ode: JaxSimModelData) -> dict: # Safe normalize the quaternion. - base_quaternion_norm = jaxsim.math.safe_norm(data_ode.base_quaternion) - base_quaternion = data_ode.base_quaternion / jnp.where( + base_quaternion_norm = jaxsim.math.safe_norm(data_ode._base_quaternion) + base_quaternion = data_ode._base_quaternion / jnp.where( base_quaternion_norm == 0, 1.0, base_quaternion_norm ) return dict( - base_position=data_ode.base_position, + base_position=data_ode._base_position, base_quaternion=base_quaternion, - joint_positions=data_ode.joint_positions, - base_linear_velocity=data_ode.base_velocity[0:3], - base_angular_velocity=data_ode.base_velocity[3:6], - joint_velocities=data_ode.joint_velocities, + joint_positions=data_ode._joint_positions, + base_linear_velocity=data_ode._base_linear_velocity, + base_angular_velocity=data_ode._base_angular_velocity, + joint_velocities=data_ode._joint_velocities, ) def f(x) -> dict[str, jtp.Matrix]: @@ -122,8 +122,7 @@ def f(x) -> dict[str, jtp.Matrix]: joint_torques=joint_torques, ) - with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - x_t0 = get_state_derivative(data) + x_t0 = get_state_derivative(data) euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt euler_fin = lambda x, dxdt: x + dt * dxdt From 39146eb6f1bd1feb467bc3d3749080035826ff9f Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 17 Feb 2025 11:42:55 +0100 Subject: [PATCH 08/12] Replace `Integrator` with `IntegratorType` --- src/jaxsim/api/integrators.py | 8 +++++--- src/jaxsim/api/model.py | 10 +++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 7d5f5b1c7..079923bbb 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -145,7 +145,9 @@ def f(x) -> dict[str, jtp.Matrix]: return data_tf.replace(model=model) -_INTEGRATORS_MAP: dict[js.model.Integrator, Callable[..., js.data.JaxSimModelData]] = { - js.model.Integrator.SemiImplicitEuler: semi_implicit_euler_integration, - js.model.Integrator.RungeKutta4: rk4_integration, +_INTEGRATORS_MAP: dict[ + js.model.IntegratorType, Callable[..., js.data.JaxSimModelData] +] = { + js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration, + js.model.IntegratorType.RungeKutta4: rk4_integration, } diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index d587dc96e..39bd2da02 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -24,7 +24,7 @@ from .common import VelRepr -class Integrator(enum.IntEnum): +class IntegratorType(enum.IntEnum): """The integrators available for the simulation.""" SemiImplicitEuler = enum.auto() @@ -63,8 +63,8 @@ class JaxSimModel(JaxsimDataclass): dataclasses.field(default=None, repr=False) ) - integrator: Static[Integrator] = dataclasses.field( - default=Integrator.SemiImplicitEuler, repr=False + integrator: Static[IntegratorType] = dataclasses.field( + default=IntegratorType.SemiImplicitEuler, repr=False ) built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( @@ -123,7 +123,7 @@ def build_from_model_description( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, - integrator: Integrator | None = None, + integrator: IntegratorType | None = None, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: @@ -197,7 +197,7 @@ def build( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None, - integrator: Integrator | None = None, + integrator: IntegratorType | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, ) -> JaxSimModel: """ From 22c1dc179e6313c9078fae4f578dc8ca53f96f78 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 17 Feb 2025 14:04:53 +0100 Subject: [PATCH 09/12] Simplify handling of ODE state --- src/jaxsim/api/integrators.py | 47 +++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 079923bbb..ef50b5be6 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -90,30 +90,11 @@ def rk4_integration( dt = model.time_step - def get_state_derivative(data_ode: JaxSimModelData) -> dict: - - # Safe normalize the quaternion. - base_quaternion_norm = jaxsim.math.safe_norm(data_ode._base_quaternion) - base_quaternion = data_ode._base_quaternion / jnp.where( - base_quaternion_norm == 0, 1.0, base_quaternion_norm - ) - - return dict( - base_position=data_ode._base_position, - base_quaternion=base_quaternion, - joint_positions=data_ode._joint_positions, - base_linear_velocity=data_ode._base_linear_velocity, - base_angular_velocity=data_ode._base_angular_velocity, - joint_velocities=data_ode._joint_velocities, - ) - def f(x) -> dict[str, jtp.Matrix]: with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - data_ti = data.replace( - model=model, **{k.lstrip("_"): v for k, v in x.items()} - ) + data_ti = data.replace(model=model, **x) return js.ode.system_dynamics( model=model, @@ -122,7 +103,19 @@ def f(x) -> dict[str, jtp.Matrix]: joint_torques=joint_torques, ) - x_t0 = get_state_derivative(data) + base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion) + base_quaternion = data._base_quaternion / jnp.where( + base_quaternion_norm == 0, 1.0, base_quaternion_norm + ) + + x_t0 = dict( + base_position=data._base_position, + base_quaternion=base_quaternion, + joint_positions=data._joint_positions, + base_linear_velocity=data._base_linear_velocity, + base_angular_velocity=data._base_angular_velocity, + joint_velocities=data._joint_velocities, + ) euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt euler_fin = lambda x, dxdt: x + dt * dxdt @@ -140,7 +133,17 @@ def f(x) -> dict[str, jtp.Matrix]: # Integrate the dynamics x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt) - data_tf = dataclasses.replace(data, **{"_" + k: v for k, v in x_tf.items()}) + data_tf = dataclasses.replace( + data, + **{ + "_base_position": x_tf["base_position"], + "_base_quaternion": x_tf["base_quaternion"], + "_joint_positions": x_tf["joint_positions"], + "_base_linear_velocity": x_tf["base_linear_velocity"], + "_base_angular_velocity": x_tf["base_angular_velocity"], + "_joint_velocities": x_tf["joint_velocities"], + }, + ) return data_tf.replace(model=model) From a6af3b9e009baefb281fe0f9fc7ceb72e980bd1b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 17 Feb 2025 16:50:45 +0100 Subject: [PATCH 10/12] Fix integrator type identification --- src/jaxsim/api/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 39bd2da02..f9bb48fc7 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2113,7 +2113,7 @@ def step( # Pass link_forces and joint_torques if the integrator is rk4 **( {"link_forces": W_f_L_total, "joint_torques": τ_total} - if model.integrator == js.integrators.rk4_integration + if model.integrator == IntegratorType.RungeKutta4 else {} ), ) From c515062efaf35235dfda170883e90b17fba6a916 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 18 Feb 2025 16:05:12 +0100 Subject: [PATCH 11/12] Update docstrings and return type of `js.ode.system_dynamics` --- src/jaxsim/api/ode.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index e4cf2658a..2f61e1d43 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -3,7 +3,6 @@ import jaxsim.api as js import jaxsim.typing as jtp -from jaxsim.api.data import JaxSimModelData from jaxsim.math import Quaternion, Skew from .common import VelRepr @@ -123,7 +122,7 @@ def system_dynamics( link_forces: jtp.Vector | None = None, joint_torques: jtp.Vector | None = None, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, -) -> JaxSimModelData: +) -> dict[str, jtp.Vector]: """ Compute the dynamics of the system. @@ -139,9 +138,9 @@ def system_dynamics( quaternion (only used in integrators not operating on the SO(3) manifold). Returns: - A tuple with an `JaxSimModelData` object storing in each of its attributes the - corresponding derivative, and the dictionary of auxiliary data returned - by the system dynamics evaluation. + A dictionary containing the derivatives of the base position, the base quaternion, + the joint positions, the base linear velocity, the base angular velocity, and the + joint velocities. """ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): From 7a2f1b9b56985140fe3a4d05b1171185697b72a3 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 18 Feb 2025 16:38:45 +0100 Subject: [PATCH 12/12] Fix import in tests --- tests/conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 19cb36599..7c97a7f2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ import jaxsim import jaxsim.api as js -from jaxsim.api.model import Integrator +from jaxsim.api.model import IntegratorType def pytest_addoption(parser): @@ -131,8 +131,8 @@ def velocity_representation(request) -> jaxsim.VelRepr: @pytest.fixture( scope="function", params=[ - pytest.param(Integrator.SemiImplicitEuler, id="semi_implicit_euler"), - pytest.param(Integrator.RungeKutta4, id="runge_kutta_4"), + pytest.param(IntegratorType.SemiImplicitEuler, id="semi_implicit_euler"), + pytest.param(IntegratorType.RungeKutta4, id="runge_kutta_4"), ], ) def integrator(request) -> str: