Skip to content

Commit

Permalink
Merge pull request #269 from ami-iit/update_contact_models
Browse files Browse the repository at this point in the history
Update contact models
  • Loading branch information
diegoferigo authored Oct 23, 2024
2 parents edc4a35 + 103750e commit ff611d9
Show file tree
Hide file tree
Showing 8 changed files with 474 additions and 254 deletions.
125 changes: 48 additions & 77 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def collidable_point_forces(
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
**kwargs,
) -> jtp.Matrix:
"""
Compute the 6D forces applied to each collidable point.
Expand All @@ -110,6 +111,7 @@ def collidable_point_forces(
representation of data.
joint_force_references:
The joint force references to apply to the joints.
kwargs: Additional keyword arguments to pass to the active contact model.
Returns:
The 6D forces applied to each collidable point expressed in the frame
Expand All @@ -121,6 +123,7 @@ def collidable_point_forces(
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
**kwargs,
)

return f_Ci
Expand All @@ -132,7 +135,8 @@ def collidable_point_dynamics(
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.Array]]:
**kwargs,
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
r"""
Compute the 6D force applied to each collidable point.
Expand All @@ -144,6 +148,7 @@ def collidable_point_dynamics(
representation of data.
joint_force_references:
The joint force references to apply to the joints.
kwargs: Additional keyword arguments to pass to the active contact model.
Returns:
The 6D force applied to each collidable point and additional data based
Expand All @@ -158,86 +163,46 @@ def collidable_point_dynamics(
Instead, the 6D forces are returned in the active representation.
"""

# Build the soft contact model.
# Build the common kw arguments to pass to the computation of the contact forces.
common_kwargs = dict(
link_forces=link_forces,
joint_force_references=joint_force_references,
)

# Build the additional kwargs to pass to the computation of the contact forces.
match model.contact_model:

case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point, and the corresponding material deformation rate.
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces(
model=model, data=data
)

# Create the dictionary of auxiliary data.
# This contact model considers the material deformation as additional state
# of the ODE system. We need to pass its dynamics to the integrator.
aux_data = dict(m_dot=CW_ṁ)
kwargs_contact_model = {}

case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
W_f_Ci, _ = model.contact_model.compute_contact_forces(
model=model,
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
)

aux_data = dict()
kwargs_contact_model = common_kwargs | kwargs

case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
W_f_Ci, _ = model.contact_model.compute_contact_forces(
model=model,
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
)

aux_data = dict()
kwargs_contact_model = common_kwargs | kwargs

case contacts.ViscoElasticContacts():
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)

# It is not yet clear how to pass the time step to this stage.
# A possibility is to restrict the integrator to only forward Euler
# and store the Δt inside the model.
module = jaxsim.rbda.contacts.visco_elastic.step.__module__
name = jaxsim.rbda.contacts.visco_elastic.step.__name__
msg = "You need to use the custom '{}.{}' function with this contact model."
jaxsim.exceptions.raise_runtime_error_if(
condition=True, msg=msg.format(module, name)
)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces(
model=model,
data=data,
dt=None, # TODO
link_forces=link_forces,
joint_force_references=joint_force_references,
)

aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf)
kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs

case _:
raise ValueError(f"Invalid contact model {model.contact_model}")
raise ValueError(f"Invalid contact model: {model.contact_model}")

# Compute the contact forces with the active contact model.
W_f_C, aux_data = model.contact_model.compute_contact_forces(
model=model,
data=data,
**kwargs_contact_model,
)

# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
# associated to each collidable point.
# In inertial-fixed representation, the computation of these transforms
# is not necessary and the conversion below becomes a no-op.
W_H_Ci = (
W_H_C = (
js.contact.transforms(model=model, data=data)
if data.velocity_representation is not VelRepr.Inertial
else jnp.zeros(
Expand All @@ -253,7 +218,7 @@ def collidable_point_dynamics(
transform=W_H_C,
is_force=True,
)
)(W_f_Ci, W_H_Ci)
)(W_f_C, W_H_C)

return f_Ci, aux_data

Expand Down Expand Up @@ -392,11 +357,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
**dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs,
**(
dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs
),
)

case contacts.ViscoElasticContacts():
Expand All @@ -410,11 +377,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
**dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs,
**(
dict(
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)
| kwargs
),
)
)

Expand All @@ -427,11 +396,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

parameters = contacts.RigidContactsParams.build(
mu=static_friction_coefficient,
**dict(
K=K,
D=2 * jnp.sqrt(K),
)
| kwargs,
**(
dict(
K=K,
D=2 * jnp.sqrt(K),
)
| kwargs
),
)

case contacts.RelaxedRigidContacts():
Expand Down
Loading

0 comments on commit ff611d9

Please sign in to comment.