From a80cbaf184f027e7c63b2967f3de5c132af2b0f6 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 11 Oct 2024 10:55:26 +0200 Subject: [PATCH] Add new tests for all contact models --- tests/test_simulations.py | 240 +++++++++++++++++++++++++++++++++++++- 1 file changed, 238 insertions(+), 2 deletions(-) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index bc928a957..c513ae1c2 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -1,3 +1,5 @@ +import functools + import jax import jax.numpy as jnp import pytest @@ -5,8 +7,8 @@ import jaxsim.api as js import jaxsim.integrators import jaxsim.rbda +import jaxsim.typing as jtp from jaxsim import VelRepr -from jaxsim.utils import Mutability def test_box_with_external_forces( @@ -102,7 +104,7 @@ def test_box_with_zero_gravity( model = jaxsim_model_box # Move the terrain (almost) infinitely far away from the box. - with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + with model.editable(validate=False) as model: model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9) # Split the PRNG key. @@ -186,3 +188,237 @@ def test_box_with_zero_gravity( + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, abs=1e-3, ) + + +def run_simulation( + model: js.model.JaxSimModel, + data_t0: js.data.JaxSimModelData, + dt: jtp.FloatLike, + tf: jtp.FloatLike, +) -> js.data.JaxSimModelData: + + @functools.cache + def get_integrator() -> tuple[jaxsim.integrators.Integrator, dict[str, jtp.PyTree]]: + + # Create the integrator. + integrator = jaxsim.integrators.fixed_step.Heun2.build( + fsal_enabled_if_supported=False, + dynamics=js.ode.wrap_system_dynamics_for_integration( + model=model, + data=data_t0, + system_dynamics=js.ode.system_dynamics, + ), + ) + + # Initialize the integrator state. + integrator_state_t0 = integrator.init(x0=data_t0.state, t0=0.0, dt=dt) + + return integrator, integrator_state_t0 + + # Initialize the integration horizon. + T_ns = jnp.arange(start=0.0, stop=int(tf * 1e9), step=int(dt * 1e9)).astype(int) + + # Initialize the simulation data. + integrator = None + integrator_state = None + data = data_t0.copy() + + for t_ns in T_ns: + + match model.contact_model: + + case jaxsim.rbda.contacts.ViscoElasticContacts(): + + data, _ = jaxsim.rbda.contacts.visco_elastic.step( + model=model, + data=data, + dt=dt, + ) + + case _: + + integrator, integrator_state = ( + get_integrator() if t_ns == 0 else (integrator, integrator_state) + ) + + data, integrator_state = js.model.step( + model=model, + data=data, + dt=dt, + integrator=integrator, + integrator_state=integrator_state, + ) + + return data + + +def test_simulation_with_soft_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=0.001, + ), + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) + assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_visco_elastic_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.ViscoElasticContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=0.001, + ), + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) + assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_rigid_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + # This model is rigid, so we expect (almost) no penetration. + max_penetration = 0.000 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + # In order to achieve almost no penetration, we need to use a fairly large + # Baumgarte stabilization term. + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + K=100_000, + ), + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) + assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_relaxed_rigid_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( + terrain=model.terrain, + ) + + # Initialize the maximum penetration of each collidable point at steady state. + # This model is quasi-rigid, so we expect (almost) no penetration. + max_penetration = 0.000 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + # In order to achieve almost no penetration, we need to use a fairly large + # Baumgarte stabilization term. + contacts_params=js.contact.estimate_good_contact_parameters( + model=model, + time_constant=0.001, + ), + ) + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) + + # With this contact model, we need to slightly adjust the tolerance on xy. + assert data_tf.base_position()[0:2] == pytest.approx( + data_t0.base_position()[0:2], abs=0.000_010 + ) + assert data_tf.base_position()[2] + max_penetration == pytest.approx( + box_height / 2, abs=0.000_100 + )