Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
xela-95 committed Mar 7, 2025
1 parent 6ea7c18 commit 9b82920
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 97 deletions.
44 changes: 32 additions & 12 deletions src/jaxsim/api/contact_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,38 @@ def link_contact_forces(
# to the frames associated to the collidable points.
W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)

# Add the forces coming from the kinematic constraints to the link on which the constraint is applied.

# W_f_loop = aux_data["kin_constr_force"]
W_f_loop_F1 = aux_data["kin_constr_force_F1"]
W_f_loop_f2 = aux_data["kin_constr_force_F2"]
F1_idx = aux_data["F1_idx"]
F2_idx = aux_data["F2_idx"]
F1_parent_idx = js.frame.idx_of_parent_link(model, frame_index=F1_idx)
F2_parent_idx = js.frame.idx_of_parent_link(model, frame_index=F2_idx)

W_f_L = W_f_L.at[F1_parent_idx].add(W_f_loop_F1)
W_f_L = W_f_L.at[F2_parent_idx].add(W_f_loop_f2)
wrench_pair_constr_inertial = aux_data["constr_wrenches_inertial"]

constraints = model.kin_dyn_parameters.get_constraints(model)
# Get the couples of parent link indices of each couple of frames.
frame_idxs_1, frame_idxs_2, types = zip(*constraints, strict=False)
frame_idxs_1 = jnp.array(frame_idxs_1)
frame_idxs_2 = jnp.array(frame_idxs_2)

jax.debug.print("frame_idxs_1: \n{}", frame_idxs_1)
jax.debug.print("frame_idxs_2: \n{}", frame_idxs_2)

parent_link_indices = jax.vmap(
lambda frame_idx_1, frame_idx_2: (
js.frame.idx_of_parent_link(model, frame_index=frame_idx_1),
js.frame.idx_of_parent_link(model, frame_index=frame_idx_2),
)
)(frame_idxs_1, frame_idxs_2)
parent_link_indices = jnp.array(parent_link_indices)
jax.debug.print("parent_link_indices: \n{}", parent_link_indices.shape)

# Apply each constraint wrench to its corresponding parent link in W_f_L.
def apply_wrench(i, W_f_L):
parent_indices = parent_link_indices[:, i]
wrench_pair = wrench_pair_constr_inertial[:, i]
jax.debug.print("parent_indices: \n{}", parent_indices)
jax.debug.print("wrench_pair: \n{}", wrench_pair)
W_f_L = W_f_L.at[parent_indices[0]].add(wrench_pair[0])
W_f_L = W_f_L.at[parent_indices[1]].add(wrench_pair[1])
return W_f_L

W_f_L = jax.lax.fori_loop(0, parent_link_indices.shape[0], apply_wrench, W_f_L)

jax.debug.print("W_f_L: \n{}", W_f_L)

return W_f_L
Expand Down
110 changes: 98 additions & 12 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import enum

import jax.lax
import jax.numpy as jnp
Expand All @@ -9,6 +10,7 @@
import numpy.typing as npt
from jax_dataclasses import Static

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
Expand Down Expand Up @@ -51,6 +53,8 @@ class KinDynParameters(JaxsimDataclass):
joint_model: JointModel
joint_parameters: JointParameters | None

constraints: Static[ConstraintMap]

@property
def motion_subspaces(self) -> jtp.Matrix:
r"""
Expand All @@ -73,12 +77,15 @@ def support_body_array_bool(self) -> jtp.Matrix:
return self._support_body_array_bool.get()

@staticmethod
def build(model_description: ModelDescription) -> KinDynParameters:
def build(
model_description: ModelDescription, constraints: ConstraintMap | None
) -> KinDynParameters:
"""
Construct the kinematic and dynamic parameters of the model.
Args:
model_description: The parsed model description to consider.
constraints: An object of type ConstraintMap specifying the kinematic constraint of the model.
Returns:
The kinematic and dynamic parameters of the model.
Expand Down Expand Up @@ -248,6 +255,12 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:

motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

# ===========
# Constraints
# ===========

constraints = ConstraintMap() if constraints is None else constraints

# =================================
# Build and return KinDynParameters
# =================================
Expand All @@ -262,6 +275,7 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
joint_parameters=joint_parameters,
contact_parameters=contact_parameters,
frame_parameters=frame_parameters,
constraints=constraints,
)

def __eq__(self, other: KinDynParameters) -> bool:
Expand All @@ -272,17 +286,15 @@ def __eq__(self, other: KinDynParameters) -> bool:
return hash(self) == hash(other)

