Skip to content

Commit

Permalink
WIP Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Mar 7, 2025
1 parent 9b82920 commit fe7a382
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 87 deletions.
46 changes: 27 additions & 19 deletions src/jaxsim/api/contact_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,39 +42,47 @@ def link_contact_forces(

# Compute the 6D forces applied to the links equivalent to the forces applied
# to the frames associated to the collidable points.
W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)
W_f_L_contact = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)

wrench_pair_constr_inertial = aux_data["constr_wrenches_inertial"]

constraints = model.kin_dyn_parameters.get_constraints(model)
# Get the couples of parent link indices of each couple of frames.
frame_idxs_1, frame_idxs_2, types = zip(*constraints, strict=False)
frame_idxs_1 = jnp.array(frame_idxs_1)
frame_idxs_2 = jnp.array(frame_idxs_2)
frame_idxs_1, frame_idxs_2 = model.kin_dyn_parameters.get_constraints(model).T

jax.debug.print("frame_idxs_1: \n{}", frame_idxs_1)
jax.debug.print("frame_idxs_2: \n{}", frame_idxs_2)

parent_link_indices = jax.vmap(
lambda frame_idx_1, frame_idx_2: (
js.frame.idx_of_parent_link(model, frame_index=frame_idx_1),
js.frame.idx_of_parent_link(model, frame_index=frame_idx_2),
lambda frame_idx_1, frame_idx_2: jnp.array(
(
js.frame.idx_of_parent_link(model, frame_index=frame_idx_1),
js.frame.idx_of_parent_link(model, frame_index=frame_idx_2),
)
)
)(frame_idxs_1, frame_idxs_2)
parent_link_indices = jnp.array(parent_link_indices)
jax.debug.print("parent_link_indices: \n{}", parent_link_indices.shape)

# Apply each constraint wrench to its corresponding parent link in W_f_L.
def apply_wrench(i, W_f_L):
parent_indices = parent_link_indices[:, i]
wrench_pair = wrench_pair_constr_inertial[:, i]
jax.debug.print("parent_indices: \n{}", parent_indices)
jax.debug.print("wrench_pair: \n{}", wrench_pair)
W_f_L = W_f_L.at[parent_indices[0]].add(wrench_pair[0])
W_f_L = W_f_L.at[parent_indices[1]].add(wrench_pair[1])
return W_f_L

W_f_L = jax.lax.fori_loop(0, parent_link_indices.shape[0], apply_wrench, W_f_L)
# def apply_wrench(i, W_f_L):
# parent_indices = parent_link_indices[:, i]
# wrench_pair = wrench_pair_constr_inertial[:, i]
# jax.debug.print("parent_indices: \n{}", parent_indices)
# jax.debug.print("wrench_pair: \n{}", wrench_pair)
# W_f_L = W_f_L.at[parent_indices[0]].add(wrench_pair[0])
# W_f_L = W_f_L.at[parent_indices[1]].add(wrench_pair[1])
# return W_f_L

mask = jax.vmap(
lambda parent_link_idxs_couple: parent_link_idxs_couple[:, None]
== jnp.arange(model.number_of_links())
)(parent_link_indices)

# b = Number of constraint, k = 2 (Constraint couple), j = Number of links, i = 6
W_f_L_constr = jnp.einsum("bkj,bki->bi", mask, wrench_pair_constr_inertial)

# W_f_L = jax.lax.fori_loop(0, parent_link_indices.shape[0], apply_wrench, W_f_L)

W_f_L = W_f_L_contact + W_f_L_constr

jax.debug.print("W_f_L: \n{}", W_f_L)

Expand Down
45 changes: 24 additions & 21 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,17 @@ def __eq__(self, other: KinDynParameters) -> bool:
return hash(self) == hash(other)

def __hash__(self) -> int:
return hash((
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(self.frame_parameters.name),
hash(self.frame_parameters.body),
hash(self._parent_array),
hash(self._support_body_array_bool),
hash(self.constraints),
))
return hash(
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(self.frame_parameters.name),
hash(self.frame_parameters.body),
hash(self._parent_array),
hash(self._support_body_array_bool),
hash(self.constraints),
)
)

