From 59a2a7b682d7b149607d2f25fbe4591df8dffb92 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Sat, 1 Mar 2025 14:53:25 +0100 Subject: [PATCH] Lint code and remove duplications --- src/jaxsim/api/com.py | 4 +--- src/jaxsim/api/contact.py | 8 ++------ src/jaxsim/api/data.py | 22 +++++++++++--------- src/jaxsim/api/ode.py | 6 +++--- src/jaxsim/parsers/kinematic_graph.py | 2 +- src/jaxsim/rbda/__init__.py | 2 +- src/jaxsim/rbda/contacts/soft.py | 2 +- src/jaxsim/rbda/forward_kinematics.py | 29 --------------------------- 8 files changed, 21 insertions(+), 54 deletions(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index c952fa0ac..ee85d078b 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -301,9 +301,7 @@ def other_representation_to_body( C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - L_H_C = L_H_W = jax.vmap( # noqa: F841 - lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L) - )(W_H_L) + L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841 L_v_LC = L_v_LW = jax.vmap( # noqa: F841 lambda i: -js.link.velocity( diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 9ee210412..5873d7e4c 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -267,7 +267,7 @@ def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jt # Build the link-to-point transform from the displacement between the link frame L # and the implicit contact frame C. - L_H_C = jax.vmap(lambda L_p_C: jnp.eye(4).at[0:3, 3].set(L_p_C))(L_p_Ci) + L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci) # Compose the work-to-link and link-to-point transforms. return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) @@ -567,9 +567,7 @@ def link_contact_forces( # Compute the 6D forces applied to the links equivalent to the forces applied # to the frames associated to the collidable points. - W_f_L = link_forces_from_contact_forces( - model=model, data=data, contact_forces=W_f_C - ) + W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) return W_f_L, aux_dict @@ -577,7 +575,6 @@ def link_contact_forces( @staticmethod def link_forces_from_contact_forces( model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, *, contact_forces: jtp.MatrixLike, ) -> jtp.Matrix: @@ -586,7 +583,6 @@ def link_forces_from_contact_forces( Args: model: The robot model considered by the contact model. - data: The data of the considered model. contact_forces: The contact forces computed by the contact model. Returns: diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index d477a8705..d7acecef6 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -5,9 +5,9 @@ from collections.abc import Sequence try: - from typing import override + from typing import Self, override except ImportError: - from typing_extensions import override + from typing_extensions import override, Self import jax import jax.numpy as jnp @@ -22,11 +22,6 @@ from . import common from .common import VelRepr -try: - from typing import Self -except ImportError: - from typing_extensions import Self - @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): @@ -364,11 +359,14 @@ def base_transform(self) -> jtp.Matrix: @js.common.named_scope @jax.jit - def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self: + def reset_base_quaternion( + self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike + ) -> Self: """ Reset the base quaternion. Args: + model: The JaxSim model to use. base_quaternion: The base orientation as a quaternion. Returns: @@ -380,15 +378,18 @@ def reset_base_quaternion(self, base_quaternion: jtp.VectorLike) -> Self: norm = jaxsim.math.safe_norm(W_Q_B) W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) - return self.replace(validate=True, base_quaternion=W_Q_B) + return self.replace(model=model, base_quaternion=W_Q_B) @js.common.named_scope @jax.jit - def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: + def reset_base_pose( + self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike + ) -> Self: """ Reset the base pose. Args: + model: The JaxSim model to use. base_pose: The base pose as an SE(3) matrix. Returns: @@ -399,6 +400,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: W_p_B = base_pose[0:3, 3] W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3]) return self.replace( + model=model, base_position=W_p_B, base_quaternion=W_Q_B, ) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index dc903b8c1..8c8a949b6 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -50,6 +50,9 @@ def system_acceleration( # Compute contact forces # ====================== + W_f_L_terrain = jnp.zeros_like(f_L) + contact_state_derivative = {} + if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact @@ -95,7 +98,6 @@ def system_acceleration( @jax.jit @js.common.named_scope def system_position_dynamics( - model: js.model.JaxSimModel, data: js.data.JaxSimModelData, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: @@ -103,7 +105,6 @@ def system_position_dynamics( Compute the dynamics of the system position. Args: - model: The model to consider. data: The data of the considered model. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient for adjusting the quaternion norm. @@ -173,7 +174,6 @@ def system_dynamics( ) W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( - model=model, data=data, baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, ) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 9c16136d3..b8b935a36 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -973,7 +973,7 @@ def find_parent_link_of_frame(self, name: str) -> str: if frame.parent_name in self.graph.links_dict: return frame.parent_name - elif frame.parent_name in self.graph.frames_dict: + if frame.parent_name in self.graph.frames_dict: return self.find_parent_link_of_frame(name=frame.parent_name) msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'" diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 25bafc1ae..eb89bd208 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -2,7 +2,7 @@ from .aba import aba from .collidable_points import collidable_points_pos_vel from .crba import crba -from .forward_kinematics import forward_kinematics, forward_kinematics_model +from .forward_kinematics import forward_kinematics_model from .jacobian import ( jacobian, jacobian_derivative_full_doubly_left, diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 24f26be1c..53654cad4 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -414,4 +414,4 @@ def compute_contact_forces( ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) - return W_f, dict(m_dot=ṁ) + return W_f, {"m_dot": ṁ} diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 355fe347c..58c230c7d 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -111,32 +111,3 @@ def propagate_kinematics( ) return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi - - -def forward_kinematics( - model: js.model.JaxSimModel, - link_index: jtp.Int, - base_position: jtp.VectorLike, - base_quaternion: jtp.VectorLike, - joint_positions: jtp.VectorLike, -) -> jtp.Matrix: - """ - Compute the forward kinematics of a specific link. - - Args: - model: The model to consider. - link_index: The index of the link to consider. - base_position: The position of the base link. - base_quaternion: The quaternion of the base link. - joint_positions: The positions of the joints. - - Returns: - The SE(3) transform of the link. - """ - - return forward_kinematics_model( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - )[link_index]