Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reintroduce alternative contact models and streamline changes to the new API #360

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
bbf8f83
Create configuration for GPU benchmarks on local runner
flferretti Feb 26, 2025
23aada9
Restore `api.model.forward_kinematics` for benchmarking
flferretti Feb 27, 2025
f791f91
Explicitly set `SoftContacts` in benchmarks
flferretti Feb 27, 2025
071b249
Refactor workflow to improve performance result handling and caching
flferretti Feb 27, 2025
6538907
Add benchmark for full simulation step
flferretti Feb 27, 2025
3494ff7
Update benchmark command
flferretti Feb 28, 2025
a6aac9e
Increase ulimit stack size and use host IPC namespace
flferretti Feb 28, 2025
9aedc95
Install Gazebo SDF
flferretti Feb 28, 2025
662c08e
Skip soft and rigid contacts model tests
flferretti Feb 28, 2025
9e37572
Pass API token for commenting PRs
flferretti Feb 28, 2025
66b8c70
Ensure GPU utilization and benchmark vectorized computations
flferretti Feb 28, 2025
33628bb
Always comment PRs and add permissions for GH pages
flferretti Feb 28, 2025
05111af
Reintroduce soft, viscoelastic and rigid contact models
flferretti Jan 28, 2025
29f9c61
Streamline new API changes to alternative contact models
flferretti Jan 28, 2025
8eef5a2
Remove viscoelastic contact model
flferretti Jan 30, 2025
65ac30c
Convert base post impact velocity to inertial before setting
flferretti Feb 10, 2025
1d0b47c
Use polymorphism for estimating good contact parameters
flferretti Feb 10, 2025
8822416
Use polymorphism to update the contact state during the step
flferretti Feb 10, 2025
e008f0d
Refactor initialization of `contact_state`
flferretti Feb 10, 2025
8e6c3b5
Use polymorphism to update the post impact velocity
flferretti Feb 10, 2025
d3b8404
Make `build_default_from_jaxsim_model` compatible with all models
flferretti Feb 10, 2025
6853fe3
Unify `api.contact_model` in `api.contact`
flferretti Feb 10, 2025
39f9989
Rename `aux_dict` to `old_contact_state`
flferretti Feb 10, 2025
6716a41
Make static methods functions in `RigidContacts`
flferretti Feb 25, 2025
f3ee2b6
Make `system_acceleration` return the contact state derivative
flferretti Feb 26, 2025
4683e81
Parametrize contact tests on integrators
flferretti Feb 26, 2025
71235cc
Allow replacing the contact state in `JaxSimModelData`
flferretti Feb 26, 2025
4f0b7df
Clarify that the gravity should be passed as positive constant
flferretti Feb 26, 2025
510a20d
Remove JIT compilation from inequality bounds function
flferretti Feb 26, 2025
2f6c231
Remove unused methods in soft contacts
flferretti Feb 26, 2025
49f062e
Speed up rigid contact model
flferretti Feb 26, 2025
383a15b
Prevent recompilation on enabled collidable points
flferretti Feb 27, 2025
c853c4f
Remove duplicated Hunt-Crossley model
flferretti Feb 27, 2025
1dd113d
Fix link forces representation in rigid contacts
flferretti Feb 27, 2025
ee5d8c8
Move `test_collidable_point_jacobians` to `test_api_contact`
flferretti Mar 1, 2025
59a2a7b
Lint code and remove duplications
flferretti Mar 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions .github/workflows/gpu_benchmark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
name: GPU Benchmarks

on:
pull_request:
types: [opened, reopened, synchronize]
workflow_dispatch:
schedule:
- cron: "0 0 * * 1" # Run At 00:00 on Monday

permissions:
pull-requests: write
deployments: write
contents: write

jobs:
benchmark:
runs-on: self-hosted
container:
image: ghcr.io/nvidia/jax:jax
options: --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up environment
run: |
pip install .[all]

- name: Install Gazebo SDF
run: |
apt update && apt install -y lsb-release wget
sh -c 'echo "deb http://packages.osrfoundation.org/gazebo/ubuntu-stable `lsb_release -cs` main" > /etc/apt/sources.list.d/gazebo-stable.list'
wget http://packages.osrfoundation.org/gazebo.key -O - | apt-key add -
apt-get update
apt install -y libsdformat14-dev libsdformat14

- name: Performance regression check
uses: actions/checkout@v4

- name: Run benchmark and store result
run: |
pytest tests/test_benchmark.py -k 'not test_rigid_contact_model and not test_soft_contact_model' --gpu-only --batch-size 128 --benchmark-only --benchmark-json output.json

- name: Download previous benchmark data
uses: actions/cache@v4
with:
path: ./cache
key: ${{ runner.os }}-benchmark

- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
tool: 'pytest'
output-file-path: output.json
external-data-json-path: ./cache/benchmark-data.json
fail-on-alert: true
summary-always: true
comment-always: true
alert-threshold: 150%
auto-push: false
gh-pages-branch: gh-pages
github-token: ${{ secrets.GITHUB_TOKEN }}

- name: Push benchmark result
if: github.event_name != 'pull_request'
run: git push 'https://ami-iit:${{ secrets.GITHUB_TOKEN }}@github.com/ami-iit/jaxsim.git' gh-pages:gh-pages
1 change: 0 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
actuation_model,
com,
contact,
contact_model,
frame,
integrators,
joint,
Expand Down
4 changes: 1 addition & 3 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,7 @@ def other_representation_to_body(
C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841

L_H_C = L_H_W = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
)(W_H_L)
L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841

