From 7f0893f849130e0df235eb04d516e7ce4c81afc4 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 28 Jan 2025 15:00:53 +0100 Subject: [PATCH] Reintroduce soft, viscoelastic and rigid contact models This reverts commit 9408f744ec6302c107c886f43b92529809acbd21. --- src/jaxsim/rbda/contacts/__init__.py | 12 +- src/jaxsim/rbda/contacts/rigid.py | 462 +++++++++ src/jaxsim/rbda/contacts/soft.py | 480 ++++++++++ src/jaxsim/rbda/contacts/visco_elastic.py | 1066 +++++++++++++++++++++ tests/test_automatic_differentiation.py | 63 +- 5 files changed, 2077 insertions(+), 6 deletions(-) create mode 100644 src/jaxsim/rbda/contacts/rigid.py create mode 100644 src/jaxsim/rbda/contacts/soft.py create mode 100644 src/jaxsim/rbda/contacts/visco_elastic.py diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 3688468cf..06646f14d 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,5 +1,13 @@ -from . import relaxed_rigid +from . import relaxed_rigid, rigid, soft, visco_elastic from .common import ContactModel, ContactsParams from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams +from .rigid import RigidContacts, RigidContactsParams +from .soft import SoftContacts, SoftContactsParams +from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams -ContactParamsTypes = RelaxedRigidContactsParams +ContactParamsTypes = ( + SoftContactsParams + | RigidContactsParams + | RelaxedRigidContactsParams + | ViscoElasticContactsParams +) diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py new file mode 100644 index 000000000..d04a7b895 --- /dev/null +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -0,0 +1,462 @@ +from __future__ import annotations + +import dataclasses +from typing import Any + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim.api as js +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr + +from . import common +from .common import ContactModel, ContactsParams + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class RigidContactsParams(ContactsParams): + """Parameters of the rigid contacts model.""" + + # Static friction coefficient + mu: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + # Baumgarte proportional term + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + # Baumgarte derivative term + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.mu), + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + ) + ) + + def __eq__(self, other: RigidContactsParams) -> bool: + return hash(self) == hash(other) + + @classmethod + def build( + cls: type[Self], + *, + mu: jtp.FloatLike | None = None, + K: jtp.FloatLike | None = None, + D: jtp.FloatLike | None = None, + ) -> Self: + """Create a `RigidContactParams` instance.""" + + return cls( + mu=jnp.array( + mu + if mu is not None + else cls.__dataclass_fields__["mu"].default_factory() + ).astype(float), + K=jnp.array( + K if K is not None else cls.__dataclass_fields__["K"].default_factory() + ).astype(float), + D=jnp.array( + D if D is not None else cls.__dataclass_fields__["D"].default_factory() + ).astype(float), + ) + + def valid(self) -> jtp.BoolLike: + """Check if the parameters are valid.""" + return bool( + jnp.all(self.mu >= 0.0) + and jnp.all(self.K >= 0.0) + and jnp.all(self.D >= 0.0) + ) + + +@jax_dataclasses.pytree_dataclass +class RigidContacts(ContactModel): + """Rigid contacts model.""" + + regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field( + default=1e-6, kw_only=True + ) + + _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( + default=("solver_tol",), kw_only=True + ) + _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( + default=(1e-3,), kw_only=True + ) + + @property + def solver_options(self) -> dict[str, Any]: + """Get the solver options as a dictionary.""" + + return dict( + zip( + self._solver_options_keys, + self._solver_options_values, + strict=True, + ) + ) + + @classmethod + def build( + cls: type[Self], + regularization_delassus: jtp.FloatLike | None = None, + solver_options: dict[str, Any] | None = None, + **kwargs, + ) -> Self: + """ + Create a `RigidContacts` instance with specified parameters. + + Args: + regularization_delassus: + The regularization term to add to the diagonal of the Delassus matrix. + solver_options: The options to pass to the QP solver. + **kwargs: Extra arguments which are ignored. + + Returns: + The `RigidContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + # Get the default solver options. + default_solver_options = dict( + zip(cls._solver_options_keys, cls._solver_options_values, strict=True) + ) + + # Create the solver options to set by combining the default solver options + # with the user-provided solver options. + solver_options = default_solver_options | ( + solver_options if solver_options is not None else {} + ) + + # Make sure that the solver options are hashable. + # We need to check this because the solver options are static. + try: + hash(tuple(solver_options.values())) + except TypeError as exc: + raise ValueError( + "The values of the solver options must be hashable." + ) from exc + + return cls( + regularization_delassus=float( + regularization_delassus + if regularization_delassus is not None + else cls.__dataclass_fields__["regularization_delassus"].default + ), + _solver_options_keys=tuple(solver_options.keys()), + _solver_options_values=tuple(solver_options.values()), + **kwargs, + ) + + @staticmethod + def compute_impact_velocity( + inactive_collidable_points: jtp.ArrayLike, + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + generalized_velocity: jtp.VectorLike, + ) -> jtp.Vector: + """ + Return the new velocity of the system after a potential impact. + + Args: + inactive_collidable_points: The activation state of the collidable points. + M: The mass matrix of the system (in mixed representation). + J_WC: The Jacobian matrix of the collidable points (in mixed representation). + generalized_velocity: The generalized velocity of the system. + + Note: + The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity` + must be expressed in the same velocity representation. + """ + + # Compute system velocity after impact maintaining zero linear velocity of active points. + sl = jnp.s_[:, 0:3, :] + Jl_WC = J_WC[sl] + + # Zero out the jacobian rows of inactive points. + Jl_WC = jnp.vstack( + jnp.where( + inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], + jnp.zeros_like(Jl_WC), + Jl_WC, + ) + ) + + A = jnp.vstack( + [ + jnp.hstack([M, -Jl_WC.T]), + jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]), + ] + ) + b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])]) + + BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0] + + return BW_ν_post_impact[0 : M.shape[0]] + + @jax.jit + def compute_contact_forces( + self, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The model to consider. + data: The data of the considered model. + link_forces: + Optional `(n_links, 6)` matrix of external forces acting on the links, + expressed in the same representation of data. + joint_force_references: + Optional `(n_joints,)` vector of joint forces. + + Returns: + A tuple containing as first element the computed contact forces. + """ + + # Import qpax privately just in this method. + import qpax + + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + n_collidable_points = len(indices_of_enabled_collidable_points) + + link_forces = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() + if link_forces is not None + else jnp.zeros((model.number_of_links(), 6)) + ) + + joint_force_references = jnp.atleast_1d( + jnp.array(joint_force_references, dtype=float).squeeze() + if joint_force_references is not None + else jnp.zeros((model.number_of_joints(),)) + ) + + # Compute kin-dyn quantities used in the contact model. + with data.switch_velocity_representation(VelRepr.Mixed): + BW_ν = data.generalized_velocity() + + M = js.model.free_floating_mass_matrix(model=model, data=data) + + J_WC = js.contact.jacobian(model=model, data=data) + J̇_WC = js.contact.jacobian_derivative(model=model, data=data) + + W_H_C = js.contact.transforms(model=model, data=data) + + # Compute the position and linear velocities (mixed representation) of + # all enabled collidable points belonging to the robot. + position, velocity = js.contact.collidable_point_kinematics( + model=model, data=data + ) + + # Compute the penetration depth and velocity of the collidable points. + # Note that this function considers the penetration in the normal direction. + δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( + position, velocity, model.terrain + ) + + # Build a references object to simplify converting link forces. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + + # Compute the generalized free acceleration. + with ( + references.switch_velocity_representation(VelRepr.Mixed), + data.switch_velocity_representation(VelRepr.Mixed), + ): + + BW_ν̇_free = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references( + model=model + ), + ) + ) + + # Compute the free linear acceleration of the collidable points. + # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. + free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points( + BW_nu=BW_ν, + BW_nu_dot=BW_ν̇_free, + CW_J_WC_BW=J_WC, + CW_J_dot_WC_BW=J̇_WC, + ).flatten() + + # Compute stabilization term. + baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term( + inactive_collidable_points=(δ <= 0), + δ=δ, + δ_dot=δ_dot, + n=n̂, + K=data.contacts_params.K, + D=data.contacts_params.D, + ).flatten() + + # Compute the Delassus matrix. + delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) + + # Initialize regularization term of the Delassus matrix for + # better numerical conditioning. + Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0]) + + # Construct the quadratic cost function. + Q = delassus_matrix + Iε + q = free_contact_acc - baumgarte_term + + # Construct the inequality constraints. + G = RigidContacts._compute_ineq_constraint_matrix( + inactive_collidable_points=(δ <= 0), mu=data.contacts_params.mu + ) + h_bounds = RigidContacts._compute_ineq_bounds( + n_collidable_points=n_collidable_points + ) + + # Construct the equality constraints. + A = jnp.zeros((0, 3 * n_collidable_points)) + b = jnp.zeros((0,)) + + # Solve the following optimization problem with qpax: + # + # min_{x} 0.5 x⊤ Q x + q⊤ x + # + # s.t. A x = b + # G x ≤ h + # + # TODO: add possibility to notify if the QP problem did not converge. + solution, _, _, _, converged, _ = qpax.solve_qp( # noqa: F841 + Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options + ) + + # Reshape the optimized solution to be a matrix of 3D contact forces. + CW_fl_C = solution.reshape(-1, 3) + + # Convert the contact forces from mixed to inertial-fixed representation. + W_f_C = jax.vmap( + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + transform=W_H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) + ), + )(CW_fl_C, W_H_C) + + return W_f_C, {} + + @staticmethod + def _delassus_matrix( + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + ) -> jtp.Matrix: + + sl = jnp.s_[:, 0:3, :] + J_WC_lin = jnp.vstack(J_WC[sl]) + + delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T + return delassus_matrix + + @staticmethod + def _compute_ineq_constraint_matrix( + inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike + ) -> jtp.Matrix: + """ + Compute the inequality constraint matrix for a single collidable point. + + Rows 0-3: enforce the friction pyramid constraint, + Row 4: last one is for the non negativity of the vertical force + Row 5: contact complementarity condition + """ + G_single_point = jnp.array( + [ + [1, 0, -mu], + [0, 1, -mu], + [-1, 0, -mu], + [0, -1, -mu], + [0, 0, -1], + [0, 0, 0], + ] + ) + G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1)) + G = G.at[:, 5, 2].set(inactive_collidable_points) + + G = jax.scipy.linalg.block_diag(*G) + return G + + @staticmethod + def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: + + n_constraints = 6 * n_collidable_points + return jnp.zeros(shape=(n_constraints,)) + + @staticmethod + def _linear_acceleration_of_collidable_points( + BW_nu: jtp.ArrayLike, + BW_nu_dot: jtp.ArrayLike, + CW_J_WC_BW: jtp.MatrixLike, + CW_J_dot_WC_BW: jtp.MatrixLike, + ) -> jtp.Matrix: + + BW_ν = BW_nu + BW_ν̇ = BW_nu_dot + CW_J̇_WC_BW = CW_J_dot_WC_BW + + # Compute the linear acceleration of the collidable points. + # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. + CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ + + CW_a_WC = CW_a_WC.reshape(-1, 6) + return CW_a_WC[:, 0:3].squeeze() + + @staticmethod + def _compute_baumgarte_stabilization_term( + inactive_collidable_points: jtp.ArrayLike, + δ: jtp.ArrayLike, + δ_dot: jtp.ArrayLike, + n: jtp.ArrayLike, + K: jtp.FloatLike, + D: jtp.FloatLike, + ) -> jtp.Array: + + return jnp.where( + inactive_collidable_points[:, jnp.newaxis], + jnp.zeros_like(n), + (K * δ + D * δ_dot)[:, jnp.newaxis] * n, + ) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py new file mode 100644 index 000000000..dde16cfb2 --- /dev/null +++ b/src/jaxsim/rbda/contacts/soft.py @@ -0,0 +1,480 @@ +from __future__ import annotations + +import dataclasses +import functools + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim.api as js +import jaxsim.math +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.math import StandardGravity +from jaxsim.terrain import Terrain + +from . import common + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class SoftContactsParams(common.ContactsParams): + """Parameters of the soft contacts model.""" + + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(1e6, dtype=float) + ) + + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(2000, dtype=float) + ) + + mu: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + p: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + q: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + def __hash__(self) -> int: + + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + HashedNumpyArray.hash_of_array(self.mu), + HashedNumpyArray.hash_of_array(self.p), + HashedNumpyArray.hash_of_array(self.q), + ) + ) + + def __eq__(self, other: SoftContactsParams) -> bool: + + if not isinstance(other, SoftContactsParams): + return NotImplemented + + return hash(self) == hash(other) + + @classmethod + def build( + cls: type[Self], + *, + K: jtp.FloatLike = 1e6, + D: jtp.FloatLike = 2_000, + mu: jtp.FloatLike = 0.5, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a SoftContactsParams instance with specified parameters. + + Args: + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + mu: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A SoftContactsParams instance with the specified parameters. + """ + + return SoftContactsParams( + K=jnp.array(K, dtype=float), + D=jnp.array(D, dtype=float), + mu=jnp.array(mu, dtype=float), + p=jnp.array(p, dtype=float), + q=jnp.array(q, dtype=float), + ) + + @classmethod + def build_default_from_jaxsim_model( + cls: type[Self], + model: js.model.JaxSimModel, + *, + standard_gravity: jtp.FloatLike = StandardGravity, + static_friction_coefficient: jtp.FloatLike = 0.5, + max_penetration: jtp.FloatLike = 0.001, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> SoftContactsParams: + """ + Create a SoftContactsParams instance with good default parameters. + + Args: + model: The target model. + standard_gravity: The standard gravity constant. + static_friction_coefficient: + The static friction coefficient between the model and the terrain. + max_penetration: The maximum penetration depth. + number_of_active_collidable_points_steady_state: + The number of contacts supporting the weight of the model + in steady state. + damping_ratio: The ratio controlling the damping behavior. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A `SoftContactsParams` instance with the specified parameters. + + Note: + The `damping_ratio` parameter allows to operate on the following conditions: + - ξ > 1.0: over-damped + - ξ = 1.0: critically damped + - ξ < 1.0: under-damped + """ + + # Use symbols for input parameters. + ξ = damping_ratio + δ_max = max_penetration + μc = static_friction_coefficient + + # Compute the total mass of the model. + m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() + + # Rename the standard gravity. + g = standard_gravity + + # Compute the average support force on each collidable point. + f_average = m * g / number_of_active_collidable_points_steady_state + + # Compute the stiffness to get the desired steady-state penetration. + # Note that this is dependent on the non-linear exponent used in + # the damping term of the Hunt/Crossley model. + K = f_average / jnp.power(δ_max, 1 + p) + + # Compute the damping using the damping ratio. + critical_damping = 2 * jnp.sqrt(K * m) + D = ξ * critical_damping + + return SoftContactsParams.build(K=K, D=D, mu=μc, p=p, q=q) + + def valid(self) -> jtp.BoolLike: + """ + Check if the parameters are valid. + + Returns: + `True` if the parameters are valid, `False` otherwise. + """ + + return jnp.hstack( + [ + self.K >= 0.0, + self.D >= 0.0, + self.mu >= 0.0, + self.p >= 0.0, + self.q >= 0.0, + ] + ).all() + + +@jax_dataclasses.pytree_dataclass +class SoftContacts(common.ContactModel): + """Soft contacts model.""" + + @classmethod + def build( + cls: type[Self], + model: js.model.JaxSimModel | None = None, + **kwargs, + ) -> Self: + """ + Create a `SoftContacts` instance with specified parameters. + + Args: + model: + The robot model considered by the contact model. + If passed, it is used to estimate good default parameters. + **kwargs: Additional parameters to pass to the contact model. + + Returns: + The `SoftContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + return cls(**kwargs) + + @classmethod + def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: + """ + Build zero state variables of the contact model. + """ + + # Initialize the material deformation to zero. + tangential_deformation = jnp.zeros( + shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), + dtype=float, + ) + + return {"tangential_deformation": tangential_deformation} + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def hunt_crossley_contact_model( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + terrain: Terrain, + K: jtp.FloatLike, + D: jtp.FloatLike, + mu: jtp.FloatLike, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force using the Hunt/Crossley model. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + terrain: The terrain model. + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + mu: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ + + # Convert the input vectors to arrays. + W_p_C = jnp.array(position, dtype=float).squeeze() + W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() + m = jnp.array(tangential_deformation, dtype=float).squeeze() + + # Use symbol for the static friction. + μ = mu + + # Compute the penetration depth, its rate, and the considered terrain normal. + δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) + + # There are few operations like computing the norm of a vector with zero length + # or computing the square root of zero that are problematic in an AD context. + # To avoid these issues, we introduce a small tolerance ε to their arguments + # and make sure that we do not check them against zero directly. + ε = jnp.finfo(float).eps + + # Compute the powers of the penetration depth. + # Inject ε to address AD issues in differentiating the square root when + # p and q are fractional. + δp = jnp.power(δ + ε, p) + δq = jnp.power(δ + ε, q) + + # ======================== + # Compute the normal force + # ======================== + + # Non-linear spring-damper model (Hunt/Crossley model). + # This is the force magnitude along the direction normal to the terrain. + force_normal_mag = (K * δp) * δ + (D * δq) * δ̇ + + # Depending on the magnitude of δ̇, the normal force could be negative. + force_normal_mag = jnp.maximum(0.0, force_normal_mag) + + # Compute the 3D linear force in C[W] frame. + f_normal = force_normal_mag * n̂ + + # ============================ + # Compute the tangential force + # ============================ + + # Extract the tangential component of the velocity. + v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂ + + # Extract the normal and tangential components of the material deformation. + m_normal = jnp.dot(m, n̂) * n̂ + m_tangential = m - jnp.dot(m, n̂) * n̂ + + # Compute the tangential force in the sticking case. + # Using the tangential component of the material deformation should not be + # necessary if the sticking-slipping transition occurs in a terrain area + # with a locally constant normal. However, this assumption is not true in + # general, especially for highly uneven terrains. + f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential) + + # Detect the contact type (sticking or slipping). + # Note that if there is no contact, sticking is set to True, and this detail + # is exploited in the computation of the `contact_status` variable. + sticking = jnp.logical_or( + δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2 + ) + + # Compute the direction of the tangential force. + # To prevent dividing by zero, we use a switch statement. + norm = jaxsim.math.safe_norm(f_tangential) + f_tangential_direction = f_tangential / ( + norm + jnp.finfo(float).eps * (norm == 0) + ) + + # Project the tangential force to the friction cone if slipping. + f_tangential = jnp.where( + sticking, + f_tangential, + jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction, + ) + + # Set the tangential force to zero if there is no contact. + f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential) + + # ===================================== + # Compute the material deformation rate + # ===================================== + + # Compute the derivative of the material deformation. + # Note that we included an additional relaxation of `m_normal` in the + # sticking case, so that the normal deformation that could have accumulated + # from a previous slipping phase can relax to zero. + ṁ_no_contact = -(K / D) * m + ṁ_sticking = v_tangential - (K / D) * m_normal + ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq) + + # Compute the contact status: + # 0: slipping + # 1: sticking + # 2: no contact + contact_status = sticking.astype(int) + contact_status += (δ <= 0).astype(int) + + # Select the right material deformation rate depending on the contact status. + ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact) + + # ========================================== + # Compute and return the final contact force + # ========================================== + + # Sum the normal and tangential forces. + CW_fl = f_normal + f_tangential + + return CW_fl, ṁ + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def compute_contact_force( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + parameters: SoftContactsParams, + terrain: Terrain, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + parameters: The parameters of the soft contacts model. + terrain: The terrain model. + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ + + CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( + position=position, + velocity=velocity, + tangential_deformation=tangential_deformation, + terrain=terrain, + K=parameters.K, + D=parameters.D, + mu=parameters.mu, + p=parameters.p, + q=parameters.q, + ) + + # Pack a mixed 6D force. + CW_f = jnp.hstack([CW_fl, jnp.zeros(3)]) + + # Compute the 6D force transform from the mixed to the inertial-fixed frame. + W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation( + translation=jnp.array(position), inverse=True + ).T + + # Compute the 6D force in the inertial-fixed frame. + W_f = W_Xf_CW @ CW_f + + return W_f, ṁ + + @staticmethod + @jax.jit + def compute_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + A tuple containing as first element the computed contact forces, and as + second element a dictionary with derivative of the material deformation. + """ + + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + # Compute the position and linear velocities (mixed representation) of + # all the collidable points belonging to the robot and extract the ones + # for the enabled collidable points. + W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) + + # Extract the material deformation corresponding to the collidable points. + m = data.state.extended["tangential_deformation"] + + m_enabled = m[indices_of_enabled_collidable_points] + + # Initialize the tangential deformation rate array for every collidable point. + ṁ = jnp.zeros_like(m) + + # Compute the contact forces only for the enabled collidable points. + # Since we treat them as independent, we can vmap the computation. + W_f, ṁ_enabled = jax.vmap( + lambda p, v, m: SoftContacts.compute_contact_force( + position=p, + velocity=v, + tangential_deformation=m, + parameters=data.contacts_params, + terrain=model.terrain, + ) + )(W_p_C, W_ṗ_C, m_enabled) + + ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) + + return W_f, dict(m_dot=ṁ) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py new file mode 100644 index 000000000..40ad4ab61 --- /dev/null +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -0,0 +1,1066 @@ +from __future__ import annotations + +import dataclasses +import functools +from typing import Any + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim +import jaxsim.api as js +import jaxsim.exceptions +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.api.common import ModelDataWithVelocityRepresentation +from jaxsim.math import StandardGravity +from jaxsim.terrain import Terrain + +from . import common +from .soft import SoftContacts, SoftContactsParams + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class ViscoElasticContactsParams(common.ContactsParams): + """Parameters of the visco-elastic contacts model.""" + + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(1e6, dtype=float) + ) + + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(2000, dtype=float) + ) + + static_friction: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + p: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + q: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + @classmethod + def build( + cls: type[Self], + K: jtp.FloatLike = 1e6, + D: jtp.FloatLike = 2_000, + static_friction: jtp.FloatLike = 0.5, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a SoftContactsParams instance with specified parameters. + + Args: + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + static_friction: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model. + + Returns: + A ViscoElasticParams instance with the specified parameters. + """ + + return ViscoElasticContactsParams( + K=jnp.array(K, dtype=float), + D=jnp.array(D, dtype=float), + static_friction=jnp.array(static_friction, dtype=float), + p=jnp.array(p, dtype=float), + q=jnp.array(q, dtype=float), + ) + + @classmethod + def build_default_from_jaxsim_model( + cls: type[Self], + model: js.model.JaxSimModel, + *, + standard_gravity: jtp.FloatLike = StandardGravity, + static_friction_coefficient: jtp.FloatLike = 0.5, + max_penetration: jtp.FloatLike = 0.001, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a ViscoElasticContactsParams instance with good default parameters. + + Args: + model: The target model. + standard_gravity: The standard gravity constant. + static_friction_coefficient: + The static friction coefficient between the model and the terrain. + max_penetration: The maximum penetration depth. + number_of_active_collidable_points_steady_state: + The number of contacts supporting the weight of the model + in steady state. + damping_ratio: The ratio controlling the damping behavior. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model. + + Returns: + A `ViscoElasticContactsParams` instance with the specified parameters. + + Note: + The `damping_ratio` parameter allows to operate on the following conditions: + - ξ > 1.0: over-damped + - ξ = 1.0: critically damped + - ξ < 1.0: under-damped + """ + + # Call the SoftContact builder instead of duplicating the logic. + soft_contacts_params = SoftContactsParams.build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_penetration, + number_of_active_collidable_points_steady_state=number_of_active_collidable_points_steady_state, + damping_ratio=damping_ratio, + ) + + return ViscoElasticContactsParams.build( + K=soft_contacts_params.K, + D=soft_contacts_params.D, + static_friction=soft_contacts_params.mu, + p=p, + q=q, + ) + + def valid(self) -> jtp.BoolLike: + """ + Check if the parameters are valid. + + Returns: + `True` if the parameters are valid, `False` otherwise. + """ + + return ( + jnp.all(self.K >= 0.0) + and jnp.all(self.D >= 0.0) + and jnp.all(self.static_friction >= 0.0) + and jnp.all(self.p >= 0.0) + and jnp.all(self.q >= 0.0) + ) + + def __hash__(self) -> int: + + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + HashedNumpyArray.hash_of_array(self.static_friction), + HashedNumpyArray.hash_of_array(self.p), + HashedNumpyArray.hash_of_array(self.q), + ) + ) + + def __eq__(self, other: ViscoElasticContactsParams) -> bool: + + if not isinstance(other, ViscoElasticContactsParams): + return False + + return hash(self) == hash(other) + + +@jax_dataclasses.pytree_dataclass +class ViscoElasticContacts(common.ContactModel): + """Visco-elastic contacts model.""" + + max_squarings: jax_dataclasses.Static[int] = dataclasses.field(default=25) + + @classmethod + def build( + cls: type[Self], + model: js.model.JaxSimModel | None = None, + max_squarings: jtp.IntLike | None = None, + **kwargs, + ) -> Self: + """ + Create a `ViscoElasticContacts` instance with specified parameters. + + Args: + model: + The robot model considered by the contact model. + If passed, it is used to estimate good default parameters. + max_squarings: + The maximum number of squarings performed in the matrix exponential. + **kwargs: Extra arguments to ignore. + + Returns: + The `ViscoElasticContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + return cls( + max_squarings=int( + max_squarings + if max_squarings is not None + else cls.__dataclass_fields__["max_squarings"].default + ), + ) + + @classmethod + def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: + """ + Build zero state variables of the contact model. + """ + + # Initialize the material deformation to zero. + tangential_deformation = jnp.zeros( + shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), + dtype=float, + ) + + return {"tangential_deformation": tangential_deformation} + + @jax.jit + def compute_contact_forces( + self, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike | None = None, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + dt: The time step to consider. If not specified, it is read from the model. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding + to the velocity representation of `data`. + joint_force_references: The joint force references to apply. + + Note: + This contact model, contrarily to most other contact models, requires the + knowledge of the integration step. It is not straightforward to assess how + this contact model behaves when used with high-order Runge-Kutta schemes. + For the time being, it is recommended to use a simple forward Euler scheme. + The main benefit of this model is that the stiff contact dynamics is computed + separately from the rest of the system dynamics, which allows to use simple + integration schemes without altering significantly the simulation stability. + + Returns: + A tuple containing as first element the computed 6D contact force applied to + the contact point and expressed in the world frame, and as second element + a dictionary of optional additional information. + """ + + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + # Initialize the time step. + dt = dt if dt is not None else model.time_step + + # Compute the average contact linear forces in mixed representation by + # integrating the contact dynamics in the continuous time domain. + CW_f̅l, CW_fl̿, m_tf = ( + ViscoElasticContacts._compute_contact_forces_with_exponential_integration( + model=model, + data=data, + dt=jnp.array(dt).astype(float), + link_forces=link_forces, + joint_force_references=joint_force_references, + indices_of_enabled_collidable_points=indices_of_enabled_collidable_points, + max_squarings=self.max_squarings, + ) + ) + + # ============================================ + # Compute the inertial-fixed 6D contact forces + # ============================================ + + # Compute the transforms of the mixed frames `C[W] = (W_p_C, [W])` + # associated to each collidable point. + W_H_C = js.contact.transforms(model=model, data=data)[ + indices_of_enabled_collidable_points, :, : + ] + + # Vmapped transformation from mixed to inertial-fixed representation. + compute_forces_inertial_fixed_vmap = jax.vmap( + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_C, + is_force=True, + ) + ) + ) + + # Express the linear contact forces in the inertial-fixed frame. + W_f̅_C, W_f̿_C = jax.vmap( + lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C) + )(jnp.stack([CW_f̅l, CW_fl̿])) + + return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf) + + @staticmethod + @functools.partial(jax.jit, static_argnames=("max_squarings",)) + def _compute_contact_forces_with_exponential_integration( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + indices_of_enabled_collidable_points: jtp.VectorLike | None = None, + max_squarings: int = 25, + ) -> tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]: + """ + Compute the average contact forces by integrating the contact dynamics. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + dt: The integration time step. + link_forces: The 6D forces to apply to the links. + joint_force_references: The joint force references to apply. + indices_of_enabled_collidable_points: + The indices of the enabled collidable points. + max_squarings: + The maximum number of squarings performed in the matrix exponential. + + Returns: + A tuple containing: + - The average contact forces. + - The average of the average contact forces. + - The tangential deformation at the final state. + """ + + # ========================== + # Populate missing arguments + # ========================== + + indices = ( + indices_of_enabled_collidable_points + if indices_of_enabled_collidable_points is not None + else jnp.arange( + len(model.kin_dyn_parameters.contact_parameters.body) + ).astype(int) + ) + + # ================================== + # Compute the contact point dynamics + # ================================== + + p_t0, v_t0 = js.contact.collidable_point_kinematics(model, data) + m_t0 = data.state.extended["tangential_deformation"][indices, :] + + p_t0 = p_t0[indices, :] + v_t0 = v_t0[indices, :] + + # Compute the linearized contact dynamics. + # Note that it linearizes the (non-linear) contact model at (p, v, m)[t0]. + A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics( + model=model, + data=data, + link_forces=link_forces, + joint_force_references=joint_force_references, + indices_of_enabled_collidable_points=indices, + p_t0=p_t0, + v_t0=v_t0, + m_t0=m_t0, + ) + + # ============================================= + # Compute the integrals of the contact dynamics + # ============================================= + + # Pack the initial state of the contact points. + x_t0 = jnp.hstack([p_t0.flatten(), v_t0.flatten(), m_t0.flatten()]) + + # Pack the augmented matrix used to compute the single and double integral + # of the exponential integration. + A̅ = jnp.vstack( + [ + jnp.hstack( + [ + A, + jnp.vstack(b), + jnp.vstack(x_t0), + jnp.vstack(jnp.zeros_like(x_t0)), + ] + ), + jnp.hstack([jnp.zeros(A.shape[1]), 0, 1, 0]), + jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 1]), + jnp.hstack([jnp.zeros(A.shape[1]), 0, 0, 0]), + ] + ) + + # Compute the matrix exponential. + exp_tA = jax.scipy.linalg.expm( + (dt * A̅).astype(float), max_squarings=max_squarings + ) + + # Integrate the contact dynamics in the continuous time domain. + x_int, x_int2 = ( + jnp.hstack([jnp.eye(A.shape[0]), jnp.zeros(shape=(A.shape[0], 3))]) + @ exp_tA + @ jnp.vstack([jnp.zeros(shape=(A.shape[0] + 1, 2)), jnp.eye(2)]) + ).T + + jaxsim.exceptions.raise_runtime_error_if( + condition=jnp.isnan(x_int).any(), + msg="NaN integration, try to increase `max_squarings` or decreasing `dt`", + ) + + # ========================== + # Compute the contact forces + # ========================== + + # Compute the average contact forces. + CW_f̅, _ = jnp.split( + (A_sc @ x_int / dt + b_sc).reshape(-1, 3), + indices_or_sections=2, + ) + + # Compute the average of the average contact forces. + CW_f̿, _ = jnp.split( + (A_sc @ x_int2 * 2 / (dt**2) + b_sc).reshape(-1, 3), + indices_or_sections=2, + ) + + # Extract the tangential deformation at the final state. + x_tf = x_int / dt + m_tf = jnp.split(x_tf, 3)[2].reshape(-1, 3) + + return CW_f̅, CW_f̿, m_tf + + @staticmethod + @jax.jit + def _contact_points_dynamics( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + indices_of_enabled_collidable_points: jtp.VectorLike | None = None, + p_t0: jtp.MatrixLike | None = None, + v_t0: jtp.MatrixLike | None = None, + m_t0: jtp.MatrixLike | None = None, + ) -> tuple[jtp.Matrix, jtp.Vector, jtp.Matrix, jtp.Vector]: + """ + Compute the dynamics of the contact points. + + Note: + This function projects the system dynamics to the contact space and + returns the matrices of a linear system to simulate its evolution. + Since the active contact model can be non-linear, this function also + linearizes the contact model at the initial state. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + link_forces: The 6D forces to apply to the links. + joint_force_references: The joint force references to apply. + indices_of_enabled_collidable_points: + The indices of the enabled collidable points. + p_t0: The initial position of the collidable points. + v_t0: The initial velocity of the collidable points. + m_t0: The initial tangential deformation of the collidable points. + + Returns: + A tuple containing: + - The `A` matrix of the linear system that models the contact dynamics. + - The `b` vector of the linear system that models the contact dynamics. + - The `A_sc` matrix of the linear system that approximates the contact model. + - The `b_sc` vector of the linear system that approximates the contact model. + """ + + indices_of_enabled_collidable_points = ( + indices_of_enabled_collidable_points + if indices_of_enabled_collidable_points is not None + else jnp.arange( + len(model.kin_dyn_parameters.contact_parameters.body) + ).astype(int) + ) + + p_t0 = jnp.atleast_2d( + p_t0 + if p_t0 is not None + else js.contact.collidable_point_positions(model=model, data=data)[ + indices_of_enabled_collidable_points, : + ] + ) + + v_t0 = jnp.atleast_2d( + v_t0 + if v_t0 is not None + else js.contact.collidable_point_velocities(model=model, data=data)[ + indices_of_enabled_collidable_points, : + ] + ) + + m_t0 = jnp.atleast_2d( + m_t0 + if m_t0 is not None + else data.state.extended["tangential_deformation"][ + indices_of_enabled_collidable_points, : + ] + ) + + # We expect that the 6D forces of the `link_forces` argument are expressed + # in the frame corresponding to the velocity representation of `data`. + references = js.references.JaxSimModelReferences.build( + model=model, + link_forces=link_forces, + joint_force_references=joint_force_references, + data=data, + velocity_representation=data.velocity_representation, + ) + + # =========================== + # Linearize the contact model + # =========================== + + # Linearize the contact model at the initial state of all considered + # contact points. + A_sc_points, b_sc_points = jax.vmap( + lambda p, v, m: ViscoElasticContacts._linearize_contact_model( + position=p, + velocity=v, + tangential_deformation=m, + parameters=data.contacts_params, + terrain=model.terrain, + ) + )(p_t0, v_t0, m_t0) + + # Since x = [p1, p2, ..., v1, v2, ..., m1, m2, ...], we need to split the A_sc of + # individual points since otherwise we'd get x = [ p1, v1, m1, p2, v2, m2, ...]. + A_sc_p, A_sc_v, A_sc_m = jnp.split(A_sc_points, indices_or_sections=3, axis=-1) + + # We want to have in output first the forces and then the material deformation rates. + # Therefore, we need to extract the components is A_sc_* separately. + A_sc = jnp.vstack( + [ + jnp.hstack( + [ + jax.scipy.linalg.block_diag(*A_sc_p[:, 0:3, :]), + jax.scipy.linalg.block_diag(*A_sc_v[:, 0:3, :]), + jax.scipy.linalg.block_diag(*A_sc_m[:, 0:3, :]), + ], + ), + jnp.hstack( + [ + jax.scipy.linalg.block_diag(*A_sc_p[:, 3:6, :]), + jax.scipy.linalg.block_diag(*A_sc_v[:, 3:6, :]), + jax.scipy.linalg.block_diag(*A_sc_m[:, 3:6, :]), + ] + ), + ] + ) + + # We need to do the same for the b_sc. + b_sc = jnp.hstack( + [b_sc_points[:, 0:3].flatten(), b_sc_points[:, 3:6].flatten()] + ) + + # =========================================================== + # Compute the A and b matrices of the contact points dynamics + # =========================================================== + + with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): + + BW_ν = data.generalized_velocity() + + M = js.model.free_floating_mass_matrix(model=model, data=data) + + CW_Jl_WC = js.contact.jacobian( + model=model, + data=data, + output_vel_repr=jaxsim.VelRepr.Mixed, + )[indices_of_enabled_collidable_points, 0:3, :] + + CW_J̇l_WC = js.contact.jacobian_derivative( + model=model, data=data, output_vel_repr=jaxsim.VelRepr.Mixed + )[indices_of_enabled_collidable_points, 0:3, :] + + # Compute the Delassus matrix. + ψ = jnp.vstack(CW_Jl_WC) @ jnp.linalg.lstsq(M, jnp.vstack(CW_Jl_WC).T)[0] + + I_nc = jnp.eye(v_t0.flatten().size) + O_nc = jnp.zeros(shape=(p_t0.flatten().size, p_t0.flatten().size)) + + # Pack the A matrix. + A = jnp.vstack( + [ + jnp.hstack([O_nc, I_nc, O_nc]), + ψ @ jnp.split(A_sc, 2, axis=0)[0], + jnp.split(A_sc, 2, axis=0)[1], + ] + ) + + # Short names for few variables. + ν = BW_ν + J = jnp.vstack(CW_Jl_WC) + J̇ = jnp.vstack(CW_J̇l_WC) + + # Compute the free system acceleration components. + with ( + data.switch_velocity_representation(jaxsim.VelRepr.Mixed), + references.switch_velocity_representation(jaxsim.VelRepr.Mixed), + ): + + BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), + ) + + # Pack the free system acceleration in mixed representation. + ν̇_free = jnp.hstack([BW_v̇_free_WB, s̈_free]) + + # Compute the acceleration of collidable points. + # This is the true derivative of ṗ only in mixed representation. + p̈ = J @ ν̇_free + J̇ @ ν + + # Pack the b array. + b = jnp.hstack( + [ + jnp.zeros_like(p_t0.flatten()), + p̈ + ψ @ jnp.split(b_sc, indices_or_sections=2)[0], + jnp.split(b_sc, indices_or_sections=2)[1], + ] + ) + + return A, b, A_sc, b_sc + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def _linearize_contact_model( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + parameters: ViscoElasticContactsParams, + terrain: Terrain, + ) -> tuple[jtp.Matrix, jtp.Vector]: + """ + Linearize the Hunt/Crossley contact model at the initial state. + + Args: + position: The position of the contact point. + velocity: The velocity of the contact point. + tangential_deformation: The tangential deformation of the contact point. + parameters: The parameters of the contact model. + terrain: The considered terrain. + + Returns: + A tuple containing the `A` matrix and the `b` vector of the linear system + corresponding to the contact dynamics linearized at the initial state. + """ + + # Initialize the state at which the model is linearized. + p0 = jnp.array(position, dtype=float).squeeze() + v0 = jnp.array(velocity, dtype=float).squeeze() + m0 = jnp.array(tangential_deformation, dtype=float).squeeze() + + # ============ + # Compute A_sc + # ============ + + compute_contact_force_non_linear_model = functools.partial( + ViscoElasticContacts._compute_contact_force_non_linear_model, + parameters=parameters, + terrain=terrain, + ) + + # Compute with AD the functions to get the Jacobians of CW_fl. + df_dp_fun, df_dv_fun, df_dm_fun = ( + jax.jacrev( + lambda p0, v0, m0: compute_contact_force_non_linear_model( + position=p0, velocity=v0, tangential_deformation=m0 + )[0], + argnums=num, + ) + for num in (0, 1, 2) + ) + + # Compute with AD the functions to get the Jacobians of ṁ. + dṁ_dp_fun, dṁ_dv_fun, dṁ_dm_fun = ( + jax.jacrev( + lambda p0, v0, m0: compute_contact_force_non_linear_model( + position=p0, velocity=v0, tangential_deformation=m0 + )[1], + argnums=num, + ) + for num in (0, 1, 2) + ) + + # Compute the Jacobians of the contact forces w.r.t. the state. + df_dp = jnp.vstack(df_dp_fun(p0, v0, m0)) + df_dv = jnp.vstack(df_dv_fun(p0, v0, m0)) + df_dm = jnp.vstack(df_dm_fun(p0, v0, m0)) + + # Compute the Jacobians of the material deformation rate w.r.t. the state. + dṁ_dp = jnp.vstack(dṁ_dp_fun(p0, v0, m0)) + dṁ_dv = jnp.vstack(dṁ_dv_fun(p0, v0, m0)) + dṁ_dm = jnp.vstack(dṁ_dm_fun(p0, v0, m0)) + + # Pack the A matrix. + A_sc = jnp.vstack( + [ + jnp.hstack([df_dp, df_dv, df_dm]), + jnp.hstack([dṁ_dp, dṁ_dv, dṁ_dm]), + ] + ) + + # ============ + # Compute b_sc + # ============ + + # Compute the output of the non-linear model at the initial state. + x0 = jnp.hstack([p0, v0, m0]) + f0, ṁ0 = compute_contact_force_non_linear_model( + position=p0, velocity=v0, tangential_deformation=m0 + ) + + # Pack the b vector. + b_sc = jnp.hstack([f0, ṁ0]) - A_sc @ x0 + + return A_sc, b_sc + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def _compute_contact_force_non_linear_model( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + parameters: ViscoElasticContactsParams, + terrain: Terrain, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact forces using the non-linear Hunt/Crossley model. + + Args: + position: The position of the contact point. + velocity: The velocity of the contact point. + tangential_deformation: The tangential deformation of the contact point. + parameters: The parameters of the contact model. + terrain: The considered terrain. + + Returns: + A tuple containing: + - The linear contact force in the mixed contact frame. + - The rate of material deformation. + """ + + # Compute the linear contact force in mixed representation using + # the non-linear Hunt/Crossley model. + # The following function also returns the rate of material deformation. + CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( + position=position, + velocity=velocity, + tangential_deformation=tangential_deformation, + terrain=terrain, + K=parameters.K, + D=parameters.D, + mu=parameters.static_friction, + p=parameters.p, + q=parameters.q, + ) + + return CW_fl, ṁ + + @staticmethod + @jax.jit + def integrate_data_with_average_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + average_link_contact_forces_inertial: jtp.MatrixLike | None = None, + average_of_average_link_contact_forces_mixed: jtp.MatrixLike | None = None, + ) -> js.data.JaxSimModelData: + """ + Advance the system state by integrating the dynamics. + + Args: + model: The model to consider. + data: The data of the considered model. + dt: The integration time step. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding + to the velocity representation of `data`. + joint_force_references: The joint force references to apply. + average_link_contact_forces_inertial: + The average contact forces computed with the exponential integrator and + expressed in the inertial-fixed frame. + average_of_average_link_contact_forces_mixed: + The average of the average contact forces computed with the exponential + integrator and expressed in the mixed frame. + + Returns: + The data object storing the system state at the final time. + """ + + s_t0 = data.joint_positions() + W_p_B_t0 = data.base_position() + W_Q_B_t0 = data.base_orientation(dcm=False) + + ṡ_t0 = data.joint_velocities() + with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): + W_ṗ_B_t0 = data.base_velocity()[0:3] + W_ω_WB_t0 = data.base_velocity()[3:6] + + with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + W_ν_t0 = data.generalized_velocity() + + # We expect that the 6D forces of the `link_forces` argument are expressed + # in the frame corresponding to the velocity representation of `data`. + references = js.references.JaxSimModelReferences.build( + model=model, + link_forces=link_forces, + joint_force_references=joint_force_references, + data=data, + velocity_representation=data.velocity_representation, + ) + + W_f̅_L = ( + jnp.array(average_link_contact_forces_inertial) + if average_link_contact_forces_inertial is not None + else jnp.zeros_like(references._link_forces) + ).astype(float) + + LW_f̿_L = ( + jnp.array(average_of_average_link_contact_forces_mixed) + if average_of_average_link_contact_forces_mixed is not None + else W_f̅_L + ).astype(float) + + # Compute the system inertial acceleration, used to integrate the system velocity. + # It considers the average contact forces computed with the exponential integrator. + with ( + data.switch_velocity_representation(jaxsim.VelRepr.Inertial), + references.switch_velocity_representation(jaxsim.VelRepr.Inertial), + ): + + W_ν̇_pr = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_force_references=references.joint_force_references( + model=model + ), + link_forces=W_f̅_L + references.link_forces(model=model, data=data), + ) + ) + + # Compute the system mixed acceleration, used to integrate the system position. + # It considers the average of the average contact forces computed with the + # exponential integrator. + with ( + data.switch_velocity_representation(jaxsim.VelRepr.Mixed), + references.switch_velocity_representation(jaxsim.VelRepr.Mixed), + ): + + BW_ν̇_pr2 = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + joint_force_references=references.joint_force_references( + model=model + ), + link_forces=LW_f̿_L + references.link_forces(model=model, data=data), + ) + ) + + # Integrate the system velocity using the inertial-fixed acceleration. + W_ν_plus = W_ν_t0 + dt * W_ν̇_pr + + # Integrate the system position using the mixed velocity. + q_plus = jnp.hstack( + [ + # Note: here both ṗ and p̈ -> need mixed representation. + W_p_B_t0 + dt * W_ṗ_B_t0 + 0.5 * dt**2 * BW_ν̇_pr2[0:3], + jaxsim.math.Quaternion.integration( + dt=dt, + quaternion=W_Q_B_t0, + omega=(W_ω_WB_t0 + 0.5 * dt * BW_ν̇_pr2[3:6]), + omega_in_body_fixed=False, + ).squeeze(), + s_t0 + dt * ṡ_t0 + 0.5 * dt**2 * BW_ν̇_pr2[6:], + ] + ) + + # Create the data at the final time. + data_tf = data.copy() + data_tf = data_tf.reset_joint_positions(q_plus[7:]) + data_tf = data_tf.reset_base_position(q_plus[0:3]) + data_tf = data_tf.reset_base_quaternion(q_plus[3:7]) + data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:]) + data_tf = data_tf.reset_base_velocity( + W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial + ) + + return data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) + + +@jax.jit +def step( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + dt: jtp.FloatLike | None = None, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, +) -> tuple[js.data.JaxSimModelData, dict[str, Any]]: + """ + Step the system dynamics with the visco-elastic contact model. + + Args: + model: The model to consider. + data: The data of the considered model. + dt: The time step to consider. If not specified, it is read from the model. + link_forces: + The 6D forces to apply to the links expressed in the frame corresponding to + the velocity representation of `data`. + joint_force_references: The joint force references to consider. + + Returns: + A tuple containing the new data of the model + and an empty dictionary of auxiliary data. + """ + + assert isinstance(model.contact_model, ViscoElasticContacts) + assert isinstance(data.contacts_params, ViscoElasticContactsParams) + + # Compute the contact forces in inertial-fixed representation. + # TODO: understand what's wrong in other representations. + data_inertial_fixed = data.replace( + velocity_representation=jaxsim.VelRepr.Inertial, validate=False + ) + + # Create the references object. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + link_forces=link_forces, + joint_force_references=joint_force_references, + velocity_representation=data.velocity_representation, + ) + + # Initialize the time step. + dt = dt if dt is not None else model.time_step + + # Compute the contact forces with the exponential integrator. + W_f̅_C, aux_data = model.contact_model.compute_contact_forces( + model=model, + data=data_inertial_fixed, + dt=jnp.array(dt).astype(float), + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), + ) + + # Extract the final material deformation and the average of average forces + # from the dictionary containing auxiliary data. + m_tf = aux_data["m_tf"] + W_f̿_C = aux_data["W_f_avg2_C"] + + # =============================== + # Compute the link contact forces + # =============================== + + # Get the link contact forces by summing the forces of contact points belonging + # to the same link. + W_f̅_L, W_f̿_L = jax.vmap( + lambda W_f_C: model.contact_model.link_forces_from_contact_forces( + model=model, data=data_inertial_fixed, contact_forces=W_f_C + ) + )(jnp.stack([W_f̅_C, W_f̿_C])) + + # Compute the link transforms. + W_H_L = ( + js.model.forward_kinematics(model=model, data=data) + if data.velocity_representation is not jaxsim.VelRepr.Inertial + else jnp.zeros(shape=(model.number_of_links(), 4, 4)) + ) + + # For integration purpose, we need the average of average forces expressed in + # mixed representation. + LW_f̿_L = jax.vmap( + lambda W_f_L, W_H_L: ( + ModelDataWithVelocityRepresentation.inertial_to_other_representation( + array=W_f_L, + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_L, + is_force=True, + ) + ) + )(W_f̿_L, W_H_L) + + # ========================== + # Integrate the system state + # ========================== + + # Integrate the system dynamics using the average contact forces. + data_tf: js.data.JaxSimModelData = ( + model.contact_model.integrate_data_with_average_contact_forces( + model=model, + data=data_inertial_fixed, + dt=dt, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), + average_link_contact_forces_inertial=W_f̅_L, + average_of_average_link_contact_forces_mixed=LW_f̿_L, + ) + ) + + # Store the tangential deformation at the final state. + # Note that this was integrated in the continuous time domain, therefore it should + # be much more accurate than the one computed with the discrete soft contacts. + with data_tf.mutable_context(): + + # Extract the indices corresponding to the enabled collidable points. + # The visco-elastic contact model computed only their contact forces. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + data_tf.state.extended |= { + "tangential_deformation": data_tf.state.extended["tangential_deformation"] + .at[indices_of_enabled_collidable_points] + .set(m_tf) + } + + # Restore the original velocity representation. + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) + + return data_tf, {} diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index d3153f2ed..83418e20a 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -9,6 +9,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr +from jaxsim.rbda.contacts import SoftContacts, SoftContactsParams # All JaxSim algorithms, excluding the variable-step integrators, should support # being automatically differentiated until second order, both in FWD and REV modes. @@ -288,6 +289,55 @@ def test_ad_jacobian( ) +def test_ad_soft_contacts( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) + p = jax.random.uniform(subkey1, shape=(3,), minval=-1) + v = jax.random.uniform(subkey2, shape=(3,), minval=-1) + m = jax.random.uniform(subkey3, shape=(3,), minval=-1) + + # Get the soft contacts parameters. + parameters = js.contact.estimate_good_contact_parameters(model=model) + + # ==== + # Test + # ==== + + # Get a closure exposing only the parameters to be differentiated. + def close_over_inputs_and_parameters( + p: jtp.VectorLike, + v: jtp.VectorLike, + m: jtp.VectorLike, + params: SoftContactsParams, + ) -> tuple[jtp.Vector, jtp.Vector]: + + W_f_Ci, CW_ṁ = SoftContacts.compute_contact_force( + position=p, + velocity=v, + tangential_deformation=m, + parameters=params, + terrain=model.terrain, + ) + + return W_f_Ci, CW_ṁ + + # Check derivatives against finite differences. + check_grads( + f=close_over_inputs_and_parameters, + args=(p, v, m, parameters), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + # On GPU, the tolerance needs to be increased. + rtol=0.02 if "gpu" in {d.platform for d in p.devices()} else None, + ) + + def test_ad_integration( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, @@ -305,7 +355,8 @@ def test_ad_integration( W_Q_B = data.base_orientation(dcm=False) s = data.joint_positions W_v_WB = data.base_velocity() - ṡ = data.joint_velocities + ṡ = data.joint_velocities(model=model) + m = data.extended_state["tangential_deformation"] # Inputs. W_f_L = references.link_forces(model=model) @@ -322,9 +373,10 @@ def step( s: jtp.Vector, W_v_WB: jtp.Vector, ṡ: jtp.Vector, + m: jtp.Vector, τ: jtp.Vector, W_f_L: jtp.Matrix, - ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: # When JAX tests against finite differences, the injected ε will make the # quaternion non-unitary, which will cause the AD check to fail. @@ -337,7 +389,9 @@ def step( base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, + extended_state={"tangential_deformation": m}, ) + data.update_cached(model) data_xf = js.model.step( @@ -352,15 +406,16 @@ def step( xf_s = data_xf.joint_positions xf_W_v_WB = data_xf.base_velocity() xf_ṡ = data_xf.joint_velocities + xf_m = data_xf.extended_state["tangential_deformation"] - return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ + return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m # Check derivatives against finite differences. # We set forward mode only because the backward mode is not supported by the # current implementation of `optax` optimizers in the relaxed rigid contact model. check_grads( f=step, - args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L), + args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, m, τ, W_f_L), order=AD_ORDER, modes=["fwd"], eps=ε,