From a8d0c5b3537cf842258b15f2295d0431bfc05104 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 23 May 2024 12:11:07 +0200 Subject: [PATCH 01/27] Avoid to use `enum` in `VelRepr` --- src/jaxsim/api/common.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index d8ba8dd7b..d8fb1d2ac 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -1,9 +1,8 @@ import abc import contextlib import dataclasses -import enum import functools -from typing import ContextManager +from typing import ClassVar, ContextManager import jax import jax.numpy as jnp @@ -20,15 +19,15 @@ from typing_extensions import Self -@enum.unique -class VelRepr(enum.IntEnum): +@dataclasses.dataclass(frozen=True) +class VelRepr: """ Enumeration of all supported 6D velocity representations. """ - Body = enum.auto() - Mixed = enum.auto() - Inertial = enum.auto() + Body: ClassVar[int] = 0 + Mixed: ClassVar[int] = 1 + Inertial: ClassVar[int] = 2 @jax_dataclasses.pytree_dataclass From acc70abfdd60f9b2b0d1e5a906f89bfe42488681 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 23 May 2024 12:11:28 +0200 Subject: [PATCH 02/27] Make `velocity_representation` non-static in `api.data` --- src/jaxsim/api/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 73883dd5c..6f28f1ced 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -621,7 +621,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: base_quaternion=W_Q_B ) - @functools.partial(jax.jit, static_argnames=["velocity_representation"]) + @jax.jit def reset_base_linear_velocity( self, linear_velocity: jtp.VectorLike, @@ -652,7 +652,7 @@ def reset_base_linear_velocity( velocity_representation=velocity_representation, ) - @functools.partial(jax.jit, static_argnames=["velocity_representation"]) + @jax.jit def reset_base_angular_velocity( self, angular_velocity: jtp.VectorLike, @@ -683,7 +683,7 @@ def reset_base_angular_velocity( velocity_representation=velocity_representation, ) - @functools.partial(jax.jit, static_argnames=["velocity_representation"]) + @jax.jit def reset_base_velocity( self, base_velocity: jtp.VectorLike, From 120b4744d8efaa340ce73123562c5e656b67eb1d Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 23 May 2024 12:34:48 +0200 Subject: [PATCH 03/27] Update documentation for `VelRepr` --- docs/modules/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/modules/api.rst b/docs/modules/api.rst index 9b3833c39..cff2b8b7f 100644 --- a/docs/modules/api.rst +++ b/docs/modules/api.rst @@ -94,7 +94,7 @@ References Common ~~~~~~ -.. autoflag:: jaxsim.api.common.VelRepr +.. autoclass:: jaxsim.api.common.VelRepr :members: .. autoclass:: jaxsim.api.common.ModelDataWithVelocityRepresentation From d346f2e383d466eb7b204bf3c84b0474b9ff43b3 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 23 May 2024 15:09:54 +0200 Subject: [PATCH 04/27] Make `output_vel_repr` non-static in `api.frame` --- src/jaxsim/api/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 2767e38dd..82e14363f 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -180,7 +180,7 @@ def transform( return W_H_L @ L_H_F -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, From e299b59bd922262aa8198797452b5d6d772cb0cc Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 23 May 2024 15:11:27 +0200 Subject: [PATCH 05/27] Make `output_vel_repr` non-static in `api.link` --- src/jaxsim/api/link.py | 170 ++++++++++++++++++++++------------------- 1 file changed, 91 insertions(+), 79 deletions(-) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 75f2e89a0..d7fe87e50 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -235,7 +235,7 @@ def com_in_inertial_frame(): ) -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -335,7 +335,7 @@ def jacobian( return O_J_WL_I -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -385,13 +385,13 @@ def velocity( return O_J_WL_I @ I_ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Matrix: r""" Compute the derivative of the free-floating jacobian of the link. @@ -442,116 +442,128 @@ def jacobian_derivative( In = jnp.eye(model.dofs()) On = jnp.zeros(shape=(model.dofs(), model.dofs())) - match data.velocity_representation: - - case VelRepr.Inertial: + def from_inertial() -> jtp.Matrix: - W_H_B = data.base_transform() - B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) + W_H_B = data.base_transform() + B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WB = data.base_velocity() - B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) - # Compute the operator to change the representation of ν, and its - # time derivative. - T = jax.scipy.linalg.block_diag(B_X_W, In) - Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On) + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_W, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On) + return T, Ṫ - case VelRepr.Body: + def from_body() -> jtp.Matrix: - B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation( - translation=jnp.zeros(3), rotation=jnp.eye(3) - ) + B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation( + translation=jnp.zeros(3), rotation=jnp.eye(3) + ) - B_Ẋ_B = jnp.zeros(shape=(6, 6)) + B_Ẋ_B = jnp.zeros(shape=(6, 6)) - # Compute the operator to change the representation of ν, and its - # time derivative. - T = jax.scipy.linalg.block_diag(B_X_B, In) - Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On) + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_B, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On) + return T, Ṫ - case VelRepr.Mixed: + def from_mixed() -> jtp.Matrix: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) - with data.switch_velocity_representation(VelRepr.Mixed): - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + with data.switch_velocity_representation(VelRepr.Mixed): + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_v_BW_B = BW_v_WB - BW_v_W_BW - B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) + BW_v_BW_B = BW_v_WB - BW_v_W_BW + B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) - # Compute the operator to change the representation of ν, and its - # time derivative. - T = jax.scipy.linalg.block_diag(B_X_BW, In) - Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_BW, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) + return T, Ṫ - case _: - raise ValueError(data.velocity_representation) + T, Ṫ = jax.lax.switch( + index=data.velocity_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) # ====================================================== # Compute quantities to adjust the output representation # ====================================================== - match output_vel_repr: + def to_inertial() -> jtp.Matrix: - case VelRepr.Inertial: + W_H_B = data.base_transform() + O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B) - W_H_B = data.base_transform() - O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B) + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841 + return O_X_B, O_Ẋ_B - O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841 + def to_body() -> jtp.Matrix: - case VelRepr.Body: + O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform( + transform=B_H_L[link_index, :, :], inverse=True + ) - O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform( - transform=B_H_L[link_index, :, :], inverse=True - ) - - B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B) + B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() - L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() + L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) - O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_L @ L_v_WL - B_v_WB - ) + O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841 + B_X_L @ L_v_WL - B_v_WB + ) + return O_X_B, O_Ẋ_B - case VelRepr.Mixed: + def to_mixed() -> jtp.Matrix: - W_H_B = data.base_transform() - W_H_L = W_H_B @ B_H_L[link_index, :, :] - LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) - LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L[link_index, :, :]) + W_H_B = data.base_transform() + W_H_L = W_H_B @ B_H_L[link_index, :, :] + LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) + LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L[link_index, :, :]) - O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B) + O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B) - B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B) + B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() - with data.switch_velocity_representation(VelRepr.Mixed): - LW_v_WL = js.link.velocity( - model=model, data=data, link_index=link_index - ) - LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3)) + with data.switch_velocity_representation(VelRepr.Mixed): + LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) + LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3)) - LW_v_LW_L = LW_v_WL - LW_v_W_LW - LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L + LW_v_LW_L = LW_v_WL - LW_v_W_LW + LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L - O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_LW @ LW_v_B_LW - ) - case _: - raise ValueError(output_vel_repr) + O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841 + B_X_LW @ LW_v_B_LW + ) + return O_X_B, O_Ẋ_B + O_X_B, O_Ẋ_B = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # ============================================================= # Express the Jacobian derivative in the target representations # ============================================================= From b4b359d5a420199033f988b17ebd5b7ed6644dff Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 12:44:00 +0200 Subject: [PATCH 06/27] Use `jax.lax.switch` inside `api.frame` for `VelRepr` --- src/jaxsim/api/frame.py | 61 ++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 82e14363f..e243df6dd 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -227,34 +227,39 @@ def jacobian( model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body ) - # Adjust the output representation. - match output_vel_repr: - case VelRepr.Inertial: - W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_X_L = Adjoint.from_transform(transform=W_H_L) - W_J_WL = W_X_L @ L_J_WL - O_J_WL_I = W_J_WL - - case VelRepr.Body: - W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = Transform.inverse(W_H_F) @ W_H_L - F_X_L = Adjoint.from_transform(transform=F_H_L) - F_J_WL = F_X_L @ L_J_WL - O_J_WL_I = F_J_WL - - case VelRepr.Mixed: - W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = Transform.inverse(W_H_F) @ W_H_L - FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3)) - FW_H_L = FW_H_F @ F_H_L - FW_X_L = Adjoint.from_transform(transform=FW_H_L) - FW_J_WL = FW_X_L @ L_J_WL - O_J_WL_I = FW_J_WL - - case _: - raise ValueError(output_vel_repr) + def to_inertial() -> jtp.Matrix: + W_H_L = js.link.transform(model=model, data=data, link_index=L) + W_X_L = Adjoint.from_transform(transform=W_H_L) + W_J_WL = W_X_L @ L_J_WL + return W_J_WL + + def to_body() -> jtp.Matrix: + W_H_L = js.link.transform(model=model, data=data, link_index=L) + W_H_F = transform(model=model, data=data, frame_index=frame_index) + F_H_L = Transform.inverse(W_H_F) @ W_H_L + F_X_L = Adjoint.from_transform(transform=F_H_L) + F_J_WL = F_X_L @ L_J_WL + return F_J_WL + + def to_mixed() -> jtp.Matrix: + W_H_L = js.link.transform(model=model, data=data, link_index=L) + W_H_F = transform(model=model, data=data, frame_index=frame_index) + F_H_L = Transform.inverse(W_H_F) @ W_H_L + FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3)) + FW_H_L = FW_H_F @ F_H_L + FW_X_L = Adjoint.from_transform(transform=FW_H_L) + FW_J_WL = FW_X_L @ L_J_WL + return FW_J_WL + + # Adjust the output representation + O_J_WL_I = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WL_I From 9cc3d3c5c5889b328e14b2c1683bb0532f541b20 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 12:44:47 +0200 Subject: [PATCH 07/27] Use `jax.lax.switch` inside `api.link` for `VelRepr` --- src/jaxsim/api/link.py | 119 +++++++++++++++++++++++------------------ 1 file changed, 66 insertions(+), 53 deletions(-) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index d7fe87e50..affc8b361 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -9,7 +9,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import exceptions -from jaxsim.math import Adjoint +from jaxsim.math import Adjoint, Transform from .common import VelRepr @@ -284,53 +284,66 @@ def jacobian( B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B # Adjust the input representation such that `J_WL_I @ I_ν`. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_B = data.base_transform() - B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 - B_X_W, jnp.eye(model.dofs()) - ) - - case VelRepr.Body: - B_J_WL_I = B_J_WL_B - - case VelRepr.Mixed: - W_R_B = data.base_orientation(dcm=True) - BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 - B_X_BW, jnp.eye(model.dofs()) - ) - - case _: - raise ValueError(data.velocity_representation) + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) + B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 + B_X_W, jnp.eye(model.dofs()) + ) + return B_J_WL_W + + def to_body() -> jtp.Matrix: + return B_J_WL_B + + def to_mixed() -> jtp.Matrix: + W_R_B = data.base_orientation(dcm=True) + BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 + B_X_BW, jnp.eye(model.dofs()) + ) + return B_J_WL_BW + + B_J_WL_I = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) B_H_L = B_H_Li[link_index] + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) + W_J_WL_I = W_X_B @ B_J_WL_I + return W_J_WL_I + + def to_body() -> jtp.Matrix: + L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) + L_J_WL_I = L_X_B @ B_J_WL_I + return L_J_WL_I + + def to_mixed() -> jtp.Matrix: + W_H_B = data.base_transform() + W_H_L = W_H_B @ B_H_L + LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) + LW_H_B = LW_H_L @ Transform.inverse(B_H_L) + LW_X_B = Adjoint.from_transform(transform=LW_H_B) + LW_J_WL_I = LW_X_B @ B_J_WL_I + return LW_J_WL_I + # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. - match output_vel_repr: - case VelRepr.Inertial: - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841 - - case VelRepr.Body: - L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) - L_J_WL_I = L_X_B @ B_J_WL_I - O_J_WL_I = L_J_WL_I - - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_L = W_H_B @ B_H_L - LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) - LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) - LW_X_B = Adjoint.from_transform(transform=LW_H_B) - LW_J_WL_I = LW_X_B @ B_J_WL_I - O_J_WL_I = LW_J_WL_I - - case _: - raise ValueError(output_vel_repr) + O_J_WL_I = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WL_I @@ -445,7 +458,7 @@ def jacobian_derivative( def from_inertial() -> jtp.Matrix: W_H_B = data.base_transform() - B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) with data.switch_velocity_representation(VelRepr.Inertial): W_v_WB = data.base_velocity() @@ -459,7 +472,7 @@ def from_inertial() -> jtp.Matrix: def from_body() -> jtp.Matrix: - B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation( + B_X_B = Adjoint.from_rotation_and_translation( translation=jnp.zeros(3), rotation=jnp.eye(3) ) @@ -474,7 +487,7 @@ def from_body() -> jtp.Matrix: def from_mixed() -> jtp.Matrix: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) with data.switch_velocity_representation(VelRepr.Mixed): BW_v_WB = data.base_velocity() @@ -505,7 +518,7 @@ def from_mixed() -> jtp.Matrix: def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() - O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B) + O_X_B = W_X_B = Adjoint.from_transform(transform=W_H_B) with data.switch_velocity_representation(VelRepr.Body): B_v_WB = data.base_velocity() @@ -515,11 +528,11 @@ def to_inertial() -> jtp.Matrix: def to_body() -> jtp.Matrix: - O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform( + O_X_B = L_X_B = Adjoint.from_transform( transform=B_H_L[link_index, :, :], inverse=True ) - B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B) + B_X_L = Adjoint.inverse(adjoint=L_X_B) with data.switch_velocity_representation(VelRepr.Body): B_v_WB = data.base_velocity() @@ -535,11 +548,11 @@ def to_mixed() -> jtp.Matrix: W_H_B = data.base_transform() W_H_L = W_H_B @ B_H_L[link_index, :, :] LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) - LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L[link_index, :, :]) + LW_H_B = LW_H_L @ Transform.inverse(B_H_L[link_index, :, :]) - O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B) + O_X_B = LW_X_B = Adjoint.from_transform(transform=LW_H_B) - B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B) + B_X_LW = Adjoint.inverse(adjoint=LW_X_B) with data.switch_velocity_representation(VelRepr.Body): B_v_WB = data.base_velocity() From c545bd78b05c75669885d9a0e3d69ee078300b7c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 12:45:55 +0200 Subject: [PATCH 08/27] Use `jax.lax.switch` inside `api.com` for `VelRepr` --- src/jaxsim/api/com.py | 247 +++++++++++++++++++++++------------------- 1 file changed, 137 insertions(+), 110 deletions(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index 6d25d276d..c9be3c15f 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -2,8 +2,8 @@ import jax.numpy as jnp import jaxsim.api as js -import jaxsim.math import jaxsim.typing as jtp +from jaxsim.math import Adjoint, Cross, Transform from .common import VelRepr @@ -27,7 +27,7 @@ def com_position( W_H_L = js.model.forward_kinematics(model=model, data=data) W_H_B = data.base_transform() - B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) + B_H_W = Transform.inverse(transform=W_H_B) def B_p̃_LCoM(i) -> jtp.Vector: m = js.link.mass(model=model, link_index=i) @@ -131,20 +131,29 @@ def centroidal_momentum_jacobian( ) W_H_B = data.base_transform() - B_H_W = jaxsim.math.Transform.inverse(W_H_B) + B_H_W = Transform.inverse(W_H_B) W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: - case VelRepr.Inertial | VelRepr.Mixed: - W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 - case VelRepr.Body: - W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 - case _: - raise ValueError(data.velocity_representation) + def to_inertial_and_mixed(): + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + return W_H_GW + + def to_body(): + W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) + return W_H_GB + + W_H_G = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_inertial_and_mixed, # VelRepr.Mixed + to_inertial_and_mixed, # VelRepr.Inertial + ), + ) # Compute the transform for 6D forces. - G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T + G_Xf_B = Adjoint.from_transform(transform=B_H_W @ W_H_G).T return G_Xf_B @ B_Jh @@ -170,17 +179,26 @@ def locked_centroidal_spatial_inertia( W_H_B = data.base_transform() W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: - case VelRepr.Inertial | VelRepr.Mixed: - W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 - case VelRepr.Body: - W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 - case _: - raise ValueError(data.velocity_representation) + def to_inertial_or_mixed() -> jtp.Matrix: + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + return W_H_GW + + def to_body() -> jtp.Matrix: + W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) + return W_H_GB + + W_H_G = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_inertial_or_mixed, # VelRepr.Mixed + to_inertial_or_mixed, # VelRepr.Inertial + ), + ) - B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G + B_H_G = Transform.inverse(W_H_B) @ W_H_G - B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G) + B_Xv_G = Adjoint.from_transform(transform=B_H_G) G_Xf_B = B_Xv_G.transpose() return G_Xf_B @ B_Mbb_B @ B_Xv_G @@ -275,80 +293,85 @@ def other_representation_to_body( C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL. """ - L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C) - C_X_L = jaxsim.math.Adjoint.inverse(L_X_C) + L_X_C = Adjoint.from_transform(transform=L_H_C) + C_X_L = Adjoint.inverse(L_X_C) - L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC) + L_v̇_WL = L_X_C @ (C_v̇_WL + Cross.vx(C_X_L @ L_v_LC) @ C_v_WC) return L_v̇_WL + def to_body() -> jtp.Vector: + L_a_bias_WL = v̇_bias_WL + return L_a_bias_WL + + def to_inertial() -> jtp.Vector: + + C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 + C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + + L_H_C = L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))( # noqa: F841 + W_H_L + ) + + L_v_LC = L_v_LW = jax.vmap( # noqa: F841 + lambda i: -js.link.velocity( + model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body + ) + )(jnp.arange(model.number_of_links())) + + L_a_bias_WL = jax.vmap( + lambda i: other_representation_to_body( + C_v̇_WL=C_v̇_WL[i], + C_v_WC=C_v_WC, + L_H_C=L_H_C[i], + L_v_LC=L_v_LC[i], + ) + )(jnp.arange(model.number_of_links())) + return L_a_bias_WL + + def to_mixed() -> jtp.Vector: + + C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841 + + C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841 + lambda i: js.link.velocity( + model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed + ) + .at[3:6] + .set(jnp.zeros(3)) + )(jnp.arange(model.number_of_links())) + + L_H_C = L_H_LW = jax.vmap( # noqa: F841 + lambda W_H_L: Transform.inverse(W_H_L.at[0:3, 3].set(jnp.zeros(3))) + )(W_H_L) + + L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841 + lambda i: -js.link.velocity( + model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body + ) + .at[0:3] + .set(jnp.zeros(3)) + )(jnp.arange(model.number_of_links())) + + L_a_bias_WL = jax.vmap( + lambda i: other_representation_to_body( + C_v̇_WL=C_v̇_WL[i], + C_v_WC=C_v_WC[i], + L_H_C=L_H_C[i], + L_v_LC=L_v_LC[i], + ) + )(jnp.arange(model.number_of_links())) + return L_a_bias_WL + # We need here to get the body-fixed bias acceleration of the links. # Since it's computed in the active representation, we need to convert it to body. - match data.velocity_representation: - - case VelRepr.Body: - L_a_bias_WL = v̇_bias_WL - - case VelRepr.Inertial: - - C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 - C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - - L_H_C = L_H_W = jax.vmap( # noqa: F841 - lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L) - )(W_H_L) - - L_v_LC = L_v_LW = jax.vmap( # noqa: F841 - lambda i: -js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body - ) - )(jnp.arange(model.number_of_links())) - - L_a_bias_WL = jax.vmap( - lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC, - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], - ) - )(jnp.arange(model.number_of_links())) - - case VelRepr.Mixed: - - C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841 - - C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841 - lambda i: js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed - ) - .at[3:6] - .set(jnp.zeros(3)) - )(jnp.arange(model.number_of_links())) - - L_H_C = L_H_LW = jax.vmap( # noqa: F841 - lambda W_H_L: jaxsim.math.Transform.inverse( - W_H_L.at[0:3, 3].set(jnp.zeros(3)) - ) - )(W_H_L) - - L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841 - lambda i: -js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body - ) - .at[0:3] - .set(jnp.zeros(3)) - )(jnp.arange(model.number_of_links())) - - L_a_bias_WL = jax.vmap( - lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC[i], - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], - ) - )(jnp.arange(model.number_of_links())) - - case _: - raise ValueError(data.velocity_representation) + L_a_bias_WL = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # Compute the bias of the 6D momentum derivative. def bias_momentum_derivative_term( @@ -364,13 +387,11 @@ def bias_momentum_derivative_term( ) # Compute the world-to-link transformations for 6D forces. - W_Xf_L = jaxsim.math.Adjoint.from_transform( - transform=W_H_L[link_index], inverse=True - ).T + W_Xf_L = Adjoint.from_transform(transform=W_H_L[link_index], inverse=True).T # Compute the contribution of the link to the bias acceleration of the CoM. W_ḣ_bias_link_contribution = W_Xf_L @ ( - L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL + L_M_L @ L_a_bias_WL + Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL ) return W_ḣ_bias_link_contribution @@ -386,30 +407,36 @@ def bias_momentum_derivative_term( # Compute the position of the CoM. W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: - + def to_inertial_or_mixed() -> jtp.Vector: # G := G[W] = (W_p_CoM, [W]) - case VelRepr.Inertial | VelRepr.Mixed: - W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + GW_Xf_W = Adjoint.from_transform(W_H_GW).T - GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias - GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m + GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias + GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m - return GW_v̇l_com_bias + return GW_v̇l_com_bias + def to_body() -> jtp.Vector: # G := G[B] = (W_p_CoM, [B]) - case VelRepr.Body: - GB_Xf_W = jaxsim.math.Adjoint.from_transform( - transform=data.base_transform().at[0:3].set(W_p_CoM) - ).T + GB_Xf_W = Adjoint.from_transform( + transform=data.base_transform().at[0:3].set(W_p_CoM) + ).T + + GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias + GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m - GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias - GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m + return GB_v̇l_com_bias - return GB_v̇l_com_bias + GB_v̇l_com_bias = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_inertial_or_mixed, # VelRepr.Mixed + to_inertial_or_mixed, # VelRepr.Inertial + ), + ) - case _: - raise ValueError(data.velocity_representation) + return GB_v̇l_com_bias From d96d98de0589eeba51722ee1ad90d49e248ca235 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 12:46:34 +0200 Subject: [PATCH 09/27] Make `VelRepr` non-static in `ModelDatawithVelocityRepresentation` --- src/jaxsim/api/common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index d8fb1d2ac..389211e72 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -7,7 +7,6 @@ import jax import jax.numpy as jnp import jax_dataclasses -from jax_dataclasses import Static import jaxsim.typing as jtp from jaxsim.math import Adjoint @@ -36,7 +35,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): Base class for model data structures with velocity representation. """ - velocity_representation: Static[VelRepr] = dataclasses.field( + velocity_representation: VelRepr = dataclasses.field( default=VelRepr.Inertial, kw_only=True ) From 41a3f89547455958d528b0f201ed9b557d2cba99 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 12:48:42 +0200 Subject: [PATCH 10/27] Make `output_vel_repr` non-static in `api.common` --- src/jaxsim/api/common.py | 125 ++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 61 deletions(-) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 389211e72..74eadab48 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -80,7 +80,7 @@ def switch_velocity_representation( self.velocity_representation = original_representation @staticmethod - @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) + @functools.partial(jax.jit, static_argnames=["is_force"]) def inertial_to_other_representation( array: jtp.Array, other_representation: VelRepr, @@ -112,42 +112,40 @@ def inertial_to_other_representation( if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) - match other_representation: - - case VelRepr.Inertial: - return W_array - - case VelRepr.Body: - - if not is_force: - O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) - O_array = O_Xv_W @ W_array - - else: - O_Xf_W = Adjoint.from_transform(transform=W_H_O).T - O_array = O_Xf_W @ W_array - - return O_array - - case VelRepr.Mixed: - W_p_O = W_H_O[0:3, 3] - W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) - - if not is_force: - OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) - OW_array = OW_Xv_W @ W_array - - else: - OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T - OW_array = OW_Xf_W @ W_array - - return OW_array - - case _: - raise ValueError(other_representation) + def to_inertial(): + return W_array + + def to_body(): + if not is_force: + O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) + O_array = O_Xv_W @ W_array + else: + O_Xf_W = Adjoint.from_transform(transform=W_H_O).T + O_array = O_Xf_W @ W_array + return O_array + + def to_mixed(): + W_p_O = W_H_O[0:3, 3] + W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) + if not is_force: + OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) + OW_array = OW_Xv_W @ W_array + else: + OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T + OW_array = OW_Xf_W @ W_array + return OW_array + + return jax.lax.switch( + index=other_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @staticmethod - @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) + @functools.partial(jax.jit, static_argnames=["is_force"]) def other_representation_to_inertial( array: jtp.Array, other_representation: VelRepr, @@ -179,38 +177,43 @@ def other_representation_to_inertial( if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) - match other_representation: - case VelRepr.Inertial: - W_array = array - return W_array + def from_inertial(): + W_array = array + return W_array - case VelRepr.Body: - O_array = array + def from_body(): + O_array = array - if not is_force: - W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O) - W_array = W_Xv_O @ O_array + if not is_force: + W_Xv_O = Adjoint.from_transform(W_H_O) + W_array = W_Xv_O @ O_array - else: - W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T - W_array = W_Xf_O @ O_array + else: + W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T + W_array = W_Xf_O @ O_array - return W_array + return W_array - case VelRepr.Mixed: - BW_array = array - W_p_O = W_H_O[0:3, 3] - W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) + def from_mixed(): + BW_array = array + W_p_O = W_H_O[0:3, 3] + W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) - if not is_force: - W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW) - W_array = W_Xv_BW @ BW_array + if not is_force: + W_Xv_BW = Adjoint.from_transform(W_H_OW) + W_array = W_Xv_BW @ BW_array - else: - W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T - W_array = W_Xf_BW @ BW_array + else: + W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T + W_array = W_Xf_BW @ BW_array - return W_array + return W_array - case _: - raise ValueError(other_representation) + return jax.lax.switch( + index=other_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) From baf8f6f387c358e0acc335cbaa4aa65262d181c2 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 12:52:36 +0200 Subject: [PATCH 11/27] Make `output_vel_repr` non-static in `api.model` --- src/jaxsim/api/model.py | 207 ++++++++++++++++++++++++---------------- 1 file changed, 125 insertions(+), 82 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 84e2b8229..d5960b04c 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -448,7 +448,7 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp return jnp.atleast_3d(W_H_LL).astype(float) -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def generalized_free_floating_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, @@ -488,47 +488,75 @@ def generalized_free_floating_jacobian( # Update the input velocity representation such that v_WL = J_WL_I @ I_ν # ====================================================================== - match data.velocity_representation: - - case VelRepr.Inertial: - - W_H_B = data.base_transform() - B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) + def to_inertial(): + W_H_B = data.base_transform() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841 - B_J_full_WX_B - @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) - ) + B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag( + B_X_W, jnp.eye(model.dofs()) + ) + return B_J_full_WX_W - case VelRepr.Body: + def to_body(): + return B_J_full_WX_B - B_J_full_WX_I = B_J_full_WX_B + def to_mixed(): + W_R_B = data.base_orientation(dcm=True) + BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - case VelRepr.Mixed: + B_J_full_WX_BW = B_J_full_WX_B @ jax.scipy.linalg.block_diag( + B_X_BW, jnp.eye(model.dofs()) + ) + return B_J_full_WX_BW + + # Update the input velocity representation such that `J_WL_I @ I_ν`. + B_J_full_WX_I = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) - W_R_B = data.base_orientation(dcm=True) - BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + def to_inertial(): + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) + W_J_full_WX_I = W_X_B @ B_J_full_WX_I + return W_J_full_WX_I - B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841 - B_J_full_WX_B - @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - ) + def to_body(): + return B_J_full_WX_I - case _: - raise ValueError(data.velocity_representation) + def to_mixed(): + W_R_B = data.base_orientation(dcm=True) + BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) + BW_X_B = Adjoint.from_transform(transform=BW_H_B) + BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I + return BW_J_full_WX_I # ==================================================================== # Create stacked Jacobian for each link by filtering the full Jacobian # ==================================================================== + # Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. + O_J_full_WX_I = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) + κ_bool = model.kin_dyn_parameters.support_body_array_bool # Keep only the columns of the full Jacobian corresponding to the support # body array of each link. B_J_WL_I = jax.vmap( lambda κ: jnp.where( - jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I) + jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I) ) )(κ_bool) @@ -1367,7 +1395,7 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec return Jh @ ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def total_momentum_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, @@ -1396,44 +1424,52 @@ def total_momentum_jacobian( with data.switch_velocity_representation(VelRepr.Body): B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6] - match data.velocity_representation: - case VelRepr.Body: - B_Jh = B_Jh_B + def to_body() -> jtp.Matrix: + return B_Jh_B - case VelRepr.Inertial: - B_X_W = Adjoint.from_transform( - transform=data.base_transform(), inverse=True - ) - B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + def to_inertial() -> jtp.Matrix: + B_X_W = Adjoint.from_transform(transform=data.base_transform(), inverse=True) + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) - case VelRepr.Mixed: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - - case _: - raise ValueError(data.velocity_representation) - - match output_vel_repr: - case VelRepr.Body: - return B_Jh + def to_mixed() -> jtp.Matrix: + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - case VelRepr.Inertial: - W_H_B = data.base_transform() - B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - W_Xf_B = B_Xv_W.T - W_Jh = W_Xf_B @ B_Jh - return W_Jh - - case VelRepr.Mixed: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - BW_Xf_B = B_Xv_BW.T - BW_Jh = BW_Xf_B @ B_Jh - return BW_Jh + B_Jh = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) - case _: - raise ValueError(output_vel_repr) + def to_body() -> jtp.Matrix: + return B_Jh + + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) + W_Xf_B = B_Xv_W.T + W_Jh = W_Xf_B @ B_Jh + return W_Jh + + def to_mixed() -> jtp.Matrix: + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + BW_Xf_B = B_Xv_BW.T + BW_Jh = BW_Xf_B @ B_Jh + return BW_Jh + + return jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @jax.jit @@ -1456,7 +1492,7 @@ def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.V return J @ ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def average_velocity_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, @@ -1483,40 +1519,47 @@ def average_velocity_jacobian( # Depending on the velocity representation, the frame G is either G[W] or G[B]. G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data) - match output_vel_repr: + def to_inertial() -> jtp.Matrix: - case VelRepr.Inertial: + GW_J = G_J + W_p_CoM = js.com.com_position(model=model, data=data) - GW_J = G_J - W_p_CoM = js.com.com_position(model=model, data=data) + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + W_X_GW = Adjoint.from_transform(transform=W_H_GW) - W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - W_X_GW = Adjoint.from_transform(transform=W_H_GW) + return W_X_GW @ GW_J - return W_X_GW @ GW_J + def to_body() -> jtp.Matrix: - case VelRepr.Body: + GB_J = G_J + W_p_B = data.base_position() + W_p_CoM = js.com.com_position(model=model, data=data) + B_R_W = data.base_orientation(dcm=True).transpose() - GB_J = G_J - W_p_B = data.base_position() - W_p_CoM = js.com.com_position(model=model, data=data) - B_R_W = data.base_orientation(dcm=True).transpose() + B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) + B_X_GB = Adjoint.from_transform(transform=B_H_GB) - B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) - B_X_GB = Adjoint.from_transform(transform=B_H_GB) + return B_X_GB @ GB_J - return B_X_GB @ GB_J + def to_mixed() -> jtp.Matrix: - case VelRepr.Mixed: + GW_J = G_J + W_p_B = data.base_position() + W_p_CoM = js.com.com_position(model=model, data=data) - GW_J = G_J - W_p_B = data.base_position() - W_p_CoM = js.com.com_position(model=model, data=data) + BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) + BW_X_GW = Adjoint.from_transform(transform=BW_H_GW) - BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) - BW_X_GW = Adjoint.from_transform(transform=BW_H_GW) + return BW_X_GW @ GW_J - return BW_X_GW @ GW_J + return jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # ======================== From 8230b3ea40e517da876531a3e9ef6358d943cbdb Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 13:00:30 +0200 Subject: [PATCH 12/27] Use `jax.lax.switch` inside `api.model` for `VelRepr` --- src/jaxsim/api/model.py | 378 ++++++++++++++++++++++------------------ 1 file changed, 204 insertions(+), 174 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index d5960b04c..8d1b06fcd 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -488,7 +488,7 @@ def generalized_free_floating_jacobian( # Update the input velocity representation such that v_WL = J_WL_I @ I_ν # ====================================================================== - def to_inertial(): + def from_inertial(): W_H_B = data.base_transform() B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) @@ -497,10 +497,10 @@ def to_inertial(): ) return B_J_full_WX_W - def to_body(): + def from_body(): return B_J_full_WX_B - def to_mixed(): + def from_mixed(): W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -514,9 +514,9 @@ def to_mixed(): B_J_full_WX_I = jax.lax.switch( index=data.velocity_representation, branches=( - to_body, # VelRepr.Body - to_mixed, # VelRepr.Mixed - to_inertial, # VelRepr.Inertial + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial ), ) @@ -564,45 +564,47 @@ def to_mixed(): # Update the output velocity representation such that O_v_WL = O_J_WL @ ν # ======================================================================= - match output_vel_repr: - - case VelRepr.Inertial: - - W_H_B = data.base_transform() - W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) - - O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841 - lambda B_J_WL_I: W_X_B @ B_J_WL_I - )(B_J_WL_I) - - case VelRepr.Body: - - O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841 - lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform( - B_H_L, inverse=True - ) - @ B_J_WL_I - )(B_H_L, B_J_WL_I) - - case VelRepr.Mixed: + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) - W_H_B = data.base_transform() + O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841 + lambda B_J_WL_I: W_X_B @ B_J_WL_I + )(B_J_WL_I) + return O_J_WL_I - LW_H_L = jax.vmap( - lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3)) - )(B_H_L) + def to_body() -> jtp.Matrix: + O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841 + lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform( + B_H_L, inverse=True + ) + @ B_J_WL_I + )(B_H_L, B_J_WL_I) + return O_J_WL_I - LW_H_B = jax.vmap( - lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) - )(LW_H_L, B_H_L) + def to_mixed() -> jtp.Matrix: + W_H_B = data.base_transform() + LW_H_L = jax.vmap(lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3)))( + B_H_L + ) + LW_H_B = jax.vmap( + lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) + )(LW_H_L, B_H_L) - O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841 - lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B) - @ B_J_WL_I - )(LW_H_B, B_J_WL_I) + O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841 + lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B) + @ B_J_WL_I + )(LW_H_B, B_J_WL_I) + return O_J_WL_I - case _: - raise ValueError(output_vel_repr) + O_J_WL_I = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WL_I @@ -784,26 +786,34 @@ def to_active( C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB) - match data.velocity_representation: - case VelRepr.Inertial: - # In this case C=W - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - - case VelRepr.Body: - # In this case C=B - W_H_C = W_H_B = data.base_transform() - W_v_WC = W_v_WB - - case VelRepr.Mixed: - # In this case C=B[W] - W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + def to_inertial(): + # In this case C=W + W_H_C = W_H_W = jnp.eye(4) # noqa: F841 + W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + return W_v_WC, W_H_C + + def to_body(): + # In this case C=B + W_H_C = W_H_B = data.base_transform() # noqa: F841 + W_v_WC = W_v_WB + return W_v_WC, W_H_C + + def to_mixed(): + # In this case C=B[W] + W_H_B = data.base_transform() + W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 + W_ṗ_B = data.base_velocity()[0:3] + W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + return W_v_WC, W_H_C - case _: - raise ValueError(data.velocity_representation) + W_v_WC, W_H_C = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # We need to convert the derivative of the base velocity to the active # representation. In Mixed representation, this conversion is not a plain @@ -940,29 +950,32 @@ def free_floating_mass_matrix( joint_positions=data.state.physics_model.joint_positions, ) - match data.velocity_representation: - case VelRepr.Body: - return M_body + def to_body() -> jtp.Matrix: + return M_body - case VelRepr.Inertial: + def to_inertial() -> jtp.Matrix: - B_X_W = Adjoint.from_transform( - transform=data.base_transform(), inverse=True - ) - invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + B_X_W = Adjoint.from_transform(transform=data.base_transform(), inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) - return invT.T @ M_body @ invT + return invT.T @ M_body @ invT - case VelRepr.Mixed: + def to_mixed() -> jtp.Matrix: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - return invT.T @ M_body @ invT + return invT.T @ M_body @ invT - case _: - raise ValueError(data.velocity_representation) + return jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @jax.jit @@ -1033,56 +1046,58 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: # Adjust the representation of the Coriolis matrix. # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. - match data.velocity_representation: - - case VelRepr.Body: - return C_B - - case VelRepr.Inertial: - - n = model.dofs() - W_H_B = data.base_transform() - B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) - B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) + def to_body(): + return C_B - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WB = data.base_velocity() - B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) + def to_inertial(): + n = model.dofs() + W_H_B = data.base_transform() + B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) + B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) - B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) - with data.switch_velocity_representation(VelRepr.Body): - M = free_floating_mass_matrix(model=model, data=data) + B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) - C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W) + with data.switch_velocity_representation(VelRepr.Body): + M = free_floating_mass_matrix(model=model, data=data) - return C + C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W) - case VelRepr.Mixed: + return C - n = model.dofs() - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) + def to_mixed(): + n = model.dofs() + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) + B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) - with data.switch_velocity_representation(VelRepr.Mixed): - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + with data.switch_velocity_representation(VelRepr.Mixed): + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_v_BW_B = BW_v_WB - BW_v_W_BW - B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) + BW_v_BW_B = BW_v_WB - BW_v_W_BW + B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) - B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n))) + B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n))) - with data.switch_velocity_representation(VelRepr.Body): - M = free_floating_mass_matrix(model=model, data=data) + with data.switch_velocity_representation(VelRepr.Body): + M = free_floating_mass_matrix(model=model, data=data) - C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW) + C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW) - return C + return C - case _: - raise ValueError(data.velocity_representation) + return jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @jax.jit @@ -1153,25 +1168,32 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # In Inertial and Body representations, the cross product is always zero. return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB) - match data.velocity_representation: - case VelRepr.Inertial: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + def convert_inertial() -> jtp.Vector: + W_H_C = W_H_W = jnp.eye(4) # noqa: F841 + W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + return W_H_C, W_v_WC - case VelRepr.Body: - W_H_C = W_H_B = data.base_transform() - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() - - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + def convert_body() -> jtp.Vector: + W_H_C = W_H_B = data.base_transform() # noqa: F841 + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 + return W_H_C, W_v_WC - case _: - raise ValueError(data.velocity_representation) + def convert_mixed() -> jtp.Vector: + W_H_B = data.base_transform() + W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 + W_ṗ_B = data.base_velocity()[0:3] + W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + return W_H_C, W_v_WC + W_H_C, W_v_WC = jax.lax.switch( + index=data.velocity_representation, + branches=( + convert_body, # VelRepr.Body + convert_mixed, # VelRepr.Mixed + convert_inertial, # VelRepr.Inertial + ), + ) # We need to convert the derivative of the base acceleration to the Inertial # representation. In Mixed representation, this conversion is not a plain # transformation with just X, but it also involves a cross product in ℝ⁶. @@ -1615,33 +1637,39 @@ def other_representation_to_inertial( # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration # W_a_WB, and intrinsic accelerations can be expressed in different frames through # a simple C_X_W 6D transform. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - with data.switch_velocity_representation(VelRepr.Inertial): - C_v_WB = W_v_WB = data.base_velocity() - - case VelRepr.Body: - W_H_C = W_H_B - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 - with data.switch_velocity_representation(VelRepr.Body): - C_v_WB = B_v_WB = data.base_velocity() - - case VelRepr.Mixed: - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_H_C = W_H_BW - with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity()[0:3] - BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) - W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW) - W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841 - with data.switch_velocity_representation(VelRepr.Mixed): - C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841 - - case _: - raise ValueError(data.velocity_representation) + def to_inertial() -> jtp.Matrix: + W_H_C = W_H_W = jnp.eye(4) # noqa: F841 + W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + with data.switch_velocity_representation(VelRepr.Inertial): + C_v_WB = W_v_WB = data.base_velocity() # noqa: F841 + return W_H_C, W_v_WC, C_v_WB + + def to_body() -> jtp.Matrix: + W_H_C = W_H_B + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 + with data.switch_velocity_representation(VelRepr.Body): + C_v_WB = B_v_WB = data.base_velocity() # noqa: F841 + return W_H_C, W_v_WC, C_v_WB + + def to_mixed() -> jtp.Matrix: + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + W_H_C = W_H_BW + with data.switch_velocity_representation(VelRepr.Mixed): + W_ṗ_B = data.base_velocity()[0:3] + W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + with data.switch_velocity_representation(VelRepr.Mixed): + C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841 + return W_H_C, W_v_WC, C_v_WB + + W_H_C, W_v_WC, C_v_WB = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # Convert a zero 6D acceleration from the active representation to inertial-fixed. W_v̇_WB = other_representation_to_inertial( @@ -1744,29 +1772,31 @@ def body_to_other_representation( C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L) return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL) - match data.velocity_representation: - case VelRepr.Body: - C_H_L = L_H_L = jnp.stack( # noqa: F841 - [jnp.eye(4)] * model.number_of_links() - ) - L_v_CL = L_v_LL = jnp.zeros( # noqa: F841 - shape=(model.number_of_links(), 6) - ) + def to_body() -> jtp.Matrix: + C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links()) # noqa: F841 + L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6)) # noqa: F841 + return C_H_L, L_v_CL - case VelRepr.Inertial: - C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data) - L_v_CL = L_v_WL + def to_inertial() -> jtp.Matrix: + C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data) # noqa: F841 + L_v_CL = L_v_WL + return C_H_L, L_v_CL - case VelRepr.Mixed: - W_H_L = js.model.forward_kinematics(model=model, data=data) - LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) - C_H_L = LW_H_L - L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 - lambda v: v.at[0:3].set(jnp.zeros(3)) - )(L_v_WL) + def to_mixed() -> jtp.Matrix: + W_H_L = js.model.forward_kinematics(model=model, data=data) + LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) + C_H_L = LW_H_L + L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL) # noqa: F841 + return C_H_L, L_v_CL - case _: - raise ValueError(data.velocity_representation) + C_H_L, L_v_CL = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # Convert from body-fixed to the active representation. O_v̇_WL = jax.vmap(body_to_other_representation)( From 1c326399e8456bf2286860d51f637a005a602b10 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 15:51:03 +0200 Subject: [PATCH 13/27] Update error messages in `api.references` --- src/jaxsim/api/references.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index c34aa2240..29d5767ae 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -184,8 +184,9 @@ def link_forces( # serialization. if model is None: if self.velocity_representation is not VelRepr.Inertial: - msg = "Missing model to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) + raise ValueError( + "Missing model to use a representation different from `VelRepr.Inertial`" + ) if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -201,8 +202,9 @@ def link_forces( return W_f_L[link_idxs, :] if data is None: - msg = "Missing model data to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) + raise ValueError( + "Missing model data to use a representation different from `VelRepr.Inertial`" + ) if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") @@ -372,8 +374,9 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # using the implicit link serialization. if model is None: if self.velocity_representation is not VelRepr.Inertial: - msg = "Missing model to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) + raise ValueError( + "Missing model to use a representation different from `VelRepr.Inertial`" + ) if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -418,8 +421,9 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: ) if data is None: - msg = "Missing model data to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) + raise ValueError( + "Missing model data to use a representation different from `VelRepr.Inertial`" + ) if not_tracing(forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") From 0a2d9c3d41bc8cf0ef84b0f4534654f461568a67 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 16:36:29 +0200 Subject: [PATCH 14/27] Use `jax.lax.switch` in `api.contact` for `VelRepr` --- src/jaxsim/api/contact.py | 55 +++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 3caa75624..9ddcc3c17 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -372,44 +372,47 @@ def jacobian( jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int) ) - # Adjust the output representation. - match output_vel_repr: + def to_inertial() -> jtp.Matrix: + O_J_WC = W_J_WC + return O_J_WC - case VelRepr.Inertial: - O_J_WC = W_J_WC + def to_body() -> jtp.Matrix: - case VelRepr.Body: + W_H_C = transforms(model=model, data=data) - W_H_C = transforms(model=model, data=data) + def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + C_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_C, inverse=True) + C_J_WC = C_X_W @ W_J_WC + return C_J_WC - def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - C_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_C, inverse=True - ) - C_J_WC = C_X_W @ W_J_WC - return C_J_WC + O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + return O_J_WC - O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) + def to_mixed() -> jtp.Matrix: - case VelRepr.Mixed: + W_H_C = transforms(model=model, data=data) - W_H_C = transforms(model=model, data=data) + def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) + CW_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_CW, inverse=True) - CW_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_CW, inverse=True - ) - - CW_J_WC = CW_X_W @ W_J_WC - return CW_J_WC + CW_J_WC = CW_X_W @ W_J_WC + return CW_J_WC - O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC) + O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + return O_J_WC - case _: - raise ValueError(output_vel_repr) + # Adjust the output representation. + O_J_WC = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WC From c3940f3fa94632e3c6a100b919ae0ec85eb2572c Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 17:29:00 +0200 Subject: [PATCH 15/27] Fix dimensions in `api.com.bias_acceleration` --- src/jaxsim/api/com.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index c9be3c15f..6057edaed 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -422,7 +422,7 @@ def to_body() -> jtp.Vector: # G := G[B] = (W_p_CoM, [B]) GB_Xf_W = Adjoint.from_transform( - transform=data.base_transform().at[0:3].set(W_p_CoM) + transform=data.base_transform().at[0:3, 3].set(W_p_CoM) ).T GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias From 6c59458e9b5383c555dcddfb924ab340c45eadb1 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 17:33:09 +0200 Subject: [PATCH 16/27] Use equality instead of identity operator for comparing `VelRepr` --- src/jaxsim/api/references.py | 8 ++++---- tests/test_api_com.py | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 29d5767ae..c9e4c81d5 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -183,7 +183,7 @@ def link_forces( # Return all link forces in inertial-fixed representation using the implicit # serialization. if model is None: - if self.velocity_representation is not VelRepr.Inertial: + if (self.velocity_representation != VelRepr.Inertial) is True: raise ValueError( "Missing model to use a representation different from `VelRepr.Inertial`" ) @@ -198,7 +198,7 @@ def link_forces( link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) # In inertial-fixed representation, we already have the link forces. - if self.velocity_representation is VelRepr.Inertial: + if (self.velocity_representation == VelRepr.Inertial) is True: return W_f_L[link_idxs, :] if data is None: @@ -373,7 +373,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # In this case, we allow only to set the inertial 6D forces to all links # using the implicit link serialization. if model is None: - if self.velocity_representation is not VelRepr.Inertial: + if (self.velocity_representation != VelRepr.Inertial) is True: raise ValueError( "Missing model to use a representation different from `VelRepr.Inertial`" ) @@ -412,7 +412,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: ) # If inertial-fixed representation, we can directly store the link forces. - if self.velocity_representation is VelRepr.Inertial: + if (self.velocity_representation == VelRepr.Inertial) is True: W_f_L = f_L return replace( forces=self.input.physics_model.f_ext.at[link_idxs, :].set( diff --git a/tests/test_api_com.py b/tests/test_api_com.py index 559689401..81fb6abef 100644 --- a/tests/test_api_com.py +++ b/tests/test_api_com.py @@ -53,7 +53,11 @@ def test_com_properties( assert pytest.approx(v_avg_com_idt) == v_avg_com_js # https://github.com/ami-iit/jaxsim/pull/117#discussion_r1535486123 - if data.velocity_representation is not VelRepr.Body: + with data.switch_velocity_representation( + data.velocity_representation + if data.velocity_representation != VelRepr.Body + else VelRepr.Mixed + ): vl_com_idt = kin_dyn.com_velocity() vl_com_js = js.com.com_linear_velocity(model=model, data=data) assert pytest.approx(vl_com_idt) == vl_com_js @@ -61,7 +65,7 @@ def test_com_properties( # iDynTree provides the bias acceleration in G[W] frame regardless of the velocity # representation. JaxSim, instead, returns the bias acceleration in G[B] when the # active representation is VelRepr.Body. - if data.velocity_representation is not VelRepr.Body: + if data.velocity_representation != VelRepr.Body: G_v̇_bias_WG_idt = kin_dyn.com_bias_acceleration() G_v̇_bias_WG_js = js.com.bias_acceleration(model=model, data=data) assert pytest.approx(G_v̇_bias_WG_idt) == G_v̇_bias_WG_js From 6163ad9e7942659c72f8cd5ed71843796d63cab0 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 4 Jun 2024 17:43:58 +0200 Subject: [PATCH 17/27] Update type-hints to `int` for `VelRepr` --- src/jaxsim/api/common.py | 8 ++++---- src/jaxsim/api/contact.py | 2 +- src/jaxsim/api/data.py | 12 ++++++------ src/jaxsim/api/frame.py | 2 +- src/jaxsim/api/link.py | 6 ++---- src/jaxsim/api/model.py | 6 +++--- src/jaxsim/api/references.py | 4 ++-- tests/test_api_com.py | 2 +- tests/test_api_data.py | 2 +- tests/test_api_frame.py | 2 +- tests/test_api_link.py | 4 ++-- tests/test_api_model.py | 6 +++--- tests/test_automatic_differentiation.py | 2 +- tests/test_contact.py | 2 +- tests/test_simulations.py | 2 +- tests/utils_idyntree.py | 4 ++-- 16 files changed, 32 insertions(+), 34 deletions(-) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 74eadab48..9ff0bece6 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -35,13 +35,13 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): Base class for model data structures with velocity representation. """ - velocity_representation: VelRepr = dataclasses.field( + velocity_representation: int = dataclasses.field( default=VelRepr.Inertial, kw_only=True ) @contextlib.contextmanager def switch_velocity_representation( - self, velocity_representation: VelRepr + self, velocity_representation: int ) -> ContextManager[Self]: """ Context manager to temporarily switch the velocity representation. @@ -83,7 +83,7 @@ def switch_velocity_representation( @functools.partial(jax.jit, static_argnames=["is_force"]) def inertial_to_other_representation( array: jtp.Array, - other_representation: VelRepr, + other_representation: int, transform: jtp.Matrix, *, is_force: bool, @@ -148,7 +148,7 @@ def to_mixed(): @functools.partial(jax.jit, static_argnames=["is_force"]) def other_representation_to_inertial( array: jtp.Array, - other_representation: VelRepr, + other_representation: int, transform: jtp.Matrix, *, is_force: bool, diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 9ddcc3c17..3f472ec90 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -334,7 +334,7 @@ def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Array: r""" Return the free-floating Jacobian of the collidable points. diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 6f28f1ced..70d2774e5 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -84,7 +84,7 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool: @staticmethod def zero( model: js.model.JaxSimModel, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: int = VelRepr.Inertial, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with zero state. @@ -113,7 +113,7 @@ def build( standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, contact: jaxsim.rbda.ContactsState | None = None, contacts_params: jaxsim.rbda.ContactsParams | None = None, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: int = VelRepr.Inertial, time: jtp.FloatLike | None = None, ) -> JaxSimModelData: """ @@ -625,7 +625,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: def reset_base_linear_velocity( self, linear_velocity: jtp.VectorLike, - velocity_representation: VelRepr | None = None, + velocity_representation: int | None = None, ) -> Self: """ Reset the base linear velocity. @@ -656,7 +656,7 @@ def reset_base_linear_velocity( def reset_base_angular_velocity( self, angular_velocity: jtp.VectorLike, - velocity_representation: VelRepr | None = None, + velocity_representation: int | None = None, ) -> Self: """ Reset the base angular velocity. @@ -687,7 +687,7 @@ def reset_base_angular_velocity( def reset_base_velocity( self, base_velocity: jtp.VectorLike, - velocity_representation: VelRepr | None = None, + velocity_representation: int | None = None, ) -> Self: """ Reset the base 6D velocity. @@ -732,7 +732,7 @@ def random_model_data( model: js.model.JaxSimModel, *, key: jax.Array | None = None, - velocity_representation: VelRepr | None = None, + velocity_representation: int | None = None, base_pos_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index e243df6dd..698b61202 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -186,7 +186,7 @@ def jacobian( data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the frame. diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index affc8b361..87db39afb 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -11,8 +11,6 @@ from jaxsim import exceptions from jaxsim.math import Adjoint, Transform -from .common import VelRepr - # ======================= # Index-related functions # ======================= @@ -241,7 +239,7 @@ def jacobian( data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the link. @@ -354,7 +352,7 @@ def velocity( data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Vector: """ Compute the 6D velocity of the link. diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 8d1b06fcd..e607b6b42 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -453,7 +453,7 @@ def generalized_free_floating_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Matrix: """ Compute the free-floating jacobians of all links. @@ -1422,7 +1422,7 @@ def total_momentum_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Matrix: """ Compute the jacobian of the total momentum. @@ -1519,7 +1519,7 @@ def average_velocity_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: int | None = None, ) -> jtp.Matrix: """ Compute the Jacobian of the average velocity of the model. diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index c9e4c81d5..bef48ff5c 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -32,7 +32,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): def zero( model: js.model.JaxSimModel, data: js.data.JaxSimModelData | None = None, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: int = VelRepr.Inertial, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with zero references. @@ -58,7 +58,7 @@ def build( joint_force_references: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, data: js.data.JaxSimModelData | None = None, - velocity_representation: VelRepr | None = None, + velocity_representation: int | None = None, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with the given references. diff --git a/tests/test_api_com.py b/tests/test_api_com.py index 81fb6abef..83a17bf75 100644 --- a/tests/test_api_com.py +++ b/tests/test_api_com.py @@ -9,7 +9,7 @@ def test_com_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): diff --git a/tests/test_api_data.py b/tests/test_api_data.py index 9db5e6561..78e19d153 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -21,7 +21,7 @@ def test_data_valid( def test_data_joint_indexing( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index 41965ae14..ee4fc5ec8 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -124,7 +124,7 @@ def test_frame_transforms( def test_frame_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 418f78a99..858587409 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -128,7 +128,7 @@ def test_link_transforms( def test_link_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): @@ -197,7 +197,7 @@ def test_link_jacobians( def test_link_bias_acceleration( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): diff --git a/tests/test_api_model.py b/tests/test_api_model.py index a0fe49db5..439e44761 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -222,7 +222,7 @@ def test_model_creation_and_reduction( def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): @@ -269,7 +269,7 @@ def test_model_properties( def test_model_rbda( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, - velocity_representation: VelRepr, + velocity_representation: int, ): model = jaxsim_models_types @@ -480,7 +480,7 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: def test_model_fd_id_consistency( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 132cfcca0..fd075ef0e 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -26,7 +26,7 @@ def get_random_data_and_references( model: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, key: jax.Array, ) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: diff --git a/tests/test_contact.py b/tests/test_contact.py index ff48385a7..20af3dab2 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -7,7 +7,7 @@ def test_collidable_point_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, prng_key: jax.Array, ): diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 41807b8e2..5dfb829fe 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -11,7 +11,7 @@ def test_box_with_external_forces( jaxsim_model_box: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: int, ): """ This test simulates a box falling due to gravity. diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 5376a82ef..8c235a94f 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -118,7 +118,7 @@ def store_jaxsim_data_in_kindyncomputations( class KinDynComputations: """High-level wrapper of the iDynTree KinDynComputations class.""" - vel_repr: VelRepr + vel_repr: int gravity: npt.NDArray kin_dyn: idt.KinDynComputations @@ -126,7 +126,7 @@ class KinDynComputations: def build( urdf: pathlib.Path | str, considered_joints: list[str] | None = None, - vel_repr: VelRepr = VelRepr.Inertial, + vel_repr: int = VelRepr.Inertial, gravity: npt.NDArray = np.array([0, 0, -10.0]), removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: From 42cd69eb61625e7a5318a53e9a19b309a7a0a87b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 5 Jun 2024 10:41:55 +0200 Subject: [PATCH 18/27] Use `jax.lax.cond` to check the velocity representation --- src/jaxsim/api/references.py | 174 ++++++++++++++++++++--------------- 1 file changed, 102 insertions(+), 72 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index bef48ff5c..a8c523784 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -183,50 +183,63 @@ def link_forces( # Return all link forces in inertial-fixed representation using the implicit # serialization. if model is None: - if (self.velocity_representation != VelRepr.Inertial) is True: - raise ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) - if link_names is not None: - raise ValueError("Link names cannot be provided without a model") + def not_inertial(): + if link_names is not None: + raise ValueError("Link names cannot be provided without a model") + + return self.input.physics_model.f_ext - return self.input.physics_model.f_ext + jax.lax.cond( + pred=(self.velocity_representation == VelRepr.Inertial), + true_fun=not_inertial, + false_fun=lambda: (_ for _ in (None,)).throw( + ValueError( + "Missing model to use a representation different from `VelRepr.Inertial`" + ) + ), + ) # If we have the model, we can extract the link names, if not provided. link_names = link_names if link_names is not None else model.link_names() link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) - # In inertial-fixed representation, we already have the link forces. - if (self.velocity_representation == VelRepr.Inertial) is True: - return W_f_L[link_idxs, :] - - if data is None: - raise ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - - if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model): - raise ValueError("The provided data is not valid for the model") + def not_inertial(): + if data is None: + raise ValueError( + "Missing model data to use a representation different from `VelRepr.Inertial`" + ) - # Helper function to convert a single 6D force to the active representation - # considering as body the link (i.e. L_f_L and LW_f_L). - def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: + if not_tracing(self.input.physics_model.f_ext) and not data.valid( + model=model + ): + raise ValueError("The provided data is not valid for the model") + + # Helper function to convert a single 6D force to the active representation + # considering as body the link (i.e. L_f_L and LW_f_L). + def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: + + return jax.vmap( + lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( + array=W_f_L, + other_representation=self.velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(W_f_L, W_H_L) - return jax.vmap( - lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( - array=W_f_L, - other_representation=self.velocity_representation, - transform=W_H_L, - is_force=True, - ) - )(W_f_L, W_H_L) + # The f_L output is either L_f_L or LW_f_L, depending on the representation. + W_H_L = js.model.forward_kinematics(model=model, data=data) + f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) - # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) - f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) + return f_L - return f_L + # In inertial-fixed representation, we already have the link forces. + return jax.lax.cond( + pred=(self.velocity_representation == VelRepr.Inertial), + true_fun=lambda: W_f_L[link_idxs, :], + false_fun=not_inertial, + ) def joint_force_references( self, @@ -373,23 +386,30 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # In this case, we allow only to set the inertial 6D forces to all links # using the implicit link serialization. if model is None: - if (self.velocity_representation != VelRepr.Inertial) is True: - raise ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) - if link_names is not None: - raise ValueError("Link names cannot be provided without a model") + def not_inertial(): + if link_names is not None: + raise ValueError("Link names cannot be provided without a model") - W_f_L = f_L + W_f_L = f_L - W_f0_L = ( - jnp.zeros_like(W_f_L) - if not additive - else self.input.physics_model.f_ext - ) + W_f0_L = ( + jnp.zeros_like(W_f_L) + if not additive + else self.input.physics_model.f_ext + ) - return replace(forces=W_f0_L + W_f_L) + return replace(forces=W_f0_L + W_f_L) + + jax.lax.cond( + pred=(self.velocity_representation == VelRepr.Inertial), + true_fun=not_inertial, + false_fun=lambda: (_ for _ in (None,)).throw( + ValueError( + "Missing model to use a representation different from `VelRepr.Inertial`" + ) + ), + ) # If we have the model, we can extract the link names if not provided. link_names = link_names if link_names is not None else model.link_names() @@ -412,7 +432,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: ) # If inertial-fixed representation, we can directly store the link forces. - if (self.velocity_representation == VelRepr.Inertial) is True: + def inertial(): W_f_L = f_L return replace( forces=self.input.physics_model.f_ext.at[link_idxs, :].set( @@ -420,35 +440,45 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: ) ) - if data is None: - raise ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - - if not_tracing(forces) and not data.valid(model=model): - raise ValueError("The provided data is not valid for the model") + def not_inertial(): - # Helper function to convert a single 6D force to the inertial representation - # considering as body the link (i.e. L_f_L and LW_f_L). - def convert_using_link_frame( - f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike - ) -> jtp.Matrix: - - return jax.vmap( - lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial( - array=f_L, - other_representation=self.velocity_representation, - transform=W_H_L, - is_force=True, + if data is None: + raise ValueError( + "Missing model data to use a representation different from `VelRepr.Inertial`" ) - )(f_L, W_H_L) - # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) - W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) + if not_tracing(forces) and not data.valid(model=model): + raise ValueError("The provided data is not valid for the model") + + # Helper function to convert a single 6D force to the inertial representation + # considering as body the link (i.e. L_f_L and LW_f_L). + def convert_using_link_frame( + f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike + ) -> jtp.Matrix: + + return jax.vmap( + lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial( + array=f_L, + other_representation=self.velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(f_L, W_H_L) + + # The f_L input is either L_f_L or LW_f_L, depending on the representation. + W_H_L = js.model.forward_kinematics(model=model, data=data) + W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) + + return replace( + forces=self.input.physics_model.f_ext.at[link_idxs, :].set( + W_f0_L + W_f_L + ) + ) - return replace( - forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L) + return jax.lax.cond( + pred=(self.velocity_representation == VelRepr.Inertial), + true_fun=inertial, + false_fun=not_inertial, ) def apply_frame_forces( From 12a6f8d968a157efdc48e9c77a88261e0ead1bd0 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 5 Jun 2024 17:02:40 +0200 Subject: [PATCH 19/27] Use `jax.pure_callback` to throw errors --- src/jaxsim/api/references.py | 45 ++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index a8c523784..7f486debd 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -184,19 +184,22 @@ def link_forces( # serialization. if model is None: - def not_inertial(): + def inertial(): if link_names is not None: raise ValueError("Link names cannot be provided without a model") return self.input.physics_model.f_ext - jax.lax.cond( + return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=not_inertial, - false_fun=lambda: (_ for _ in (None,)).throw( - ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) + true_fun=inertial, + false_fun=lambda: jax.pure_callback( + callback=lambda: (_ for _ in ()).throw( + ValueError( + "Missing model to use a representation different from `VelRepr.Inertial`" + ) + ), + result_shape_dtypes=self.input.physics_model.f_ext, ), ) @@ -238,7 +241,10 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), true_fun=lambda: W_f_L[link_idxs, :], - false_fun=not_inertial, + false_fun=lambda: jax.pure_callback( + callback=not_inertial, + result_shape_dtypes=W_f_L[link_idxs, :], + ), ) def joint_force_references( @@ -387,7 +393,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # using the implicit link serialization. if model is None: - def not_inertial(): + def inertial(): if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -403,11 +409,14 @@ def not_inertial(): jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=not_inertial, - false_fun=lambda: (_ for _ in (None,)).throw( - ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) + true_fun=inertial, + false_fun=lambda: jax.pure_callback( + callback=lambda: (_ for _ in ()).throw( + ValueError( + "Missing model to use a representation different from `VelRepr.Inertial`" + ) + ), + result_shape_dtypes=self, ), ) @@ -440,8 +449,7 @@ def inertial(): ) ) - def not_inertial(): - + def not_inertial(data): if data is None: raise ValueError( "Missing model data to use a representation different from `VelRepr.Inertial`" @@ -478,7 +486,10 @@ def convert_using_link_frame( return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), true_fun=inertial, - false_fun=not_inertial, + false_fun=lambda: jax.experimental.io_callback( + callback=not_inertial, + result_shape_dtypes=self, + ), ) def apply_frame_forces( From 0bc1d76c42e4c42edf30bc5a6db186e8e29d1d23 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 6 Jun 2024 09:00:31 +0200 Subject: [PATCH 20/27] Add additional type checks --- src/jaxsim/api/link.py | 2 + src/jaxsim/api/references.py | 76 ++++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 26 deletions(-) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 87db39afb..d154b3b0d 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -11,6 +11,8 @@ from jaxsim import exceptions from jaxsim.math import Adjoint, Transform +from .common import VelRepr + # ======================= # Index-related functions # ======================= diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 7f486debd..30a6c7649 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -19,6 +19,8 @@ except ImportError: from typing_extensions import Self +from .data import JaxSimModelData + @jax_dataclasses.pytree_dataclass class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): @@ -184,7 +186,7 @@ def link_forces( # serialization. if model is None: - def inertial(): + def inertial() -> jtp.Array: if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -207,7 +209,7 @@ def inertial(): link_names = link_names if link_names is not None else model.link_names() link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) - def not_inertial(): + def check_not_inertial() -> None: if data is None: raise ValueError( "Missing model data to use a representation different from `VelRepr.Inertial`" @@ -218,6 +220,17 @@ def not_inertial(): ): raise ValueError("The provided data is not valid for the model") + # If not inertial-fixed representation, we need the model data. + jax.lax.cond( + pred=(self.velocity_representation != VelRepr.Inertial), + true_fun=lambda: jax.pure_callback( + callback=check_not_inertial, + result_shape_dtypes=None, + ), + false_fun=lambda: None, + ) + + def not_inertial(velocity_representation: int) -> jtp.Matrix: # Helper function to convert a single 6D force to the active representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: @@ -225,14 +238,16 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: return jax.vmap( lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( array=W_f_L, - other_representation=self.velocity_representation, + other_representation=velocity_representation, transform=W_H_L, is_force=True, ) )(W_f_L, W_H_L) # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ) f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) return f_L @@ -240,11 +255,9 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: # In inertial-fixed representation, we already have the link forces. return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=lambda: W_f_L[link_idxs, :], - false_fun=lambda: jax.pure_callback( - callback=not_inertial, - result_shape_dtypes=W_f_L[link_idxs, :], - ), + true_fun=lambda _: W_f_L[link_idxs, :], + false_fun=not_inertial, + operand=self.velocity_representation, ) def joint_force_references( @@ -393,7 +406,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # using the implicit link serialization. if model is None: - def inertial(): + def inertial() -> JaxSimModelReferences: if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -440,16 +453,7 @@ def inertial(): else self.input.physics_model.f_ext[link_idxs, :] ) - # If inertial-fixed representation, we can directly store the link forces. - def inertial(): - W_f_L = f_L - return replace( - forces=self.input.physics_model.f_ext.at[link_idxs, :].set( - W_f0_L + W_f_L - ) - ) - - def not_inertial(data): + def check_not_inertial() -> None: if data is None: raise ValueError( "Missing model data to use a representation different from `VelRepr.Inertial`" @@ -458,6 +462,26 @@ def not_inertial(data): if not_tracing(forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") + # If not inertial-fixed representation, we need the model data. + jax.lax.cond( + pred=(self.velocity_representation != VelRepr.Inertial), + true_fun=lambda: jax.pure_callback( + callback=check_not_inertial, + result_shape_dtypes=None, + ), + false_fun=lambda: None, + ) + + # If inertial-fixed representation, we can directly store the link forces. + def inertial(velocity_representation: int) -> JaxSimModelReferences: + W_f_L = f_L + return replace( + forces=self.input.physics_model.f_ext.at[link_idxs, :].set( + W_f0_L + W_f_L + ) + ) + + def not_inertial(velocity_representation: int) -> JaxSimModelReferences: # Helper function to convert a single 6D force to the inertial representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert_using_link_frame( @@ -467,14 +491,16 @@ def convert_using_link_frame( return jax.vmap( lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial( array=f_L, - other_representation=self.velocity_representation, + other_representation=velocity_representation, transform=W_H_L, is_force=True, ) )(f_L, W_H_L) # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ) W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace( @@ -486,10 +512,8 @@ def convert_using_link_frame( return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), true_fun=inertial, - false_fun=lambda: jax.experimental.io_callback( - callback=not_inertial, - result_shape_dtypes=self, - ), + false_fun=not_inertial, + operand=self.velocity_representation, ) def apply_frame_forces( From 68e2f8e916fb1b3141f40aa83aeb29d68a8fa668 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 24 May 2024 13:21:48 +0200 Subject: [PATCH 21/27] Fix base acceleration transform to `VelRepr.Mixed` In Articulated Body Algorithm Co-authored-by: Alessandro Croci --- src/jaxsim/api/model.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index e607b6b42..65f05129c 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -568,9 +568,9 @@ def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) - O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841 - lambda B_J_WL_I: W_X_B @ B_J_WL_I - )(B_J_WL_I) + O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)( # noqa: F841 + B_J_WL_I + ) return O_J_WL_I def to_body() -> jtp.Matrix: @@ -784,7 +784,7 @@ def to_active( # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) - return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB) + return C_X_W @ W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB def to_inertial(): # In this case C=W @@ -1778,7 +1778,9 @@ def to_body() -> jtp.Matrix: return C_H_L, L_v_CL def to_inertial() -> jtp.Matrix: - C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data) # noqa: F841 + C_H_L = W_H_L = js.model.forward_kinematics( + model=model, data=data + ) # noqa: F841 L_v_CL = L_v_WL return C_H_L, L_v_CL @@ -1786,7 +1788,9 @@ def to_mixed() -> jtp.Matrix: W_H_L = js.model.forward_kinematics(model=model, data=data) LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) C_H_L = LW_H_L - L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL) # noqa: F841 + L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))( + L_v_WL + ) # noqa: F841 return C_H_L, L_v_CL C_H_L, L_v_CL = jax.lax.switch( From 0da3a628ed0c8737ec6e84f5cab1dec01b8d9e4e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Jun 2024 15:54:06 +0200 Subject: [PATCH 22/27] Define `jaxsim.typing.VelRepr` as `Int` Co-authored-by: Diego Ferigo --- src/jaxsim/api/common.py | 8 +- src/jaxsim/api/contact.py | 2 +- src/jaxsim/api/data.py | 12 +- src/jaxsim/api/frame.py | 2 +- src/jaxsim/api/link.py | 6 +- src/jaxsim/api/model.py | 16 +-- src/jaxsim/api/references.py | 10 +- src/jaxsim/typing.py | 2 + tests/test_api_com.py | 3 +- tests/test_api_data.py | 3 +- tests/test_api_frame.py | 3 +- tests/test_api_link.py | 7 +- tests/test_api_model.py | 7 +- tests/test_api_references.py | 150 ++++++++++++++++++++++++ tests/test_automatic_differentiation.py | 2 +- tests/test_contact.py | 3 +- tests/test_simulations.py | 5 +- tests/utils_idyntree.py | 5 +- 18 files changed, 203 insertions(+), 43 deletions(-) create mode 100644 tests/test_api_references.py diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 9ff0bece6..d9dba423e 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -35,13 +35,13 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): Base class for model data structures with velocity representation. """ - velocity_representation: int = dataclasses.field( + velocity_representation: jtp.VelRepr = dataclasses.field( default=VelRepr.Inertial, kw_only=True ) @contextlib.contextmanager def switch_velocity_representation( - self, velocity_representation: int + self, velocity_representation: jtp.VelRepr ) -> ContextManager[Self]: """ Context manager to temporarily switch the velocity representation. @@ -83,7 +83,7 @@ def switch_velocity_representation( @functools.partial(jax.jit, static_argnames=["is_force"]) def inertial_to_other_representation( array: jtp.Array, - other_representation: int, + other_representation: jtp.VelRepr, transform: jtp.Matrix, *, is_force: bool, @@ -148,7 +148,7 @@ def to_mixed(): @functools.partial(jax.jit, static_argnames=["is_force"]) def other_representation_to_inertial( array: jtp.Array, - other_representation: int, + other_representation: jtp.VelRepr, transform: jtp.Matrix, *, is_force: bool, diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 3f472ec90..c59608ff1 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -334,7 +334,7 @@ def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Array: r""" Return the free-floating Jacobian of the collidable points. diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 70d2774e5..0a8185fee 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -84,7 +84,7 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool: @staticmethod def zero( model: js.model.JaxSimModel, - velocity_representation: int = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with zero state. @@ -113,7 +113,7 @@ def build( standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, contact: jaxsim.rbda.ContactsState | None = None, contacts_params: jaxsim.rbda.ContactsParams | None = None, - velocity_representation: int = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, time: jtp.FloatLike | None = None, ) -> JaxSimModelData: """ @@ -625,7 +625,7 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: def reset_base_linear_velocity( self, linear_velocity: jtp.VectorLike, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base linear velocity. @@ -656,7 +656,7 @@ def reset_base_linear_velocity( def reset_base_angular_velocity( self, angular_velocity: jtp.VectorLike, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base angular velocity. @@ -687,7 +687,7 @@ def reset_base_angular_velocity( def reset_base_velocity( self, base_velocity: jtp.VectorLike, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base 6D velocity. @@ -732,7 +732,7 @@ def random_model_data( model: js.model.JaxSimModel, *, key: jax.Array | None = None, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, base_pos_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 698b61202..468d3be71 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -186,7 +186,7 @@ def jacobian( data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the frame. diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index d154b3b0d..5906792a9 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -241,7 +241,7 @@ def jacobian( data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the link. @@ -354,7 +354,7 @@ def velocity( data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Vector: """ Compute the 6D velocity of the link. @@ -404,7 +404,7 @@ def jacobian_derivative( data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the derivative of the free-floating jacobian of the link. diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 65f05129c..8011d3438 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -453,7 +453,7 @@ def generalized_free_floating_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: """ Compute the free-floating jacobians of all links. @@ -1422,7 +1422,7 @@ def total_momentum_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: """ Compute the jacobian of the total momentum. @@ -1519,7 +1519,7 @@ def average_velocity_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: int | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: """ Compute the Jacobian of the average velocity of the model. @@ -1778,9 +1778,9 @@ def to_body() -> jtp.Matrix: return C_H_L, L_v_CL def to_inertial() -> jtp.Matrix: - C_H_L = W_H_L = js.model.forward_kinematics( + C_H_L = W_H_L = js.model.forward_kinematics( # noqa: F841 model=model, data=data - ) # noqa: F841 + ) L_v_CL = L_v_WL return C_H_L, L_v_CL @@ -1788,9 +1788,9 @@ def to_mixed() -> jtp.Matrix: W_H_L = js.model.forward_kinematics(model=model, data=data) LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) C_H_L = LW_H_L - L_v_CL = L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))( - L_v_WL - ) # noqa: F841 + L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 + lambda v: v.at[0:3].set(jnp.zeros(3)) + )(L_v_WL) return C_H_L, L_v_CL C_H_L, L_v_CL = jax.lax.switch( diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 30a6c7649..5e7197040 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -34,7 +34,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): def zero( model: js.model.JaxSimModel, data: js.data.JaxSimModelData | None = None, - velocity_representation: int = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with zero references. @@ -60,7 +60,7 @@ def build( joint_force_references: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, data: js.data.JaxSimModelData | None = None, - velocity_representation: int | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with the given references. @@ -230,7 +230,7 @@ def check_not_inertial() -> None: false_fun=lambda: None, ) - def not_inertial(velocity_representation: int) -> jtp.Matrix: + def not_inertial(velocity_representation: jtp.VelRepr) -> jtp.Matrix: # Helper function to convert a single 6D force to the active representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: @@ -473,7 +473,7 @@ def check_not_inertial() -> None: ) # If inertial-fixed representation, we can directly store the link forces. - def inertial(velocity_representation: int) -> JaxSimModelReferences: + def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: W_f_L = f_L return replace( forces=self.input.physics_model.f_ext.at[link_idxs, :].set( @@ -481,7 +481,7 @@ def inertial(velocity_representation: int) -> JaxSimModelReferences: ) ) - def not_inertial(velocity_representation: int) -> JaxSimModelReferences: + def not_inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: # Helper function to convert a single 6D force to the inertial representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert_using_link_frame( diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 5a2c0a54e..4a6ff0474 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -37,3 +37,5 @@ IntLike = int | Int | jax.typing.ArrayLike BoolLike = bool | Bool | jax.typing.ArrayLike FloatLike = float | Float | jax.typing.ArrayLike + +VelRepr = int diff --git a/tests/test_api_com.py b/tests/test_api_com.py index 83a17bf75..e019a53fd 100644 --- a/tests/test_api_com.py +++ b/tests/test_api_com.py @@ -2,6 +2,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -9,7 +10,7 @@ def test_com_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_data.py b/tests/test_api_data.py index 78e19d153..996356a8b 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -3,6 +3,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.utils import Mutability @@ -21,7 +22,7 @@ def test_data_valid( def test_data_joint_indexing( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index ee4fc5ec8..820e8ce31 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -4,6 +4,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.math.quaternion import Quaternion @@ -124,7 +125,7 @@ def test_frame_transforms( def test_frame_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 858587409..74cba2def 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -5,6 +5,7 @@ import jaxsim.api as js import jaxsim.math +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -128,7 +129,7 @@ def test_link_transforms( def test_link_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -197,7 +198,7 @@ def test_link_jacobians( def test_link_bias_acceleration( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -298,7 +299,7 @@ def test_link_bias_acceleration( def test_link_jacobian_derivative( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 439e44761..fe0643c7f 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -8,6 +8,7 @@ import jaxsim.api as js import jaxsim.math +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -222,7 +223,7 @@ def test_model_creation_and_reduction( def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -269,7 +270,7 @@ def test_model_properties( def test_model_rbda( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, - velocity_representation: int, + velocity_representation: jtp.VelRepr, ): model = jaxsim_models_types @@ -480,7 +481,7 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: def test_model_fd_id_consistency( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_references.py b/tests/test_api_references.py new file mode 100644 index 000000000..b656d2719 --- /dev/null +++ b/tests/test_api_references.py @@ -0,0 +1,150 @@ +import jax +import jax.numpy as jnp +import pytest +from jaxlib.xla_extension import XlaRuntimeError + +import jaxsim.api as js +import jaxsim.typing as jtp +from jaxsim import VelRepr + + +def get_random_references( + model: js.model.JaxSimModel | None = None, + data: js.data.JaxSimModelData | None = None, + *, + velocity_representation: jtp.VelRepr, + key: jax.Array, +) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: + + _, subkey = jax.random.split(key, num=2) + + _, subkey1, subkey2 = jax.random.split(subkey, num=3) + + references = js.references.JaxSimModelReferences.build( + model=model, + joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), + link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), + data=data, + velocity_representation=velocity_representation, + ) + + # Remove the force applied to the base link if the model is fixed-base. + if not model.floating_base(): + references = references.apply_link_forces( + forces=jnp.atleast_2d(jnp.zeros(6)), + model=model, + data=data, + link_names=(model.base_link(),), + additive=False, + ) + + return references + + +def test_raise_errors_link_forces( + jaxsim_model_box: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_model_box + + _, subkey1, subkey2 = jax.random.split(prng_key, num=3) + + data = js.data.random_model_data(model=model, key=subkey1) + + # ================ + # VelRepr.Inertial + # ================ + + references_inertial = get_random_references( + model=model, data=None, velocity_representation=VelRepr.Inertial, key=subkey2 + ) + + # `model` is None and `link_names` is not None. + with pytest.raises( + ValueError, match="Link names cannot be provided without a model" + ): + references_inertial.link_forces(model=None, link_names=model.link_names()) + + # ============ + # VelRepr.Body + # ============ + + references_body = get_random_references( + model=model, data=data, velocity_representation=VelRepr.Body, key=subkey2 + ) + + # `model` is None and `link_names` is not None. + with pytest.raises( + ValueError, match="Link names cannot be provided without a model" + ): + references_body.link_forces(model=None, link_names=model.link_names()) + + # `model` is not None and `data` is None. + with pytest.raises( + XlaRuntimeError, + match="Missing model data to use a representation different from `VelRepr.Inertial`", + ): + references_body.link_forces(model=model, data=None) + + +def test_raise_errors_apply_link_forces( + jaxsim_model_box: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_model_box + + _, subkey1, subkey2 = jax.random.split(prng_key, num=3) + + data = js.data.random_model_data(model=model, key=subkey1) + + # ================ + # VelRepr.Inertial + # ================ + + references_inertial = get_random_references( + model=model, data=None, velocity_representation=VelRepr.Inertial, key=subkey2 + ) + + # `model` is None + with pytest.raises( + ValueError, + match="Link names cannot be provided without a model", + ): + references_inertial.apply_link_forces( + forces=jnp.zeros(6), model=None, data=None, link_names=model.link_names() + ) + + # ============ + # VelRepr.Body + # ============ + + references_body = get_random_references( + model=model, data=data, velocity_representation=VelRepr.Body, key=subkey2 + ) + + # `model` is None + with pytest.raises( + ValueError, + match="Link names cannot be provided without a model", + ): + references_body.apply_link_forces( + forces=jnp.zeros(6), model=None, data=None, link_names=model.link_names() + ) + + # `model` is None + with pytest.raises( + XlaRuntimeError, + match="Missing model to use a representation different from `VelRepr.Inertial`", + ): + references_body.apply_link_forces(forces=jnp.zeros(6), model=None, data=None) + + # `model` is not None and `data` is None. + with pytest.raises( + XlaRuntimeError, + match="Missing model data to use a representation different from `VelRepr.Inertial`", + ): + references_body.apply_link_forces( + forces=jnp.zeros(6), model=model, data=None, link_names=model.link_names() + ) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index fd075ef0e..7a6294dad 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -26,7 +26,7 @@ def get_random_data_and_references( model: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, key: jax.Array, ) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: diff --git a/tests/test_contact.py b/tests/test_contact.py index 20af3dab2..5bcdc2cd6 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -2,12 +2,13 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr def test_collidable_point_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 5dfb829fe..6e5f85bad 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -5,13 +5,14 @@ import jaxsim.api as js import jaxsim.integrators import jaxsim.rbda +import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.utils import Mutability def test_box_with_external_forces( jaxsim_model_box: js.model.JaxSimModel, - velocity_representation: int, + velocity_representation: jtp.VelRepr, ): """ This test simulates a box falling due to gravity. @@ -96,7 +97,7 @@ def test_box_with_external_forces( def test_box_with_zero_gravity( jaxsim_model_box: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jnp.ndarray, ): diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 8c235a94f..80070a6f5 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -8,6 +8,7 @@ import numpy.typing as npt import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr @@ -118,7 +119,7 @@ def store_jaxsim_data_in_kindyncomputations( class KinDynComputations: """High-level wrapper of the iDynTree KinDynComputations class.""" - vel_repr: int + vel_repr: jtp.VelRepr gravity: npt.NDArray kin_dyn: idt.KinDynComputations @@ -126,7 +127,7 @@ class KinDynComputations: def build( urdf: pathlib.Path | str, considered_joints: list[str] | None = None, - vel_repr: int = VelRepr.Inertial, + vel_repr: jtp.VelRepr = VelRepr.Inertial, gravity: npt.NDArray = np.array([0, 0, -10.0]), removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: From 06dd88e8d79576343fdcfabdea2488059839e6ec Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 14 Jun 2024 18:17:46 +0200 Subject: [PATCH 23/27] Update JIT checks in `JaxSimModelReferences` --- src/jaxsim/api/references.py | 58 ++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 5e7197040..e85511587 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -245,8 +245,19 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: )(W_f_L, W_H_L) # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics( - model=model, data=data or JaxSimModelData.zero(model=model) + W_H_L = jax.lax.cond( + pred=(data is None), + true_fun=lambda: jax.pure_callback( + callback=lambda: (_ for _ in ()).throw( + ValueError( + "Missing model data to use a representation different from `VelRepr.Inertial`" + ) + ), + result_shape_dtypes=jnp.empty(shape=(1, 4, 4)), + ), + false_fun=lambda: js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ), ) f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) @@ -420,7 +431,7 @@ def inertial() -> JaxSimModelReferences: return replace(forces=W_f0_L + W_f_L) - jax.lax.cond( + return jax.lax.cond( pred=(self.velocity_representation == VelRepr.Inertial), true_fun=inertial, false_fun=lambda: jax.pure_callback( @@ -453,25 +464,6 @@ def inertial() -> JaxSimModelReferences: else self.input.physics_model.f_ext[link_idxs, :] ) - def check_not_inertial() -> None: - if data is None: - raise ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - - if not_tracing(forces) and not data.valid(model=model): - raise ValueError("The provided data is not valid for the model") - - # If not inertial-fixed representation, we need the model data. - jax.lax.cond( - pred=(self.velocity_representation != VelRepr.Inertial), - true_fun=lambda: jax.pure_callback( - callback=check_not_inertial, - result_shape_dtypes=None, - ), - false_fun=lambda: None, - ) - # If inertial-fixed representation, we can directly store the link forces. def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: W_f_L = f_L @@ -497,10 +489,26 @@ def convert_using_link_frame( ) )(f_L, W_H_L) - # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics( - model=model, data=data or JaxSimModelData.zero(model=model) + # If not inertial-fixed representation, we need the model data. + W_H_L = jax.lax.cond( + pred=(data is None), + true_fun=lambda: jax.pure_callback( + callback=lambda: (_ for _ in ()).throw( + ValueError( + "Missing model data to use a representation different from `VelRepr.Inertial`" + ) + ), + result_shape_dtypes=jnp.empty(shape=(1, 4, 4)), + ), + false_fun=lambda: js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ), ) + + # The f_L input is either L_f_L or LW_f_L, depending on the representation. + # W_H_L = js.model.forward_kinematics( + # model=model, data=data or JaxSimModelData.zero(model=model) + # ) W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace( From 7543660d0afd54374dddd905df13fed23c93750f Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 17 Jun 2024 10:33:04 +0200 Subject: [PATCH 24/27] Update `result_shape_dtypes` in `JaxSimModelReferences` --- src/jaxsim/api/references.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index e85511587..1ec43eee7 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -253,7 +253,9 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: "Missing model data to use a representation different from `VelRepr.Inertial`" ) ), - result_shape_dtypes=jnp.empty(shape=(1, 4, 4)), + result_shape_dtypes=jnp.empty( + shape=(model.number_of_links(), 4, 4) + ), ), false_fun=lambda: js.model.forward_kinematics( model=model, data=data or JaxSimModelData.zero(model=model) @@ -498,7 +500,9 @@ def convert_using_link_frame( "Missing model data to use a representation different from `VelRepr.Inertial`" ) ), - result_shape_dtypes=jnp.empty(shape=(1, 4, 4)), + result_shape_dtypes=jnp.empty( + shape=(model.number_of_links(), 4, 4) + ), ), false_fun=lambda: js.model.forward_kinematics( model=model, data=data or JaxSimModelData.zero(model=model) From 0aae8fcee2a05259825899a9b4d0f06559bca987 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 31 Jul 2024 11:24:22 +0200 Subject: [PATCH 25/27] Use `jax.lax.switch` for Jacobian derivative computation --- src/jaxsim/api/contact.py | 135 +++++++++++++++++------------- src/jaxsim/api/frame.py | 154 ++++++++++++++++++++--------------- src/jaxsim/api/references.py | 16 ++-- 3 files changed, 172 insertions(+), 133 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index c59608ff1..6c0785c7c 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -469,40 +469,50 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # Compute the operator to change the representation of ν, and its # time derivative. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_W = jnp.eye(4) - W_X_W = Adjoint.from_transform(transform=W_H_W) - W_Ẋ_W = jnp.zeros((6, 6)) + def from_inertial(): + W_H_W = jnp.eye(4) + W_X_W = Adjoint.from_transform(transform=W_H_W) + W_Ẋ_W = jnp.zeros((6, 6)) - T = compute_T(model=model, X=W_X_W) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) + T = compute_T(model=model, X=W_X_W) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) - case VelRepr.Body: - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity() - B_vx_WB = Cross.vx(B_v_WB) - W_Ẋ_B = W_X_B @ B_vx_WB + return T, Ṫ - T = compute_T(model=model, X=W_X_B) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) + def from_body(): + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) + B_v_WB = data.base_velocity() + B_vx_WB = Cross.vx(B_v_WB) + W_Ẋ_B = W_X_B @ B_vx_WB - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_vx_W_BW = Cross.vx(BW_v_W_BW) - W_Ẋ_BW = W_X_BW @ BW_vx_W_BW + T = compute_T(model=model, X=W_X_B) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) - T = compute_T(model=model, X=W_X_BW) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) + return T, Ṫ - case _: - raise ValueError(data.velocity_representation) + def from_mixed(): + W_H_B = data.base_transform() + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + W_X_BW = Adjoint.from_transform(transform=W_H_BW) + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + BW_vx_W_BW = Cross.vx(BW_v_W_BW) + W_Ẋ_BW = W_X_BW @ BW_vx_W_BW + + T = compute_T(model=model, X=W_X_BW) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) + + return T, Ṫ + T, Ṫ = jax.lax.switch( + index=data.velocity_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) # ===================================================== # Compute quantities to adjust the output representation # ===================================================== @@ -538,37 +548,46 @@ def compute_O_J̇_WC_I( parent_link_idx = parent_link_idxs[contact_idx] - match output_vel_repr: - case VelRepr.Inertial: - O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841 - transform=jnp.eye(4) - ) - O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841 - - case VelRepr.Body: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C - O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) - with data.switch_velocity_representation(VelRepr.Inertial): - W_nu = data.generalized_velocity() - W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu - W_vx_WC = Cross.vx(W_v_WC) - O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841 - - case VelRepr.Mixed: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C - W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - CW_H_W = Transform.inverse(W_H_CW) - O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W) - with data.switch_velocity_representation(VelRepr.Mixed): - CW_v_WC = CW_J_WC_BW @ data.generalized_velocity() - W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3]) - W_vx_W_CW = Cross.vx(W_v_W_CW) - O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841 - - case _: - raise ValueError(output_vel_repr) + def to_inertial(): + W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) + W_Ẋ_W = jnp.zeros((6, 6)) + + return W_X_W, W_Ẋ_W + + def to_body(): + L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) + W_H_C = W_H_L[parent_link_idx] @ L_H_C + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) + with data.switch_velocity_representation(VelRepr.Inertial): + W_nu = data.generalized_velocity() + W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu + W_vx_WC = Cross.vx(W_v_WC) + C_Ẋ_W = -C_X_W @ W_vx_WC + + return C_X_W, C_Ẋ_W + + def to_mixed(): + L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) + W_H_C = W_H_L[parent_link_idx] @ L_H_C + W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) + CW_H_W = Transform.inverse(W_H_CW) + CW_X_W = Adjoint.from_transform(transform=CW_H_W) + with data.switch_velocity_representation(VelRepr.Mixed): + CW_v_WC = CW_J_WC_BW @ data.generalized_velocity() + W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3]) + W_vx_W_CW = Cross.vx(W_v_W_CW) + CW_Ẋ_W = -CW_X_W @ W_vx_W_CW + + return CW_X_W, CW_Ẋ_W + + O_X_W, O_Ẋ_W = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs())) O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 468d3be71..0dc2d88a1 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -339,77 +339,99 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # Compute the operator to change the representation of ν, and its # time derivative. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_W = jnp.eye(4) - W_X_W = Adjoint.from_transform(transform=W_H_W) - W_Ẋ_W = jnp.zeros((6, 6)) - - T = compute_T(model=model, X=W_X_W) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) - - case VelRepr.Body: - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity() - B_vx_WB = Cross.vx(B_v_WB) - W_Ẋ_B = W_X_B @ B_vx_WB - - T = compute_T(model=model, X=W_X_B) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) - - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_vx_W_BW = Cross.vx(BW_v_W_BW) - W_Ẋ_BW = W_X_BW @ BW_vx_W_BW - - T = compute_T(model=model, X=W_X_BW) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) - - case _: - raise ValueError(data.velocity_representation) + def from_inertial(): + W_H_W = jnp.eye(4) + W_X_W = Adjoint.from_transform(transform=W_H_W) + W_Ẋ_W = jnp.zeros((6, 6)) + + T = compute_T(model=model, X=W_X_W) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) + + return T, Ṫ + + def from_body(): + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) + B_v_WB = data.base_velocity() + B_vx_WB = Cross.vx(B_v_WB) + W_Ẋ_B = W_X_B @ B_vx_WB + + T = compute_T(model=model, X=W_X_B) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) + + return T, Ṫ + + def from_mixed(): + W_H_B = data.base_transform() + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + W_X_BW = Adjoint.from_transform(transform=W_H_BW) + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + BW_vx_W_BW = Cross.vx(BW_v_W_BW) + W_Ẋ_BW = W_X_BW @ BW_vx_W_BW + + T = compute_T(model=model, X=W_X_BW) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) + + return T, Ṫ + + T, Ṫ = jax.lax.switch( + index=data.velocity_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) # ===================================================== # Compute quantities to adjust the output representation # ===================================================== - match output_vel_repr: - case VelRepr.Inertial: - O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) - O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) - - case VelRepr.Body: - W_H_F = transform(model=model, data=data, frame_index=frame_index) - O_X_W = F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True) - with data.switch_velocity_representation(VelRepr.Inertial): - W_nu = data.generalized_velocity() - W_v_WF = W_J_WL_W @ W_nu - W_vx_WF = Cross.vx(W_v_WF) - O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF # noqa: F841 - - case VelRepr.Mixed: - W_H_F = transform(model=model, data=data, frame_index=frame_index) - W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3)) - FW_H_W = Transform.inverse(W_H_FW) - O_X_W = FW_X_W = Adjoint.from_transform(transform=FW_H_W) - with data.switch_velocity_representation(VelRepr.Mixed): - FW_J_WF_FW = jacobian( - model=model, - data=data, - frame_index=frame_index, - output_vel_repr=VelRepr.Mixed, - ) - FW_v_WF = FW_J_WF_FW @ data.generalized_velocity() - W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3]) - W_vx_W_FW = Cross.vx(W_v_W_FW) - O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW # noqa: F841 - - case _: - raise ValueError(output_vel_repr) + def to_inertial(): + W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) + W_Ẋ_W = jnp.zeros((6, 6)) + + return W_X_W, W_Ẋ_W + + def to_body(): + W_H_F = transform(model=model, data=data, frame_index=frame_index) + F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True) + with data.switch_velocity_representation(VelRepr.Inertial): + W_nu = data.generalized_velocity() + W_v_WF = W_J_WL_W @ W_nu + W_vx_WF = Cross.vx(W_v_WF) + F_Ẋ_W = -F_X_W @ W_vx_WF + + return F_X_W, F_Ẋ_W + + def to_mixed(): + W_H_F = transform(model=model, data=data, frame_index=frame_index) + W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3)) + FW_H_W = Transform.inverse(W_H_FW) + FW_X_W = Adjoint.from_transform(transform=FW_H_W) + with data.switch_velocity_representation(VelRepr.Mixed): + FW_J_WF_FW = jacobian( + model=model, + data=data, + frame_index=frame_index, + output_vel_repr=VelRepr.Mixed, + ) + FW_v_WF = FW_J_WF_FW @ data.generalized_velocity() + W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3]) + W_vx_W_FW = Cross.vx(W_v_W_FW) + FW_Ẋ_W = -FW_X_W @ W_vx_W_FW + + return FW_X_W, FW_Ẋ_W + + O_X_W, O_Ẋ_W = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) O_J̇_WF_I = jnp.zeros(shape=(6, 6 + model.dofs())) O_J̇_WF_I += O_Ẋ_W @ W_J_WL_W @ T diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 1ec43eee7..ca21133ba 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -597,15 +597,13 @@ def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix: is_force=True, ) - match self.velocity_representation: - case VelRepr.Inertial: - W_f_F = f_F - - case VelRepr.Body | VelRepr.Mixed: - W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi) - - case _: - raise ValueError("Invalid velocity representation.") + W_f_F = jax.lax.switch( + index=self.velocity_representation, + branches=( + lambda: f_F, + lambda: jax.vmap(to_inertial)(f_F, W_H_Fi), + ), + ) # Sum the forces on the parent links. mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links()) From f307618b611ae820546a6df32d30abcc15a96c4b Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 1 Aug 2024 17:18:16 +0200 Subject: [PATCH 26/27] Adjust output names and return type hints --- src/jaxsim/api/com.py | 37 ++++----- src/jaxsim/api/common.py | 9 ++- src/jaxsim/api/contact.py | 20 ++--- src/jaxsim/api/frame.py | 9 ++- src/jaxsim/api/link.py | 54 +++++++------ src/jaxsim/api/model.py | 164 +++++++++++++++++--------------------- src/jaxsim/typing.py | 2 +- 7 files changed, 146 insertions(+), 149 deletions(-) diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index 6057edaed..9701ced92 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -301,18 +301,17 @@ def other_representation_to_body( def to_body() -> jtp.Vector: L_a_bias_WL = v̇_bias_WL + return L_a_bias_WL def to_inertial() -> jtp.Vector: - C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 - C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + W_v̇_bias_WL = v̇_bias_WL + W_v_WW = jnp.zeros(6) - L_H_C = L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))( # noqa: F841 - W_H_L - ) + L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))(W_H_L) - L_v_LC = L_v_LW = jax.vmap( # noqa: F841 + L_v_LW = jax.vmap( lambda i: -js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body ) @@ -320,19 +319,20 @@ def to_inertial() -> jtp.Vector: L_a_bias_WL = jax.vmap( lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC, - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], + C_v̇_WL=W_v̇_bias_WL[i], + C_v_WC=W_v_WW, + L_H_C=L_H_W[i], + L_v_LC=L_v_LW[i], ) )(jnp.arange(model.number_of_links())) + return L_a_bias_WL def to_mixed() -> jtp.Vector: - C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841 + LW_v̇_bias_WL = v̇_bias_WL - C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841 + LW_v_W_LW = jax.vmap( lambda i: js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed ) @@ -340,11 +340,11 @@ def to_mixed() -> jtp.Vector: .set(jnp.zeros(3)) )(jnp.arange(model.number_of_links())) - L_H_C = L_H_LW = jax.vmap( # noqa: F841 + L_H_LW = jax.vmap( lambda W_H_L: Transform.inverse(W_H_L.at[0:3, 3].set(jnp.zeros(3))) )(W_H_L) - L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841 + L_v_L_LW = jax.vmap( lambda i: -js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body ) @@ -354,12 +354,13 @@ def to_mixed() -> jtp.Vector: L_a_bias_WL = jax.vmap( lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC[i], - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], + C_v̇_WL=LW_v̇_bias_WL[i], + C_v_WC=LW_v_W_LW[i], + L_H_C=L_H_LW[i], + L_v_LC=L_v_L_LW[i], ) )(jnp.arange(model.number_of_links())) + return L_a_bias_WL # We need here to get the body-fixed bias acceleration of the links. diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index d9dba423e..33bec4567 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -112,19 +112,21 @@ def inertial_to_other_representation( if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) - def to_inertial(): + def to_inertial() -> jtp.Array: + return W_array - def to_body(): + def to_body() -> jtp.Array: if not is_force: O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) O_array = O_Xv_W @ W_array else: O_Xf_W = Adjoint.from_transform(transform=W_H_O).T O_array = O_Xf_W @ W_array + return O_array - def to_mixed(): + def to_mixed() -> jtp.Array: W_p_O = W_H_O[0:3, 3] W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) if not is_force: @@ -133,6 +135,7 @@ def to_mixed(): else: OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T OW_array = OW_Xf_W @ W_array + return OW_array return jax.lax.switch( diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 6c0785c7c..96d66282f 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -373,8 +373,8 @@ def jacobian( ) def to_inertial() -> jtp.Matrix: - O_J_WC = W_J_WC - return O_J_WC + + return W_J_WC def to_body() -> jtp.Matrix: @@ -385,8 +385,9 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: C_J_WC = C_X_W @ W_J_WC return C_J_WC - O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) - return O_J_WC + C_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + + return C_J_WC def to_mixed() -> jtp.Matrix: @@ -401,8 +402,9 @@ def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: CW_J_WC = CW_X_W @ W_J_WC return CW_J_WC - O_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) - return O_J_WC + CW_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + + return CW_J_WC # Adjust the output representation. O_J_WC = jax.lax.switch( @@ -548,13 +550,13 @@ def compute_O_J̇_WC_I( parent_link_idx = parent_link_idxs[contact_idx] - def to_inertial(): + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) W_Ẋ_W = jnp.zeros((6, 6)) return W_X_W, W_Ẋ_W - def to_body(): + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) W_H_C = W_H_L[parent_link_idx] @ L_H_C C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) @@ -566,7 +568,7 @@ def to_body(): return C_X_W, C_Ẋ_W - def to_mixed(): + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) W_H_C = W_H_L[parent_link_idx] @ L_H_C W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 0dc2d88a1..0d8ae9dcc 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -231,6 +231,7 @@ def to_inertial() -> jtp.Matrix: W_H_L = js.link.transform(model=model, data=data, link_index=L) W_X_L = Adjoint.from_transform(transform=W_H_L) W_J_WL = W_X_L @ L_J_WL + return W_J_WL def to_body() -> jtp.Matrix: @@ -239,6 +240,7 @@ def to_body() -> jtp.Matrix: F_H_L = Transform.inverse(W_H_F) @ W_H_L F_X_L = Adjoint.from_transform(transform=F_H_L) F_J_WL = F_X_L @ L_J_WL + return F_J_WL def to_mixed() -> jtp.Matrix: @@ -249,6 +251,7 @@ def to_mixed() -> jtp.Matrix: FW_H_L = FW_H_F @ F_H_L FW_X_L = Adjoint.from_transform(transform=FW_H_L) FW_J_WL = FW_X_L @ L_J_WL + return FW_J_WL # Adjust the output representation @@ -388,13 +391,13 @@ def from_mixed(): # Compute quantities to adjust the output representation # ===================================================== - def to_inertial(): + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) W_Ẋ_W = jnp.zeros((6, 6)) return W_X_W, W_Ẋ_W - def to_body(): + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_F = transform(model=model, data=data, frame_index=frame_index) F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True) with data.switch_velocity_representation(VelRepr.Inertial): @@ -405,7 +408,7 @@ def to_body(): return F_X_W, F_Ẋ_W - def to_mixed(): + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_F = transform(model=model, data=data, frame_index=frame_index) W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3)) FW_H_W = Transform.inverse(W_H_FW) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 5906792a9..9519f5c14 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -287,21 +287,22 @@ def jacobian( def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 - B_X_W, jnp.eye(model.dofs()) - ) + B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + return B_J_WL_W def to_body() -> jtp.Matrix: + return B_J_WL_B def to_mixed() -> jtp.Matrix: W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 + B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( B_X_BW, jnp.eye(model.dofs()) ) + return B_J_WL_BW B_J_WL_I = jax.lax.switch( @@ -319,11 +320,13 @@ def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() W_X_B = Adjoint.from_transform(transform=W_H_B) W_J_WL_I = W_X_B @ B_J_WL_I + return W_J_WL_I def to_body() -> jtp.Matrix: L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) L_J_WL_I = L_X_B @ B_J_WL_I + return L_J_WL_I def to_mixed() -> jtp.Matrix: @@ -333,6 +336,7 @@ def to_mixed() -> jtp.Matrix: LW_H_B = LW_H_L @ Transform.inverse(B_H_L) LW_X_B = Adjoint.from_transform(transform=LW_H_B) LW_J_WL_I = LW_X_B @ B_J_WL_I + return LW_J_WL_I # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. @@ -455,7 +459,7 @@ def jacobian_derivative( In = jnp.eye(model.dofs()) On = jnp.zeros(shape=(model.dofs(), model.dofs())) - def from_inertial() -> jtp.Matrix: + def from_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_B = data.base_transform() B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) @@ -468,9 +472,10 @@ def from_inertial() -> jtp.Matrix: # time derivative. T = jax.scipy.linalg.block_diag(B_X_W, In) Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On) + return T, Ṫ - def from_body() -> jtp.Matrix: + def from_body() -> tuple[jtp.Matrix, jtp.Matrix]: B_X_B = Adjoint.from_rotation_and_translation( translation=jnp.zeros(3), rotation=jnp.eye(3) @@ -482,9 +487,10 @@ def from_body() -> jtp.Matrix: # time derivative. T = jax.scipy.linalg.block_diag(B_X_B, In) Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On) + return T, Ṫ - def from_mixed() -> jtp.Matrix: + def from_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -500,6 +506,7 @@ def from_mixed() -> jtp.Matrix: # time derivative. T = jax.scipy.linalg.block_diag(B_X_BW, In) Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) + return T, Ṫ T, Ṫ = jax.lax.switch( @@ -515,22 +522,21 @@ def from_mixed() -> jtp.Matrix: # Compute quantities to adjust the output representation # ====================================================== - def to_inertial() -> jtp.Matrix: + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_B = data.base_transform() - O_X_B = W_X_B = Adjoint.from_transform(transform=W_H_B) + W_X_B = Adjoint.from_transform(transform=W_H_B) with data.switch_velocity_representation(VelRepr.Body): B_v_WB = data.base_velocity() - O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841 - return O_X_B, O_Ẋ_B + W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) - def to_body() -> jtp.Matrix: + return W_X_B, W_Ẋ_B - O_X_B = L_X_B = Adjoint.from_transform( - transform=B_H_L[link_index, :, :], inverse=True - ) + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: + + L_X_B = Adjoint.from_transform(transform=B_H_L[link_index, :, :], inverse=True) B_X_L = Adjoint.inverse(adjoint=L_X_B) @@ -538,19 +544,18 @@ def to_body() -> jtp.Matrix: B_v_WB = data.base_velocity() L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) - O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_L @ L_v_WL - B_v_WB - ) - return O_X_B, O_Ẋ_B + L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB) - def to_mixed() -> jtp.Matrix: + return L_X_B, L_Ẋ_B + + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: W_H_B = data.base_transform() W_H_L = W_H_B @ B_H_L[link_index, :, :] LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) LW_H_B = LW_H_L @ Transform.inverse(B_H_L[link_index, :, :]) - O_X_B = LW_X_B = Adjoint.from_transform(transform=LW_H_B) + LW_X_B = Adjoint.from_transform(transform=LW_H_B) B_X_LW = Adjoint.inverse(adjoint=LW_X_B) @@ -564,10 +569,9 @@ def to_mixed() -> jtp.Matrix: LW_v_LW_L = LW_v_WL - LW_v_W_LW LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L - O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_LW @ LW_v_B_LW - ) - return O_X_B, O_Ẋ_B + LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW) + + return LW_X_B, LW_Ẋ_B O_X_B, O_Ẋ_B = jax.lax.switch( index=output_vel_repr, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 8011d3438..b8f9f5701 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -488,19 +488,21 @@ def generalized_free_floating_jacobian( # Update the input velocity representation such that v_WL = J_WL_I @ I_ν # ====================================================================== - def from_inertial(): + def from_inertial() -> jtp.Matrix: W_H_B = data.base_transform() B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag( B_X_W, jnp.eye(model.dofs()) ) + return B_J_full_WX_W - def from_body(): + def from_body() -> jtp.Matrix: + return B_J_full_WX_B - def from_mixed(): + def from_mixed() -> jtp.Matrix: W_R_B = data.base_orientation(dcm=True) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -508,6 +510,7 @@ def from_mixed(): B_J_full_WX_BW = B_J_full_WX_B @ jax.scipy.linalg.block_diag( B_X_BW, jnp.eye(model.dofs()) ) + return B_J_full_WX_BW # Update the input velocity representation such that `J_WL_I @ I_ν`. @@ -520,43 +523,17 @@ def from_mixed(): ), ) - def to_inertial(): - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - W_J_full_WX_I = W_X_B @ B_J_full_WX_I - return W_J_full_WX_I - - def to_body(): - return B_J_full_WX_I - - def to_mixed(): - W_R_B = data.base_orientation(dcm=True) - BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - BW_X_B = Adjoint.from_transform(transform=BW_H_B) - BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I - return BW_J_full_WX_I - # ==================================================================== # Create stacked Jacobian for each link by filtering the full Jacobian # ==================================================================== - # Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. - O_J_full_WX_I = jax.lax.switch( - index=output_vel_repr, - branches=( - to_body, # VelRepr.Body - to_mixed, # VelRepr.Mixed - to_inertial, # VelRepr.Inertial - ), - ) - κ_bool = model.kin_dyn_parameters.support_body_array_bool # Keep only the columns of the full Jacobian corresponding to the support # body array of each link. B_J_WL_I = jax.vmap( lambda κ: jnp.where( - jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I) + jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I) ) )(κ_bool) @@ -568,19 +545,19 @@ def to_inertial() -> jtp.Matrix: W_H_B = data.base_transform() W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) - O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)( # noqa: F841 - B_J_WL_I - ) - return O_J_WL_I + W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I) + + return W_J_WL_I def to_body() -> jtp.Matrix: - O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841 + L_J_WL_I = jax.vmap( lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform( B_H_L, inverse=True ) @ B_J_WL_I )(B_H_L, B_J_WL_I) - return O_J_WL_I + + return L_J_WL_I def to_mixed() -> jtp.Matrix: W_H_B = data.base_transform() @@ -591,11 +568,12 @@ def to_mixed() -> jtp.Matrix: lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) )(LW_H_L, B_H_L) - O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841 + LW_J_WL_I = jax.vmap( lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B) @ B_J_WL_I )(LW_H_B, B_J_WL_I) - return O_J_WL_I + + return LW_J_WL_I O_J_WL_I = jax.lax.switch( index=output_vel_repr, @@ -786,25 +764,27 @@ def to_active( C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) return C_X_W @ W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB - def to_inertial(): + def to_inertial() -> tuple[jtp.Vector, jtp.Matrix]: # In this case C=W - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - return W_v_WC, W_H_C + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) + + return W_v_WW, W_H_W - def to_body(): + def to_body() -> tuple[jtp.Vector, jtp.Matrix]: # In this case C=B - W_H_C = W_H_B = data.base_transform() # noqa: F841 - W_v_WC = W_v_WB - return W_v_WC, W_H_C + W_H_B = data.base_transform() - def to_mixed(): + return W_v_WB, W_H_B + + def to_mixed() -> tuple[jtp.Vector, jtp.Matrix]: # In this case C=B[W] W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 - return W_v_WC, W_H_C + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + + return W_v_W_BW, W_H_BW W_v_WC, W_H_C = jax.lax.switch( index=data.velocity_representation, @@ -1046,10 +1026,10 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: # Adjust the representation of the Coriolis matrix. # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. - def to_body(): + def to_body() -> jtp.Matrix: return C_B - def to_inertial(): + def to_inertial() -> jtp.Matrix: n = model.dofs() W_H_B = data.base_transform() B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) @@ -1068,7 +1048,7 @@ def to_inertial(): return C - def to_mixed(): + def to_mixed() -> jtp.Matrix: n = model.dofs() BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -1169,22 +1149,25 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB) def convert_inertial() -> jtp.Vector: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - return W_H_C, W_v_WC + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) + + return W_H_W, W_v_WW def convert_body() -> jtp.Vector: - W_H_C = W_H_B = data.base_transform() # noqa: F841 + W_H_B = data.base_transform() with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC + W_v_WB = data.base_velocity() + + return W_H_B, W_v_WB def convert_mixed() -> jtp.Vector: W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 - return W_H_C, W_v_WC + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + + return W_H_BW, W_v_W_BW W_H_C, W_v_WC = jax.lax.switch( index=data.velocity_representation, @@ -1447,15 +1430,18 @@ def total_momentum_jacobian( B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6] def to_body() -> jtp.Matrix: + return B_Jh_B def to_inertial() -> jtp.Matrix: B_X_W = Adjoint.from_transform(transform=data.base_transform(), inverse=True) + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) def to_mixed() -> jtp.Matrix: BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) B_Jh = jax.lax.switch( @@ -1638,29 +1624,30 @@ def other_representation_to_inertial( # W_a_WB, and intrinsic accelerations can be expressed in different frames through # a simple C_X_W 6D transform. def to_inertial() -> jtp.Matrix: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) with data.switch_velocity_representation(VelRepr.Inertial): - C_v_WB = W_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC, C_v_WB + W_v_WB = data.base_velocity() + + return W_H_W, W_v_WW, W_v_WB def to_body() -> jtp.Matrix: - W_H_C = W_H_B with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 + W_v_WB = data.base_velocity() with data.switch_velocity_representation(VelRepr.Body): - C_v_WB = B_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC, C_v_WB + B_v_WB = data.base_velocity() + + return W_H_B, W_v_WB, B_v_WB def to_mixed() -> jtp.Matrix: W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_H_C = W_H_BW with data.switch_velocity_representation(VelRepr.Mixed): W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) with data.switch_velocity_representation(VelRepr.Mixed): - C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841 - return W_H_C, W_v_WC, C_v_WB + BW_v_WB = data.base_velocity() + + return W_H_BW, W_v_W_BW, BW_v_WB W_H_C, W_v_WC, C_v_WB = jax.lax.switch( index=data.velocity_representation, @@ -1772,26 +1759,23 @@ def body_to_other_representation( C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L) return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL) - def to_body() -> jtp.Matrix: - C_H_L = L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links()) # noqa: F841 - L_v_CL = L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6)) # noqa: F841 - return C_H_L, L_v_CL + def to_body() -> tuple[jtp.Matrix, jtp.Vector]: + L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links()) + L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6)) - def to_inertial() -> jtp.Matrix: - C_H_L = W_H_L = js.model.forward_kinematics( # noqa: F841 - model=model, data=data - ) - L_v_CL = L_v_WL - return C_H_L, L_v_CL + return L_H_L, L_v_LL - def to_mixed() -> jtp.Matrix: + def to_inertial() -> tuple[jtp.Matrix, jtp.Vector]: + W_H_L = js.model.forward_kinematics(model=model, data=data) + + return W_H_L, L_v_WL + + def to_mixed() -> tuple[jtp.Matrix, jtp.Vector]: W_H_L = js.model.forward_kinematics(model=model, data=data) LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) - C_H_L = LW_H_L - L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 - lambda v: v.at[0:3].set(jnp.zeros(3)) - )(L_v_WL) - return C_H_L, L_v_CL + L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL) + + return LW_H_L, L_v_LW_L C_H_L, L_v_CL = jax.lax.switch( index=data.velocity_representation, diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 4a6ff0474..1e1ed5da3 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -38,4 +38,4 @@ BoolLike = bool | Bool | jax.typing.ArrayLike FloatLike = float | Float | jax.typing.ArrayLike -VelRepr = int +VelRepr = Int From d58062290f7e80faf3980d8b8cfa944b30e8edc8 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 22 Aug 2024 15:58:50 +0200 Subject: [PATCH 27/27] Use `jaxsim.exceptions` module to handle dynamic checks --- src/jaxsim/api/references.py | 148 +++++++++++++---------------------- tests/test_api_references.py | 11 +-- 2 files changed, 54 insertions(+), 105 deletions(-) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index ca21133ba..1d80d994e 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -186,48 +186,36 @@ def link_forces( # serialization. if model is None: - def inertial() -> jtp.Array: - if link_names is not None: - raise ValueError("Link names cannot be provided without a model") - - return self.input.physics_model.f_ext - - return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=inertial, - false_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=self.input.physics_model.f_ext, - ), + exceptions.raise_value_error_if( + condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + msg="Missing model to use a representation different from `VelRepr.Inertial`", ) + if link_names is not None: + raise ValueError("Link names cannot be provided without a model") + + return self.input.physics_model.f_ext + # If we have the model, we can extract the link names, if not provided. link_names = link_names if link_names is not None else model.link_names() link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) - def check_not_inertial() -> None: - if data is None: - raise ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - - if not_tracing(self.input.physics_model.f_ext) and not data.valid( - model=model - ): - raise ValueError("The provided data is not valid for the model") - # If not inertial-fixed representation, we need the model data. - jax.lax.cond( - pred=(self.velocity_representation != VelRepr.Inertial), - true_fun=lambda: jax.pure_callback( - callback=check_not_inertial, - result_shape_dtypes=None, + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, ), - false_fun=lambda: None, + msg="Missing model data to use a representation different from `VelRepr.Inertial`", + ) + + # The f_L output is either L_f_L or LW_f_L, depending on the representation. + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, + ), + msg="Missing model data to use a representation different from `VelRepr.Inertial`", ) def not_inertial(velocity_representation: jtp.VelRepr) -> jtp.Matrix: @@ -244,22 +232,8 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: ) )(W_f_L, W_H_L) - # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = jax.lax.cond( - pred=(data is None), - true_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=jnp.empty( - shape=(model.number_of_links(), 4, 4) - ), - ), - false_fun=lambda: js.model.forward_kinematics( - model=model, data=data or JaxSimModelData.zero(model=model) - ), + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) ) f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) @@ -267,7 +241,7 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: # In inertial-fixed representation, we already have the link forces. return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), + pred=jnp.equal(self.velocity_representation, VelRepr.Inertial), true_fun=lambda _: W_f_L[link_idxs, :], false_fun=not_inertial, operand=self.velocity_representation, @@ -417,35 +391,28 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # In this case, we allow only to set the inertial 6D forces to all links # using the implicit link serialization. - if model is None: - - def inertial() -> JaxSimModelReferences: - if link_names is not None: - raise ValueError("Link names cannot be provided without a model") + exceptions.raise_value_error_if( + condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial) + & (model is None), + msg="Missing model to use a representation different from `VelRepr.Inertial`", + ) - W_f_L = f_L + exceptions.raise_value_error_if( + condition=jnp.logical_and(link_names is not None, model is None), + msg="Link names cannot be provided without a model", + ) - W_f0_L = ( - jnp.zeros_like(W_f_L) - if not additive - else self.input.physics_model.f_ext - ) + if model is None: + W_f_L = f_L - return replace(forces=W_f0_L + W_f_L) - - return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), - true_fun=inertial, - false_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=self, - ), + W_f0_L = ( + jnp.zeros_like(W_f_L) + if not additive + else self.input.physics_model.f_ext ) + return replace(forces=W_f0_L + W_f_L) + # If we have the model, we can extract the link names if not provided. link_names = link_names if link_names is not None else model.link_names() @@ -466,6 +433,14 @@ def inertial() -> JaxSimModelReferences: else self.input.physics_model.f_ext[link_idxs, :] ) + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, + ), + msg="Missing model data to use a representation different from `VelRepr.Inertial`", + ) + # If inertial-fixed representation, we can directly store the link forces. def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: W_f_L = f_L @@ -491,28 +466,11 @@ def convert_using_link_frame( ) )(f_L, W_H_L) - # If not inertial-fixed representation, we need the model data. - W_H_L = jax.lax.cond( - pred=(data is None), - true_fun=lambda: jax.pure_callback( - callback=lambda: (_ for _ in ()).throw( - ValueError( - "Missing model data to use a representation different from `VelRepr.Inertial`" - ) - ), - result_shape_dtypes=jnp.empty( - shape=(model.number_of_links(), 4, 4) - ), - ), - false_fun=lambda: js.model.forward_kinematics( - model=model, data=data or JaxSimModelData.zero(model=model) - ), + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) ) # The f_L input is either L_f_L or LW_f_L, depending on the representation. - # W_H_L = js.model.forward_kinematics( - # model=model, data=data or JaxSimModelData.zero(model=model) - # ) W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) return replace( @@ -522,7 +480,7 @@ def convert_using_link_frame( ) return jax.lax.cond( - pred=(self.velocity_representation == VelRepr.Inertial), + pred=jnp.equal(self.velocity_representation, VelRepr.Inertial), true_fun=inertial, false_fun=not_inertial, operand=self.velocity_representation, diff --git a/tests/test_api_references.py b/tests/test_api_references.py index b656d2719..7bc35823e 100644 --- a/tests/test_api_references.py +++ b/tests/test_api_references.py @@ -109,7 +109,7 @@ def test_raise_errors_apply_link_forces( # `model` is None with pytest.raises( - ValueError, + XlaRuntimeError, match="Link names cannot be provided without a model", ): references_inertial.apply_link_forces( @@ -124,15 +124,6 @@ def test_raise_errors_apply_link_forces( model=model, data=data, velocity_representation=VelRepr.Body, key=subkey2 ) - # `model` is None - with pytest.raises( - ValueError, - match="Link names cannot be provided without a model", - ): - references_body.apply_link_forces( - forces=jnp.zeros(6), model=None, data=None, link_names=model.link_names() - ) - # `model` is None with pytest.raises( XlaRuntimeError,