Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor contact models and add tests #260

Merged
merged 13 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ def collidable_point_kinematics(
the linear component of the mixed 6D frame velocity.
"""

from jaxsim.rbda import collidable_points

# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_Ci, W_ṗ_Ci = collidable_points.collidable_points_pos_vel(

W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
base_position=data.base_position(),
base_quaternion=data.base_orientation(dcm=False),
Expand Down Expand Up @@ -304,6 +303,15 @@ def in_contact(


def estimate_good_soft_contacts_parameters(
*args, **kwargs
) -> jaxsim.rbda.contacts.ContactParamsTypes:

msg = "This method is deprecated, please use `{}`."
logging.warning(msg.format(estimate_good_contact_parameters.__name__))
return estimate_good_contact_parameters(*args, **kwargs)


def estimate_good_contact_parameters(
model: js.model.JaxSimModel,
*,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
Expand All @@ -312,14 +320,9 @@ def estimate_good_soft_contacts_parameters(
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
**kwargs,
) -> (
jaxsim.rbda.contacts.RelaxedRigidContactsParams
| jaxsim.rbda.contacts.RigidContactsParams
| jaxsim.rbda.contacts.SoftContactsParams
| jaxsim.rbda.contacts.ViscoElasticContactsParams
):
) -> jaxsim.rbda.contacts.ContactParamsTypes:
"""
Estimate good parameters for soft-like contact models.
Estimate good contact parameters.

Args:
model: The model to consider.
Expand All @@ -332,12 +335,19 @@ def estimate_good_soft_contacts_parameters(
max_penetration:
The maximum penetration allowed in steady state when the robot is
supported by the configured number of active collidable points.
kwargs:
Additional model-specific parameters passed to the builder method of
the parameters class.

Returns:
The estimated good soft contacts parameters.
The estimated good contacts parameters.

Note:
This is primarily a convenience function for soft-like contact models.
However, it provides with some good default parameters also for the other ones.

Note:
This method provides a good starting point for the soft contacts parameters.
This method provides a good set of contacts parameters.
The user is encouraged to fine-tune the parameters based on the
specific application.
"""
Expand All @@ -364,6 +374,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_δ = (
max_penetration
if max_penetration is not None
# Consider as default a 0.5% of the model height.
else 0.005 * estimate_model_height(model=model)
)

Expand All @@ -381,8 +392,11 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
**dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs,
)

case contacts.ViscoElasticContacts():
Expand All @@ -396,15 +410,40 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
**kwargs,
**dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs,
)
)

case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Disable Baumgarte stabilization by default since it does not play
# well with the forward Euler integrator.
K = kwargs.get("K", 0.0)

parameters = contacts.RigidContactsParams.build(
mu=static_friction_coefficient,
**dict(
K=K,
D=2 * jnp.sqrt(K),
)
| kwargs,
)

case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

parameters = contacts.RelaxedRigidContactsParams.build(
mu=static_friction_coefficient,
**kwargs,
)

case _:
logging.warning("The active contact model is not soft-like, no-op.")
parameters = model.contact_model.parameters
raise ValueError(f"Invalid contact model: {model.contact_model}")

return parameters

Expand Down
5 changes: 3 additions & 2 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

state: ODEState

gravity: jtp.Array
gravity: jtp.Vector

contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)

Expand Down Expand Up @@ -224,7 +224,8 @@ def build(
jaxsim.rbda.contacts.SoftContacts
| jaxsim.rbda.contacts.ViscoElasticContacts,
):
contacts_params = js.contact.estimate_good_soft_contacts_parameters(

contacts_params = js.contact.estimate_good_contact_parameters(
model=model, standard_gravity=standard_gravity
)

Expand Down
20 changes: 8 additions & 12 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class JaxSimModel(JaxsimDataclass):
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
)

# Note that this is the default contact model.
# Its parameters, if any, are then overridden from those stored in JaxSimModelData.
contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
default=None, repr=False
)
Expand Down Expand Up @@ -2044,24 +2046,18 @@ def step(
M = js.model.free_floating_mass_matrix(model, data_tf)
W_p_C = js.contact.collidable_point_positions(model, data_tf)

# Compute the height of the terrain below each collidable point.
px, py, _ = W_p_C.T
terrain_height = jax.vmap(model.terrain.height)(px, py)

# Compute the contact state.
inactive_collidable_points, _ = (
jaxsim.rbda.contacts.RigidContacts.detect_contacts(
W_p_C=W_p_C,
terrain_height=terrain_height,
)
)
# Compute the penetration depth of the collidable points.
δ, *_ = jax.vmap(
jaxsim.rbda.contacts.common.compute_penetration_data,
in_axes=(0, 0, None),
)(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

# Compute the impact velocity.
# It may be discontinuous in case new contacts are made.
BW_nu_post_impact = (
jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity(
data=data_tf,
inactive_collidable_points=inactive_collidable_points,
inactive_collidable_points=(δ <= 0),
M=M,
J_WC=J_WC,
)
Expand Down
7 changes: 7 additions & 0 deletions src/jaxsim/rbda/contacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@
from .rigid import RigidContacts, RigidContactsParams
from .soft import SoftContacts, SoftContactsParams
from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams

ContactParamsTypes = (
SoftContactsParams
| RigidContactsParams
| RelaxedRigidContactsParams
| ViscoElasticContactsParams
)
52 changes: 49 additions & 3 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

import abc
import functools
from typing import Any

import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
Expand All @@ -14,6 +18,47 @@
from typing_extensions import Self


@functools.partial(jax.jit, static_argnames=("terrain",))
def compute_penetration_data(
p: jtp.VectorLike,
v: jtp.VectorLike,
terrain: jaxsim.terrain.Terrain,
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
"""
Compute the penetration data (depth, rate, and terrain normal) of a collidable point.

Args:
p: The position of the collidable point.
v:
The linear velocity of the point (linear component of the mixed 6D velocity
of the implicit frame `C = (W_p_C, [W])` associated to the point).
terrain: The considered terrain.

Returns:
A tuple containing the penetration depth, the penetration velocity,
and the considered terrain normal.
"""

# Pre-process the position and the linear velocity of the collidable point.
W_ṗ_C = jnp.array(v).squeeze()
px, py, pz = jnp.array(p).squeeze()

# Compute the terrain normal and the contact depth.
n̂ = terrain.normal(x=px, y=py).squeeze()
h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])

# Compute the penetration depth normal to the terrain.
δ = jnp.maximum(0.0, jnp.dot(h, n̂))

# Compute the penetration normal velocity.
δ_dot = -jnp.dot(W_ṗ_C, n̂)

# Enforce the penetration rate to be zero when the penetration depth is zero.
δ_dot = jnp.where(δ > 0, δ_dot, 0.0)

return δ, δ_dot, n̂


class ContactsParams(JaxsimDataclass):
"""
Abstract class representing the parameters of a contact model.
Expand Down Expand Up @@ -86,7 +131,7 @@ def compute_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
**kwargs,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
) -> tuple[jtp.Matrix, tuple[Any, ...]]:
"""
Compute the contact forces.

Expand All @@ -95,8 +140,9 @@ def compute_contact_forces(
data: The data of the considered model.

Returns:
A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
and as second element a tuple of optional additional information.
A tuple containing as first element the computed 6D contact force applied to
the contact points and expressed in the world frame, and as second element
a tuple of optional additional information.
"""

pass
Expand Down
Loading