Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sprint] Simplify contact model usage in Jaxsim #347

Merged
merged 40 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
7911bbb
erase integrators
younik Jan 20, 2025
3642e9a
fix import
younik Jan 20, 2025
88395a4
remotveintegrators
younik Jan 20, 2025
55f5c4a
restore heun2 integrator
younik Jan 20, 2025
002ad30
tentative change
younik Jan 20, 2025
f019cba
adding quaternion integration from axis angle
CarlottaSartore Jan 20, 2025
054107b
improve semi implicit euler and add velocity representation
CarlottaSartore Jan 20, 2025
342f788
tentative fix
younik Jan 21, 2025
d7cde88
Pass link and joint forces to `js.ode.system_dynamics`
flferretti Jan 21, 2025
deb9378
fix tests but one
younik Jan 21, 2025
42e2c77
Fix position derivative extraction in semi-implicit Euler
flferretti Jan 21, 2025
2091bd0
Add `pytest-benchmark` in environment YAML
flferretti Jan 20, 2025
68560c7
Remove `rigid.py`
xela-95 Jan 20, 2025
9e22325
Remove `soft.py`
xela-95 Jan 20, 2025
f7a37d1
Remove `visco_elastic.py`
xela-95 Jan 20, 2025
67615ac
Remove commented note about default contact model in JaxSimModel
xela-95 Jan 20, 2025
ef8c4b6
Update `jaxsim.rbda.contacts.__init__.py`
xela-95 Jan 20, 2025
ebe90de
Add `_hunt_crossley_contact_model` to initialize RelaxedRigidContacts
xela-95 Jan 20, 2025
5f337dc
Refactor `contact.py` to use only RelaxedRigidContacts
xela-95 Jan 20, 2025
0a7bd9f
Update `JaxSimModel` to use only RelaxedRigidContacts
xela-95 Jan 20, 2025
20078a7
Update `JaxSimModelData` to use only RelaxedRigidContacts
xela-95 Jan 20, 2025
696bd55
Refactor `ode.py` to simplify `system_dynamics` and `system_velocity_…
xela-95 Jan 20, 2025
e40d796
Remove unused soft contacts test and related parameters from automati…
xela-95 Jan 20, 2025
f9627ad
Refactor simulation tests to remove soft contacts and streamline cont…
xela-95 Jan 20, 2025
42e86b8
Create `contact_model.py` to contain all functions related to contact…
xela-95 Jan 21, 2025
d38a1a8
Add `contact_model` __init__.py
xela-95 Jan 21, 2025
b48e7f8
Remove `collidable_point_dynamics` function from contact module
xela-95 Jan 21, 2025
f7aa6a7
Remove `link_contact_forces` function from model API
xela-95 Jan 21, 2025
2a0903b
Remove compute_link_contact_forces and link_forces_from_contact_force…
xela-95 Jan 21, 2025
7be7794
Update link_contact_forces reference to use contact_model module
xela-95 Jan 21, 2025
721db59
Remove collidable_point_forces function from contact.py
xela-95 Jan 21, 2025
15a7ff1
Refactor `link_contact_forces` to compute contact forces in inertial …
xela-95 Jan 21, 2025
3a61ac6
Refactor `link_forces_from_contact_forces` to work only in inertial
xela-95 Jan 21, 2025
bb2afa3
Update docstring to clarify that computed contact forces are in inert…
xela-95 Jan 21, 2025
d7494de
Refactor heun2_integration to remove aux_dict
xela-95 Jan 21, 2025
0370e26
Refactor system_velocity_dynamics and system_dynamics to use inertial…
xela-95 Jan 21, 2025
9f4111b
Fix semi-implicit integrators
flferretti Jan 21, 2025
e3af2db
Fix pre-commit
flferretti Jan 21, 2025
3441f67
Improve error message for incompatible PyTree structures
flferretti Jan 21, 2025
ff88446
Update AD test to include aux dict and fix gradient check modes
flferretti Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies:
# [testing]
- idyntree >= 12.2.1
- pytest
- pytest-benchmark
- pytest-icdiff
- robot_descriptions
- icub-models
Expand Down
6 changes: 3 additions & 3 deletions examples/jaxsim_as_physics_engine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
"\n",
"# Simulate\n",
"for _t in T:\n",
" data, _ = js.model.step(\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
Expand Down Expand Up @@ -239,7 +239,7 @@
"\n",
"data = data_batch_t0\n",
"for _t in T:\n",
" data, _ = step_parallel(model, data)"
" data = step_parallel(model, data)"
]
}
],
Expand All @@ -266,7 +266,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
"version": "3.13.1"
}
},
"nbformat": 4,
Expand Down
2 changes: 0 additions & 2 deletions examples/jaxsim_as_physics_engine_advanced.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"import jaxsim\n",
"import rod\n",
"from jaxsim import logging\n",
"from rod.builder.primitives import SphereBuilder\n",
Expand Down Expand Up @@ -163,7 +162,6 @@
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string,\n",
" time_step=0.001,\n",
" integrator=jaxsim.integrators.fixed_step.Heun2,\n",
")\n",
"\n",
"# Create the data of a single model.\n",
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,5 @@ def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
del _is_editable

from . import terrain # isort:skip
from . import api, integrators, logging, math, rbda
from . import api, logging, math, rbda
from .api.common import VelRepr
2 changes: 2 additions & 0 deletions src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from . import (
com,
contact,
contact_model,
frame,
integrators,
joint,
kin_dyn_parameters,
link,
Expand Down
189 changes: 3 additions & 186 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,143 +95,6 @@ def collidable_point_velocities(
return W_ṗ_Ci


@jax.jit
@js.common.named_scope
def collidable_point_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
**kwargs,
) -> jtp.Matrix:
"""
Compute the 6D forces applied to each collidable point.

Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D external forces to apply to the links expressed in the same
representation of data.
joint_force_references:
The joint force references to apply to the joints.
kwargs: Additional keyword arguments to pass to the active contact model.

Returns:
The 6D forces applied to each collidable point expressed in the frame
corresponding to the active representation.
"""

f_Ci, _ = collidable_point_dynamics(
model=model,
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
**kwargs,
)

return f_Ci


@jax.jit
@js.common.named_scope
def collidable_point_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
**kwargs,
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
r"""
Compute the 6D force applied to each enabled collidable point.

Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D external forces to apply to the links expressed in the same
representation of data.
joint_force_references:
The joint force references to apply to the joints.
kwargs: Additional keyword arguments to pass to the active contact model.

Returns:
The 6D force applied to each enabled collidable point and additional data based
on the contact model configured:
- Soft: the material deformation rate.
- Rigid: no additional data.
- QuasiRigid: no additional data.

Note:
The material deformation rate is always returned in the mixed frame
`C[W] = ({}^W \mathbf{p}_C, [W])`. This is convenient for integration purpose.
Instead, the 6D forces are returned in the active representation.
"""

# Build the common kw arguments to pass to the computation of the contact forces.
common_kwargs = dict(
link_forces=link_forces,
joint_force_references=joint_force_references,
)

# Build the additional kwargs to pass to the computation of the contact forces.
match model.contact_model:

case contacts.SoftContacts():

kwargs_contact_model = {}

case contacts.RigidContacts():

kwargs_contact_model = common_kwargs | kwargs

case contacts.RelaxedRigidContacts():

kwargs_contact_model = common_kwargs | kwargs

case contacts.ViscoElasticContacts():

kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs

case _:
raise ValueError(f"Invalid contact model: {model.contact_model}")

# Compute the contact forces with the active contact model.
W_f_C, aux_data = model.contact_model.compute_contact_forces(
model=model,
data=data,
**kwargs_contact_model,
)

# Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])`
# associated to the enabled collidable point.
# In inertial-fixed representation, the computation of these transforms
# is not necessary and the conversion below becomes a no-op.

# Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = (
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

W_H_C = (
js.contact.transforms(model=model, data=data)
if data.velocity_representation is not VelRepr.Inertial
else jnp.stack([jnp.eye(4)] * len(indices_of_enabled_collidable_points))
)

# Convert the 6D forces to the active representation.
f_Ci = jax.vmap(
lambda W_f_C, W_H_C: data.inertial_to_other_representation(
array=W_f_C,
other_representation=data.velocity_representation,
transform=W_H_C,
is_force=True,
)
)(W_f_C, W_H_C)

return f_Ci, aux_data


@functools.partial(jax.jit, static_argnames=["link_names"])
@js.common.named_scope
def in_contact(
Expand Down Expand Up @@ -351,7 +214,7 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

zero_data = js.data.JaxSimModelData.build(
model=model,
contacts_params=jaxsim.rbda.contacts.SoftContactsParams(),
contacts_params=jaxsim.rbda.contacts.RelaxedRigidContactsParams(),
)

W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
Expand All @@ -362,63 +225,17 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

return 2 * W_pz_CoM

max_δ = (
max_δ = ( # noqa: F841
max_penetration
if max_penetration is not None
# Consider as default a 0.5% of the model height.
else 0.005 * estimate_model_height(model=model)
)

nc = number_of_active_collidable_points_steady_state
nc = number_of_active_collidable_points_steady_state # noqa: F841

match model.contact_model:

case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
**kwargs,
)

case contacts.ViscoElasticContacts():
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)

