Skip to content

Commit

Permalink
Merge pull request #257 from ami-iit/split_static_and_dynamic_contact…
Browse files Browse the repository at this point in the history
…_params

Split static and dynamic contact params
  • Loading branch information
diegoferigo authored Oct 7, 2024
2 parents 4d85117 + 49eadf0 commit b7e2fee
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 34 deletions.
1 change: 0 additions & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def collidable_point_dynamics(
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
solver_tol=1e-3,
)

aux_data = dict()
Expand Down
10 changes: 8 additions & 2 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
class ContactsParams(JaxsimDataclass):
"""
Abstract class representing the parameters of a contact model.
Note:
This class is supposed to store only the tunable parameters of the contact
model, i.e. all those parameters that can be changed during runtime.
If the contact model has also static parameters, they should be stored
in the corresponding `ContactModel` class.
"""

@classmethod
Expand Down Expand Up @@ -47,7 +53,7 @@ class ContactModel(JaxsimDataclass):
Attributes:
parameters: The parameters of the contact model.
terrain: The terrain model.
terrain: The considered terrain.
"""

parameters: ContactsParams
Expand Down Expand Up @@ -85,7 +91,7 @@ def compute_contact_forces(
Compute the contact forces.
Args:
model: The model to consider.
model: The robot model considered by the contact model.
data: The data of the considered model.
Returns:
Expand Down
74 changes: 52 additions & 22 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ class RelaxedRigidContactsParams(ContactsParams):
default_factory=lambda: jnp.array(0.5, dtype=float)
)

# Maximum number of iterations
max_iterations: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(50, dtype=int)
)

# Solver tolerance
tolerance: jtp.Float = dataclasses.field(
default_factory=lambda: jnp.array(1e-6, dtype=float)
)

def __hash__(self) -> int:
from jaxsim.utils.wrappers import HashedNumpyArray

Expand All @@ -103,8 +93,6 @@ def __hash__(self) -> int:
HashedNumpyArray(self.stiffness),
HashedNumpyArray(self.damping),
HashedNumpyArray(self.mu),
HashedNumpyArray(self.max_iterations),
HashedNumpyArray(self.tolerance),
)
)

Expand All @@ -125,8 +113,6 @@ def build(
stiffness: jtp.FloatLike | None = None,
damping: jtp.FloatLike | None = None,
mu: jtp.FloatLike | None = None,
max_iterations: jtp.IntLike | None = None,
tolerance: jtp.FloatLike | None = None,
) -> Self:
"""Create a `RelaxedRigidContactsParams` instance"""

Expand All @@ -152,8 +138,6 @@ def valid(self) -> jtp.BoolLike:
and jnp.all(self.midpoint >= 0.0)
and jnp.all(self.power >= 0.0)
and jnp.all(self.mu >= 0.0)
and jnp.all(self.max_iterations > 0)
and jnp.all(self.tolerance > 0.0)
)


Expand All @@ -169,11 +153,30 @@ class RelaxedRigidContacts(ContactModel):
default_factory=FlatTerrain.build
)

_solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
default=("tol", "maxiter", "memory_size"), kw_only=True
)
_solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
default=(1e-6, 50, 10), kw_only=True
)

@property
def solver_options(self) -> dict[str, Any]:

return dict(
zip(
self._solver_options_keys,
self._solver_options_values,
strict=True,
)
)

@classmethod
def build(
cls: type[Self],
parameters: RelaxedRigidContactsParams | None = None,
terrain: Terrain | None = None,
solver_options: dict[str, Any] | None = None,
**kwargs,
) -> Self:
"""
Expand All @@ -182,6 +185,7 @@ def build(
Args:
parameters: The parameters of the rigid contacts model.
terrain: The considered terrain.
solver_options: The options to pass to the L-BFGS solver.
Returns:
The `RelaxedRigidContacts` instance.
Expand All @@ -190,11 +194,31 @@ def build(
if len(kwargs) != 0:
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

# Get the default solver options.
default_solver_options = dict(
zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
)

# Create the solver options to set by combining the default solver options
# with the user-provided solver options.
solver_options = default_solver_options | (solver_options or {})

# Make sure that the solver options are hashable.
# We need to check this because the solver options are static.
try:
hash(tuple(solver_options.values()))
except TypeError as exc:
raise ValueError(
"The values of the solver options must be hashable."
) from exc

return cls(
parameters=(
parameters or cls.__dataclass_fields__["parameters"].default_factory()
),
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
_solver_options_keys=tuple(solver_options.keys()),
_solver_options_values=tuple(solver_options.values()),
)

@jax.jit
Expand Down Expand Up @@ -357,17 +381,23 @@ def continuing_criterion(carry):
+ D[:, jnp.newaxis] * velocity
).flatten()

# Compute the 3D linear force in C[W] frame
# Get the solver options.
solver_options = self.solver_options

# Extract the options corresponding to the convergence criteria.
# All the remaining options are passed to the solver.
tol = solver_options.pop("tol")
maxiter = solver_options.pop("maxiter")

# Compute the 3D linear force in C[W] frame.
CW_f_Ci, _ = run_optimization(
init_params=init_params,
A=A,
b=b,
maxiter=self.parameters.max_iterations,
opt=optax.lbfgs(
memory_size=10,
),
maxiter=maxiter,
opt=optax.lbfgs(**solver_options),
fun=objective,
tol=self.parameters.tolerance,
tol=tol,
)

CW_f_Ci = CW_f_Ci.reshape((-1, 3))
Expand Down
64 changes: 55 additions & 9 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,35 @@ class RigidContacts(ContactModel):
default_factory=FlatTerrain.build
)

regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
default=1e-6, kw_only=True
)

_solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
default=("solver_tol",), kw_only=True
)
_solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
default=(1e-3,), kw_only=True
)

@property
def solver_options(self) -> dict[str, Any]:

return dict(
zip(
self._solver_options_keys,
self._solver_options_values,
strict=True,
)
)

@classmethod
def build(
cls: type[Self],
parameters: RigidContactsParams | None = None,
terrain: Terrain | None = None,
regularization_delassus: jtp.FloatLike | None = None,
solver_options: dict[str, Any] | None = None,
**kwargs,
) -> Self:
"""
Expand All @@ -104,6 +128,9 @@ def build(
Args:
parameters: The parameters of the rigid contacts model.
terrain: The considered terrain.
regularization_delassus:
The regularization term to add to the diagonal of the Delassus matrix.
solver_options: The options to pass to the QP solver.
Returns:
The `RigidContacts` instance.
Expand All @@ -112,11 +139,35 @@ def build(
if len(kwargs) != 0:
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

# Get the default solver options.
default_solver_options = dict(
zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
)

# Create the solver options to set by combining the default solver options
# with the user-provided solver options.
solver_options = default_solver_options | (solver_options or {})

# Make sure that the solver options are hashable.
# We need to check this because the solver options are static.
try:
hash(tuple(solver_options.values()))
except TypeError as exc:
raise ValueError(
"The values of the solver options must be hashable."
) from exc

return cls(
parameters=(
parameters or cls.__dataclass_fields__["parameters"].default_factory()
),
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
regularization_delassus=float(
regularization_delassus
or cls.__dataclass_fields__["regularization_delassus"].default
),
_solver_options_keys=tuple(solver_options.keys()),
_solver_options_values=tuple(solver_options.values()),
)

@staticmethod
Expand Down Expand Up @@ -230,8 +281,6 @@ def compute_contact_forces(
*,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
regularization_term: jtp.FloatLike = 1e-6,
solver_tol: jtp.FloatLike = 1e-3,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
"""
Compute the contact forces.
Expand All @@ -244,10 +293,6 @@ def compute_contact_forces(
expressed in the same representation of data.
joint_force_references:
Optional `(n_joints,)` vector of joint forces.
regularization_term:
The regularization term to add to the diagonal of the Delassus
matrix for better numerical conditioning.
solver_tol: The convergence tolerance to consider in the QP solver.
Returns:
A tuple containing the contact forces.
Expand Down Expand Up @@ -296,10 +341,11 @@ def compute_contact_forces(
terrain_height=terrain_height,
)

# Compute the Delassus matrix.
delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)

# Add regularization for better numerical conditioning
delassus_matrix = delassus_matrix + regularization_term * jnp.eye(
# Add regularization for better numerical conditioning.
delassus_matrix = delassus_matrix + self.regularization_delassus * jnp.eye(
delassus_matrix.shape[0]
)

Expand Down Expand Up @@ -359,7 +405,7 @@ def compute_contact_forces(

# Solve the optimization problem
solution, *_ = qpax.solve_qp(
Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, solver_tol=solver_tol
Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
)

f_C_lin = solution.reshape(-1, 3)
Expand Down

0 comments on commit b7e2fee

Please sign in to comment.