diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 2a1f6d569..2b0aed385 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -188,7 +188,6 @@ def collidable_point_dynamics( data=data, link_forces=link_forces, joint_force_references=joint_force_references, - solver_tol=1e-3, ) aux_data = dict() diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index cad3fbf32..7639a9738 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -17,6 +17,12 @@ class ContactsParams(JaxsimDataclass): """ Abstract class representing the parameters of a contact model. + + Note: + This class is supposed to store only the tunable parameters of the contact + model, i.e. all those parameters that can be changed during runtime. + If the contact model has also static parameters, they should be stored + in the corresponding `ContactModel` class. """ @classmethod @@ -47,7 +53,7 @@ class ContactModel(JaxsimDataclass): Attributes: parameters: The parameters of the contact model. - terrain: The terrain model. + terrain: The considered terrain. """ parameters: ContactsParams @@ -85,7 +91,7 @@ def compute_contact_forces( Compute the contact forces. Args: - model: The model to consider. + model: The robot model considered by the contact model. data: The data of the considered model. Returns: diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 9699b97a6..6737a14b3 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -78,16 +78,6 @@ class RelaxedRigidContactsParams(ContactsParams): default_factory=lambda: jnp.array(0.5, dtype=float) ) - # Maximum number of iterations - max_iterations: jtp.Int = dataclasses.field( - default_factory=lambda: jnp.array(50, dtype=int) - ) - - # Solver tolerance - tolerance: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(1e-6, dtype=float) - ) - def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray @@ -103,8 +93,6 @@ def __hash__(self) -> int: HashedNumpyArray(self.stiffness), HashedNumpyArray(self.damping), HashedNumpyArray(self.mu), - HashedNumpyArray(self.max_iterations), - HashedNumpyArray(self.tolerance), ) ) @@ -125,8 +113,6 @@ def build( stiffness: jtp.FloatLike | None = None, damping: jtp.FloatLike | None = None, mu: jtp.FloatLike | None = None, - max_iterations: jtp.IntLike | None = None, - tolerance: jtp.FloatLike | None = None, ) -> Self: """Create a `RelaxedRigidContactsParams` instance""" @@ -152,8 +138,6 @@ def valid(self) -> jtp.BoolLike: and jnp.all(self.midpoint >= 0.0) and jnp.all(self.power >= 0.0) and jnp.all(self.mu >= 0.0) - and jnp.all(self.max_iterations > 0) - and jnp.all(self.tolerance > 0.0) ) @@ -169,11 +153,30 @@ class RelaxedRigidContacts(ContactModel): default_factory=FlatTerrain.build ) + _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( + default=("tol", "maxiter", "memory_size"), kw_only=True + ) + _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( + default=(1e-6, 50, 10), kw_only=True + ) + + @property + def solver_options(self) -> dict[str, Any]: + + return dict( + zip( + self._solver_options_keys, + self._solver_options_values, + strict=True, + ) + ) + @classmethod def build( cls: type[Self], parameters: RelaxedRigidContactsParams | None = None, terrain: Terrain | None = None, + solver_options: dict[str, Any] | None = None, **kwargs, ) -> Self: """ @@ -182,6 +185,7 @@ def build( Args: parameters: The parameters of the rigid contacts model. terrain: The considered terrain. + solver_options: The options to pass to the L-BFGS solver. Returns: The `RelaxedRigidContacts` instance. @@ -190,11 +194,31 @@ def build( if len(kwargs) != 0: logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + # Get the default solver options. + default_solver_options = dict( + zip(cls._solver_options_keys, cls._solver_options_values, strict=True) + ) + + # Create the solver options to set by combining the default solver options + # with the user-provided solver options. + solver_options = default_solver_options | (solver_options or {}) + + # Make sure that the solver options are hashable. + # We need to check this because the solver options are static. + try: + hash(tuple(solver_options.values())) + except TypeError as exc: + raise ValueError( + "The values of the solver options must be hashable." + ) from exc + return cls( parameters=( parameters or cls.__dataclass_fields__["parameters"].default_factory() ), terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + _solver_options_keys=tuple(solver_options.keys()), + _solver_options_values=tuple(solver_options.values()), ) @jax.jit @@ -357,17 +381,23 @@ def continuing_criterion(carry): + D[:, jnp.newaxis] * velocity ).flatten() - # Compute the 3D linear force in C[W] frame + # Get the solver options. + solver_options = self.solver_options + + # Extract the options corresponding to the convergence criteria. + # All the remaining options are passed to the solver. + tol = solver_options.pop("tol") + maxiter = solver_options.pop("maxiter") + + # Compute the 3D linear force in C[W] frame. CW_f_Ci, _ = run_optimization( init_params=init_params, A=A, b=b, - maxiter=self.parameters.max_iterations, - opt=optax.lbfgs( - memory_size=10, - ), + maxiter=maxiter, + opt=optax.lbfgs(**solver_options), fun=objective, - tol=self.parameters.tolerance, + tol=tol, ) CW_f_Ci = CW_f_Ci.reshape((-1, 3)) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 7c6d1f6ac..bfacba19a 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -91,11 +91,35 @@ class RigidContacts(ContactModel): default_factory=FlatTerrain.build ) + regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field( + default=1e-6, kw_only=True + ) + + _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( + default=("solver_tol",), kw_only=True + ) + _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( + default=(1e-3,), kw_only=True + ) + + @property + def solver_options(self) -> dict[str, Any]: + + return dict( + zip( + self._solver_options_keys, + self._solver_options_values, + strict=True, + ) + ) + @classmethod def build( cls: type[Self], parameters: RigidContactsParams | None = None, terrain: Terrain | None = None, + regularization_delassus: jtp.FloatLike | None = None, + solver_options: dict[str, Any] | None = None, **kwargs, ) -> Self: """ @@ -104,6 +128,9 @@ def build( Args: parameters: The parameters of the rigid contacts model. terrain: The considered terrain. + regularization_delassus: + The regularization term to add to the diagonal of the Delassus matrix. + solver_options: The options to pass to the QP solver. Returns: The `RigidContacts` instance. @@ -112,11 +139,35 @@ def build( if len(kwargs) != 0: logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + # Get the default solver options. + default_solver_options = dict( + zip(cls._solver_options_keys, cls._solver_options_values, strict=True) + ) + + # Create the solver options to set by combining the default solver options + # with the user-provided solver options. + solver_options = default_solver_options | (solver_options or {}) + + # Make sure that the solver options are hashable. + # We need to check this because the solver options are static. + try: + hash(tuple(solver_options.values())) + except TypeError as exc: + raise ValueError( + "The values of the solver options must be hashable." + ) from exc + return cls( parameters=( parameters or cls.__dataclass_fields__["parameters"].default_factory() ), terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + regularization_delassus=float( + regularization_delassus + or cls.__dataclass_fields__["regularization_delassus"].default + ), + _solver_options_keys=tuple(solver_options.keys()), + _solver_options_values=tuple(solver_options.values()), ) @staticmethod @@ -230,8 +281,6 @@ def compute_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - regularization_term: jtp.FloatLike = 1e-6, - solver_tol: jtp.FloatLike = 1e-3, ) -> tuple[jtp.Vector, tuple[Any, ...]]: """ Compute the contact forces. @@ -244,10 +293,6 @@ def compute_contact_forces( expressed in the same representation of data. joint_force_references: Optional `(n_joints,)` vector of joint forces. - regularization_term: - The regularization term to add to the diagonal of the Delassus - matrix for better numerical conditioning. - solver_tol: The convergence tolerance to consider in the QP solver. Returns: A tuple containing the contact forces. @@ -296,10 +341,11 @@ def compute_contact_forces( terrain_height=terrain_height, ) + # Compute the Delassus matrix. delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) - # Add regularization for better numerical conditioning - delassus_matrix = delassus_matrix + regularization_term * jnp.eye( + # Add regularization for better numerical conditioning. + delassus_matrix = delassus_matrix + self.regularization_delassus * jnp.eye( delassus_matrix.shape[0] ) @@ -359,7 +405,7 @@ def compute_contact_forces( # Solve the optimization problem solution, *_ = qpax.solve_qp( - Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, solver_tol=solver_tol + Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options ) f_C_lin = solution.reshape(-1, 3)