# =============================
# Helpers to extract parameters
Expand Down Expand Up @@ -349,7 +351,9 @@ def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
jnp.where(self.support_body_array_bool[link_index])[0], dtype=int
)

def get_constraints(self, model:js.model.JaxSimModel) -> tuple[tuple[str, str, ConstraintType], ...]:
def get_constraints(
self, model: js.model.JaxSimModel
) -> tuple[tuple[str, str, ConstraintType], ...]:
r"""
Return the constraints of the model.
"""
Expand Down Expand Up @@ -955,16 +959,15 @@ def get_constraints(
Returns:
A tuple, in which each element defines a kinematic constraint.
"""
return tuple(
return jnp.array(
(
js.frame.name_to_idx(model, frame_name=frame_name_1),
js.frame.name_to_idx(model, frame_name=frame_name_2),
constraint_type,
jax.tree.map(
lambda f1: js.frame.name_to_idx(model, frame_name=f1),
self.frame_names_1,
),
jax.tree.map(
lambda f1: js.frame.name_to_idx(model, frame_name=f1),
self.frame_names_2,
),
)
for frame_name_1, frame_name_2, constraint_type in zip(
self.frame_names_1,
self.frame_names_2,
self.constraint_types,
strict=True,
)
)
).T
101 changes: 54 additions & 47 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,22 +311,21 @@ def compute_contact_forces(
W_H_C = js.contact.transforms(model=model, data=data)

# Retrieve the kinematic constraints
constraints = model.kin_dyn_parameters.get_constraints(model)
constraints = jnp.array(constraints)
n_kin_constraints = 6 * len(constraints)
idxs = model.kin_dyn_parameters.get_constraints(model)
n_kin_constraints = 6 * len(idxs)
jax.debug.print("n_kin_constraints: \n{}", n_kin_constraints)

# TODO (xela-95): manage the case of contact constraint
def compute_constraint_jacobians(constraint):
frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841
def compute_constraint_jacobians(data, constraint):
frame_1_idx, frame_2_idx = constraint

J_WF1 = js.frame.jacobian(model=model, data=data, frame_index=frame_1_idx)
J_WF2 = js.frame.jacobian(model=model, data=data, frame_index=frame_2_idx)

return J_WF1 - J_WF2

def compute_constraint_jacobians_derivative(constraint):
frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841
def compute_constraint_jacobians_derivative(data, constraint):
frame_1_idx, frame_2_idx = constraint

J̇_WF1 = js.frame.jacobian_derivative(
model=model, data=data, frame_index=frame_1_idx
Expand All @@ -337,7 +336,9 @@ def compute_constraint_jacobians_derivative(constraint):

return J̇_WF1 - J̇_WF2

def compute_constraint_baumgarte_term(J_constr, BW_ν, W_H_F1, W_H_F2):
def compute_constraint_baumgarte_term(data, J_constr, BW_ν, W_H_F):
W_H_F1, W_H_F2 = W_H_F

W_p_F1 = W_H_F1[0:3, 3]
W_p_F2 = W_H_F2[0:3, 3]

Expand All @@ -354,13 +355,13 @@ def compute_constraint_baumgarte_term(J_constr, BW_ν, W_H_F1, W_H_F2):

return baumgarte_term

def compute_constraint_transforms(constraint):
frame_1_idx, frame_2_idx, constraint_type = constraint # noqa: F841
def compute_constraint_transforms(data, constraint):
frame_1_idx, frame_2_idx = constraint

W_H_F1 = js.frame.transform(model=model, data=data, frame_index=frame_1_idx)
W_H_F2 = js.frame.transform(model=model, data=data, frame_index=frame_2_idx)

return W_H_F1, W_H_F2
return jnp.array((W_H_F1, W_H_F2))

with (
data.switch_velocity_representation(VelRepr.Mixed),
Expand Down Expand Up @@ -393,33 +394,37 @@ def compute_constraint_transforms(constraint):
),
)

J_constr = jnp.vstack(jax.vmap(compute_constraint_jacobians)(constraints))
J_constr = jnp.vstack(
jax.vmap(compute_constraint_jacobians, in_axes=(None, 0))(data, idxs)
)
J̇_constr = jnp.vstack(
jax.vmap(compute_constraint_jacobians_derivative)(constraints)
jax.vmap(compute_constraint_jacobians_derivative, in_axes=(None, 0))(
data, idxs
)
)

J = jnp.vstack([Jl_WC, J_constr])
= jnp.vstack([J̇l_WC, J̇_constr])

# Compute the regularization terms.
a_ref, R, *_ = self._regularizers(
a_ref, r, *_ = self._regularizers(
model=model,
position_constraint=position_constraint,
velocity_constraint=velocity,
parameters=model.contacts_params,
)

W_H_constr = jnp.array(jax.vmap(compute_constraint_transforms)(constraints))
W_H_constr = jax.vmap(compute_constraint_transforms, in_axes=(None, 0))(
data, idxs
)

constr_baumgarte_term = jnp.hstack(
jax.vmap(compute_constraint_baumgarte_term, in_axes=(None, None, 0, 0))(
J_constr, BW_ν, W_H_constr[0], W_H_constr[1]
)
jax.vmap(compute_constraint_baumgarte_term, in_axes=(None, None, None))(
data, J_constr, BW_ν, W_H_F=W_H_constr
),
).squeeze()

R_constr = jnp.pad(
R, ((0, n_kin_constraints), (0, n_kin_constraints)), mode="constant"
)
R_constr = jnp.diag(jnp.hstack([r, jnp.zeros(n_kin_constraints)]))
a_ref_constr = jnp.hstack([a_ref, -constr_baumgarte_term])

# Compute the Delassus matrix and the free mixed linear acceleration of
Expand Down Expand Up @@ -517,12 +522,14 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
)[0]
)(position, velocity).flatten()

init_params = jnp.hstack([
init_params,
jnp.zeros(
n_kin_constraints,
),
])
init_params = jnp.hstack(
[
init_params,
jnp.zeros(
n_kin_constraints,
),
]
)

# Get the solver options.
solver_options = self.solver_options
Expand All @@ -545,36 +552,36 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
kin_constr_wrench_mixed = solution[-n_kin_constraints:].reshape(-1, 6)

# Form an array of tuples with each wrench and its opposite using jax constructs
kin_constr_wrench_pairs_mixed = jax.vmap(lambda wrench: (wrench, -wrench))(
kin_constr_wrench_mixed
)
kin_constr_wrench_pairs_mixed = jnp.array(kin_constr_wrench_pairs_mixed)
kin_constr_wrench_pairs_mixed = jax.vmap(
lambda wrench: jnp.array((wrench, -wrench))
)(kin_constr_wrench_mixed)

jax.debug.print(
"kin_constr_wrench_pairs_mixed: \n{}", kin_constr_wrench_pairs_mixed.shape
)

# Transform each wrench in the pair to inertial representation using the appropriate transform
def transform_wrench_pair_to_inertial(wrench_pair, transform):
return (
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=wrench_pair[0],
transform=transform[0],
other_representation=VelRepr.Mixed,
is_force=True,
),
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=wrench_pair[1],
transform=transform[1],
other_representation=VelRepr.Mixed,
is_force=True,
),
return jnp.array(
(
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=wrench_pair[0],
transform=transform[0],
other_representation=VelRepr.Mixed,
is_force=True,
),
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=wrench_pair[1],
transform=transform[1],
other_representation=VelRepr.Mixed,
is_force=True,
),
)
)

kin_constr_force_pairs_inertial = jax.vmap(
transform_wrench_pair_to_inertial,
in_axes=(1, 1),
)(kin_constr_wrench_pairs_mixed, W_H_constr)
kin_constr_force_pairs_inertial = jnp.array(kin_constr_force_pairs_inertial)

jax.debug.print(
"kin_constr_force_pairs_inertial: \n{}",
Expand Down Expand Up @@ -728,7 +735,7 @@ def compute_row(
),
)

return a_ref, jnp.diag(R), K, D
return a_ref, R, K, D

@staticmethod
@functools.partial(jax.jit, static_argnames=("terrain",))
Expand Down

0 comments on commit fe7a382

Please sign in to comment.