Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove object references cycle in LinkDescription #374

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ def build(model_description: ModelDescription) -> KinDynParameters:
# Build the parent array λ(i) of the model.
# Note: the parent of the base link is not set since it's not defined.
parent_array_dict = {
link.index: link.parent.index
link.index: model_description.links_dict[link.parent_name].index
for link in ordered_links
if link.parent is not None
if link.parent_name is not None
}
parent_array = jnp.array([-1, *list(parent_array_dict.values())], dtype=int)

Expand Down Expand Up @@ -862,7 +862,7 @@ def build_from(model_description: ModelDescription) -> FrameParameters:

# For each frame, extract the index of the link to which it is attached to.
parent_link_index_of_frames = tuple(
model_description.links_dict[frame.parent.name].index
model_description.links_dict[frame.parent_name].index
for frame in model_description.frames
)

Expand Down
11 changes: 3 additions & 8 deletions src/jaxsim/parsers/descriptions/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class LinkDescription(JaxsimDataclass):
mass: float = dataclasses.field(repr=False)
inertia: jtp.Matrix = dataclasses.field(repr=False)
index: int | None = None
parent: LinkDescription | None = dataclasses.field(default=None, repr=False)
parent_name: Static[str | None] = dataclasses.field(default=None, repr=False)
pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)

children: Static[tuple[LinkDescription]] = dataclasses.field(
Expand All @@ -50,8 +50,7 @@ def __hash__(self) -> int:
hash(int(self.index)) if self.index is not None else 0,
HashedNumpyArray.hash_of_array(self.pose),
hash(tuple(self.children)),
# Here only using the name to prevent circular recursion:
hash(self.parent.name) if self.parent is not None else 0,
hash(self.parent_name) if self.parent_name is not None else 0,
)
)

Expand All @@ -67,11 +66,7 @@ def __eq__(self, other: LinkDescription) -> bool:
and self.index == other.index
and np.allclose(self.pose, other.pose)
and self.children == other.children
and (
(self.parent is not None and self.parent.name == other.parent.name)
if self.parent is not None
else other.parent is None
),
and self.parent_name == other.parent_name
):
return False

Expand Down
8 changes: 4 additions & 4 deletions src/jaxsim/parsers/descriptions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,16 @@ def build_model_from(
for cp in collision_shape.collidable_points:
# Find the link that is part of the (reduced) model in which the
# collision shape's parent was lumped into
real_parent_link_of_shape = kinematic_graph.frames_dict[
real_parent_link_name = kinematic_graph.frames_dict[
parent_link_of_shape.name
].parent
].parent_name

# Change the link associated to the collidable point, updating their
# relative pose
moved_cp = cp.change_link(
new_link=real_parent_link_of_shape,
new_link=kinematic_graph.links_dict[real_parent_link_name],
new_H_old=fk.relative_transform(
relative_to=real_parent_link_of_shape.name,
relative_to=real_parent_link_name,
name=cp.parent_link.name,
),
)
Expand Down
37 changes: 17 additions & 20 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,10 @@ def _create_graph(

# Check that our parser correctly resolved the frame's parent to be a link.
for frame in frames:
assert frame.parent.name != "", frame
assert frame.parent.name is not None, frame
assert frame.parent.name != "__model__", frame
assert frame.parent.name not in frames_dict, frame
assert frame.parent_name != "", frame
assert frame.parent_name is not None, frame
assert frame.parent_name != "__model__", frame
assert frame.parent_name not in frames_dict, frame

# ===========================================================
# Populate the kinematic graph with links, joints, and frames
Expand All @@ -302,7 +302,7 @@ def _create_graph(
assert parent_link.name == joint.parent.name

# Assign link's parent.
child_link.parent = parent_link
child_link.parent_name = parent_link.name

# Assign link's children and make sure they are unique.
if child_link.name not in {l.name for l in parent_link.children}:
Expand Down Expand Up @@ -331,7 +331,7 @@ def _create_graph(
# Collect all the frames of the kinematic graph.
# Note: our parser ensures that the parent of a frame is not another frame.
all_frames_in_graph = [
frame for frame in frames if frame.parent.name in all_link_names_in_graph
frame for frame in frames if frame.parent_name in all_link_names_in_graph
]

# Get the names of all frames in the kinematic graph.
Expand Down Expand Up @@ -450,7 +450,7 @@ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:

# Get the link to remove and its parent, i.e. the lumped link
link_to_remove = links_dict[link.name]
parent_of_link_to_remove = links_dict[link.parent.name]
parent_of_link_to_remove = links_dict[link.parent_name]

msg = "Lumping chain: {}->({})->{}"
logging.info(
Expand Down Expand Up @@ -586,7 +586,7 @@ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link

# Notify the user if the parent link has changed.
if name_of_new_parent_link != frame.parent.name:
if name_of_new_parent_link != frame.parent_name:
msg = "New parent of frame '{}' is '{}'"
logging.debug(msg=msg.format(frame.name, name_of_new_parent_link))

Expand All @@ -601,7 +601,7 @@ def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
)

# Update the parent link such that the pose is expressed in its frame.
frame.parent = reduced_graph.links_dict[name_of_new_parent_link]
frame.parent_name = name_of_new_parent_link

# Update dynamic parameters of the frame.
frame.mass = 0.0
Expand Down Expand Up @@ -880,7 +880,7 @@ def transform(self, name: str) -> npt.NDArray:

# Get the joint between the link and its parent.
parent_joint = self.graph.joints_connection_dict[
link.parent.name, link.name
link.parent_name, link.name
]

# Get the transform of the parent joint.
Expand All @@ -901,7 +901,7 @@ def transform(self, name: str) -> npt.NDArray:
frame = self.graph.frames_dict[name]

# Get the transform of the parent link.
M_H_L = self.transform(name=frame.parent.name)
M_H_L = self.transform(name=frame.parent_name)

# Rename the pose of the frame w.r.t. its parent link.
L_H_F = frame.pose
Expand Down Expand Up @@ -971,13 +971,10 @@ def find_parent_link_of_frame(self, name: str) -> str:
except KeyError as e:
raise ValueError(f"Frame '{name}' not found in the kinematic graph") from e

match frame.parent.name:
case parent_name if parent_name in self.graph.links_dict:
return parent_name
if frame.parent_name in self.graph.links_dict:
return frame.parent_name
elif frame.parent_name in self.graph.frames_dict:
return self.find_parent_link_of_frame(name=frame.parent_name)

case parent_name if parent_name in self.graph.frames_dict:
return self.find_parent_link_of_frame(name=parent_name)

case _:
msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent.name}'"
raise RuntimeError(msg)
msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'"
raise RuntimeError(msg)
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def extract_model_data(
name=f.name,
mass=jnp.array(0.0, dtype=float),
inertia=jnp.zeros(shape=(3, 3)),
parent=links_dict[f.attached_to],
parent_name=f.attached_to,
pose=f.pose.transform() if f.pose is not None else jnp.eye(4),
)
for f in sdf_model.frames()
Expand Down
Loading