Skip to content

Commit

Permalink
Lint code and remove duplications
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Mar 1, 2025
1 parent 694428a commit f3af1d1
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def other_representation_to_body(
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841

L_H_C = L_H_W = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L)
jaxsim.math.Transform.inverse(W_H_L)
)(W_H_L)

L_v_LC = L_v_LW = jax.vmap( # noqa: F841
Expand Down
8 changes: 2 additions & 6 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt

# Build the link-to-point transform from the displacement between the link frame L
# and the implicit contact frame C.
L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci)
L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set(L_p_Ci))(L_p_Ci)

# Compose the work-to-link and link-to-point transforms.
return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)
Expand Down Expand Up @@ -567,17 +567,14 @@ def link_contact_forces(

# Compute the 6D forces applied to the links equivalent to the forces applied
# to the frames associated to the collidable points.
W_f_L = link_forces_from_contact_forces(
model=model, data=data, contact_forces=W_f_C
)
W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)

return W_f_L, aux_dict


@staticmethod
def link_forces_from_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
contact_forces: jtp.MatrixLike,
) -> jtp.Matrix:
Expand All @@ -586,7 +583,6 @@ def link_forces_from_contact_forces(
Args:
model: The robot model considered by the contact model.
data: The data of the considered model.
contact_forces: The contact forces computed by the contact model.
Returns:
Expand Down
22 changes: 12 additions & 10 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from collections.abc import Sequence

try:
from typing import override
from typing import Self, override
except ImportError:
from typing_extensions import override
from typing_extensions import override, Self

import jax
import jax.numpy as jnp
Expand All @@ -22,11 +22,6 @@
from . import common
from .common import VelRepr

try:
from typing import Self
except ImportError:
from typing_extensions import Self


@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
Expand Down Expand Up @@ -364,11 +359,14 @@ def base_transform(self) -> jtp.Matrix:

@js.common.named_scope
@jax.jit
def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
def reset_base_quaternion(
self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike
) -> Self:
"""
Reset the base quaternion.
Args:
model: The JaxSim model to use.
base_quaternion: The base orientation as a quaternion.
Returns:
Expand All @@ -380,15 +378,18 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self:
norm = jaxsim.math.safe_norm(W_Q_B)
W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

return self.replace(validate=True, base_quaternion=W_Q_B)
return self.replace(model=model, base_quaternion=W_Q_B)

@js.common.named_scope
@jax.jit
def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
def reset_base_pose(
self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike
) -> Self:
"""
Reset the base pose.
Args:
model: The JaxSim model to use.
base_pose: The base pose as an SE(3) matrix.
Returns:
Expand All @@ -399,6 +400,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self:
W_p_B = base_pose[0:3, 3]
W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])
return self.replace(
model=model,
base_position=W_p_B,
base_quaternion=W_Q_B,
)
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def system_acceleration(
# Compute contact forces
# ======================

W_f_L_terrain = jnp.zeros_like(f_L)
contact_state_derivative = {}

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:

# Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
Expand Down Expand Up @@ -95,15 +98,13 @@ def system_acceleration(
@jax.jit
@js.common.named_scope
def system_position_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
r"""
Compute the dynamics of the system position.
Args:
model: The model to consider.
data: The data of the considered model.
baumgarte_quaternion_regularization:
The Baumgarte regularization coefficient for adjusting the quaternion norm.
Expand Down Expand Up @@ -173,7 +174,6 @@ def system_dynamics(
)

W_ṗ_B, W_Q̇_B, = system_position_dynamics(
model=model,
data=data,
baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
)
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ def find_parent_link_of_frame(self, name: str) -> str:

if frame.parent_name in self.graph.links_dict:
return frame.parent_name
elif frame.parent_name in self.graph.frames_dict:
if frame.parent_name in self.graph.frames_dict:
return self.find_parent_link_of_frame(name=frame.parent_name)

msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'"
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/rod/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def extract_points_select_points_over_axis(
arr = mesh.vertices

# Sort rows lexicographically first, then columnar.
arr.sort(axis=0)
arr.sort(axis=VALID_AXIS[axis])
sorted_arr = arr[dirs[direction]]
return sorted_arr

Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,4 +414,4 @@ def compute_contact_forces(

= .at[indices_of_enabled_collidable_points].set(ṁ_enabled)

return W_f, dict(m_dot=)
return W_f, {"m_dot": }
29 changes: 0 additions & 29 deletions src/jaxsim/rbda/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,32 +111,3 @@ def propagate_kinematics(
)

return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi


def forward_kinematics(
model: js.model.JaxSimModel,
link_index: jtp.Int,
base_position: jtp.VectorLike,
base_quaternion: jtp.VectorLike,
joint_positions: jtp.VectorLike,
) -> jtp.Matrix:
"""
Compute the forward kinematics of a specific link.
Args:
model: The model to consider.
link_index: The index of the link to consider.
base_position: The position of the base link.
base_quaternion: The quaternion of the base link.
joint_positions: The positions of the joints.
Returns:
The SE(3) transform of the link.
"""

return forward_kinematics_model(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
)[link_index]

0 comments on commit f3af1d1

Please sign in to comment.