Skip to content

Commit

Permalink
Add new tests for all contact models
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 14, 2024
1 parent e63fda4 commit 2e0b3af
Showing 1 changed file with 238 additions and 2 deletions.
240 changes: 238 additions & 2 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import functools

import jax
import jax.numpy as jnp
import pytest

import jaxsim.api as js
import jaxsim.integrators
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import VelRepr
from jaxsim.utils import Mutability


def test_box_with_external_forces(
Expand Down Expand Up @@ -102,7 +104,7 @@ def test_box_with_zero_gravity(
model = jaxsim_model_box

# Move the terrain (almost) infinitely far away from the box.
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
with model.editable(validate=False) as model:
model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9)

# Split the PRNG key.
Expand Down Expand Up @@ -186,3 +188,237 @@ def test_box_with_zero_gravity(
+ 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
abs=1e-3,
)


def run_simulation(
model: js.model.JaxSimModel,
data_t0: js.data.JaxSimModelData,
dt: jtp.FloatLike,
tf: jtp.FloatLike,
) -> js.data.JaxSimModelData:

@functools.cache
def get_integrator() -> tuple[jaxsim.integrators.Integrator, dict[str, jtp.PyTree]]:

# Create the integrator.
integrator = jaxsim.integrators.fixed_step.Heun2.build(
fsal_enabled_if_supported=False,
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model,
data=data_t0,
system_dynamics=js.ode.system_dynamics,
),
)

# Initialize the integrator state.
integrator_state_t0 = integrator.init(x0=data_t0.state, t0=0.0, dt=dt)

return integrator, integrator_state_t0

# Initialize the integration horizon.
T_ns = jnp.arange(start=0.0, stop=int(tf * 1e9), step=int(dt * 1e9)).astype(int)

# Initialize the simulation data.
integrator = None
integrator_state = None
data = data_t0.copy()

for t_ns in T_ns:

match model.contact_model:

case jaxsim.rbda.contacts.ViscoElasticContacts():

data, _ = jaxsim.rbda.contacts.visco_elastic.step(
model=model,
data=data,
dt=dt,
)

case _:

integrator, integrator_state = (
get_integrator() if t_ns == 0 else (integrator, integrator_state)
)

data, integrator_state = js.model.step(
model=model,
data=data,
dt=dt,
integrator=integrator,
integrator_state=integrator_state,
)

return data


def test_simulation_with_soft_contacts(
jaxsim_model_box: js.model.JaxSimModel,
):

model = jaxsim_model_box

with model.editable(validate=False) as model:

model.contact_model = jaxsim.rbda.contacts.SoftContacts.build(
terrain=model.terrain,
)

# Initialize the maximum penetration of each collidable point at steady state.
max_penetration = 0.001

# Check [email protected].
box_height = 0.1

# Build the data of the model.
data_t0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, box_height * 2]),
velocity_representation=VelRepr.Inertial,
contacts_params=js.contact.estimate_good_contact_parameters(
model=model,
number_of_active_collidable_points_steady_state=4,
static_friction_coefficient=1.0,
damping_ratio=1.0,
max_penetration=0.001,
),
)

# ===========================================
# Run the simulation and test the final state
# ===========================================

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0)

assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2])
assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2)


def test_simulation_with_visco_elastic_contacts(
jaxsim_model_box: js.model.JaxSimModel,
):

model = jaxsim_model_box

with model.editable(validate=False) as model:

model.contact_model = jaxsim.rbda.contacts.ViscoElasticContacts.build(
terrain=model.terrain,
)

# Initialize the maximum penetration of each collidable point at steady state.
max_penetration = 0.001

# Check [email protected].
box_height = 0.1

# Build the data of the model.
data_t0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, box_height * 2]),
velocity_representation=VelRepr.Inertial,
contacts_params=js.contact.estimate_good_contact_parameters(
model=model,
number_of_active_collidable_points_steady_state=4,
static_friction_coefficient=1.0,
damping_ratio=1.0,
max_penetration=0.001,
),
)

# ===========================================
# Run the simulation and test the final state
# ===========================================

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0)

assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2])
assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2)


def test_simulation_with_rigid_contacts(
jaxsim_model_box: js.model.JaxSimModel,
):

model = jaxsim_model_box

with model.editable(validate=False) as model:

model.contact_model = jaxsim.rbda.contacts.RigidContacts.build(
terrain=model.terrain,
)

# Initialize the maximum penetration of each collidable point at steady state.
# This model is rigid, so we expect (almost) no penetration.
max_penetration = 0.000

# Check [email protected].
box_height = 0.1

# Build the data of the model.
data_t0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, box_height * 2]),
velocity_representation=VelRepr.Inertial,
# In order to achieve almost no penetration, we need to use a fairly large
# Baumgarte stabilization term.
contacts_params=js.contact.estimate_good_contact_parameters(
model=model,
K=100_000,
),
)

# ===========================================
# Run the simulation and test the final state
# ===========================================

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0)

assert data_tf.base_position()[0:2] == pytest.approx(data_t0.base_position()[0:2])
assert data_tf.base_position()[2] + max_penetration == pytest.approx(box_height / 2)


def test_simulation_with_relaxed_rigid_contacts(
jaxsim_model_box: js.model.JaxSimModel,
):

model = jaxsim_model_box

with model.editable(validate=False) as model:

model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build(
terrain=model.terrain,
)

# Initialize the maximum penetration of each collidable point at steady state.
# This model is quasi-rigid, so we expect (almost) no penetration.
max_penetration = 0.000

# Check [email protected].
box_height = 0.1

# Build the data of the model.
data_t0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, box_height * 2]),
velocity_representation=VelRepr.Inertial,
# In order to achieve almost no penetration, we need to use a fairly large
# Baumgarte stabilization term.
contacts_params=js.contact.estimate_good_contact_parameters(
model=model,
time_constant=0.001,
),
)
# ===========================================
# Run the simulation and test the final state
# ===========================================

data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=1.0)

# With this contact model, we need to slightly adjust the tolerance on xy.
assert data_tf.base_position()[0:2] == pytest.approx(
data_t0.base_position()[0:2], abs=0.000_010
)
assert data_tf.base_position()[2] + max_penetration == pytest.approx(
box_height / 2, abs=0.000_100
)

0 comments on commit 2e0b3af

Please sign in to comment.