From 9b82920f91803f2e2a18c29ca83597151fc66b18 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Fri, 7 Mar 2025 15:41:01 +0100 Subject: [PATCH] WIP --- src/jaxsim/api/contact_model.py | 44 ++++-- src/jaxsim/api/kin_dyn_parameters.py | 110 +++++++++++-- src/jaxsim/api/model.py | 10 +- src/jaxsim/rbda/contacts/relaxed_rigid.py | 181 +++++++++++++--------- 4 files changed, 248 insertions(+), 97 deletions(-) diff --git a/src/jaxsim/api/contact_model.py b/src/jaxsim/api/contact_model.py index 98ff34187..e4a3dc5c3 100644 --- a/src/jaxsim/api/contact_model.py +++ b/src/jaxsim/api/contact_model.py @@ -44,18 +44,38 @@ def link_contact_forces( # to the frames associated to the collidable points. W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) - # Add the forces coming from the kinematic constraints to the link on which the constraint is applied. - - # W_f_loop = aux_data["kin_constr_force"] - W_f_loop_F1 = aux_data["kin_constr_force_F1"] - W_f_loop_f2 = aux_data["kin_constr_force_F2"] - F1_idx = aux_data["F1_idx"] - F2_idx = aux_data["F2_idx"] - F1_parent_idx = js.frame.idx_of_parent_link(model, frame_index=F1_idx) - F2_parent_idx = js.frame.idx_of_parent_link(model, frame_index=F2_idx) - - W_f_L = W_f_L.at[F1_parent_idx].add(W_f_loop_F1) - W_f_L = W_f_L.at[F2_parent_idx].add(W_f_loop_f2) + wrench_pair_constr_inertial = aux_data["constr_wrenches_inertial"] + + constraints = model.kin_dyn_parameters.get_constraints(model) + # Get the couples of parent link indices of each couple of frames. + frame_idxs_1, frame_idxs_2, types = zip(*constraints, strict=False) + frame_idxs_1 = jnp.array(frame_idxs_1) + frame_idxs_2 = jnp.array(frame_idxs_2) + + jax.debug.print("frame_idxs_1: \n{}", frame_idxs_1) + jax.debug.print("frame_idxs_2: \n{}", frame_idxs_2) + + parent_link_indices = jax.vmap( + lambda frame_idx_1, frame_idx_2: ( + js.frame.idx_of_parent_link(model, frame_index=frame_idx_1), + js.frame.idx_of_parent_link(model, frame_index=frame_idx_2), + ) + )(frame_idxs_1, frame_idxs_2) + parent_link_indices = jnp.array(parent_link_indices) + jax.debug.print("parent_link_indices: \n{}", parent_link_indices.shape) + + # Apply each constraint wrench to its corresponding parent link in W_f_L. + def apply_wrench(i, W_f_L): + parent_indices = parent_link_indices[:, i] + wrench_pair = wrench_pair_constr_inertial[:, i] + jax.debug.print("parent_indices: \n{}", parent_indices) + jax.debug.print("wrench_pair: \n{}", wrench_pair) + W_f_L = W_f_L.at[parent_indices[0]].add(wrench_pair[0]) + W_f_L = W_f_L.at[parent_indices[1]].add(wrench_pair[1]) + return W_f_L + + W_f_L = jax.lax.fori_loop(0, parent_link_indices.shape[0], apply_wrench, W_f_L) + jax.debug.print("W_f_L: \n{}", W_f_L) return W_f_L diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 93fe7d6ff..c1de4bc75 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import enum import jax.lax import jax.numpy as jnp @@ -9,6 +10,7 @@ import numpy.typing as npt from jax_dataclasses import Static +import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription @@ -51,6 +53,8 @@ class KinDynParameters(JaxsimDataclass): joint_model: JointModel joint_parameters: JointParameters | None + constraints: Static[ConstraintMap] + @property def motion_subspaces(self) -> jtp.Matrix: r""" @@ -73,12 +77,15 @@ def support_body_array_bool(self) -> jtp.Matrix: return self._support_body_array_bool.get() @staticmethod - def build(model_description: ModelDescription) -> KinDynParameters: + def build( + model_description: ModelDescription, constraints: ConstraintMap | None + ) -> KinDynParameters: """ Construct the kinematic and dynamic parameters of the model. Args: model_description: The parsed model description to consider. + constraints: An object of type ConstraintMap specifying the kinematic constraint of the model. Returns: The kinematic and dynamic parameters of the model. @@ -248,6 +255,12 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike: motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J]) + # =========== + # Constraints + # =========== + + constraints = ConstraintMap() if constraints is None else constraints + # ================================= # Build and return KinDynParameters # ================================= @@ -262,6 +275,7 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike: joint_parameters=joint_parameters, contact_parameters=contact_parameters, frame_parameters=frame_parameters, + constraints=constraints, ) def __eq__(self, other: KinDynParameters) -> bool: @@ -272,17 +286,15 @@ def __eq__(self, other: KinDynParameters) -> bool: return hash(self) == hash(other) def __hash__(self) -> int: - - return hash( - ( - hash(self.number_of_links()), - hash(self.number_of_joints()), - hash(self.frame_parameters.name), - hash(self.frame_parameters.body), - hash(self._parent_array), - hash(self._support_body_array_bool), - ) - ) + return hash(( + hash(self.number_of_links()), + hash(self.number_of_joints()), + hash(self.frame_parameters.name), + hash(self.frame_parameters.body), + hash(self._parent_array), + hash(self._support_body_array_bool), + hash(self.constraints), + )) # ============================= # Helpers to extract parameters @@ -337,6 +349,13 @@ def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector: jnp.where(self.support_body_array_bool[link_index])[0], dtype=int ) + def get_constraints(self, model:js.model.JaxSimModel) -> tuple[tuple[str, str, ConstraintType], ...]: + r""" + Return the constraints of the model. + """ + + return self.constraints.get_constraints(model) + # ======================== # Quantities used by RBDAs # ======================== @@ -882,3 +901,70 @@ def build_from(model_description: ModelDescription) -> FrameParameters: assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0] return fp + + +@enum.unique +class ConstraintType(enum.IntEnum): + """ + Enumeration of all supported constraint types. + """ + + Weld = enum.auto() + Connect = enum.auto() + + +@jax_dataclasses.pytree_dataclass +class ConstraintMap(JaxsimDataclass): + """ + Class storing the kinematic constraints of a model. + """ + + frame_names_1: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple) + frame_names_2: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple) + constraint_types: Static[tuple[ConstraintType, ...]] = dataclasses.field( + default_factory=tuple + ) + + def add_constraint( + self, frame_name_1: str, frame_name_2: str, constraint_type: ConstraintType + ) -> ConstraintMap: + """ + Add a constraint to the constraint map. + + Args: + frame_name_1: The name of the first frame. + frame_name_2: The name of the second frame. + constraint_type: The type of constraint. + + Returns: + A new ConstraintMap instance with the added constraint. + """ + return self.replace( + frame_names_1=(*self.frame_names_1, frame_name_1), + frame_names_2=(*self.frame_names_2, frame_name_2), + constraint_types=(*self.constraint_types, constraint_type), + validate=False, + ) + + def get_constraints( + self, model: js.model.JaxSimModel + ) -> tuple[tuple[int, int, ConstraintType], ...]: + """ + Get the list of constraints. + + Returns: + A tuple, in which each element defines a kinematic constraint. + """ + return tuple( + ( + js.frame.name_to_idx(model, frame_name=frame_name_1), + js.frame.name_to_idx(model, frame_name=frame_name_2), + constraint_type, + ) + for frame_name_1, frame_name_2, constraint_type in zip( + self.frame_names_1, + self.frame_names_2, + self.constraint_types, + strict=True, + ) + ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index f9bb48fc7..ed62cd192 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -17,6 +17,7 @@ import jaxsim.exceptions import jaxsim.terrain import jaxsim.typing as jtp +from jaxsim.api.kin_dyn_parameters import ConstraintMap from jaxsim.math import Adjoint, Cross from jaxsim.parsers.descriptions import ModelDescription from jaxsim.utils import JaxsimDataclass, Mutability, wrappers @@ -126,6 +127,7 @@ def build_from_model_description( integrator: IntegratorType | None = None, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, + constraints: ConstraintMap | None = None, ) -> JaxSimModel: """ Build a Model object from a model description. @@ -150,6 +152,8 @@ def build_from_model_description( This is usually automatically inferred. considered_joints: The list of joints to consider. If None, all joints are considered. + constraints: + An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. Returns: The built Model object. @@ -179,6 +183,7 @@ def build_from_model_description( contact_model=contact_model, contacts_params=contact_params, integrator=integrator, + constraints=constraints, ) # Store the origin of the model, in case downstream logic needs it. @@ -199,6 +204,7 @@ def build( contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None, integrator: IntegratorType | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, + constraints: ConstraintMap | None = None, ) -> JaxSimModel: """ Build a Model object from an intermediate model description. @@ -220,6 +226,8 @@ def build( contacts_params: The parameters of the soft contacts. integrator: The integrator to use for the simulation. gravity: The gravity constant. + constraints: + An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. Returns: The built Model object. @@ -265,7 +273,7 @@ def build( model = cls( model_name=model_name, kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build( - model_description=model_description + model_description=model_description, constraints=constraints ), time_step=time_step, terrain=terrain, diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index b9fa96631..45a16b720 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -259,7 +259,7 @@ def compute_contact_forces( model: The model to consider. data: The data of the considered model. link_forces: - Optional `(n_links, 6)` matrix of external forces acting on the links, + Optional `(n_links, 6)` matrix of external forces acting on the links,` expressed in the same representation of data. joint_force_references: Optional `(n_joints,)` vector of joint forces. @@ -268,6 +268,9 @@ def compute_contact_forces( A tuple containing as first element the computed contact forces in inertial representation. """ + K_P = 0 # 1e1 + K_D = 0 # 2 * jnp.sqrt(K_P) + link_forces = jnp.atleast_2d( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None @@ -307,6 +310,58 @@ def compute_contact_forces( # collidable points. W_H_C = js.contact.transforms(model=model, data=data) + # Retrieve the kinematic constraints + constraints = model.kin_dyn_parameters.get_constraints(model) + constraints = jnp.array(constraints) + n_kin_constraints = 6 * len(constraints) + jax.debug.print("n_kin_constraints: \n{}", n_kin_constraints) + + # TODO (xela-95): manage the case of contact constraint + def compute_constraint_jacobians(constraint): + frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841 + + J_WF1 = js.frame.jacobian(model=model, data=data, frame_index=frame_1_idx) + J_WF2 = js.frame.jacobian(model=model, data=data, frame_index=frame_2_idx) + + return J_WF1 - J_WF2 + + def compute_constraint_jacobians_derivative(constraint): + frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841 + + J̇_WF1 = js.frame.jacobian_derivative( + model=model, data=data, frame_index=frame_1_idx + ) + J̇_WF2 = js.frame.jacobian_derivative( + model=model, data=data, frame_index=frame_2_idx + ) + + return J̇_WF1 - J̇_WF2 + + def compute_constraint_baumgarte_term(J_constr, BW_ν, W_H_F1, W_H_F2): + W_p_F1 = W_H_F1[0:3, 3] + W_p_F2 = W_H_F2[0:3, 3] + + W_R_F1 = W_H_F1[0:3, 0:3] + W_R_F2 = W_H_F2[0:3, 0:3] + + vel_error = J_constr @ BW_ν + position_error = W_p_F1 - W_p_F2 + R_error = W_R_F2.T @ W_R_F1 + orientation_error = jaxsim.math.rotation.Rotation.log_vee(R_error) + baumgarte_term = ( + K_P * jnp.hstack([position_error, orientation_error]) + K_D * vel_error + ) + + return baumgarte_term + + def compute_constraint_transforms(constraint): + frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841 + + W_H_F1 = js.frame.transform(model=model, data=data, frame_index=frame_1_idx) + W_H_F2 = js.frame.transform(model=model, data=data, frame_index=frame_2_idx) + + return W_H_F1, W_H_F2 + with ( data.switch_velocity_representation(VelRepr.Mixed), references.switch_velocity_representation(VelRepr.Mixed), @@ -324,43 +379,27 @@ def compute_contact_forces( M = js.model.free_floating_mass_matrix(model=model, data=data) + # Compute the linear part of the Jacobian of the collidable points Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( js.contact.jacobian(model=model, data=data)[:, :3, :], δ ) ) - J̇_WC = jnp.vstack( + # Compute the linear part of the Jacobian derivative of the collidable points + J̇l_WC = jnp.vstack( jax.vmap(lambda J̇, δ: J̇ * (δ > 0))( js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ ), ) - # Compute the Jacobians for the closed-chain kinematic constraint\ - F1_name = "BC1_frame" - F2_name = "BC2_frame" - F1_idx = js.frame.name_to_idx(model=model, frame_name=F1_name) - F2_idx = js.frame.name_to_idx(model=model, frame_name=F2_name) - - W_H_F1 = js.frame.transform(model=model, data=data, frame_index=F1_idx) - W_H_F2 = js.frame.transform(model=model, data=data, frame_index=F2_idx) - W_p_F1 = W_H_F1[:3, 3] - W_p_F2 = W_H_F2[:3, 3] - W_R_F1 = W_H_F1[:3, :3] - W_R_F2 = W_H_F2[:3, :3] - - J_WF1 = js.frame.jacobian(model=model, data=data, frame_index=F1_idx) - J_WF2 = js.frame.jacobian(model=model, data=data, frame_index=F2_idx) - J̇_WF1 = js.frame.jacobian_derivative( - model=model, data=data, frame_index=F1_idx - ) - J̇_WF2 = js.frame.jacobian_derivative( - model=model, data=data, frame_index=F2_idx + J_constr = jnp.vstack(jax.vmap(compute_constraint_jacobians)(constraints)) + J̇_constr = jnp.vstack( + jax.vmap(compute_constraint_jacobians_derivative)(constraints) ) - # Add the kinematic constraint terms to the Jacobians - Jl_WC = jnp.vstack([Jl_WC, J_WF1 - J_WF2]) - J̇_WC = jnp.vstack([J̇_WC, J̇_WF1 - J̇_WF2]) + J = jnp.vstack([Jl_WC, J_constr]) + J̇ = jnp.vstack([J̇l_WC, J̇_constr]) # Compute the regularization terms. a_ref, R, *_ = self._regularizers( @@ -370,38 +409,27 @@ def compute_contact_forces( parameters=model.contacts_params, ) - # jax.debug.print("a_ref.shape: \n{}", a_ref.shape) - # jax.debug.print("R.shape: \n{}", R.shape) + W_H_constr = jnp.array(jax.vmap(compute_constraint_transforms)(constraints)) - dim_kin_constr = J_WF1.shape[0] - - # Compute Baumgarte stabilization terms for the kinematic constraints + constr_baumgarte_term = jnp.hstack( + jax.vmap(compute_constraint_baumgarte_term, in_axes=(None, None, 0, 0))( + J_constr, BW_ν, W_H_constr[0], W_H_constr[1] + ) + ).squeeze() - K_P = 100 # 1e1 - K_D = 2 * jnp.sqrt(K_P) - K_D = jnp.diag(jnp.array([K_D, K_D, K_D, K_D, K_D, K_D])) - vel_error = (J_WF1 - J_WF2) @ BW_ν - position_error = W_p_F1 - W_p_F2 - R_error = W_R_F2.T @ W_R_F1 - orientation_error = jaxsim.math.rotation.Rotation.log_vee(R_error) - baumgarte_term = ( - K_P * jnp.hstack([position_error, orientation_error]) + K_D @ vel_error + R_constr = jnp.pad( + R, ((0, n_kin_constraints), (0, n_kin_constraints)), mode="constant" ) - - # baumgarte_term = K_P * position_error + K_D * vel_error - - R_ext = jnp.pad(R, ((0, dim_kin_constr), (0, dim_kin_constr)), mode="constant") - a_ref_ext = jnp.hstack([a_ref, -baumgarte_term]) - # a_ref_ext = jnp.hstack([a_ref, jnp.zeros((3,))]) + a_ref_constr = jnp.hstack([a_ref, -constr_baumgarte_term]) # Compute the Delassus matrix and the free mixed linear acceleration of # the collidable points. - G = Jl_WC @ jnp.linalg.pinv(M) @ Jl_WC.T - CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν + G = J @ jnp.linalg.pinv(M) @ J.T + CW_al_free_WC = J @ BW_ν̇_free + J̇ @ BW_ν # Calculate quantities for the linear optimization problem. - A = G + R_ext - b = CW_al_free_WC - a_ref_ext + A = G + R_constr + b = CW_al_free_WC - a_ref_constr # Create the objective function to minimize as a lambda computing the cost # from the optimized variables x. @@ -492,7 +520,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: init_params = jnp.hstack([ init_params, jnp.zeros( - dim_kin_constr, + n_kin_constraints, ), ]) @@ -513,35 +541,48 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: maxiter=maxiter, ) - # Extract the last 6 values from the solution - kin_constr_force_mixed_F1 = solution[-dim_kin_constr:] - # jax.debug.print("f_loop_mixed: \n{}", kin_constr_force_mixed_linear) + # Extract the last n_kin_constr values from the solution and split them into 6D wrenches + kin_constr_wrench_mixed = solution[-n_kin_constraints:].reshape(-1, 6) - kin_constr_force_mixed_F2 = -kin_constr_force_mixed_F1 + # Form an array of tuples with each wrench and its opposite using jax constructs + kin_constr_wrench_pairs_mixed = jax.vmap(lambda wrench: (wrench, -wrench))( + kin_constr_wrench_mixed + ) + kin_constr_wrench_pairs_mixed = jnp.array(kin_constr_wrench_pairs_mixed) - # Transform the wrench in inertial representation - kin_constr_force_inertial_F1 = ( + jax.debug.print( + "kin_constr_wrench_pairs_mixed: \n{}", kin_constr_wrench_pairs_mixed.shape + ) + # Transform each wrench in the pair to inertial representation using the appropriate transform + def transform_wrench_pair_to_inertial(wrench_pair, transform): + return ( ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=kin_constr_force_mixed_F1, - transform=W_H_F1, + array=wrench_pair[0], + transform=transform[0], other_representation=VelRepr.Mixed, is_force=True, - ) - ) - - kin_constr_force_inertial_F2 = ( + ), ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=kin_constr_force_mixed_F2, - transform=W_H_F2, + array=wrench_pair[1], + transform=transform[1], other_representation=VelRepr.Mixed, is_force=True, + ), ) + + kin_constr_force_pairs_inertial = jax.vmap( + transform_wrench_pair_to_inertial, + in_axes=(1, 1), + )(kin_constr_wrench_pairs_mixed, W_H_constr) + kin_constr_force_pairs_inertial = jnp.array(kin_constr_force_pairs_inertial) + + jax.debug.print( + "kin_constr_force_pairs_inertial: \n{}", + kin_constr_force_pairs_inertial.shape, ) - jax.debug.print("f_loop_inertial_F1: \n{}", kin_constr_force_inertial_F1) - jax.debug.print("f_loop_inertial_F2: \n{}", kin_constr_force_inertial_F2) # Reshape the optimized solution to be a matrix of 3D contact forces. - CW_fl_C = solution[0:-dim_kin_constr].reshape(-1, 3) + CW_fl_C = solution[0:-n_kin_constraints].reshape(-1, 3) # Convert the contact forces from mixed to inertial-fixed representation. W_f_C = jax.vmap( @@ -556,11 +597,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: )(CW_fl_C, W_H_C) return W_f_C, { - "kin_constr_force_F1": kin_constr_force_inertial_F1, - "kin_constr_force_F2": kin_constr_force_inertial_F2, - # "kin_constr_force": kin_constr_force_mixed, - "F1_idx": F1_idx, - "F2_idx": F2_idx, + "constr_wrenches_inertial": kin_constr_force_pairs_inertial, } @staticmethod