-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new tests for all contact models
- Loading branch information
1 parent
e63fda4
commit 2e0b3af
Showing
1 changed file
with
238 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,14 @@ | ||
import functools | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
|
||
import jaxsim.api as js | ||
import jaxsim.integrators | ||
import jaxsim.rbda | ||
import jaxsim.typing as jtp | ||
from jaxsim import VelRepr | ||
from jaxsim.utils import Mutability | ||
|
||
|
||
def test_box_with_external_forces( | ||
|
@@ -102,7 +104,7 @@ def test_box_with_zero_gravity( | |
model = jaxsim_model_box | ||
|
||
# Move the terrain (almost) infinitely far away from the box. | ||
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): | ||
with model.editable(validate=False) as model: | ||
model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9) | ||
|
||
# Split the PRNG key. | ||
|
@@ -186,3 +188,237 @@ def test_box_with_zero_gravity( | |
+ 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, | ||
abs=1e-3, | ||
) | ||
|
||
|
||
def run_simulation( | ||
model: js.model.JaxSimModel, | ||
data_t0: js.data.JaxSimModelData, | ||
dt: jtp.FloatLike, | ||
tf: jtp.FloatLike, | ||
) -> js.data.JaxSimModelData: | ||
|
||
@functools.cache | ||
def get_integrator() -> tuple[jaxsim.integrators.Integrator, dict[str, jtp.PyTree]]: | ||
|
||
# Create the integrator. | ||
integrator = jaxsim.integrators.fixed_step.Heun2.build( | ||
fsal_enabled_if_supported=False, | ||
dynamics=js.ode.wrap_system_dynamics_for_integration( | ||
model=model, | ||
data=data_t0, | ||
system_dynamics=js.ode.system_dynamics, | ||
), | ||
) | ||
|
||
# Initialize the integrator state. | ||
integrator_state_t0 = integrator.init(x0=data_t0.state, t0=0.0, dt=dt) | ||
|
||
return integrator, integrator_state_t0 | ||
|
||
# Initialize the integration horizon. | ||
T_ns = jnp.arange(start=0.0, stop=int(tf * 1e9), step=int(dt * 1e9)).astype(int) | ||
|
||
# Initialize the simulation data. | ||
integrator = None | ||
integrator_state = None | ||
data = data_t0.copy() | ||
|
||
for t_ns in T_ns: | ||
|
||
match model.contact_model: | ||
|
||
case jaxsim.rbda.contacts.ViscoElasticContacts(): | ||
|
||
data, _ = jaxsim.rbda.contacts.visco_elastic.step( | ||
model=model, | ||
data=data, | ||
dt=dt, | ||
) | ||
|
||
case _: | ||
|
||
integrator, integrator_state = ( | ||
get_integrator() if t_ns == 0 else (integrator, integrator_state) | ||
) | ||
|
||
data, integrator_state = js.model.step( | ||
model=model, | ||
data=data, | ||
dt=dt, | ||
integrator=integrator, | ||
integrator_state=integrator_state, | ||
) | ||
|
||
return data | ||
|
||
|
||
def test_simulation_with_soft_contacts( | ||
jaxsim_model_box: js.model.JaxSimModel, | ||
): | ||
|
||
model = jaxsim_model_box | ||
|
||
with model.editable(validate=False) as model: | ||
|
||
model.contact_model = jaxsim.rbda.contacts.SoftContacts.build( | ||
terrain=model.terrain, | ||
) | ||
|
||
# Initialize the maximum penetration of each collidable point at steady state. | ||
max_penetration = 0.001 | ||
|
||
# Check [email protected]. | ||
box_height = 0.1 | ||
|
||
# Build the data of the model. | ||
data_t0 = js.data.JaxSimModelData.build( | ||
model=model, | ||
base_position=jnp.array([0.0, 0.0, box_height * 2]), | ||
velocity_representation=VelRepr.Inertial, | ||
contacts_params=js.contact.estimate_good_contact_parameters( | ||
model=model, | ||
number_of_active_collidable_points_steady_state=4, | ||
static_friction_coefficient=1.0, | ||
damping_ratio=1.0, | ||
max_penetration=0.001, | ||
), | ||
) | ||
|
||
# =========================================== | ||
# Run the simulation and test the final state | ||
# =========================================== | ||
|
||
data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) | ||
|
||
assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) | ||
assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) | ||
|
||
|
||
def test_simulation_with_visco_elastic_contacts( | ||
jaxsim_model_box: js.model.JaxSimModel, | ||
): | ||
|
||
model = jaxsim_model_box | ||
|
||
with model.editable(validate=False) as model: | ||
|
||
model.contact_model = jaxsim.rbda.contacts.ViscoElasticContacts.build( | ||
terrain=model.terrain, | ||
) | ||
|
||
# Initialize the maximum penetration of each collidable point at steady state. | ||
max_penetration = 0.001 | ||
|
||
# Check [email protected]. | ||
box_height = 0.1 | ||
|
||
# Build the data of the model. | ||
data_t0 = js.data.JaxSimModelData.build( | ||
model=model, | ||
base_position=jnp.array([0.0, 0.0, box_height * 2]), | ||
velocity_representation=VelRepr.Inertial, | ||
contacts_params=js.contact.estimate_good_contact_parameters( | ||
model=model, | ||
number_of_active_collidable_points_steady_state=4, | ||
static_friction_coefficient=1.0, | ||
damping_ratio=1.0, | ||
max_penetration=0.001, | ||
), | ||
) | ||
|
||
# =========================================== | ||
# Run the simulation and test the final state | ||
# =========================================== | ||
|
||
data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) | ||
|
||
assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) | ||
assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) | ||
|
||
|
||
def test_simulation_with_rigid_contacts( | ||
jaxsim_model_box: js.model.JaxSimModel, | ||
): | ||
|
||
model = jaxsim_model_box | ||
|
||
with model.editable(validate=False) as model: | ||
|
||
model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( | ||
terrain=model.terrain, | ||
) | ||
|
||
# Initialize the maximum penetration of each collidable point at steady state. | ||
# This model is rigid, so we expect (almost) no penetration. | ||
max_penetration = 0.000 | ||
|
||
# Check [email protected]. | ||
box_height = 0.1 | ||
|
||
# Build the data of the model. | ||
data_t0 = js.data.JaxSimModelData.build( | ||
model=model, | ||
base_position=jnp.array([0.0, 0.0, box_height * 2]), | ||
velocity_representation=VelRepr.Inertial, | ||
# In order to achieve almost no penetration, we need to use a fairly large | ||
# Baumgarte stabilization term. | ||
contacts_params=js.contact.estimate_good_contact_parameters( | ||
model=model, | ||
K=100_000, | ||
), | ||
) | ||
|
||
# =========================================== | ||
# Run the simulation and test the final state | ||
# =========================================== | ||
|
||
data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) | ||
|
||
assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2]) | ||
assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2) | ||
|
||
|
||
def test_simulation_with_relaxed_rigid_contacts( | ||
jaxsim_model_box: js.model.JaxSimModel, | ||
): | ||
|
||
model = jaxsim_model_box | ||
|
||
with model.editable(validate=False) as model: | ||
|
||
model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( | ||
terrain=model.terrain, | ||
) | ||
|
||
# Initialize the maximum penetration of each collidable point at steady state. | ||
# This model is quasi-rigid, so we expect (almost) no penetration. | ||
max_penetration = 0.000 | ||
|
||
# Check [email protected]. | ||
box_height = 0.1 | ||
|
||
# Build the data of the model. | ||
data_t0 = js.data.JaxSimModelData.build( | ||
model=model, | ||
base_position=jnp.array([0.0, 0.0, box_height * 2]), | ||
velocity_representation=VelRepr.Inertial, | ||
# In order to achieve almost no penetration, we need to use a fairly large | ||
# Baumgarte stabilization term. | ||
contacts_params=js.contact.estimate_good_contact_parameters( | ||
model=model, | ||
time_constant=0.001, | ||
), | ||
) | ||
# =========================================== | ||
# Run the simulation and test the final state | ||
# =========================================== | ||
|
||
data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0) | ||
|
||
# With this contact model, we need to slightly adjust the tolerance on xy. | ||
assert data_tf.base_position()[0:2] == pytest.approx( | ||
data_t0.base_position()[0:2], abs=0.000_010 | ||
) | ||
assert data_tf.base_position()[2] + max_penetration == pytest.approx( | ||
box_height / 2, abs=0.000_100 | ||
) |