Skip to content

Commit

Permalink
Merge pull request #248 from ami-iit/visco_elastic_contacts
Browse files Browse the repository at this point in the history
Add new `ViscoElasticContacts`
  • Loading branch information
diegoferigo authored Oct 8, 2024
2 parents d2be436 + aa41dfc commit 9c695f7
Show file tree
Hide file tree
Showing 6 changed files with 1,152 additions and 23 deletions.
83 changes: 67 additions & 16 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.exceptions
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Cross, Transform
from jaxsim.rbda import contacts

from .common import VelRepr

Expand Down Expand Up @@ -156,14 +159,11 @@ def collidable_point_dynamics(
Instead, the 6D forces are returned in the active representation.
"""

# Import privately the contacts classes.
from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Build the soft contact model.
match model.contact_model:

case SoftContacts():
assert isinstance(model.contact_model, SoftContacts)
case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point, and the corresponding material deformation rate.
Expand All @@ -178,8 +178,8 @@ def collidable_point_dynamics(
# of the ODE system. We need to pass its dynamics to the integrator.
aux_data = dict(m_dot=CW_ṁ)

case RigidContacts():
assert isinstance(model.contact_model, RigidContacts)
case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand All @@ -192,8 +192,8 @@ def collidable_point_dynamics(

aux_data = dict()

case RelaxedRigidContacts():
assert isinstance(model.contact_model, RelaxedRigidContacts)
case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand All @@ -206,6 +206,31 @@ def collidable_point_dynamics(

aux_data = dict()

case contacts.ViscoElasticContacts():
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)

# It is not yet clear how to pass the time step to this stage.
# A possibility is to restrict the integrator to only forward Euler
# and store the Δt inside the model.
module = jaxsim.rbda.contacts.visco_elastic.step.__module__
name = jaxsim.rbda.contacts.visco_elastic.step.__name__
msg = "You need to use the custom '{}.{}' function with this contact model."
jaxsim.exceptions.raise_runtime_error_if(
condition=True, msg=msg.format(module, name)
)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces(
model=model,
data=data,
dt=None, # TODO
link_forces=link_forces,
joint_force_references=joint_force_references,
)

aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf)

case _:
raise ValueError(f"Invalid contact model {model.contact_model}")

Expand Down Expand Up @@ -278,7 +303,6 @@ def in_contact(
return links_in_contact


@jax.jit
def estimate_good_soft_contacts_parameters(
model: js.model.JaxSimModel,
*,
Expand All @@ -287,9 +311,15 @@ def estimate_good_soft_contacts_parameters(
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> jaxsim.rbda.contacts.SoftContactsParams:
**kwargs,
) -> (
jaxsim.rbda.contacts.RelaxedRigidContactsParams
| jaxsim.rbda.contacts.RigidContactsParams
| jaxsim.rbda.contacts.SoftContactsParams
| jaxsim.rbda.contacts.ViscoElasticContactsParams
):
"""
Estimate good soft contacts parameters for the given model.
Estimate good parameters for soft-like contact models.
Args:
model: The model to consider.
Expand All @@ -313,7 +343,10 @@ def estimate_good_soft_contacts_parameters(
"""

def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""
"""
Displacement between the CoM and the lowest collidable point using zero
joint positions.
"""

zero_data = js.data.JaxSimModelData.build(
model=model,
Expand All @@ -338,21 +371,39 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

match model.contact_model:

case jaxsim.rbda.contacts.SoftContacts():
assert isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts)
case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)

case contacts.ViscoElasticContacts():
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)

parameters = (
jaxsim.rbda.contacts.SoftContactsParams.build_default_from_jaxsim_model(
contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
**kwargs,
)
)

case _:
logging.warning("The active contact model is not soft-like, no-op.")
parameters = model.contact_model.parameters

return parameters
Expand Down
6 changes: 5 additions & 1 deletion src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ def build(

if contacts_params is None:

if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
if isinstance(
model.contact_model,
jaxsim.rbda.contacts.SoftContacts
| jaxsim.rbda.contacts.ViscoElasticContacts,
):
contacts_params = js.contact.estimate_good_soft_contacts_parameters(
model=model, standard_gravity=standard_gravity
)
Expand Down
15 changes: 14 additions & 1 deletion src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import jax.lax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import numpy.typing as npt
from jax_dataclasses import Static

import jaxsim.typing as jtp
Expand Down Expand Up @@ -753,6 +755,13 @@ class ContactParameters(JaxsimDataclass):

point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))

enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple)

@property
def indices_of_enabled_collidable_points(self) -> npt.NDArray:

return np.where(np.array(self.enabled))[0]

@staticmethod
def build_from(model_description: ModelDescription) -> ContactParameters:
"""
Expand Down Expand Up @@ -785,7 +794,11 @@ def build_from(model_description: ModelDescription) -> ContactParameters:
)

# Build the ContactParameters object.
cp = ContactParameters(point=points, body=link_index_of_points)
cp = ContactParameters(
point=points,
body=link_index_of_points,
enabled=tuple(True for _ in link_index_of_points),
)

assert cp.point.shape[1] == 3, cp.point.shape[1]
assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]
Expand Down
15 changes: 11 additions & 4 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jaxsim.typing as jtp
from jaxsim.integrators import Time
from jaxsim.math import Quaternion
from jaxsim.rbda import contacts

from .common import VelRepr
from .ode_data import ODEState
Expand Down Expand Up @@ -371,8 +372,6 @@ def system_dynamics(
by the system dynamics evaluation.
"""

from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Compute the accelerations and the material deformation rate.
W_v̇_WB, , aux_dict = system_velocity_dynamics(
model=model,
Expand All @@ -387,10 +386,18 @@ def system_dynamics(

match model.contact_model:

case SoftContacts():
case contacts.SoftContacts():
extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]

case RigidContacts() | RelaxedRigidContacts():
case contacts.ViscoElasticContacts():

extended_ode_state["contacts_state"] = {
"tangential_deformation": jnp.zeros_like(
data.state.extended["tangential_deformation"]
)
}

case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
pass

case _:
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/rbda/contacts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import relaxed_rigid, rigid, soft
from . import relaxed_rigid, rigid, soft, visco_elastic
from .common import ContactModel, ContactsParams
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
from .rigid import RigidContacts, RigidContactsParams
from .soft import SoftContacts, SoftContactsParams
from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
Loading

0 comments on commit 9c695f7

Please sign in to comment.