From fe7a382d80c1fe2a202b9335aff64afe9249f74d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 7 Mar 2025 16:05:56 +0000 Subject: [PATCH] WIP Refactor --- src/jaxsim/api/contact_model.py | 46 ++++++---- src/jaxsim/api/kin_dyn_parameters.py | 45 +++++----- src/jaxsim/rbda/contacts/relaxed_rigid.py | 101 ++++++++++++---------- 3 files changed, 105 insertions(+), 87 deletions(-) diff --git a/src/jaxsim/api/contact_model.py b/src/jaxsim/api/contact_model.py index e4a3dc5c3..64aa53fdd 100644 --- a/src/jaxsim/api/contact_model.py +++ b/src/jaxsim/api/contact_model.py @@ -42,39 +42,47 @@ def link_contact_forces( # Compute the 6D forces applied to the links equivalent to the forces applied # to the frames associated to the collidable points. - W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) + W_f_L_contact = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) wrench_pair_constr_inertial = aux_data["constr_wrenches_inertial"] - constraints = model.kin_dyn_parameters.get_constraints(model) # Get the couples of parent link indices of each couple of frames. - frame_idxs_1, frame_idxs_2, types = zip(*constraints, strict=False) - frame_idxs_1 = jnp.array(frame_idxs_1) - frame_idxs_2 = jnp.array(frame_idxs_2) + frame_idxs_1, frame_idxs_2 = model.kin_dyn_parameters.get_constraints(model).T jax.debug.print("frame_idxs_1: \n{}", frame_idxs_1) jax.debug.print("frame_idxs_2: \n{}", frame_idxs_2) parent_link_indices = jax.vmap( - lambda frame_idx_1, frame_idx_2: ( - js.frame.idx_of_parent_link(model, frame_index=frame_idx_1), - js.frame.idx_of_parent_link(model, frame_index=frame_idx_2), + lambda frame_idx_1, frame_idx_2: jnp.array( + ( + js.frame.idx_of_parent_link(model, frame_index=frame_idx_1), + js.frame.idx_of_parent_link(model, frame_index=frame_idx_2), + ) ) )(frame_idxs_1, frame_idxs_2) - parent_link_indices = jnp.array(parent_link_indices) jax.debug.print("parent_link_indices: \n{}", parent_link_indices.shape) # Apply each constraint wrench to its corresponding parent link in W_f_L. - def apply_wrench(i, W_f_L): - parent_indices = parent_link_indices[:, i] - wrench_pair = wrench_pair_constr_inertial[:, i] - jax.debug.print("parent_indices: \n{}", parent_indices) - jax.debug.print("wrench_pair: \n{}", wrench_pair) - W_f_L = W_f_L.at[parent_indices[0]].add(wrench_pair[0]) - W_f_L = W_f_L.at[parent_indices[1]].add(wrench_pair[1]) - return W_f_L - - W_f_L = jax.lax.fori_loop(0, parent_link_indices.shape[0], apply_wrench, W_f_L) + # def apply_wrench(i, W_f_L): + # parent_indices = parent_link_indices[:, i] + # wrench_pair = wrench_pair_constr_inertial[:, i] + # jax.debug.print("parent_indices: \n{}", parent_indices) + # jax.debug.print("wrench_pair: \n{}", wrench_pair) + # W_f_L = W_f_L.at[parent_indices[0]].add(wrench_pair[0]) + # W_f_L = W_f_L.at[parent_indices[1]].add(wrench_pair[1]) + # return W_f_L + + mask = jax.vmap( + lambda parent_link_idxs_couple: parent_link_idxs_couple[:, None] + == jnp.arange(model.number_of_links()) + )(parent_link_indices) + + # b = Number of constraint, k = 2 (Constraint couple), j = Number of links, i = 6 + W_f_L_constr = jnp.einsum("bkj,bki->bi", mask, wrench_pair_constr_inertial) + + # W_f_L = jax.lax.fori_loop(0, parent_link_indices.shape[0], apply_wrench, W_f_L) + + W_f_L = W_f_L_contact + W_f_L_constr jax.debug.print("W_f_L: \n{}", W_f_L) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index c1de4bc75..64fb3f153 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -286,15 +286,17 @@ def __eq__(self, other: KinDynParameters) -> bool: return hash(self) == hash(other) def __hash__(self) -> int: - return hash(( - hash(self.number_of_links()), - hash(self.number_of_joints()), - hash(self.frame_parameters.name), - hash(self.frame_parameters.body), - hash(self._parent_array), - hash(self._support_body_array_bool), - hash(self.constraints), - )) + return hash( + ( + hash(self.number_of_links()), + hash(self.number_of_joints()), + hash(self.frame_parameters.name), + hash(self.frame_parameters.body), + hash(self._parent_array), + hash(self._support_body_array_bool), + hash(self.constraints), + ) + ) # ============================= # Helpers to extract parameters @@ -349,7 +351,9 @@ def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector: jnp.where(self.support_body_array_bool[link_index])[0], dtype=int ) - def get_constraints(self, model:js.model.JaxSimModel) -> tuple[tuple[str, str, ConstraintType], ...]: + def get_constraints( + self, model: js.model.JaxSimModel + ) -> tuple[tuple[str, str, ConstraintType], ...]: r""" Return the constraints of the model. """ @@ -955,16 +959,15 @@ def get_constraints( Returns: A tuple, in which each element defines a kinematic constraint. """ - return tuple( + return jnp.array( ( - js.frame.name_to_idx(model, frame_name=frame_name_1), - js.frame.name_to_idx(model, frame_name=frame_name_2), - constraint_type, + jax.tree.map( + lambda f1: js.frame.name_to_idx(model, frame_name=f1), + self.frame_names_1, + ), + jax.tree.map( + lambda f1: js.frame.name_to_idx(model, frame_name=f1), + self.frame_names_2, + ), ) - for frame_name_1, frame_name_2, constraint_type in zip( - self.frame_names_1, - self.frame_names_2, - self.constraint_types, - strict=True, - ) - ) + ).T diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 45a16b720..a63522cab 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -311,22 +311,21 @@ def compute_contact_forces( W_H_C = js.contact.transforms(model=model, data=data) # Retrieve the kinematic constraints - constraints = model.kin_dyn_parameters.get_constraints(model) - constraints = jnp.array(constraints) - n_kin_constraints = 6 * len(constraints) + idxs = model.kin_dyn_parameters.get_constraints(model) + n_kin_constraints = 6 * len(idxs) jax.debug.print("n_kin_constraints: \n{}", n_kin_constraints) # TODO (xela-95): manage the case of contact constraint - def compute_constraint_jacobians(constraint): - frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841 + def compute_constraint_jacobians(data, constraint): + frame_1_idx, frame_2_idx = constraint J_WF1 = js.frame.jacobian(model=model, data=data, frame_index=frame_1_idx) J_WF2 = js.frame.jacobian(model=model, data=data, frame_index=frame_2_idx) return J_WF1 - J_WF2 - def compute_constraint_jacobians_derivative(constraint): - frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841 + def compute_constraint_jacobians_derivative(data, constraint): + frame_1_idx, frame_2_idx = constraint J̇_WF1 = js.frame.jacobian_derivative( model=model, data=data, frame_index=frame_1_idx @@ -337,7 +336,9 @@ def compute_constraint_jacobians_derivative(constraint): return J̇_WF1 - J̇_WF2 - def compute_constraint_baumgarte_term(J_constr, BW_ν, W_H_F1, W_H_F2): + def compute_constraint_baumgarte_term(data, J_constr, BW_ν, W_H_F): + W_H_F1, W_H_F2 = W_H_F + W_p_F1 = W_H_F1[0:3, 3] W_p_F2 = W_H_F2[0:3, 3] @@ -354,13 +355,13 @@ def compute_constraint_baumgarte_term(J_constr, BW_ν, W_H_F1, W_H_F2): return baumgarte_term - def compute_constraint_transforms(constraint): - frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841 + def compute_constraint_transforms(data, constraint): + frame_1_idx, frame_2_idx = constraint W_H_F1 = js.frame.transform(model=model, data=data, frame_index=frame_1_idx) W_H_F2 = js.frame.transform(model=model, data=data, frame_index=frame_2_idx) - return W_H_F1, W_H_F2 + return jnp.array((W_H_F1, W_H_F2)) with ( data.switch_velocity_representation(VelRepr.Mixed), @@ -393,33 +394,37 @@ def compute_constraint_transforms(constraint): ), ) - J_constr = jnp.vstack(jax.vmap(compute_constraint_jacobians)(constraints)) + J_constr = jnp.vstack( + jax.vmap(compute_constraint_jacobians, in_axes=(None, 0))(data, idxs) + ) J̇_constr = jnp.vstack( - jax.vmap(compute_constraint_jacobians_derivative)(constraints) + jax.vmap(compute_constraint_jacobians_derivative, in_axes=(None, 0))( + data, idxs + ) ) J = jnp.vstack([Jl_WC, J_constr]) J̇ = jnp.vstack([J̇l_WC, J̇_constr]) # Compute the regularization terms. - a_ref, R, *_ = self._regularizers( + a_ref, r, *_ = self._regularizers( model=model, position_constraint=position_constraint, velocity_constraint=velocity, parameters=model.contacts_params, ) - W_H_constr = jnp.array(jax.vmap(compute_constraint_transforms)(constraints)) + W_H_constr = jax.vmap(compute_constraint_transforms, in_axes=(None, 0))( + data, idxs + ) constr_baumgarte_term = jnp.hstack( - jax.vmap(compute_constraint_baumgarte_term, in_axes=(None, None, 0, 0))( - J_constr, BW_ν, W_H_constr[0], W_H_constr[1] - ) + jax.vmap(compute_constraint_baumgarte_term, in_axes=(None, None, None))( + data, J_constr, BW_ν, W_H_F=W_H_constr + ), ).squeeze() - R_constr = jnp.pad( - R, ((0, n_kin_constraints), (0, n_kin_constraints)), mode="constant" - ) + R_constr = jnp.diag(jnp.hstack([r, jnp.zeros(n_kin_constraints)])) a_ref_constr = jnp.hstack([a_ref, -constr_baumgarte_term]) # Compute the Delassus matrix and the free mixed linear acceleration of @@ -517,12 +522,14 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: )[0] )(position, velocity).flatten() - init_params = jnp.hstack([ - init_params, - jnp.zeros( - n_kin_constraints, - ), - ]) + init_params = jnp.hstack( + [ + init_params, + jnp.zeros( + n_kin_constraints, + ), + ] + ) # Get the solver options. solver_options = self.solver_options @@ -545,36 +552,36 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: kin_constr_wrench_mixed = solution[-n_kin_constraints:].reshape(-1, 6) # Form an array of tuples with each wrench and its opposite using jax constructs - kin_constr_wrench_pairs_mixed = jax.vmap(lambda wrench: (wrench, -wrench))( - kin_constr_wrench_mixed - ) - kin_constr_wrench_pairs_mixed = jnp.array(kin_constr_wrench_pairs_mixed) + kin_constr_wrench_pairs_mixed = jax.vmap( + lambda wrench: jnp.array((wrench, -wrench)) + )(kin_constr_wrench_mixed) jax.debug.print( "kin_constr_wrench_pairs_mixed: \n{}", kin_constr_wrench_pairs_mixed.shape ) + # Transform each wrench in the pair to inertial representation using the appropriate transform def transform_wrench_pair_to_inertial(wrench_pair, transform): - return ( - ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=wrench_pair[0], - transform=transform[0], - other_representation=VelRepr.Mixed, - is_force=True, - ), - ModelDataWithVelocityRepresentation.other_representation_to_inertial( - array=wrench_pair[1], - transform=transform[1], - other_representation=VelRepr.Mixed, - is_force=True, - ), + return jnp.array( + ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=wrench_pair[0], + transform=transform[0], + other_representation=VelRepr.Mixed, + is_force=True, + ), + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=wrench_pair[1], + transform=transform[1], + other_representation=VelRepr.Mixed, + is_force=True, + ), + ) ) kin_constr_force_pairs_inertial = jax.vmap( transform_wrench_pair_to_inertial, - in_axes=(1, 1), )(kin_constr_wrench_pairs_mixed, W_H_constr) - kin_constr_force_pairs_inertial = jnp.array(kin_constr_force_pairs_inertial) jax.debug.print( "kin_constr_force_pairs_inertial: \n{}", @@ -728,7 +735,7 @@ def compute_row( ), ) - return a_ref, jnp.diag(R), K, D + return a_ref, R, K, D @staticmethod @functools.partial(jax.jit, static_argnames=("terrain",))