L_v_LC = L_v_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
Expand Down
164 changes: 140 additions & 24 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Cross, Transform
from jaxsim.rbda import contacts
from jaxsim.rbda.contacts import SoftContacts

from .common import VelRepr

Expand All @@ -37,14 +37,11 @@ def collidable_point_kinematics(
the linear component of the mixed 6D frame velocity.
"""

# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
with data.switch_velocity_representation(VelRepr.Inertial):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should pay attention to this change of representation for PR #359 in which the property returns data in the active representation (not always inertial)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm accessing the attributes, not the property, so I believe it should always return the quantity in inertial representation, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree, it should be equivalent

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep true. Coming from other languages I don't like to access directly the "private" attributes of a class, is like coupling to the inner implementation of it. If in a next PR we change the representation in which they are stored it will be difficult to backtrace errors here, while if we use the class API we are relying on the class contract.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, we should see them as protected rather than private

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving unresolved for now, thanks for pointing this out


W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
link_transforms=data._link_transforms,
link_velocities=data._link_velocities,
)
W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
link_transforms=data._link_transforms,
link_velocities=data._link_velocities,
)

return W_p_Ci, W_ṗ_Ci

Expand Down Expand Up @@ -164,18 +161,23 @@ def estimate_good_soft_contacts_parameters(
def estimate_good_contact_parameters(
model: js.model.JaxSimModel,
*,
standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
static_friction_coefficient: jtp.FloatLike = 0.5,
**kwargs,
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> jaxsim.rbda.contacts.ContactParamsTypes:
"""
Estimate good contact parameters.

Args:
model: The model to consider.
standard_gravity: The standard gravity acceleration.
static_friction_coefficient: The static friction coefficient.
kwargs:
Additional model-specific parameters passed to the builder method of
the parameters class.
number_of_active_collidable_points_steady_state:
The number of active collidable points in steady state.
damping_ratio: The damping ratio.
max_penetration: The maximum penetration allowed.

Returns:
The estimated good contacts parameters.
Expand All @@ -190,20 +192,41 @@ def estimate_good_contact_parameters(
specific application.
"""

match model.contact_model:
def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
"""
Displacement between the CoM and the lowest collidable point using zero
joint positions.
"""

zero_data = js.data.JaxSimModelData.build(
model=model,
)

W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]

case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)
if model.floating_base():
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
return 2 * (W_pz_CoM - W_pz_C.min())

parameters = contacts.RelaxedRigidContactsParams.build(
mu=static_friction_coefficient,
**kwargs,
)
return 2 * W_pz_CoM

case _:
raise ValueError(f"Invalid contact model: {model.contact_model}")
max_δ = (
max_penetration
if max_penetration is not None
# Consider as default a 0.5% of the model height.
else 0.005 * estimate_model_height(model=model)
)

return parameters
nc = number_of_active_collidable_points_steady_state

return model.contact_model._parameters_class().build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
)


@jax.jit
Expand Down Expand Up @@ -244,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt

# Build the link-to-point transform from the displacement between the link frame L
# and the implicit contact frame C.
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci)

# Compose the work-to-link and link-to-point transforms.
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
Expand Down Expand Up @@ -504,3 +527,96 @@ def compute_O_J̇_WC_I(
)

return O_J̇_WC


@jax.jit
@js.common.named_scope
def link_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces: jtp.MatrixLike | None = None,
joint_torques: jtp.VectorLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
"""
Compute the 6D contact forces of all links of the model in inertial representation.

Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D external forces to apply to the links expressed in inertial representation
joint_torques:
The joint torques acting on the joints.

Returns:
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
expressed in inertial representation.
"""

# Compute the contact forces for each collidable point with the active contact model.
W_f_C, aux_dict = model.contact_model.compute_contact_forces(
model=model,
data=data,
**(
dict(link_forces=link_forces, joint_force_references=joint_torques)
if not isinstance(model.contact_model, SoftContacts)
else {}
),
)

# Compute the 6D forces applied to the links equivalent to the forces applied
# to the frames associated to the collidable points.
W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)

return W_f_L, aux_dict


@staticmethod
def link_forces_from_contact_forces(
model: js.model.JaxSimModel,
*,
contact_forces: jtp.MatrixLike,
) -> jtp.Matrix:
"""
Compute the link forces from the contact forces.

Args:
model: The robot model considered by the contact model.
contact_forces: The contact forces computed by the contact model.

Returns:
The 6D contact forces applied to the links and expressed in the frame of
the velocity representation of data.
"""

# Get the object storing the contact parameters of the model.
contact_parameters = model.kin_dyn_parameters.contact_parameters

# Extract the indices corresponding to the enabled collidable points.
indices_of_enabled_collidable_points = (
contact_parameters.indices_of_enabled_collidable_points
)

# Convert the contact forces to a JAX array.
W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())

# Construct the vector defining the parent link index of each collidable point.
# We use this vector to sum the 6D forces of all collidable points rigidly
# attached to the same link.
parent_link_index_of_collidable_points = jnp.array(
contact_parameters.body, dtype=int
)[indices_of_enabled_collidable_points]

# Create the mask that associate each collidable point to their parent link.
# We use this mask to sum the collidable points to the right link.
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
model.number_of_links()
)

# Sum the forces of all collidable points rigidly attached to a body.
# Since the contact forces W_f_C are expressed in the world frame,
# we don't need any coordinate transformation.
W_f_L = mask.T @ W_f_C

return W_f_L
Loading
Loading