diff --git a/tests/conftest.py b/tests/conftest.py index 3ad751e85..19cb36599 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ import jaxsim import jaxsim.api as js +from jaxsim.api.model import Integrator def pytest_addoption(parser): @@ -127,6 +128,24 @@ def velocity_representation(request) -> jaxsim.VelRepr: return request.param +@pytest.fixture( + scope="function", + params=[ + pytest.param(Integrator.SemiImplicitEuler, id="semi_implicit_euler"), + pytest.param(Integrator.RungeKutta4, id="runge_kutta_4"), + ], +) +def integrator(request) -> str: + """ + Fixture providing the integrator to use in the simulation. + + Returns: + The integrator to use in the simulation. + """ + + return request.param + + @pytest.fixture(scope="session") def batch_size(request) -> int: """ diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 50fa4beb9..56a65ef0a 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -188,7 +188,7 @@ def run_simulation( def test_simulation_with_relaxed_rigid_contacts( - jaxsim_model_box: js.model.JaxSimModel, + jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box @@ -206,6 +206,7 @@ def test_simulation_with_relaxed_rigid_contacts( model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) + model.integrator = integrator assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4