Skip to content

Commit

Permalink
Remove duplicated Hunt-Crossley model
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 27, 2025
1 parent 9287655 commit 59bd313
Showing 1 changed file with 2 additions and 151 deletions.
153 changes: 2 additions & 151 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import dataclasses
import functools
from collections.abc import Callable
from typing import Any

Expand All @@ -11,12 +10,10 @@
import optax

import jaxsim.api as js
import jaxsim.rbda.contacts
import jaxsim.typing as jtp
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
from jaxsim.terrain.terrain import Terrain

from . import common
from . import common, soft

try:
from typing import Self
Expand Down Expand Up @@ -424,7 +421,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:

# Initialize the optimized forces with a linear Hunt/Crossley model.
init_params = jax.vmap(
lambda p, v: self._hunt_crossley_contact_model(
lambda p, v: soft.SoftContacts.hunt_crossley_contact_model(
position=p,
velocity=v,
terrain=model.terrain,
Expand Down Expand Up @@ -601,149 +598,3 @@ def compute_row(
)

return a_ref, jnp.diag(R), K, D

@staticmethod
@functools.partial(jax.jit, static_argnames=("terrain",))
def _hunt_crossley_contact_model(
position: jtp.VectorLike,
velocity: jtp.VectorLike,
tangential_deformation: jtp.VectorLike,
terrain: Terrain,
K: jtp.FloatLike,
D: jtp.FloatLike,
mu: jtp.FloatLike,
p: jtp.FloatLike = 0.5,
q: jtp.FloatLike = 0.5,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the contact force using the Hunt/Crossley model.
Args:
position: The position of the collidable point.
velocity: The velocity of the collidable point.
tangential_deformation: The material deformation of the collidable point.
terrain: The terrain model.
K: The stiffness parameter.
D: The damping parameter of the soft contacts model.
mu: The static friction coefficient.
p:
The exponent p corresponding to the damping-related non-linearity
of the Hunt/Crossley model.
q:
The exponent q corresponding to the spring-related non-linearity
of the Hunt/Crossley model
Returns:
A tuple containing the computed contact force and the derivative of the
material deformation.
"""

# Convert the input vectors to arrays.
W_p_C = jnp.array(position, dtype=float).squeeze()
W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
m = jnp.array(tangential_deformation, dtype=float).squeeze()

# Use symbol for the static friction.
μ = mu

# Compute the penetration depth, its rate, and the considered terrain normal.
δ, δ̇, = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)

# There are few operations like computing the norm of a vector with zero length
# or computing the square root of zero that are problematic in an AD context.
# To avoid these issues, we introduce a small tolerance ε to their arguments
# and make sure that we do not check them against zero directly.
ε = jnp.finfo(float).eps

# Compute the powers of the penetration depth.
# Inject ε to address AD issues in differentiating the square root when
# p and q are fractional.
δp = jnp.power(δ + ε, p)
δq = jnp.power(δ + ε, q)

# ========================
# Compute the normal force
# ========================

# Non-linear spring-damper model (Hunt/Crossley model).
# This is the force magnitude along the direction normal to the terrain.
force_normal_mag = (K * δp) * δ + (D * δq) * δ̇

# Depending on the magnitude of δ̇, the normal force could be negative.
force_normal_mag = jnp.maximum(0.0, force_normal_mag)

# Compute the 3D linear force in C[W] frame.
f_normal = force_normal_mag *

# ============================
# Compute the tangential force
# ============================

# Extract the tangential component of the velocity.
v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, ) *

# Extract the normal and tangential components of the material deformation.
m_normal = jnp.dot(m, ) *
m_tangential = m - jnp.dot(m, ) *

# Compute the tangential force in the sticking case.
# Using the tangential component of the material deformation should not be
# necessary if the sticking-slipping transition occurs in a terrain area
# with a locally constant normal. However, this assumption is not true in
# general, especially for highly uneven terrains.
f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)

# Detect the contact type (sticking or slipping).
# Note that if there is no contact, sticking is set to True, and this detail
# is exploited in the computation of the `contact_status` variable.
sticking = jnp.logical_or(
δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
)

# Compute the direction of the tangential force.
# To prevent dividing by zero, we use a switch statement.
norm = jaxsim.math.safe_norm(f_tangential)
f_tangential_direction = f_tangential / (
norm + jnp.finfo(float).eps * (norm == 0)
)

# Project the tangential force to the friction cone if slipping.
f_tangential = jnp.where(
sticking,
f_tangential,
jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
)

# Set the tangential force to zero if there is no contact.
f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)

# =====================================
# Compute the material deformation rate
# =====================================

# Compute the derivative of the material deformation.
# Note that we included an additional relaxation of `m_normal` in the
# sticking case, so that the normal deformation that could have accumulated
# from a previous slipping phase can relax to zero.
ṁ_no_contact = -(K / D) * m
ṁ_sticking = v_tangential - (K / D) * m_normal
ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)

# Compute the contact status:
# 0: slipping
# 1: sticking
# 2: no contact
contact_status = sticking.astype(int)
contact_status += (δ <= 0).astype(int)

# Select the right material deformation rate depending on the contact status.
= jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)

# ==========================================
# Compute and return the final contact force
# ==========================================

# Sum the normal and tangential forces.
CW_fl = f_normal + f_tangential

return CW_fl,

0 comments on commit 59bd313

Please sign in to comment.