def __hash__(self) -> int:

return hash(
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(self.frame_parameters.name),
hash(self.frame_parameters.body),
hash(self._parent_array),
hash(self._support_body_array_bool),
)
)
return hash((
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(self.frame_parameters.name),
hash(self.frame_parameters.body),
hash(self._parent_array),
hash(self._support_body_array_bool),
hash(self.constraints),
))

# =============================
# Helpers to extract parameters
Expand Down Expand Up @@ -337,6 +349,13 @@ def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
jnp.where(self.support_body_array_bool[link_index])[0], dtype=int
)

def get_constraints(self, model:js.model.JaxSimModel) -> tuple[tuple[str, str, ConstraintType], ...]:
r"""
Return the constraints of the model.
"""

return self.constraints.get_constraints(model)

# ========================
# Quantities used by RBDAs
# ========================
Expand Down Expand Up @@ -882,3 +901,70 @@ def build_from(model_description: ModelDescription) -> FrameParameters:
assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]

return fp


@enum.unique
class ConstraintType(enum.IntEnum):
"""
Enumeration of all supported constraint types.
"""

Weld = enum.auto()
Connect = enum.auto()


@jax_dataclasses.pytree_dataclass
class ConstraintMap(JaxsimDataclass):
"""
Class storing the kinematic constraints of a model.
"""

frame_names_1: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)
frame_names_2: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)
constraint_types: Static[tuple[ConstraintType, ...]] = dataclasses.field(
default_factory=tuple
)

def add_constraint(
self, frame_name_1: str, frame_name_2: str, constraint_type: ConstraintType
) -> ConstraintMap:
"""
Add a constraint to the constraint map.
Args:
frame_name_1: The name of the first frame.
frame_name_2: The name of the second frame.
constraint_type: The type of constraint.
Returns:
A new ConstraintMap instance with the added constraint.
"""
return self.replace(
frame_names_1=(*self.frame_names_1, frame_name_1),
frame_names_2=(*self.frame_names_2, frame_name_2),
constraint_types=(*self.constraint_types, constraint_type),
validate=False,
)

def get_constraints(
self, model: js.model.JaxSimModel
) -> tuple[tuple[int, int, ConstraintType], ...]:
"""
Get the list of constraints.
Returns:
A tuple, in which each element defines a kinematic constraint.
"""
return tuple(
(
js.frame.name_to_idx(model, frame_name=frame_name_1),
js.frame.name_to_idx(model, frame_name=frame_name_2),
constraint_type,
)
for frame_name_1, frame_name_2, constraint_type in zip(
self.frame_names_1,
self.frame_names_2,
self.constraint_types,
strict=True,
)
)
10 changes: 9 additions & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jaxsim.exceptions
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.api.kin_dyn_parameters import ConstraintMap
from jaxsim.math import Adjoint, Cross
from jaxsim.parsers.descriptions import ModelDescription
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
Expand Down Expand Up @@ -126,6 +127,7 @@ def build_from_model_description(
integrator: IntegratorType | None = None,
is_urdf: bool | None = None,
considered_joints: Sequence[str] | None = None,
constraints: ConstraintMap | None = None,
) -> JaxSimModel:
"""
Build a Model object from a model description.
Expand All @@ -150,6 +152,8 @@ def build_from_model_description(
This is usually automatically inferred.
considered_joints:
The list of joints to consider. If None, all joints are considered.
constraints:
An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered.
Returns:
The built Model object.
Expand Down Expand Up @@ -179,6 +183,7 @@ def build_from_model_description(
contact_model=contact_model,
contacts_params=contact_params,
integrator=integrator,
constraints=constraints,
)

# Store the origin of the model, in case downstream logic needs it.
Expand All @@ -199,6 +204,7 @@ def build(
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
integrator: IntegratorType | None = None,
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
constraints: ConstraintMap | None = None,
) -> JaxSimModel:
"""
Build a Model object from an intermediate model description.
Expand All @@ -220,6 +226,8 @@ def build(
contacts_params: The parameters of the soft contacts.
integrator: The integrator to use for the simulation.
gravity: The gravity constant.
constraints:
An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered.
Returns:
The built Model object.
Expand Down Expand Up @@ -265,7 +273,7 @@ def build(
model = cls(
model_name=model_name,
kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build(
model_description=model_description
model_description=model_description, constraints=constraints
),
time_step=time_step,
terrain=terrain,
Expand Down
Loading

0 comments on commit 9b82920

Please sign in to comment.