From 804ee9f7bca4148226278c066de356567fc7e390 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sun, 5 Jan 2025 23:27:58 +0100 Subject: [PATCH] Remove `wrappers` module and custom hash methods --- src/jaxsim/api/kin_dyn_parameters.py | 46 ++--- src/jaxsim/api/model.py | 46 ++--- src/jaxsim/math/joint_model.py | 6 +- src/jaxsim/parsers/descriptions/__init__.py | 2 +- src/jaxsim/parsers/descriptions/collision.py | 99 +++-------- src/jaxsim/parsers/descriptions/joint.py | 177 ++++++++++--------- src/jaxsim/parsers/descriptions/link.py | 92 +++++----- src/jaxsim/parsers/descriptions/model.py | 30 ---- src/jaxsim/parsers/kinematic_graph.py | 134 ++++++++------ src/jaxsim/parsers/rod/parser.py | 48 +++-- src/jaxsim/parsers/rod/utils.py | 6 +- src/jaxsim/rbda/contacts/relaxed_rigid.py | 114 +++--------- src/jaxsim/terrain/terrain.py | 41 ----- src/jaxsim/utils/__init__.py | 1 - src/jaxsim/utils/wrappers.py | 159 ----------------- 15 files changed, 337 insertions(+), 664 deletions(-) delete mode 100644 src/jaxsim/utils/wrappers.py diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 93fe7d6ff..fd6bc053f 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -12,7 +12,7 @@ import jaxsim.typing as jtp from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription -from jaxsim.utils import HashedNumpyArray, JaxsimDataclass +from jaxsim.utils import JaxsimDataclass @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) @@ -34,9 +34,9 @@ class KinDynParameters(JaxsimDataclass): # Static link_names: Static[tuple[str]] - _parent_array: Static[HashedNumpyArray] - _support_body_array_bool: Static[HashedNumpyArray] - _motion_subspaces: Static[HashedNumpyArray] + _parent_array: Static[tuple[int]] + _support_body_array_bool: Static[tuple[int]] + _motion_subspaces: Static[tuple[float]] # Links link_parameters: LinkParameters @@ -56,21 +56,21 @@ def motion_subspaces(self) -> jtp.Matrix: r""" Return the motion subspaces :math:`\mathbf{S}(s)` of the joints. """ - return self._motion_subspaces.get() + return jnp.array(self._motion_subspaces, dtype=float) @property def parent_array(self) -> jtp.Vector: r""" Return the parent array :math:`\lambda(i)` of the model. """ - return self._parent_array.get() + return jnp.array(self._parent_array, dtype=int) @property def support_body_array_bool(self) -> jtp.Matrix: r""" Return the boolean support parent array :math:`\kappa_{b}(i)` of the model. """ - return self._support_body_array_bool.get() + return jnp.array(self._support_body_array_bool, dtype=int) @staticmethod def build(model_description: ModelDescription) -> KinDynParameters: @@ -227,8 +227,8 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike: S = { JointType.Fixed: np.zeros(shape=(6, 1)), - JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])), - JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])), + JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis])), + JointType.Prismatic: np.vstack(np.hstack([axis, np.zeros(3)])), } return S[joint_type] @@ -254,9 +254,9 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike: return KinDynParameters( link_names=tuple(l.name for l in ordered_links), - _parent_array=HashedNumpyArray(array=parent_array), - _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool), - _motion_subspaces=HashedNumpyArray(array=motion_subspaces), + _parent_array=tuple(parent_array.tolist()), + _support_body_array_bool=tuple(support_body_array_bool.tolist()), + _motion_subspaces=tuple(motion_subspaces.tolist()), link_parameters=link_parameters, joint_model=joint_model, joint_parameters=joint_parameters, @@ -264,26 +264,6 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike: frame_parameters=frame_parameters, ) - def __eq__(self, other: KinDynParameters) -> bool: - - if not isinstance(other, KinDynParameters): - return False - - return hash(self) == hash(other) - - def __hash__(self) -> int: - - return hash( - ( - hash(self.number_of_links()), - hash(self.number_of_joints()), - hash(self.frame_parameters.name), - hash(self.frame_parameters.body), - hash(self._parent_array), - hash(self._support_body_array_bool), - ) - ) - # ============================= # Helpers to extract parameters # ============================= @@ -409,7 +389,7 @@ def joint_transforms( pre_H_suc_J = jax.vmap(supported_joint_motion)( joint_types=jnp.array(self.joint_model.joint_types[1:]).astype(int), joint_positions=jnp.array(joint_positions), - joint_axes=jnp.array([j.axis for j in self.joint_model.joint_axis]), + joint_axes=jnp.array(self.joint_model.joint_axis), ) # Extract the transforms and motion subspaces of the joints. diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 47f601512..554e8c807 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -19,7 +19,7 @@ import jaxsim.typing as jtp from jaxsim.math import Adjoint, Cross from jaxsim.parsers.descriptions import ModelDescription -from jaxsim.utils import JaxsimDataclass, Mutability, wrappers +from jaxsim.utils import JaxsimDataclass, Mutability from .common import VelRepr @@ -69,8 +69,8 @@ class JaxSimModel(JaxsimDataclass): default=None, repr=False ) - _description: Static[wrappers.HashlessObject[ModelDescription | None]] = ( - dataclasses.field(default=None, repr=False) + _description: Static[ModelDescription | None] = dataclasses.field( + default=None, repr=False ) @property @@ -78,34 +78,7 @@ def description(self) -> ModelDescription: """ Return the model description. """ - return self._description.get() - - def __eq__(self, other: JaxSimModel) -> bool: - - if not isinstance(other, JaxSimModel): - return False - - if self.model_name != other.model_name: - return False - - if self.time_step != other.time_step: - return False - - if self.kin_dyn_parameters != other.kin_dyn_parameters: - return False - - return True - - def __hash__(self) -> int: - - return hash( - ( - hash(self.model_name), - hash(self.time_step), - hash(self.kin_dyn_parameters), - hash(self.contact_model), - ) - ) + return self._description # ======================== # Initialization and state @@ -275,7 +248,7 @@ def build( # don't want to trigger recompilation if it changes. All relevant parameters # needed to compute kinematics and dynamics quantities are stored in the # kin_dyn_parameters attribute. - _description=wrappers.HashlessObject(obj=model_description), + _description=model_description, ) return model @@ -446,15 +419,16 @@ def reduce( # Operate on a deep copy of the model description in order to prevent problems # when mutable attributes are updated. - intermediate_description = copy.deepcopy(model.description) + intermediate_description = copy.deepcopy(model._description) # Update the initial position of the joints. # This is necessary to compute the correct pose of the link pairs connected # to removed joints. for joint_name in set(model.joint_names()) - set(considered_joints): - j = intermediate_description.joints_dict[joint_name] - with j.mutable_context(): - j.initial_position = locked_joint_positions.get(joint_name, 0.0) + intermediate_description.joints_dict[joint_name] = dataclasses.replace( + intermediate_description.joints_dict[joint_name], + _initial_position=float(locked_joint_positions.get(joint_name, 0.0)), + ) # Reduce the model description. # If `considered_joints` contains joints not existing in the model, diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index 6d94800ff..65aefa6ae 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -8,7 +8,7 @@ import jaxsim.typing as jtp from jaxsim.math import Rotation -from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription +from jaxsim.parsers.descriptions import JointType, ModelDescription from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms @@ -39,7 +39,7 @@ class JointModel: joint_dofs: Static[tuple[int, ...]] joint_names: Static[tuple[str, ...]] joint_types: Static[tuple[int, ...]] - joint_axis: Static[tuple[JointGenericAxis, ...]] + joint_axis: Static[tuple[tuple[int]]] @staticmethod def build(description: ModelDescription) -> JointModel: @@ -108,7 +108,7 @@ def build(description: ModelDescription) -> JointModel: joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]), joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]), - joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints), + joint_axis=tuple(tuple(j.axis.tolist()) for j in ordered_joints), ) def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix: diff --git a/src/jaxsim/parsers/descriptions/__init__.py b/src/jaxsim/parsers/descriptions/__init__.py index ff3bf631d..a391cb3d7 100644 --- a/src/jaxsim/parsers/descriptions/__init__.py +++ b/src/jaxsim/parsers/descriptions/__init__.py @@ -5,6 +5,6 @@ MeshCollision, SphereCollision, ) -from .joint import JointDescription, JointGenericAxis, JointType +from .joint import JointDescription, JointType from .link import LinkDescription from .model import ModelDescription diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index 719c92d2b..b9b5365a8 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -3,11 +3,9 @@ import abc import dataclasses -import jax.numpy as jnp import numpy as np import numpy.typing as npt -import jaxsim.typing as jtp from jaxsim import logging from .link import LinkDescription @@ -25,8 +23,28 @@ class CollidablePoint: """ parent_link: LinkDescription - position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3)) enabled: bool = True + _position: tuple[float] = dataclasses.field(default=(0.0, 0.0, 0.0)) + + @property + def position(self) -> npt.NDArray: + """ + Get the position of the collidable point. + + Returns: + The position of the collidable point. + """ + return np.array(self._position) + + @position.setter + def position(self, value: npt.NDArray) -> None: + """ + Set the position of the collidable point. + + Args: + value: The new position of the collidable point. + """ + self._position = tuple(value.tolist()) def change_link( self, new_link: LinkDescription, new_H_old: npt.NDArray @@ -35,8 +53,8 @@ def change_link( Move the collidable point to a new parent link. Args: - new_link (LinkDescription): The new parent link to which the collidable point is moved. - new_H_old (npt.NDArray): The transformation matrix from the new link's frame to the old link's frame. + new_link: The new parent link to which the collidable point is moved. + new_H_old: The transformation matrix from the new link's frame to the old link's frame. Returns: CollidablePoint: A new collidable point associated with the new parent link. @@ -47,27 +65,12 @@ def change_link( return CollidablePoint( parent_link=new_link, - position=(new_H_old @ jnp.hstack([self.position, 1.0])).squeeze()[0:3], + _position=tuple( + (new_H_old @ np.hstack([self.position, 1.0])).squeeze()[0:3].tolist() + ), enabled=self.enabled, ) - def __hash__(self) -> int: - - return hash( - ( - hash(self.parent_link), - hash(tuple(self.position.tolist())), - hash(self.enabled), - ) - ) - - def __eq__(self, other: CollidablePoint) -> bool: - - if not isinstance(other, CollidablePoint): - return False - - return hash(self) == hash(other) - def __str__(self) -> str: return ( f"{self.__class__.__name__}(" @@ -107,22 +110,7 @@ class BoxCollision(CollisionShape): center: The center of the box in the local frame of the collision shape. """ - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(super()), - hash(tuple(self.center.tolist())), - ) - ) - - def __eq__(self, other: BoxCollision) -> bool: - - if not isinstance(other, BoxCollision): - return False - - return hash(self) == hash(other) + center: tuple[float, float, float] @dataclasses.dataclass @@ -134,22 +122,7 @@ class SphereCollision(CollisionShape): center: The center of the sphere in the local frame of the collision shape. """ - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(super()), - hash(tuple(self.center.tolist())), - ) - ) - - def __eq__(self, other: BoxCollision) -> bool: - - if not isinstance(other, BoxCollision): - return False - - return hash(self) == hash(other) + center: tuple[float, float, float] @dataclasses.dataclass @@ -161,18 +134,4 @@ class MeshCollision(CollisionShape): center: The center of the mesh in the local frame of the collision shape. """ - center: jtp.VectorLike - - def __hash__(self) -> int: - return hash( - ( - hash(tuple(self.center.tolist())), - hash(self.collidable_points), - ) - ) - - def __eq__(self, other: MeshCollision) -> bool: - if not isinstance(other, MeshCollision): - return False - - return hash(self) == hash(other) + center: tuple[float, float, float] diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 04ccfaa4a..181b42cbf 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -3,11 +3,8 @@ import dataclasses from typing import ClassVar -import jax_dataclasses import numpy as np - -import jaxsim.typing as jtp -from jaxsim.utils import JaxsimDataclass, Mutability +import numpy.typing as npt from .link import LinkDescription @@ -23,31 +20,10 @@ class JointType: Prismatic: ClassVar[int] = 2 -@jax_dataclasses.pytree_dataclass -class JointGenericAxis: - """ - A joint requiring the specification of a 3D axis. +@dataclasses.dataclass(eq=False) +class JointDescription: """ - - # The axis of rotation or translation of the joint (must have norm 1). - axis: jtp.Vector - - def __hash__(self) -> int: - - return hash(tuple(self.axis.tolist())) - - def __eq__(self, other: JointGenericAxis) -> bool: - - if not isinstance(other, JointGenericAxis): - return False - - return hash(self) == hash(other) - - -@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) -class JointDescription(JaxsimDataclass): - """ - In-memory description of a robot link. + In-memory description of a robot joint. Attributes: name: The name of the joint. @@ -65,66 +41,101 @@ class JointDescription(JaxsimDataclass): initial_position: The initial position of the joint. """ - name: jax_dataclasses.Static[str] - axis: jtp.Vector - pose: jtp.Matrix - jtype: jax_dataclasses.Static[jtp.IntLike] - child: LinkDescription = dataclasses.dataclass(repr=False) - parent: LinkDescription = dataclasses.dataclass(repr=False) + name: str + _axis: tuple[float] + _pose: tuple[float] + jtype: int + child: LinkDescription = dataclasses.field( + default_factory=LinkDescription, repr=False + ) + parent: LinkDescription = dataclasses.field( + default_factory=LinkDescription, repr=False + ) - index: jtp.IntLike | None = None + index: int | None = None - friction_static: jtp.FloatLike = 0.0 - friction_viscous: jtp.FloatLike = 0.0 + friction_static: float = 0.0 + friction_viscous: float = 0.0 - position_limit_damper: jtp.FloatLike = 0.0 - position_limit_spring: jtp.FloatLike = 0.0 + position_limit_damper: float = 0.0 + position_limit_spring: float = 0.0 - position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0) - initial_position: jtp.FloatLike | jtp.VectorLike = 0.0 + position_limit: tuple[float, float] = (0.0, 0.0) + _initial_position: float | tuple[float] = 0.0 - motor_inertia: jtp.FloatLike = 0.0 - motor_viscous_friction: jtp.FloatLike = 0.0 - motor_gear_ratio: jtp.FloatLike = 1.0 + motor_inertia: float = 0.0 + motor_viscous_friction: float = 0.0 + motor_gear_ratio: float = 1.0 def __post_init__(self) -> None: - if self.axis is not None: - - with self.mutable_context( - mutability=Mutability.MUTABLE, restore_after_exception=False - ): - norm_of_axis = np.linalg.norm(self.axis) - self.axis = self.axis / norm_of_axis - - def __eq__(self, other: JointDescription) -> bool: - - if not isinstance(other, JointDescription): - return False - - return hash(self) == hash(other) - - def __hash__(self) -> int: - - from jaxsim.utils.wrappers import HashedNumpyArray - - return hash( - ( - hash(self.name), - HashedNumpyArray.hash_of_array(self.axis), - HashedNumpyArray.hash_of_array(self.pose), - hash(int(self.jtype)), - hash(self.child), - hash(self.parent), - hash(int(self.index)) if self.index is not None else 0, - HashedNumpyArray.hash_of_array(self.friction_static), - HashedNumpyArray.hash_of_array(self.friction_viscous), - HashedNumpyArray.hash_of_array(self.position_limit_damper), - HashedNumpyArray.hash_of_array(self.position_limit_spring), - HashedNumpyArray.hash_of_array(self.position_limit), - HashedNumpyArray.hash_of_array(self.initial_position), - HashedNumpyArray.hash_of_array(self.motor_inertia), - HashedNumpyArray.hash_of_array(self.motor_viscous_friction), - HashedNumpyArray.hash_of_array(self.motor_gear_ratio), - ), - ) + if self._axis is not None: + + self._axis = self.axis / np.linalg.norm(self.axis) + + @property + def axis(self) -> npt.NDArray: + """ + Get the axis of the joint. + + Returns: + npt.NDArray: The axis of the joint. + """ + + return np.array(self._axis) + + @axis.setter + def axis(self, value: npt.NDArray) -> None: + """ + Set the axis of the joint. + + Args: + value (npt.NDArray): The new axis of the joint. + """ + + norm_of_axis = np.linalg.norm(value) + self._axis = tuple((value / norm_of_axis).tolist()) + + @property + def pose(self) -> npt.NDArray: + """ + Get the pose of the joint. + + Returns: + The pose of the joint. + """ + + return np.array(self._pose, dtype=float) + + @pose.setter + def pose(self, value: npt.NDArray) -> None: + """ + Set the pose of the joint. + + Args: + value: The new pose of the joint. + """ + + self._pose = tuple(np.array(value).tolist()) + + @property + def initial_position(self) -> float | npt.NDArray: + """ + Get the initial position of the joint. + + Returns: + The initial position of the joint. + """ + + return np.array(self._initial_position, dtype=float) + + @initial_position.setter + def initial_position(self, value: float | npt.NDArray) -> None: + """ + Set the initial position of the joint. + + Args: + value: The new initial position of the joint. + """ + + self._initial_position = tuple(np.array(value).tolist()) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 9bad65be9..f94024268 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -2,18 +2,15 @@ import dataclasses -import jax.numpy as jnp -import jax_dataclasses import numpy as np -from jax_dataclasses import Static +import numpy.typing as npt import jaxsim.typing as jtp from jaxsim.math import Adjoint -from jaxsim.utils import JaxsimDataclass -@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) -class LinkDescription(JaxsimDataclass): +@dataclasses.dataclass(eq=False, unsafe_hash=False) +class LinkDescription: """ In-memory description of a robot link. @@ -27,50 +24,58 @@ class LinkDescription(JaxsimDataclass): children: The children links. """ - name: Static[str] + name: str mass: float = dataclasses.field(repr=False) - inertia: jtp.Matrix = dataclasses.field(repr=False) + _inertia: tuple[float] = dataclasses.field(repr=False) index: int | None = None - parent_name: Static[str | None] = dataclasses.field(default=None, repr=False) - pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False) + parent_name: str | None = dataclasses.field(default=None, repr=False) + _pose: tuple[float] = dataclasses.field( + default=tuple(np.eye(4).tolist()), repr=False + ) - children: Static[tuple[LinkDescription]] = dataclasses.field( + children: tuple[LinkDescription] = dataclasses.field( default_factory=list, repr=False ) - def __hash__(self) -> int: + @property + def inertia(self) -> npt.NDArray: + """ + Get the inertia tensor of the link. - from jaxsim.utils.wrappers import HashedNumpyArray + Returns: + npt.NDArray: The inertia tensor of the link. + """ + return np.array(self._inertia) - return hash( - ( - hash(self.name), - hash(float(self.mass)), - HashedNumpyArray.hash_of_array(self.inertia), - hash(int(self.index)) if self.index is not None else 0, - HashedNumpyArray.hash_of_array(self.pose), - hash(tuple(self.children)), - hash(self.parent_name) if self.parent_name is not None else 0, - ) - ) + @inertia.setter + def inertia(self, inertia: npt.NDArray) -> None: + """ + Set the inertia tensor of the link. + + Args: + inertia: The inertia tensor of the link. + """ + self._inertia = tuple(inertia.tolist()) - def __eq__(self, other: LinkDescription) -> bool: + @property + def pose(self) -> npt.NDArray: + """ + Get the pose transformation matrix of the link. - if not isinstance(other, LinkDescription): - return False + Returns: + npt.NDArray: The pose transformation matrix of the link. + """ + return np.array(self._pose) - if not ( - self.name == other.name - and np.allclose(self.mass, other.mass) - and np.allclose(self.inertia, other.inertia) - and self.index == other.index - and np.allclose(self.pose, other.pose) - and self.children == other.children - and self.parent_name == other.parent_name - ): - return False + @pose.setter + def pose(self, pose: npt.NDArray) -> None: + """ + Set the pose transformation matrix of the link. - return True + Args: + pose: The pose transformation matrix of the link. + """ + self._pose = tuple(pose.tolist()) @property def name_and_index(self) -> str: @@ -101,15 +106,18 @@ def lump_with( I_removed = link.inertia # Create the SE3 object. Note the inverse. - r_X_l = Adjoint.from_transform(transform=lumped_H_removed, inverse=True) + r_X_l = np.array( + Adjoint.from_transform(transform=lumped_H_removed, inverse=True) + ) # Move the inertia I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l # Create the new combined link - lumped_link = self.replace( - mass=self.mass + link.mass, - inertia=self.inertia + I_removed_in_lumped_frame, + lumped_link = dataclasses.replace( + self, + mass=float(self.mass + link.mass), + _inertia=tuple((self.inertia + I_removed_in_lumped_frame).tolist()), ) return lumped_link diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index cd54ba27f..ea0a69bb9 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -243,33 +243,3 @@ def all_enabled_collidable_points(self) -> list[CollidablePoint]: # Return enabled collidable points return [cp for cp in all_collidable_points if cp.enabled] - - def __eq__(self, other: ModelDescription) -> bool: - - if not isinstance(other, ModelDescription): - return False - - if not ( - self.name == other.name - and self.fixed_base == other.fixed_base - and self.root == other.root - and self.joints == other.joints - and self.frames == other.frames - and self.root_pose == other.root_pose - ): - return False - - return True - - def __hash__(self) -> int: - - return hash( - ( - hash(self.name), - hash(self.fixed_base), - hash(self.root), - hash(tuple(self.joints)), - hash(tuple(self.frames)), - hash(self.root_pose), - ) - ) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 9c16136d3..6a60c558e 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -9,9 +9,7 @@ import numpy as np import numpy.typing as npt -import jaxsim.utils from jaxsim import logging -from jaxsim.utils import Mutability from .descriptions.joint import JointDescription, JointType from .descriptions.link import LinkDescription @@ -31,35 +29,53 @@ class RootPose: The root link of the kinematic graph is the base link. """ - root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3)) + _root_position: tuple[float] = dataclasses.field( + default=tuple(np.zeros(3).tolist()) + ) - root_quaternion: npt.NDArray = dataclasses.field( - default_factory=lambda: np.array([1.0, 0, 0, 0]) + _root_quaternion: tuple[float] = dataclasses.field( + default=tuple(np.array([1.0, 0, 0, 0]).tolist()) ) - def __hash__(self) -> int: + @property + def root_position(self) -> npt.NDArray: + """ + Get the 3D position of the root link of the graph. - from jaxsim.utils.wrappers import HashedNumpyArray + Returns: + The 3D position of the root link. + """ + return np.array(self._root_position) - return hash( - ( - HashedNumpyArray.hash_of_array(self.root_position), - HashedNumpyArray.hash_of_array(self.root_quaternion), - ) - ) + @root_position.setter + def root_position(self, value: npt.NDArray) -> None: + """ + Set the 3D position of the root link of the graph. - def __eq__(self, other: RootPose) -> bool: + Args: + value: The new 3D position of the root link. + """ + self._root_position = tuple(value.tolist()) - if not isinstance(other, RootPose): - return False + @property + def root_quaternion(self) -> npt.NDArray: + """ + Get the quaternion representing the rotation of the root link of the graph. - if not np.allclose(self.root_position, other.root_position): - return False + Returns: + The quaternion representing the rotation of the root link. + """ + return np.array(self._root_quaternion) - if not np.allclose(self.root_quaternion, other.root_quaternion): - return False + @root_quaternion.setter + def root_quaternion(self, value: npt.NDArray) -> None: + """ + Set the quaternion representing the rotation of the root link of the graph. - return True + Args: + value: The new quaternion representing the rotation of the root link. + """ + self._root_quaternion = tuple(value.tolist()) @dataclasses.dataclass(frozen=True) @@ -131,7 +147,7 @@ def __post_init__(self) -> None: # Here we assume the model being fixed-base, therefore the base link will # have index 0. We will deal with the floating base in a later stage. for index, link in enumerate(self): - link.mutable(validate=False).index = index + link.index = index # Get the names of the links, frames, and joints. link_names = [l.name for l in self] @@ -152,15 +168,17 @@ def __post_init__(self) -> None: # We assume the model being fixed-base, therefore the first frame will # have last_link_idx + 1. for index, frame in enumerate(self.frames): - with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - frame.index = int(index + len(self.link_names())) + frame.index = int(index + len(self.link_names())) # Number joints so that their index matches their child link index. # Therefore, the first joint has index 1. links_dict = {l.name: l for l in iter(self)} - for joint in self.joints: - with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - joint.index = links_dict[joint.child.name].index + + joints = [ + dataclasses.replace(joint, index=int(links_dict[joint.child.name].index)) + for joint in self.joints + ] + super().__setattr__("joints", joints) # Check that joint indices are unique. assert len([j.index for j in self.joints]) == len( @@ -262,9 +280,7 @@ def _create_graph( """ # Create a dictionary that maps the link name to the link, for easy retrieval. - links_dict: dict[str, LinkDescription] = { - l.name: l.mutable(validate=False) for l in links - } + links_dict: dict[str, LinkDescription] = {l.name: l for l in links} # Create an empty list of frames if not provided. frames = frames if frames is not None else [] @@ -306,8 +322,11 @@ def _create_graph( # Assign link's children and make sure they are unique. if child_link.name not in {l.name for l in parent_link.children}: - with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION): - parent_link.children = (*parent_link.children, child_link) + parent_link.children = (*parent_link.children, child_link) + + # Update links_dict with the modified links. + links_dict[child_link.name] = child_link + links_dict[parent_link.name] = parent_link # Collect all the links of the kinematic graph. all_links_in_graph = list( @@ -352,11 +371,14 @@ def _create_graph( unconnected_links = [l for l in links if l.name not in all_link_names_in_graph] # Update the unconnected links by removing their children. The other properties - # are left untouched, it's caller responsibility to post-process them if needed. - for link in unconnected_links: - link.children = tuple() - msg = "Link '{}' won't be part of the kinematic graph because unconnected" - logging.debug(msg=msg.format(link.name)) + # are left untouched; it's the caller's responsibility to post-process them if needed. + updated_unconnected_links = [ + dataclasses.replace(link, children=tuple()) for link in unconnected_links + ] + + for link in updated_unconnected_links: + msg = "Link '{}' won't be part of the kinematic graph because it's unconnected." + logging.debug(msg.format(link.name)) # Collect all the frames that are not part of the kinematic graph. unconnected_frames = [ @@ -368,7 +390,7 @@ def _create_graph( logging.debug(msg=msg.format(frame.name)) return ( - links_dict[root_link_name].mutable(mutable=False), + links_dict[root_link_name], list(set(joints) - set(removed_joints)), all_frames_in_graph, unconnected_links, @@ -508,13 +530,12 @@ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph: for joint in joints_with_removed_parent_link: # Update the pose. Note that after the lumping process, the dict entry # links_dict[joint.parent.name] contains the final lumped link - with joint.mutable_context(mutability=Mutability.MUTABLE): - joint.pose = fk.relative_transform( - relative_to=links_dict[joint.parent.name].name, name=joint.name - ) - with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - # Update the parent link - joint.parent = links_dict[joint.parent.name] + + joint.pose = fk.relative_transform( + relative_to=links_dict[joint.parent.name].name, name=joint.name + ) + # Update the parent link + joint.parent = links_dict[joint.parent.name] # =================================================================== # 3. Create the reduced graph considering the removed links as frames @@ -591,21 +612,20 @@ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph: logging.debug(msg=msg.format(frame.name, name_of_new_parent_link)) # Always recompute the pose of the frame, and set zero inertial params. - with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION): - # Update kinematic parameters of the frame. - # Note that here we compute the transform using the FK object of the - # full model, so that we are sure that the kinematic is not altered. - frame.pose = fk.relative_transform( - relative_to=name_of_new_parent_link, name=frame.name - ) + # Update kinematic parameters of the frame. + # Note that here we compute the transform using the FK object of the + # full model, so that we are sure that the kinematic is not altered. + frame.pose = fk.relative_transform( + relative_to=name_of_new_parent_link, name=frame.name + ) - # Update the parent link such that the pose is expressed in its frame. - frame.parent_name = name_of_new_parent_link + # Update the parent link such that the pose is expressed in its frame. + frame.parent_name = name_of_new_parent_link - # Update dynamic parameters of the frame. - frame.mass = 0.0 - frame.inertia = np.zeros_like(frame.inertia) + # Update dynamic parameters of the frame. + frame.mass = 0.0 + frame.inertia = np.zeros_like(frame.inertia) # Return the reduced graph. return reduced_graph diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index 531bcd73f..7dd6b43d7 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -98,8 +98,8 @@ def extract_model_data( else: W_H_M = sdf_model.pose.transform() model_pose = kinematic_graph.RootPose( - root_position=W_H_M[0:3, 3], - root_quaternion=Quaternion.from_dcm(dcm=W_H_M[0:3, 0:3]), + _root_position=tuple(W_H_M[0:3, 3].tolist()), + _root_quaternion=tuple(Quaternion.from_dcm(dcm=W_H_M[0:3, 0:3]).tolist()), ) # =========== @@ -111,8 +111,12 @@ def extract_model_data( descriptions.LinkDescription( name=l.name, mass=float(l.inertial.mass), - inertia=utils.from_sdf_inertial(inertial=l.inertial), - pose=l.pose.transform() if l.pose is not None else np.eye(4), + _inertia=tuple(utils.from_sdf_inertial(inertial=l.inertial).tolist()), + _pose=tuple( + l.pose.transform().tolist() + if l.pose is not None + else np.eye(4).tolist() + ), ) for l in sdf_model.links() if l.inertial.mass > 0 @@ -129,10 +133,14 @@ def extract_model_data( frames = [ descriptions.LinkDescription( name=f.name, - mass=jnp.array(0.0, dtype=float), - inertia=jnp.zeros(shape=(3, 3)), parent_name=f.attached_to, - pose=f.pose.transform() if f.pose is not None else jnp.eye(4), + mass=0.0, + _inertia=tuple(np.zeros(shape=(3, 3)).tolist()), + _pose=tuple( + f.pose.transform().tolist() + if f.pose is not None + else np.eye(4).tolist() + ), ) for f in sdf_model.frames() if f.attached_to in links_dict @@ -147,7 +155,7 @@ def extract_model_data( if sdf_model.is_fixed_base(): # Create a massless word link world_link = descriptions.LinkDescription( - name="world", mass=0, inertia=np.zeros(shape=(6, 6)) + name="world", mass=0, _inertia=tuple(np.zeros(shape=(6, 6)).tolist()) ) # Gather joints connecting fixed-base models to the world. @@ -160,14 +168,18 @@ def extract_model_data( parent=world_link, child=links_dict[j.child], jtype=utils.joint_to_joint_type(joint=j), - axis=( - np.array(j.axis.xyz.xyz) + _axis=( + np.array(j.axis.xyz.xyz).tolist() if j.axis is not None and j.axis.xyz is not None and j.axis.xyz.xyz is not None else None ), - pose=j.pose.transform() if j.pose is not None else np.eye(4), + _pose=tuple( + j.pose.transform().tolist() + if j.pose is not None + else np.eye(4).tolist() + ), ) for j in sdf_model.joints() if j.type == "fixed" @@ -192,7 +204,7 @@ def extract_model_data( # Combine the pose of the base link (child of the found fixed joint) # with the pose of the fixed joint connecting with the world. # Note: we assume it's a fixed joint and ignore any joint angle. - links_dict[base_link_name].mutable(validate=False).pose = ( + links_dict[base_link_name].pose = ( joints_with_world_parent[0].pose @ links_dict[base_link_name].pose ) @@ -222,15 +234,19 @@ def extract_model_data( parent=links_dict[j.parent], child=links_dict[j.child], jtype=utils.joint_to_joint_type(joint=j), - axis=( - np.array(j.axis.xyz.xyz, dtype=float) + _axis=tuple( + np.array(j.axis.xyz.xyz, dtype=float).tolist() if j.axis is not None and j.axis.xyz is not None and j.axis.xyz.xyz is not None else None ), - pose=j.pose.transform() if j.pose is not None else np.eye(4), - initial_position=0.0, + _pose=tuple( + j.pose.transform().tolist() + if j.pose is not None + else np.eye(4).tolist() + ), + _initial_position=0.0, position_limit=( float( j.axis.limit.lower diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index a295b7fab..19f79433c 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -144,7 +144,7 @@ def create_box_collision( collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, - position=np.array(corner), + _position=tuple(np.array(corner).tolist()), enabled=True, ) for corner in box_corners_wrt_link.T @@ -214,7 +214,7 @@ def fibonacci_sphere(samples: int) -> npt.NDArray: collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, - position=np.array(point), + _position=tuple(np.array(point).tolist()), enabled=True, ) for point in sphere_points_wrt_link.T @@ -271,7 +271,7 @@ def create_mesh_collision( collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, - position=point, + _position=tuple(point).tolist(), enabled=True, ) for point in mesh_points_wrt_link diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 5e20ada54..76a70d9bd 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -29,75 +29,34 @@ class RelaxedRigidContactsParams(common.ContactsParams): """Parameters of the relaxed rigid contacts model.""" # Time constant - time_constant: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.01, dtype=float) - ) + time_constant: jtp.Float = dataclasses.field(default=0.01) # Adimensional damping coefficient - damping_coefficient: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(1.0, dtype=float) - ) + damping_coefficient: jtp.Float = dataclasses.field(default=1.0) # Minimum impedance - d_min: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.9, dtype=float) - ) + d_min: jtp.Float = dataclasses.field(default=0.9) # Maximum impedance - d_max: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.95, dtype=float) - ) + d_max: jtp.Float = dataclasses.field(default=0.95) # Width - width: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.0001, dtype=float) - ) + width: jtp.Float = dataclasses.field(default=0.0001) # Midpoint - midpoint: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.1, dtype=float) - ) + midpoint: jtp.Float = dataclasses.field(default=0.1) # Power exponent - power: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(1.0, dtype=float) - ) + power: jtp.Float = dataclasses.field(default=1.0) # Stiffness - stiffness: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.0, dtype=float) - ) + stiffness: jtp.Float = dataclasses.field(default=0.0) # Damping - damping: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.0, dtype=float) - ) + damping: jtp.Float = dataclasses.field(default=0.0) # Friction coefficient - mu: jtp.Float = dataclasses.field( - default_factory=lambda: jnp.array(0.5, dtype=float) - ) - - def __hash__(self) -> int: - from jaxsim.utils.wrappers import HashedNumpyArray - - return hash( - ( - HashedNumpyArray(self.time_constant), - HashedNumpyArray(self.damping_coefficient), - HashedNumpyArray(self.d_min), - HashedNumpyArray(self.d_max), - HashedNumpyArray(self.width), - HashedNumpyArray(self.midpoint), - HashedNumpyArray(self.power), - HashedNumpyArray(self.stiffness), - HashedNumpyArray(self.damping), - HashedNumpyArray(self.mu), - ) - ) - - def __eq__(self, other: RelaxedRigidContactsParams) -> bool: - return hash(self) == hash(other) + mu: jtp.Float = dataclasses.field(default=0.5) @classmethod def build( @@ -117,48 +76,25 @@ def build( """Create a `RelaxedRigidContactsParams` instance.""" def default(name: str): - return cls.__dataclass_fields__[name].default_factory() + return cls.__dataclass_fields__[name].default return cls( - time_constant=jnp.array( - ( - time_constant - if time_constant is not None - else default("time_constant") - ), - dtype=float, - ), - damping_coefficient=jnp.array( - ( - damping_coefficient - if damping_coefficient is not None - else default("damping_coefficient") - ), - dtype=float, - ), - d_min=jnp.array( - d_min if d_min is not None else default("d_min"), dtype=float - ), - d_max=jnp.array( - d_max if d_max is not None else default("d_max"), dtype=float - ), - width=jnp.array( - width if width is not None else default("width"), dtype=float - ), - midpoint=jnp.array( - midpoint if midpoint is not None else default("midpoint"), dtype=float - ), - power=jnp.array( - power if power is not None else default("power"), dtype=float - ), - stiffness=jnp.array( - stiffness if stiffness is not None else default("stiffness"), - dtype=float, + time_constant=( + time_constant if time_constant is not None else default("time_constant") ), - damping=jnp.array( - damping if damping is not None else default("damping"), dtype=float + damping_coefficient=( + damping_coefficient + if damping_coefficient is not None + else default("damping_coefficient") ), - mu=jnp.array(mu if mu is not None else default("mu"), dtype=float), + d_min=d_min if d_min is not None else default("d_min"), + d_max=d_max if d_max is not None else default("d_max"), + width=width if width is not None else default("width"), + midpoint=midpoint if midpoint is not None else default("midpoint"), + power=power if power is not None else default("power"), + stiffness=stiffness if stiffness is not None else default("stiffness"), + damping=damping if damping is not None else default("damping"), + mu=mu if mu is not None else default("mu"), ) def valid(self) -> jtp.BoolLike: diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index f5b364dec..025c7be93 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -5,7 +5,6 @@ import jax.numpy as jnp import jax_dataclasses -import numpy as np import jaxsim.math import jaxsim.typing as jtp @@ -112,17 +111,6 @@ def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: return jnp.array([0.0, 0.0, 1.0], dtype=float) - def __hash__(self) -> int: - - return hash(self._height) - - def __eq__(self, other: FlatTerrain) -> bool: - - if not isinstance(other, FlatTerrain): - return False - - return self._height == other._height - @jax_dataclasses.pytree_dataclass class PlaneTerrain(FlatTerrain): @@ -207,32 +195,3 @@ def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: # Invert the plane equation to get the height at the given (x, y) coordinates. return jnp.array(-(A * x + B * y + D) / C).astype(float) - - def __hash__(self) -> int: - - from jaxsim.utils.wrappers import HashedNumpyArray - - return hash( - ( - hash(self._height), - HashedNumpyArray.hash_of_array( - array=np.array(self._normal, dtype=float) - ), - ) - ) - - def __eq__(self, other: PlaneTerrain) -> bool: - - if not isinstance(other, PlaneTerrain): - return False - - if not ( - np.allclose(self._height, other._height) - and np.allclose( - np.array(self._normal, dtype=float), - np.array(other._normal, dtype=float), - ) - ): - return False - - return True diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index d0b881ceb..9e689bc29 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -2,4 +2,3 @@ from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing -from .wrappers import HashedNumpyArray, HashlessObject diff --git a/src/jaxsim/utils/wrappers.py b/src/jaxsim/utils/wrappers.py deleted file mode 100644 index bfb29701f..000000000 --- a/src/jaxsim/utils/wrappers.py +++ /dev/null @@ -1,159 +0,0 @@ -from __future__ import annotations - -import dataclasses -from collections.abc import Callable -from typing import Generic, TypeVar - -import jax -import jax_dataclasses -import numpy as np -import numpy.typing as npt - -T = TypeVar("T") - - -@dataclasses.dataclass -class HashlessObject(Generic[T]): - """ - A class that wraps an object and makes it hashless. - - This is useful for creating particular JAX pytrees. - For example, to create a pytree with a static leaf that is ignored - by JAX when it compares two instances to trigger a JIT recompilation. - """ - - obj: T - - def get(self: HashlessObject[T]) -> T: - """ - Get the wrapped object. - """ - return self.obj - - def __hash__(self) -> int: - - return 0 - - def __eq__(self, other: HashlessObject[T]) -> bool: - - if not isinstance(other, HashlessObject) and isinstance( - other.get(), type(self.get()) - ): - return False - - return hash(self) == hash(other) - - -@dataclasses.dataclass -class CustomHashedObject(Generic[T]): - """ - A class that wraps an object and computes its hash with a custom hash function. - """ - - obj: T - - hash_function: Callable[[T], int] = hash - - def get(self: CustomHashedObject[T]) -> T: - """ - Get the wrapped object. - """ - return self.obj - - def __hash__(self) -> int: - - return self.hash_function(self.obj) - - def __eq__(self, other: CustomHashedObject[T]) -> bool: - - if not isinstance(other, CustomHashedObject) and isinstance( - other.get(), type(self.get()) - ): - return False - - return hash(self) == hash(other) - - -@jax_dataclasses.pytree_dataclass -class HashedNumpyArray: - """ - A class that wraps a numpy array and makes it hashable. - - This is useful for creating particular JAX pytrees. - For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf. - - Note: - Calculating with the wrapper class the hash of a very large array can be - very expensive. If the array is large and only the equality operator is needed, - set `large_array=True` to use a faster comparison method. - """ - - array: jax.Array | npt.NDArray - - precision: float | None = dataclasses.field( - default=1e-9, repr=False, compare=False, hash=False - ) - - large_array: jax_dataclasses.Static[bool] = dataclasses.field( - default=False, repr=False, compare=False, hash=False - ) - - def get(self) -> jax.Array | npt.NDArray: - """ - Get the wrapped array. - """ - return self.array - - def __hash__(self) -> int: - - return HashedNumpyArray.hash_of_array( - array=self.array, precision=self.precision - ) - - def __eq__(self, other: HashedNumpyArray) -> bool: - - if not isinstance(other, HashedNumpyArray): - return False - - if self.large_array: - return np.allclose( - self.array, - other.array, - **(dict(atol=self.precision) if self.precision is not None else {}), - ) - - return hash(self) == hash(other) - - @staticmethod - def hash_of_array( - array: jax.Array | npt.NDArray, precision: float | None = 1e-9 - ) -> int: - """ - Calculate the hash of a NumPy array. - - Args: - array: The array to hash. - precision: Optionally limit the precision over which the hash is computed. - - Returns: - The hash of the array. - """ - - array = np.array(array).flatten() - - array = np.where(array == np.nan, hash(np.nan), array) - array = np.where(array == np.inf, hash(np.inf), array) - array = np.where(array == -np.inf, hash(-np.inf), array) - - if precision is not None: - - integer1 = (array * precision).astype(int) - integer2 = (array - integer1 / precision).astype(int) - - decimal_array = ((array - integer1 * 1e9 - integer2) / precision).astype( - int - ) - - array = np.hstack([integer1, integer2, decimal_array]).astype(int) - - return hash(tuple(array.tolist()))