parameters = (
contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
**kwargs,
)
)

case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Disable Baumgarte stabilization by default since it does not play
# well with the forward Euler integrator.
K = kwargs.get("K", 0.0)

parameters = contacts.RigidContactsParams.build(
mu=static_friction_coefficient,
**(
dict(
K=K,
D=2 * jnp.sqrt(K),
)
| kwargs
),
)

case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

Expand Down
103 changes: 103 additions & 0 deletions src/jaxsim/api/contact_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp


@jax.jit
@js.common.named_scope
def link_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
**kwargs,
) -> jtp.Matrix:
"""
Compute the 6D contact forces of all links of the model in inertial representation.

Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D external forces to apply to the links expressed in inertial representation
joint_force_references:
The joint force references to apply to the joints.
kwargs: Additional keyword arguments to pass to the active contact model..

Returns:
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
expressed in inertial representation.
"""

# Compute the contact forces for each collidable point with the active contact model.
W_f_C, _ = model.contact_model.compute_contact_forces(
model=model,
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
)

# Compute the 6D forces applied to the links equivalent to the forces applied
# to the frames associated to the collidable points.
W_f_L = link_forces_from_contact_forces(
model=model, data=data, contact_forces=W_f_C
)

return W_f_L


@staticmethod
def link_forces_from_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
contact_forces: jtp.MatrixLike,
) -> jtp.Matrix:
"""
Compute the link forces from the contact forces.

Args:
model: The robot model considered by the contact model.
data: The data of the considered model.
contact_forces: The contact forces computed by the contact model.

Returns:
The 6D contact forces applied to the links and expressed in the frame of
the velocity representation of data.
"""

# Get the object storing the contact parameters of the model.
contact_parameters = model.kin_dyn_parameters.contact_parameters

# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
contact_parameters.indices_of_enabled_collidable_points
)

# Convert the contact forces to a JAX array.
W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())

# Construct the vector defining the parent link index of each collidable point.
# We use this vector to sum the 6D forces of all collidable points rigidly
# attached to the same link.
parent_link_index_of_collidable_points = jnp.array(
contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

# Create the mask that associate each collidable point to their parent link.
# We use this mask to sum the collidable points to the right link.
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
model.number_of_links()
)

# Sum the forces of all collidable points rigidly attached to a body.
# Since the contact forces W_f_C are expressed in the world frame,
# we don't need any coordinate transformation.
W_f_L = mask.T @ W_f_C

return W_f_L
Loading
Loading