Skip to content

Commit

Permalink
Update parameters handling of the relaxed-rigid contact model
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 7, 2024
1 parent 7db9adb commit 49eadf0
Showing 1 changed file with 52 additions and 22 deletions.
74 changes: 52 additions & 22 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ class RelaxedRigidContactsParams(ContactsParams):
default_factory=lambda: jnp.array(0.5, dtype=float)
)

# Maximum number of iterations
max_iterations: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(50, dtype=int)
)

# Solver tolerance
tolerance: jtp.Float = dataclasses.field(
default_factory=lambda: jnp.array(1e-6, dtype=float)
)

def __hash__(self) -> int:
from jaxsim.utils.wrappers import HashedNumpyArray

Expand All @@ -103,8 +93,6 @@ def __hash__(self) -> int:
HashedNumpyArray(self.stiffness),
HashedNumpyArray(self.damping),
HashedNumpyArray(self.mu),
HashedNumpyArray(self.max_iterations),
HashedNumpyArray(self.tolerance),
)
)

Expand All @@ -125,8 +113,6 @@ def build(
stiffness: jtp.FloatLike | None = None,
damping: jtp.FloatLike | None = None,
mu: jtp.FloatLike | None = None,
max_iterations: jtp.IntLike | None = None,
tolerance: jtp.FloatLike | None = None,
) -> Self:
"""Create a `RelaxedRigidContactsParams` instance"""

Expand All @@ -152,8 +138,6 @@ def valid(self) -> jtp.BoolLike:
and jnp.all(self.midpoint >= 0.0)
and jnp.all(self.power >= 0.0)
and jnp.all(self.mu >= 0.0)
and jnp.all(self.max_iterations > 0)
and jnp.all(self.tolerance > 0.0)
)


Expand All @@ -169,11 +153,30 @@ class RelaxedRigidContacts(ContactModel):
default_factory=FlatTerrain.build
)

_solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
default=("tol", "maxiter", "memory_size"), kw_only=True
)
_solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
default=(1e-6, 50, 10), kw_only=True
)

@property
def solver_options(self) -> dict[str, Any]:

return dict(
zip(
self._solver_options_keys,
self._solver_options_values,
strict=True,
)
)

@classmethod
def build(
cls: type[Self],
parameters: RelaxedRigidContactsParams | None = None,
terrain: Terrain | None = None,
solver_options: dict[str, Any] | None = None,
**kwargs,
) -> Self:
"""
Expand All @@ -182,6 +185,7 @@ def build(
Args:
parameters: The parameters of the rigid contacts model.
terrain: The considered terrain.
solver_options: The options to pass to the L-BFGS solver.
Returns:
The `RelaxedRigidContacts` instance.
Expand All @@ -190,11 +194,31 @@ def build(
if len(kwargs) != 0:
logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

# Get the default solver options.
default_solver_options = dict(
zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
)

# Create the solver options to set by combining the default solver options
# with the user-provided solver options.
solver_options = default_solver_options | (solver_options or {})

# Make sure that the solver options are hashable.
# We need to check this because the solver options are static.
try:
hash(tuple(solver_options.values()))
except TypeError as exc:
raise ValueError(
"The values of the solver options must be hashable."
) from exc

return cls(
parameters=(
parameters or cls.__dataclass_fields__["parameters"].default_factory()
),
terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(),
_solver_options_keys=tuple(solver_options.keys()),
_solver_options_values=tuple(solver_options.values()),
)

@jax.jit
Expand Down Expand Up @@ -357,17 +381,23 @@ def continuing_criterion(carry):
+ D[:, jnp.newaxis] * velocity
).flatten()

# Compute the 3D linear force in C[W] frame
# Get the solver options.
solver_options = self.solver_options

# Extract the options corresponding to the convergence criteria.
# All the remaining options are passed to the solver.
tol = solver_options.pop("tol")
maxiter = solver_options.pop("maxiter")

# Compute the 3D linear force in C[W] frame.
CW_f_Ci, _ = run_optimization(
init_params=init_params,
A=A,
b=b,
maxiter=self.parameters.max_iterations,
opt=optax.lbfgs(
memory_size=10,
),
maxiter=maxiter,
opt=optax.lbfgs(**solver_options),
fun=objective,
tol=self.parameters.tolerance,
tol=tol,
)

CW_f_Ci = CW_f_Ci.reshape((-1, 3))
Expand Down

0 comments on commit 49eadf0

Please sign in to comment.