diff --git a/docs/modules/api.rst b/docs/modules/api.rst index 9b3833c39..cff2b8b7f 100644 --- a/docs/modules/api.rst +++ b/docs/modules/api.rst @@ -94,7 +94,7 @@ References Common ~~~~~~ -.. autoflag:: jaxsim.api.common.VelRepr +.. autoclass:: jaxsim.api.common.VelRepr :members: .. autoclass:: jaxsim.api.common.ModelDataWithVelocityRepresentation diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index 6d25d276d..9701ced92 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -2,8 +2,8 @@ import jax.numpy as jnp import jaxsim.api as js -import jaxsim.math import jaxsim.typing as jtp +from jaxsim.math import Adjoint, Cross, Transform from .common import VelRepr @@ -27,7 +27,7 @@ def com_position( W_H_L = js.model.forward_kinematics(model=model, data=data) W_H_B = data.base_transform() - B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) + B_H_W = Transform.inverse(transform=W_H_B) def B_p̃_LCoM(i) -> jtp.Vector: m = js.link.mass(model=model, link_index=i) @@ -131,20 +131,29 @@ def centroidal_momentum_jacobian( ) W_H_B = data.base_transform() - B_H_W = jaxsim.math.Transform.inverse(W_H_B) + B_H_W = Transform.inverse(W_H_B) W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: - case VelRepr.Inertial | VelRepr.Mixed: - W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 - case VelRepr.Body: - W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 - case _: - raise ValueError(data.velocity_representation) + def to_inertial_and_mixed(): + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + return W_H_GW + + def to_body(): + W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) + return W_H_GB + + W_H_G = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_inertial_and_mixed, # VelRepr.Mixed + to_inertial_and_mixed, # VelRepr.Inertial + ), + ) # Compute the transform for 6D forces. - G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T + G_Xf_B = Adjoint.from_transform(transform=B_H_W @ W_H_G).T return G_Xf_B @ B_Jh @@ -170,17 +179,26 @@ def locked_centroidal_spatial_inertia( W_H_B = data.base_transform() W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: - case VelRepr.Inertial | VelRepr.Mixed: - W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 - case VelRepr.Body: - W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 - case _: - raise ValueError(data.velocity_representation) + def to_inertial_or_mixed() -> jtp.Matrix: + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + return W_H_GW + + def to_body() -> jtp.Matrix: + W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) + return W_H_GB + + W_H_G = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_inertial_or_mixed, # VelRepr.Mixed + to_inertial_or_mixed, # VelRepr.Inertial + ), + ) - B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G + B_H_G = Transform.inverse(W_H_B) @ W_H_G - B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G) + B_Xv_G = Adjoint.from_transform(transform=B_H_G) G_Xf_B = B_Xv_G.transpose() return G_Xf_B @ B_Mbb_B @ B_Xv_G @@ -275,80 +293,86 @@ def other_representation_to_body( C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL. """ - L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C) - C_X_L = jaxsim.math.Adjoint.inverse(L_X_C) + L_X_C = Adjoint.from_transform(transform=L_H_C) + C_X_L = Adjoint.inverse(L_X_C) - L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC) + L_v̇_WL = L_X_C @ (C_v̇_WL + Cross.vx(C_X_L @ L_v_LC) @ C_v_WC) return L_v̇_WL + def to_body() -> jtp.Vector: + L_a_bias_WL = v̇_bias_WL + + return L_a_bias_WL + + def to_inertial() -> jtp.Vector: + + W_v̇_bias_WL = v̇_bias_WL + W_v_WW = jnp.zeros(6) + + L_H_W = jax.vmap(lambda W_H_L: Transform.inverse(W_H_L))(W_H_L) + + L_v_LW = jax.vmap( + lambda i: -js.link.velocity( + model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body + ) + )(jnp.arange(model.number_of_links())) + + L_a_bias_WL = jax.vmap( + lambda i: other_representation_to_body( + C_v̇_WL=W_v̇_bias_WL[i], + C_v_WC=W_v_WW, + L_H_C=L_H_W[i], + L_v_LC=L_v_LW[i], + ) + )(jnp.arange(model.number_of_links())) + + return L_a_bias_WL + + def to_mixed() -> jtp.Vector: + + LW_v̇_bias_WL = v̇_bias_WL + + LW_v_W_LW = jax.vmap( + lambda i: js.link.velocity( + model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed + ) + .at[3:6] + .set(jnp.zeros(3)) + )(jnp.arange(model.number_of_links())) + + L_H_LW = jax.vmap( + lambda W_H_L: Transform.inverse(W_H_L.at[0:3, 3].set(jnp.zeros(3))) + )(W_H_L) + + L_v_L_LW = jax.vmap( + lambda i: -js.link.velocity( + model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body + ) + .at[0:3] + .set(jnp.zeros(3)) + )(jnp.arange(model.number_of_links())) + + L_a_bias_WL = jax.vmap( + lambda i: other_representation_to_body( + C_v̇_WL=LW_v̇_bias_WL[i], + C_v_WC=LW_v_W_LW[i], + L_H_C=L_H_LW[i], + L_v_LC=L_v_L_LW[i], + ) + )(jnp.arange(model.number_of_links())) + + return L_a_bias_WL + # We need here to get the body-fixed bias acceleration of the links. # Since it's computed in the active representation, we need to convert it to body. - match data.velocity_representation: - - case VelRepr.Body: - L_a_bias_WL = v̇_bias_WL - - case VelRepr.Inertial: - - C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 - C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - - L_H_C = L_H_W = jax.vmap( # noqa: F841 - lambda W_H_L: jaxsim.math.Transform.inverse(W_H_L) - )(W_H_L) - - L_v_LC = L_v_LW = jax.vmap( # noqa: F841 - lambda i: -js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body - ) - )(jnp.arange(model.number_of_links())) - - L_a_bias_WL = jax.vmap( - lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC, - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], - ) - )(jnp.arange(model.number_of_links())) - - case VelRepr.Mixed: - - C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841 - - C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841 - lambda i: js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed - ) - .at[3:6] - .set(jnp.zeros(3)) - )(jnp.arange(model.number_of_links())) - - L_H_C = L_H_LW = jax.vmap( # noqa: F841 - lambda W_H_L: jaxsim.math.Transform.inverse( - W_H_L.at[0:3, 3].set(jnp.zeros(3)) - ) - )(W_H_L) - - L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841 - lambda i: -js.link.velocity( - model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body - ) - .at[0:3] - .set(jnp.zeros(3)) - )(jnp.arange(model.number_of_links())) - - L_a_bias_WL = jax.vmap( - lambda i: other_representation_to_body( - C_v̇_WL=C_v̇_WL[i], - C_v_WC=C_v_WC[i], - L_H_C=L_H_C[i], - L_v_LC=L_v_LC[i], - ) - )(jnp.arange(model.number_of_links())) - - case _: - raise ValueError(data.velocity_representation) + L_a_bias_WL = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # Compute the bias of the 6D momentum derivative. def bias_momentum_derivative_term( @@ -364,13 +388,11 @@ def bias_momentum_derivative_term( ) # Compute the world-to-link transformations for 6D forces. - W_Xf_L = jaxsim.math.Adjoint.from_transform( - transform=W_H_L[link_index], inverse=True - ).T + W_Xf_L = Adjoint.from_transform(transform=W_H_L[link_index], inverse=True).T # Compute the contribution of the link to the bias acceleration of the CoM. W_ḣ_bias_link_contribution = W_Xf_L @ ( - L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL + L_M_L @ L_a_bias_WL + Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL ) return W_ḣ_bias_link_contribution @@ -386,30 +408,36 @@ def bias_momentum_derivative_term( # Compute the position of the CoM. W_p_CoM = com_position(model=model, data=data) - match data.velocity_representation: - + def to_inertial_or_mixed() -> jtp.Vector: # G := G[W] = (W_p_CoM, [W]) - case VelRepr.Inertial | VelRepr.Mixed: - W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + GW_Xf_W = Adjoint.from_transform(W_H_GW).T - GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias - GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m + GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias + GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m - return GW_v̇l_com_bias + return GW_v̇l_com_bias + def to_body() -> jtp.Vector: # G := G[B] = (W_p_CoM, [B]) - case VelRepr.Body: - GB_Xf_W = jaxsim.math.Adjoint.from_transform( - transform=data.base_transform().at[0:3].set(W_p_CoM) - ).T + GB_Xf_W = Adjoint.from_transform( + transform=data.base_transform().at[0:3, 3].set(W_p_CoM) + ).T - GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias - GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m + GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias + GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m - return GB_v̇l_com_bias + return GB_v̇l_com_bias + + GB_v̇l_com_bias = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_inertial_or_mixed, # VelRepr.Mixed + to_inertial_or_mixed, # VelRepr.Inertial + ), + ) - case _: - raise ValueError(data.velocity_representation) + return GB_v̇l_com_bias diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index d8ba8dd7b..33bec4567 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -1,14 +1,12 @@ import abc import contextlib import dataclasses -import enum import functools -from typing import ContextManager +from typing import ClassVar, ContextManager import jax import jax.numpy as jnp import jax_dataclasses -from jax_dataclasses import Static import jaxsim.typing as jtp from jaxsim.math import Adjoint @@ -20,15 +18,15 @@ from typing_extensions import Self -@enum.unique -class VelRepr(enum.IntEnum): +@dataclasses.dataclass(frozen=True) +class VelRepr: """ Enumeration of all supported 6D velocity representations. """ - Body = enum.auto() - Mixed = enum.auto() - Inertial = enum.auto() + Body: ClassVar[int] = 0 + Mixed: ClassVar[int] = 1 + Inertial: ClassVar[int] = 2 @jax_dataclasses.pytree_dataclass @@ -37,13 +35,13 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): Base class for model data structures with velocity representation. """ - velocity_representation: Static[VelRepr] = dataclasses.field( + velocity_representation: jtp.VelRepr = dataclasses.field( default=VelRepr.Inertial, kw_only=True ) @contextlib.contextmanager def switch_velocity_representation( - self, velocity_representation: VelRepr + self, velocity_representation: jtp.VelRepr ) -> ContextManager[Self]: """ Context manager to temporarily switch the velocity representation. @@ -82,10 +80,10 @@ def switch_velocity_representation( self.velocity_representation = original_representation @staticmethod - @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) + @functools.partial(jax.jit, static_argnames=["is_force"]) def inertial_to_other_representation( array: jtp.Array, - other_representation: VelRepr, + other_representation: jtp.VelRepr, transform: jtp.Matrix, *, is_force: bool, @@ -114,45 +112,46 @@ def inertial_to_other_representation( if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) - match other_representation: - - case VelRepr.Inertial: - return W_array - - case VelRepr.Body: - - if not is_force: - O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) - O_array = O_Xv_W @ W_array - - else: - O_Xf_W = Adjoint.from_transform(transform=W_H_O).T - O_array = O_Xf_W @ W_array - - return O_array - - case VelRepr.Mixed: - W_p_O = W_H_O[0:3, 3] - W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) - - if not is_force: - OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) - OW_array = OW_Xv_W @ W_array - - else: - OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T - OW_array = OW_Xf_W @ W_array - - return OW_array - - case _: - raise ValueError(other_representation) + def to_inertial() -> jtp.Array: + + return W_array + + def to_body() -> jtp.Array: + if not is_force: + O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) + O_array = O_Xv_W @ W_array + else: + O_Xf_W = Adjoint.from_transform(transform=W_H_O).T + O_array = O_Xf_W @ W_array + + return O_array + + def to_mixed() -> jtp.Array: + W_p_O = W_H_O[0:3, 3] + W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) + if not is_force: + OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) + OW_array = OW_Xv_W @ W_array + else: + OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).T + OW_array = OW_Xf_W @ W_array + + return OW_array + + return jax.lax.switch( + index=other_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @staticmethod - @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) + @functools.partial(jax.jit, static_argnames=["is_force"]) def other_representation_to_inertial( array: jtp.Array, - other_representation: VelRepr, + other_representation: jtp.VelRepr, transform: jtp.Matrix, *, is_force: bool, @@ -181,38 +180,43 @@ def other_representation_to_inertial( if W_H_O.shape != (4, 4): raise ValueError(W_H_O.shape, (4, 4)) - match other_representation: - case VelRepr.Inertial: - W_array = array - return W_array + def from_inertial(): + W_array = array + return W_array - case VelRepr.Body: - O_array = array + def from_body(): + O_array = array - if not is_force: - W_Xv_O: jtp.Array = Adjoint.from_transform(W_H_O) - W_array = W_Xv_O @ O_array + if not is_force: + W_Xv_O = Adjoint.from_transform(W_H_O) + W_array = W_Xv_O @ O_array - else: - W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T - W_array = W_Xf_O @ O_array + else: + W_Xf_O = Adjoint.from_transform(transform=W_H_O, inverse=True).T + W_array = W_Xf_O @ O_array - return W_array + return W_array - case VelRepr.Mixed: - BW_array = array - W_p_O = W_H_O[0:3, 3] - W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) + def from_mixed(): + BW_array = array + W_p_O = W_H_O[0:3, 3] + W_H_OW = jnp.eye(4).at[0:3, 3].set(W_p_O) - if not is_force: - W_Xv_BW: jtp.Array = Adjoint.from_transform(W_H_OW) - W_array = W_Xv_BW @ BW_array + if not is_force: + W_Xv_BW = Adjoint.from_transform(W_H_OW) + W_array = W_Xv_BW @ BW_array - else: - W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T - W_array = W_Xf_BW @ BW_array + else: + W_Xf_BW = Adjoint.from_transform(transform=W_H_OW, inverse=True).T + W_array = W_Xf_BW @ BW_array - return W_array + return W_array - case _: - raise ValueError(other_representation) + return jax.lax.switch( + index=other_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 3caa75624..96d66282f 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -334,7 +334,7 @@ def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Array: r""" Return the free-floating Jacobian of the collidable points. @@ -372,44 +372,49 @@ def jacobian( jnp.array(model.kin_dyn_parameters.contact_parameters.body, dtype=int) ) - # Adjust the output representation. - match output_vel_repr: + def to_inertial() -> jtp.Matrix: - case VelRepr.Inertial: - O_J_WC = W_J_WC + return W_J_WC - case VelRepr.Body: + def to_body() -> jtp.Matrix: - W_H_C = transforms(model=model, data=data) + W_H_C = transforms(model=model, data=data) - def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - C_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_C, inverse=True - ) - C_J_WC = C_X_W @ W_J_WC - return C_J_WC + def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + C_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_C, inverse=True) + C_J_WC = C_X_W @ W_J_WC + return C_J_WC - O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) + C_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) - case VelRepr.Mixed: + return C_J_WC - W_H_C = transforms(model=model, data=data) + def to_mixed() -> jtp.Matrix: - def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + W_H_C = transforms(model=model, data=data) - W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) + def jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: - CW_X_W = jaxsim.math.Adjoint.from_transform( - transform=W_H_CW, inverse=True - ) + W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - CW_J_WC = CW_X_W @ W_J_WC - return CW_J_WC + CW_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_CW, inverse=True) - O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC) + CW_J_WC = CW_X_W @ W_J_WC + return CW_J_WC - case _: - raise ValueError(output_vel_repr) + CW_J_WC = jax.vmap(jacobian)(W_H_C, W_J_WC) + + return CW_J_WC + + # Adjust the output representation. + O_J_WC = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WC @@ -466,40 +471,50 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # Compute the operator to change the representation of ν, and its # time derivative. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_W = jnp.eye(4) - W_X_W = Adjoint.from_transform(transform=W_H_W) - W_Ẋ_W = jnp.zeros((6, 6)) - - T = compute_T(model=model, X=W_X_W) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) - - case VelRepr.Body: - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity() - B_vx_WB = Cross.vx(B_v_WB) - W_Ẋ_B = W_X_B @ B_vx_WB - - T = compute_T(model=model, X=W_X_B) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) - - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_vx_W_BW = Cross.vx(BW_v_W_BW) - W_Ẋ_BW = W_X_BW @ BW_vx_W_BW - - T = compute_T(model=model, X=W_X_BW) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) - - case _: - raise ValueError(data.velocity_representation) - + def from_inertial(): + W_H_W = jnp.eye(4) + W_X_W = Adjoint.from_transform(transform=W_H_W) + W_Ẋ_W = jnp.zeros((6, 6)) + + T = compute_T(model=model, X=W_X_W) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) + + return T, Ṫ + + def from_body(): + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) + B_v_WB = data.base_velocity() + B_vx_WB = Cross.vx(B_v_WB) + W_Ẋ_B = W_X_B @ B_vx_WB + + T = compute_T(model=model, X=W_X_B) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) + + return T, Ṫ + + def from_mixed(): + W_H_B = data.base_transform() + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + W_X_BW = Adjoint.from_transform(transform=W_H_BW) + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + BW_vx_W_BW = Cross.vx(BW_v_W_BW) + W_Ẋ_BW = W_X_BW @ BW_vx_W_BW + + T = compute_T(model=model, X=W_X_BW) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) + + return T, Ṫ + + T, Ṫ = jax.lax.switch( + index=data.velocity_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) # ===================================================== # Compute quantities to adjust the output representation # ===================================================== @@ -535,37 +550,46 @@ def compute_O_J̇_WC_I( parent_link_idx = parent_link_idxs[contact_idx] - match output_vel_repr: - case VelRepr.Inertial: - O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841 - transform=jnp.eye(4) - ) - O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841 - - case VelRepr.Body: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C - O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) - with data.switch_velocity_representation(VelRepr.Inertial): - W_nu = data.generalized_velocity() - W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu - W_vx_WC = Cross.vx(W_v_WC) - O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841 - - case VelRepr.Mixed: - L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) - W_H_C = W_H_L[parent_link_idx] @ L_H_C - W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) - CW_H_W = Transform.inverse(W_H_CW) - O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W) - with data.switch_velocity_representation(VelRepr.Mixed): - CW_v_WC = CW_J_WC_BW @ data.generalized_velocity() - W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3]) - W_vx_W_CW = Cross.vx(W_v_W_CW) - O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841 - - case _: - raise ValueError(output_vel_repr) + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: + W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) + W_Ẋ_W = jnp.zeros((6, 6)) + + return W_X_W, W_Ẋ_W + + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: + L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) + W_H_C = W_H_L[parent_link_idx] @ L_H_C + C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) + with data.switch_velocity_representation(VelRepr.Inertial): + W_nu = data.generalized_velocity() + W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu + W_vx_WC = Cross.vx(W_v_WC) + C_Ẋ_W = -C_X_W @ W_vx_WC + + return C_X_W, C_Ẋ_W + + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: + L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) + W_H_C = W_H_L[parent_link_idx] @ L_H_C + W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) + CW_H_W = Transform.inverse(W_H_CW) + CW_X_W = Adjoint.from_transform(transform=CW_H_W) + with data.switch_velocity_representation(VelRepr.Mixed): + CW_v_WC = CW_J_WC_BW @ data.generalized_velocity() + W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3]) + W_vx_W_CW = Cross.vx(W_v_W_CW) + CW_Ẋ_W = -CW_X_W @ W_vx_W_CW + + return CW_X_W, CW_Ẋ_W + + O_X_W, O_Ẋ_W = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs())) O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 73883dd5c..0a8185fee 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -84,7 +84,7 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool: @staticmethod def zero( model: js.model.JaxSimModel, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with zero state. @@ -113,7 +113,7 @@ def build( standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, contact: jaxsim.rbda.ContactsState | None = None, contacts_params: jaxsim.rbda.ContactsParams | None = None, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, time: jtp.FloatLike | None = None, ) -> JaxSimModelData: """ @@ -621,11 +621,11 @@ def reset_base_pose(self, base_pose: jtp.MatrixLike) -> Self: base_quaternion=W_Q_B ) - @functools.partial(jax.jit, static_argnames=["velocity_representation"]) + @jax.jit def reset_base_linear_velocity( self, linear_velocity: jtp.VectorLike, - velocity_representation: VelRepr | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base linear velocity. @@ -652,11 +652,11 @@ def reset_base_linear_velocity( velocity_representation=velocity_representation, ) - @functools.partial(jax.jit, static_argnames=["velocity_representation"]) + @jax.jit def reset_base_angular_velocity( self, angular_velocity: jtp.VectorLike, - velocity_representation: VelRepr | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base angular velocity. @@ -683,11 +683,11 @@ def reset_base_angular_velocity( velocity_representation=velocity_representation, ) - @functools.partial(jax.jit, static_argnames=["velocity_representation"]) + @jax.jit def reset_base_velocity( self, base_velocity: jtp.VectorLike, - velocity_representation: VelRepr | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> Self: """ Reset the base 6D velocity. @@ -732,7 +732,7 @@ def random_model_data( model: js.model.JaxSimModel, *, key: jax.Array | None = None, - velocity_representation: VelRepr | None = None, + velocity_representation: jtp.VelRepr | None = None, base_pos_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 2767e38dd..0d8ae9dcc 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -180,13 +180,13 @@ def transform( return W_H_L @ L_H_F -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the frame. @@ -227,34 +227,42 @@ def jacobian( model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body ) - # Adjust the output representation. - match output_vel_repr: - case VelRepr.Inertial: - W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_X_L = Adjoint.from_transform(transform=W_H_L) - W_J_WL = W_X_L @ L_J_WL - O_J_WL_I = W_J_WL - - case VelRepr.Body: - W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = Transform.inverse(W_H_F) @ W_H_L - F_X_L = Adjoint.from_transform(transform=F_H_L) - F_J_WL = F_X_L @ L_J_WL - O_J_WL_I = F_J_WL - - case VelRepr.Mixed: - W_H_L = js.link.transform(model=model, data=data, link_index=L) - W_H_F = transform(model=model, data=data, frame_index=frame_index) - F_H_L = Transform.inverse(W_H_F) @ W_H_L - FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3)) - FW_H_L = FW_H_F @ F_H_L - FW_X_L = Adjoint.from_transform(transform=FW_H_L) - FW_J_WL = FW_X_L @ L_J_WL - O_J_WL_I = FW_J_WL - - case _: - raise ValueError(output_vel_repr) + def to_inertial() -> jtp.Matrix: + W_H_L = js.link.transform(model=model, data=data, link_index=L) + W_X_L = Adjoint.from_transform(transform=W_H_L) + W_J_WL = W_X_L @ L_J_WL + + return W_J_WL + + def to_body() -> jtp.Matrix: + W_H_L = js.link.transform(model=model, data=data, link_index=L) + W_H_F = transform(model=model, data=data, frame_index=frame_index) + F_H_L = Transform.inverse(W_H_F) @ W_H_L + F_X_L = Adjoint.from_transform(transform=F_H_L) + F_J_WL = F_X_L @ L_J_WL + + return F_J_WL + + def to_mixed() -> jtp.Matrix: + W_H_L = js.link.transform(model=model, data=data, link_index=L) + W_H_F = transform(model=model, data=data, frame_index=frame_index) + F_H_L = Transform.inverse(W_H_F) @ W_H_L + FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3)) + FW_H_L = FW_H_F @ F_H_L + FW_X_L = Adjoint.from_transform(transform=FW_H_L) + FW_J_WL = FW_X_L @ L_J_WL + + return FW_J_WL + + # Adjust the output representation + O_J_WL_I = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WL_I @@ -334,77 +342,99 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: # Compute the operator to change the representation of ν, and its # time derivative. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_W = jnp.eye(4) - W_X_W = Adjoint.from_transform(transform=W_H_W) - W_Ẋ_W = jnp.zeros((6, 6)) - - T = compute_T(model=model, X=W_X_W) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) - - case VelRepr.Body: - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity() - B_vx_WB = Cross.vx(B_v_WB) - W_Ẋ_B = W_X_B @ B_vx_WB - - T = compute_T(model=model, X=W_X_B) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) - - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - BW_vx_W_BW = Cross.vx(BW_v_W_BW) - W_Ẋ_BW = W_X_BW @ BW_vx_W_BW - - T = compute_T(model=model, X=W_X_BW) - Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) - - case _: - raise ValueError(data.velocity_representation) + def from_inertial(): + W_H_W = jnp.eye(4) + W_X_W = Adjoint.from_transform(transform=W_H_W) + W_Ẋ_W = jnp.zeros((6, 6)) + + T = compute_T(model=model, X=W_X_W) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) + + return T, Ṫ + + def from_body(): + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) + B_v_WB = data.base_velocity() + B_vx_WB = Cross.vx(B_v_WB) + W_Ẋ_B = W_X_B @ B_vx_WB + + T = compute_T(model=model, X=W_X_B) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) + + return T, Ṫ + + def from_mixed(): + W_H_B = data.base_transform() + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + W_X_BW = Adjoint.from_transform(transform=W_H_BW) + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + BW_vx_W_BW = Cross.vx(BW_v_W_BW) + W_Ẋ_BW = W_X_BW @ BW_vx_W_BW + + T = compute_T(model=model, X=W_X_BW) + Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) + + return T, Ṫ + + T, Ṫ = jax.lax.switch( + index=data.velocity_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) # ===================================================== # Compute quantities to adjust the output representation # ===================================================== - match output_vel_repr: - case VelRepr.Inertial: - O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) - O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) - - case VelRepr.Body: - W_H_F = transform(model=model, data=data, frame_index=frame_index) - O_X_W = F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True) - with data.switch_velocity_representation(VelRepr.Inertial): - W_nu = data.generalized_velocity() - W_v_WF = W_J_WL_W @ W_nu - W_vx_WF = Cross.vx(W_v_WF) - O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF # noqa: F841 - - case VelRepr.Mixed: - W_H_F = transform(model=model, data=data, frame_index=frame_index) - W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3)) - FW_H_W = Transform.inverse(W_H_FW) - O_X_W = FW_X_W = Adjoint.from_transform(transform=FW_H_W) - with data.switch_velocity_representation(VelRepr.Mixed): - FW_J_WF_FW = jacobian( - model=model, - data=data, - frame_index=frame_index, - output_vel_repr=VelRepr.Mixed, - ) - FW_v_WF = FW_J_WF_FW @ data.generalized_velocity() - W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3]) - W_vx_W_FW = Cross.vx(W_v_W_FW) - O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW # noqa: F841 - - case _: - raise ValueError(output_vel_repr) + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: + W_X_W = Adjoint.from_transform(transform=jnp.eye(4)) + W_Ẋ_W = jnp.zeros((6, 6)) + + return W_X_W, W_Ẋ_W + + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: + W_H_F = transform(model=model, data=data, frame_index=frame_index) + F_X_W = Adjoint.from_transform(transform=W_H_F, inverse=True) + with data.switch_velocity_representation(VelRepr.Inertial): + W_nu = data.generalized_velocity() + W_v_WF = W_J_WL_W @ W_nu + W_vx_WF = Cross.vx(W_v_WF) + F_Ẋ_W = -F_X_W @ W_vx_WF + + return F_X_W, F_Ẋ_W + + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: + W_H_F = transform(model=model, data=data, frame_index=frame_index) + W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.eye(3)) + FW_H_W = Transform.inverse(W_H_FW) + FW_X_W = Adjoint.from_transform(transform=FW_H_W) + with data.switch_velocity_representation(VelRepr.Mixed): + FW_J_WF_FW = jacobian( + model=model, + data=data, + frame_index=frame_index, + output_vel_repr=VelRepr.Mixed, + ) + FW_v_WF = FW_J_WF_FW @ data.generalized_velocity() + W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3]) + W_vx_W_FW = Cross.vx(W_v_W_FW) + FW_Ẋ_W = -FW_X_W @ W_vx_W_FW + + return FW_X_W, FW_Ẋ_W + + O_X_W, O_Ẋ_W = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) O_J̇_WF_I = jnp.zeros(shape=(6, 6 + model.dofs())) O_J̇_WF_I += O_Ẋ_W @ W_J_WL_W @ T diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 75f2e89a0..9519f5c14 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -9,7 +9,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import exceptions -from jaxsim.math import Adjoint +from jaxsim.math import Adjoint, Transform from .common import VelRepr @@ -235,13 +235,13 @@ def com_in_inertial_frame(): ) -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the link. @@ -284,64 +284,81 @@ def jacobian( B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B # Adjust the input representation such that `J_WL_I @ I_ν`. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_B = data.base_transform() - B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 - B_X_W, jnp.eye(model.dofs()) - ) - - case VelRepr.Body: - B_J_WL_I = B_J_WL_B - - case VelRepr.Mixed: - W_R_B = data.base_orientation(dcm=True) - BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 - B_X_BW, jnp.eye(model.dofs()) - ) - - case _: - raise ValueError(data.velocity_representation) + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) + B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + + return B_J_WL_W + + def to_body() -> jtp.Matrix: + + return B_J_WL_B + + def to_mixed() -> jtp.Matrix: + W_R_B = data.base_orientation(dcm=True) + BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( + B_X_BW, jnp.eye(model.dofs()) + ) + + return B_J_WL_BW + + B_J_WL_I = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) B_H_L = B_H_Li[link_index] + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) + W_J_WL_I = W_X_B @ B_J_WL_I + + return W_J_WL_I + + def to_body() -> jtp.Matrix: + L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) + L_J_WL_I = L_X_B @ B_J_WL_I + + return L_J_WL_I + + def to_mixed() -> jtp.Matrix: + W_H_B = data.base_transform() + W_H_L = W_H_B @ B_H_L + LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) + LW_H_B = LW_H_L @ Transform.inverse(B_H_L) + LW_X_B = Adjoint.from_transform(transform=LW_H_B) + LW_J_WL_I = LW_X_B @ B_J_WL_I + + return LW_J_WL_I + # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. - match output_vel_repr: - case VelRepr.Inertial: - W_H_B = data.base_transform() - W_X_B = Adjoint.from_transform(transform=W_H_B) - O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841 - - case VelRepr.Body: - L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) - L_J_WL_I = L_X_B @ B_J_WL_I - O_J_WL_I = L_J_WL_I - - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_L = W_H_B @ B_H_L - LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) - LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) - LW_X_B = Adjoint.from_transform(transform=LW_H_B) - LW_J_WL_I = LW_X_B @ B_J_WL_I - O_J_WL_I = LW_J_WL_I - - case _: - raise ValueError(output_vel_repr) + O_J_WL_I = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WL_I -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Vector: """ Compute the 6D velocity of the link. @@ -385,13 +402,13 @@ def velocity( return O_J_WL_I @ I_ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the derivative of the free-floating jacobian of the link. @@ -442,116 +459,128 @@ def jacobian_derivative( In = jnp.eye(model.dofs()) On = jnp.zeros(shape=(model.dofs(), model.dofs())) - match data.velocity_representation: + def from_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: + + W_H_B = data.base_transform() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) + + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) - case VelRepr.Inertial: + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_W, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On) - W_H_B = data.base_transform() - B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) + return T, Ṫ - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WB = data.base_velocity() - B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) + def from_body() -> tuple[jtp.Matrix, jtp.Matrix]: - # Compute the operator to change the representation of ν, and its - # time derivative. - T = jax.scipy.linalg.block_diag(B_X_W, In) - Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On) + B_X_B = Adjoint.from_rotation_and_translation( + translation=jnp.zeros(3), rotation=jnp.eye(3) + ) - case VelRepr.Body: + B_Ẋ_B = jnp.zeros(shape=(6, 6)) - B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation( - translation=jnp.zeros(3), rotation=jnp.eye(3) - ) + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_B, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On) - B_Ẋ_B = jnp.zeros(shape=(6, 6)) + return T, Ṫ - # Compute the operator to change the representation of ν, and its - # time derivative. - T = jax.scipy.linalg.block_diag(B_X_B, In) - Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On) + def from_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: - case VelRepr.Mixed: + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) + with data.switch_velocity_representation(VelRepr.Mixed): + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - with data.switch_velocity_representation(VelRepr.Mixed): - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + BW_v_BW_B = BW_v_WB - BW_v_W_BW + B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) - BW_v_BW_B = BW_v_WB - BW_v_W_BW - B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) + # Compute the operator to change the representation of ν, and its + # time derivative. + T = jax.scipy.linalg.block_diag(B_X_BW, In) + Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) - # Compute the operator to change the representation of ν, and its - # time derivative. - T = jax.scipy.linalg.block_diag(B_X_BW, In) - Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On) + return T, Ṫ - case _: - raise ValueError(data.velocity_representation) + T, Ṫ = jax.lax.switch( + index=data.velocity_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) # ====================================================== # Compute quantities to adjust the output representation # ====================================================== - match output_vel_repr: + def to_inertial() -> tuple[jtp.Matrix, jtp.Matrix]: + + W_H_B = data.base_transform() + W_X_B = Adjoint.from_transform(transform=W_H_B) - case VelRepr.Inertial: + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() - W_H_B = data.base_transform() - O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B) + W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + return W_X_B, W_Ẋ_B - O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841 + def to_body() -> tuple[jtp.Matrix, jtp.Matrix]: - case VelRepr.Body: + L_X_B = Adjoint.from_transform(transform=B_H_L[link_index, :, :], inverse=True) - O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform( - transform=B_H_L[link_index, :, :], inverse=True - ) + B_X_L = Adjoint.inverse(adjoint=L_X_B) - B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B) + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() + L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() - L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) + L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx(B_X_L @ L_v_WL - B_v_WB) - O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_L @ L_v_WL - B_v_WB - ) + return L_X_B, L_Ẋ_B - case VelRepr.Mixed: + def to_mixed() -> tuple[jtp.Matrix, jtp.Matrix]: - W_H_B = data.base_transform() - W_H_L = W_H_B @ B_H_L[link_index, :, :] - LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) - LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L[link_index, :, :]) + W_H_B = data.base_transform() + W_H_L = W_H_B @ B_H_L[link_index, :, :] + LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) + LW_H_B = LW_H_L @ Transform.inverse(B_H_L[link_index, :, :]) - O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B) + LW_X_B = Adjoint.from_transform(transform=LW_H_B) - B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B) + B_X_LW = Adjoint.inverse(adjoint=LW_X_B) - with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() - with data.switch_velocity_representation(VelRepr.Mixed): - LW_v_WL = js.link.velocity( - model=model, data=data, link_index=link_index - ) - LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3)) + with data.switch_velocity_representation(VelRepr.Mixed): + LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index) + LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3)) - LW_v_LW_L = LW_v_WL - LW_v_W_LW - LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L + LW_v_LW_L = LW_v_WL - LW_v_W_LW + LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L - O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841 - B_X_LW @ LW_v_B_LW - ) - case _: - raise ValueError(output_vel_repr) + LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx(B_X_LW @ LW_v_B_LW) + return LW_X_B, LW_Ẋ_B + + O_X_B, O_Ẋ_B = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # ============================================================= # Express the Jacobian derivative in the target representations # ============================================================= diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 84e2b8229..b8f9f5701 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -448,12 +448,12 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp return jnp.atleast_3d(W_H_LL).astype(float) -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def generalized_free_floating_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: """ Compute the free-floating jacobians of all links. @@ -488,35 +488,40 @@ def generalized_free_floating_jacobian( # Update the input velocity representation such that v_WL = J_WL_I @ I_ν # ====================================================================== - match data.velocity_representation: + def from_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - case VelRepr.Inertial: + B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag( + B_X_W, jnp.eye(model.dofs()) + ) - W_H_B = data.base_transform() - B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) + return B_J_full_WX_W - B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841 - B_J_full_WX_B - @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) - ) + def from_body() -> jtp.Matrix: - case VelRepr.Body: + return B_J_full_WX_B - B_J_full_WX_I = B_J_full_WX_B + def from_mixed() -> jtp.Matrix: + W_R_B = data.base_orientation(dcm=True) + BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - case VelRepr.Mixed: - - W_R_B = data.base_orientation(dcm=True) - BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + B_J_full_WX_BW = B_J_full_WX_B @ jax.scipy.linalg.block_diag( + B_X_BW, jnp.eye(model.dofs()) + ) - B_J_full_WX_I = B_J_full_WX_BW = ( # noqa: F841 - B_J_full_WX_B - @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - ) + return B_J_full_WX_BW - case _: - raise ValueError(data.velocity_representation) + # Update the input velocity representation such that `J_WL_I @ I_ν`. + B_J_full_WX_I = jax.lax.switch( + index=data.velocity_representation, + branches=( + from_body, # VelRepr.Body + from_mixed, # VelRepr.Mixed + from_inertial, # VelRepr.Inertial + ), + ) # ==================================================================== # Create stacked Jacobian for each link by filtering the full Jacobian @@ -536,45 +541,48 @@ def generalized_free_floating_jacobian( # Update the output velocity representation such that O_v_WL = O_J_WL @ ν # ======================================================================= - match output_vel_repr: - - case VelRepr.Inertial: - - W_H_B = data.base_transform() - W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) - - O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841 - lambda B_J_WL_I: W_X_B @ B_J_WL_I - )(B_J_WL_I) - - case VelRepr.Body: + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) - O_J_WL_I = L_J_WL_I = jax.vmap( # noqa: F841 - lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform( - B_H_L, inverse=True - ) - @ B_J_WL_I - )(B_H_L, B_J_WL_I) + W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I) - case VelRepr.Mixed: + return W_J_WL_I - W_H_B = data.base_transform() - - LW_H_L = jax.vmap( - lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3)) - )(B_H_L) - - LW_H_B = jax.vmap( - lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) - )(LW_H_L, B_H_L) + def to_body() -> jtp.Matrix: + L_J_WL_I = jax.vmap( + lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform( + B_H_L, inverse=True + ) + @ B_J_WL_I + )(B_H_L, B_J_WL_I) - O_J_WL_I = LW_J_WL_I = jax.vmap( # noqa: F841 - lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B) - @ B_J_WL_I - )(LW_H_B, B_J_WL_I) + return L_J_WL_I - case _: - raise ValueError(output_vel_repr) + def to_mixed() -> jtp.Matrix: + W_H_B = data.base_transform() + LW_H_L = jax.vmap(lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3)))( + B_H_L + ) + LW_H_B = jax.vmap( + lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) + )(LW_H_L, B_H_L) + + LW_J_WL_I = jax.vmap( + lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B) + @ B_J_WL_I + )(LW_H_B, B_J_WL_I) + + return LW_J_WL_I + + O_J_WL_I = jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) return O_J_WL_I @@ -754,28 +762,38 @@ def to_active( # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) - return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB) - - match data.velocity_representation: - case VelRepr.Inertial: - # In this case C=W - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - - case VelRepr.Body: - # In this case C=B - W_H_C = W_H_B = data.base_transform() - W_v_WC = W_v_WB - - case VelRepr.Mixed: - # In this case C=B[W] - W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + return C_X_W @ W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB + + def to_inertial() -> tuple[jtp.Vector, jtp.Matrix]: + # In this case C=W + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) + + return W_v_WW, W_H_W + + def to_body() -> tuple[jtp.Vector, jtp.Matrix]: + # In this case C=B + W_H_B = data.base_transform() + + return W_v_WB, W_H_B + + def to_mixed() -> tuple[jtp.Vector, jtp.Matrix]: + # In this case C=B[W] + W_H_B = data.base_transform() + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + W_ṗ_B = data.base_velocity()[0:3] + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + + return W_v_W_BW, W_H_BW - case _: - raise ValueError(data.velocity_representation) + W_v_WC, W_H_C = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # We need to convert the derivative of the base velocity to the active # representation. In Mixed representation, this conversion is not a plain @@ -912,29 +930,32 @@ def free_floating_mass_matrix( joint_positions=data.state.physics_model.joint_positions, ) - match data.velocity_representation: - case VelRepr.Body: - return M_body + def to_body() -> jtp.Matrix: + return M_body - case VelRepr.Inertial: + def to_inertial() -> jtp.Matrix: - B_X_W = Adjoint.from_transform( - transform=data.base_transform(), inverse=True - ) - invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + B_X_W = Adjoint.from_transform(transform=data.base_transform(), inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) - return invT.T @ M_body @ invT + return invT.T @ M_body @ invT - case VelRepr.Mixed: + def to_mixed() -> jtp.Matrix: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - return invT.T @ M_body @ invT + return invT.T @ M_body @ invT - case _: - raise ValueError(data.velocity_representation) + return jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @jax.jit @@ -1005,56 +1026,58 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: # Adjust the representation of the Coriolis matrix. # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. - match data.velocity_representation: + def to_body() -> jtp.Matrix: + return C_B - case VelRepr.Body: - return C_B + def to_inertial() -> jtp.Matrix: + n = model.dofs() + W_H_B = data.base_transform() + B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) + B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) - case VelRepr.Inertial: + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) - n = model.dofs() - W_H_B = data.base_transform() - B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) - B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) + B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WB = data.base_velocity() - B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) + with data.switch_velocity_representation(VelRepr.Body): + M = free_floating_mass_matrix(model=model, data=data) - B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) + C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W) - with data.switch_velocity_representation(VelRepr.Body): - M = free_floating_mass_matrix(model=model, data=data) + return C - C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W) + def to_mixed() -> jtp.Matrix: + n = model.dofs() + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) + B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) - return C + with data.switch_velocity_representation(VelRepr.Mixed): + BW_v_WB = data.base_velocity() + BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) - case VelRepr.Mixed: + BW_v_BW_B = BW_v_WB - BW_v_W_BW + B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) - n = model.dofs() - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) + B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n))) - with data.switch_velocity_representation(VelRepr.Mixed): - BW_v_WB = data.base_velocity() - BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) + with data.switch_velocity_representation(VelRepr.Body): + M = free_floating_mass_matrix(model=model, data=data) - BW_v_BW_B = BW_v_WB - BW_v_W_BW - B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) + C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW) - B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n))) + return C - with data.switch_velocity_representation(VelRepr.Body): - M = free_floating_mass_matrix(model=model, data=data) - - C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW) - - return C - - case _: - raise ValueError(data.velocity_representation) + return jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @jax.jit @@ -1125,25 +1148,35 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # In Inertial and Body representations, the cross product is always zero. return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB) - match data.velocity_representation: - case VelRepr.Inertial: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 + def convert_inertial() -> jtp.Vector: + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) - case VelRepr.Body: - W_H_C = W_H_B = data.base_transform() - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() + return W_H_W, W_v_WW - case VelRepr.Mixed: - W_H_B = data.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity()[0:3] - W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 + def convert_body() -> jtp.Vector: + W_H_B = data.base_transform() + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + + return W_H_B, W_v_WB - case _: - raise ValueError(data.velocity_representation) + def convert_mixed() -> jtp.Vector: + W_H_B = data.base_transform() + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + W_ṗ_B = data.base_velocity()[0:3] + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + return W_H_BW, W_v_W_BW + + W_H_C, W_v_WC = jax.lax.switch( + index=data.velocity_representation, + branches=( + convert_body, # VelRepr.Body + convert_mixed, # VelRepr.Mixed + convert_inertial, # VelRepr.Inertial + ), + ) # We need to convert the derivative of the base acceleration to the Inertial # representation. In Mixed representation, this conversion is not a plain # transformation with just X, but it also involves a cross product in ℝ⁶. @@ -1367,12 +1400,12 @@ def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vec return Jh @ ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def total_momentum_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: """ Compute the jacobian of the total momentum. @@ -1396,44 +1429,55 @@ def total_momentum_jacobian( with data.switch_velocity_representation(VelRepr.Body): B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6] - match data.velocity_representation: - case VelRepr.Body: - B_Jh = B_Jh_B + def to_body() -> jtp.Matrix: - case VelRepr.Inertial: - B_X_W = Adjoint.from_transform( - transform=data.base_transform(), inverse=True - ) - B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) + return B_Jh_B - case VelRepr.Mixed: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) + def to_inertial() -> jtp.Matrix: + B_X_W = Adjoint.from_transform(transform=data.base_transform(), inverse=True) - case _: - raise ValueError(data.velocity_representation) + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) - match output_vel_repr: - case VelRepr.Body: - return B_Jh + def to_mixed() -> jtp.Matrix: + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - case VelRepr.Inertial: - W_H_B = data.base_transform() - B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) - W_Xf_B = B_Xv_W.T - W_Jh = W_Xf_B @ B_Jh - return W_Jh + return B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) - case VelRepr.Mixed: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) - B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) - BW_Xf_B = B_Xv_BW.T - BW_Jh = BW_Xf_B @ B_Jh - return BW_Jh + B_Jh = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) - case _: - raise ValueError(output_vel_repr) + def to_body() -> jtp.Matrix: + return B_Jh + + def to_inertial() -> jtp.Matrix: + W_H_B = data.base_transform() + B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) + W_Xf_B = B_Xv_W.T + W_Jh = W_Xf_B @ B_Jh + return W_Jh + + def to_mixed() -> jtp.Matrix: + BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) + BW_Xf_B = B_Xv_BW.T + BW_Jh = BW_Xf_B @ B_Jh + return BW_Jh + + return jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) @jax.jit @@ -1456,12 +1500,12 @@ def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.V return J @ ν -@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) +@jax.jit def average_velocity_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, - output_vel_repr: VelRepr | None = None, + output_vel_repr: jtp.VelRepr | None = None, ) -> jtp.Matrix: """ Compute the Jacobian of the average velocity of the model. @@ -1483,40 +1527,47 @@ def average_velocity_jacobian( # Depending on the velocity representation, the frame G is either G[W] or G[B]. G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data) - match output_vel_repr: + def to_inertial() -> jtp.Matrix: - case VelRepr.Inertial: + GW_J = G_J + W_p_CoM = js.com.com_position(model=model, data=data) - GW_J = G_J - W_p_CoM = js.com.com_position(model=model, data=data) + W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) + W_X_GW = Adjoint.from_transform(transform=W_H_GW) - W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - W_X_GW = Adjoint.from_transform(transform=W_H_GW) + return W_X_GW @ GW_J - return W_X_GW @ GW_J + def to_body() -> jtp.Matrix: - case VelRepr.Body: + GB_J = G_J + W_p_B = data.base_position() + W_p_CoM = js.com.com_position(model=model, data=data) + B_R_W = data.base_orientation(dcm=True).transpose() - GB_J = G_J - W_p_B = data.base_position() - W_p_CoM = js.com.com_position(model=model, data=data) - B_R_W = data.base_orientation(dcm=True).transpose() + B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) + B_X_GB = Adjoint.from_transform(transform=B_H_GB) - B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) - B_X_GB = Adjoint.from_transform(transform=B_H_GB) + return B_X_GB @ GB_J - return B_X_GB @ GB_J + def to_mixed() -> jtp.Matrix: - case VelRepr.Mixed: + GW_J = G_J + W_p_B = data.base_position() + W_p_CoM = js.com.com_position(model=model, data=data) - GW_J = G_J - W_p_B = data.base_position() - W_p_CoM = js.com.com_position(model=model, data=data) + BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) + BW_X_GW = Adjoint.from_transform(transform=BW_H_GW) - BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) - BW_X_GW = Adjoint.from_transform(transform=BW_H_GW) + return BW_X_GW @ GW_J - return BW_X_GW @ GW_J + return jax.lax.switch( + index=output_vel_repr, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # ======================== @@ -1572,33 +1623,40 @@ def other_representation_to_inertial( # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration # W_a_WB, and intrinsic accelerations can be expressed in different frames through # a simple C_X_W 6D transform. - match data.velocity_representation: - case VelRepr.Inertial: - W_H_C = W_H_W = jnp.eye(4) # noqa: F841 - W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 - with data.switch_velocity_representation(VelRepr.Inertial): - C_v_WB = W_v_WB = data.base_velocity() - - case VelRepr.Body: - W_H_C = W_H_B - with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 - with data.switch_velocity_representation(VelRepr.Body): - C_v_WB = B_v_WB = data.base_velocity() - - case VelRepr.Mixed: - W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_H_C = W_H_BW - with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity()[0:3] - BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) - W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW) - W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841 - with data.switch_velocity_representation(VelRepr.Mixed): - C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841 - - case _: - raise ValueError(data.velocity_representation) + def to_inertial() -> jtp.Matrix: + W_H_W = jnp.eye(4) + W_v_WW = jnp.zeros(6) + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + + return W_H_W, W_v_WW, W_v_WB + + def to_body() -> jtp.Matrix: + with data.switch_velocity_representation(VelRepr.Inertial): + W_v_WB = data.base_velocity() + with data.switch_velocity_representation(VelRepr.Body): + B_v_WB = data.base_velocity() + + return W_H_B, W_v_WB, B_v_WB + + def to_mixed() -> jtp.Matrix: + W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) + with data.switch_velocity_representation(VelRepr.Mixed): + W_ṗ_B = data.base_velocity()[0:3] + W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) + with data.switch_velocity_representation(VelRepr.Mixed): + BW_v_WB = data.base_velocity() + + return W_H_BW, W_v_W_BW, BW_v_WB + + W_H_C, W_v_WC, C_v_WB = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # Convert a zero 6D acceleration from the active representation to inertial-fixed. W_v̇_WB = other_representation_to_inertial( @@ -1701,29 +1759,32 @@ def body_to_other_representation( C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L) return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL) - match data.velocity_representation: - case VelRepr.Body: - C_H_L = L_H_L = jnp.stack( # noqa: F841 - [jnp.eye(4)] * model.number_of_links() - ) - L_v_CL = L_v_LL = jnp.zeros( # noqa: F841 - shape=(model.number_of_links(), 6) - ) + def to_body() -> tuple[jtp.Matrix, jtp.Vector]: + L_H_L = jnp.stack([jnp.eye(4)] * model.number_of_links()) + L_v_LL = jnp.zeros(shape=(model.number_of_links(), 6)) + + return L_H_L, L_v_LL - case VelRepr.Inertial: - C_H_L = W_H_L = js.model.forward_kinematics(model=model, data=data) - L_v_CL = L_v_WL + def to_inertial() -> tuple[jtp.Matrix, jtp.Vector]: + W_H_L = js.model.forward_kinematics(model=model, data=data) - case VelRepr.Mixed: - W_H_L = js.model.forward_kinematics(model=model, data=data) - LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) - C_H_L = LW_H_L - L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 - lambda v: v.at[0:3].set(jnp.zeros(3)) - )(L_v_WL) + return W_H_L, L_v_WL - case _: - raise ValueError(data.velocity_representation) + def to_mixed() -> tuple[jtp.Matrix, jtp.Vector]: + W_H_L = js.model.forward_kinematics(model=model, data=data) + LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) + L_v_LW_L = jax.vmap(lambda v: v.at[0:3].set(jnp.zeros(3)))(L_v_WL) + + return LW_H_L, L_v_LW_L + + C_H_L, L_v_CL = jax.lax.switch( + index=data.velocity_representation, + branches=( + to_body, # VelRepr.Body + to_mixed, # VelRepr.Mixed + to_inertial, # VelRepr.Inertial + ), + ) # Convert from body-fixed to the active representation. O_v̇_WL = jax.vmap(body_to_other_representation)( diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index c34aa2240..1d80d994e 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -19,6 +19,8 @@ except ImportError: from typing_extensions import Self +from .data import JaxSimModelData + @jax_dataclasses.pytree_dataclass class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): @@ -32,7 +34,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): def zero( model: js.model.JaxSimModel, data: js.data.JaxSimModelData | None = None, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: jtp.VelRepr = VelRepr.Inertial, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with zero references. @@ -58,7 +60,7 @@ def build( joint_force_references: jtp.Vector | None = None, link_forces: jtp.Matrix | None = None, data: js.data.JaxSimModelData | None = None, - velocity_representation: VelRepr | None = None, + velocity_representation: jtp.VelRepr | None = None, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with the given references. @@ -183,9 +185,11 @@ def link_forces( # Return all link forces in inertial-fixed representation using the implicit # serialization. if model is None: - if self.velocity_representation is not VelRepr.Inertial: - msg = "Missing model to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) + + exceptions.raise_value_error_if( + condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + msg="Missing model to use a representation different from `VelRepr.Inertial`", + ) if link_names is not None: raise ValueError("Link names cannot be provided without a model") @@ -196,35 +200,52 @@ def link_forces( link_names = link_names if link_names is not None else model.link_names() link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) - # In inertial-fixed representation, we already have the link forces. - if self.velocity_representation is VelRepr.Inertial: - return W_f_L[link_idxs, :] - - if data is None: - msg = "Missing model data to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) - - if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model): - raise ValueError("The provided data is not valid for the model") - - # Helper function to convert a single 6D force to the active representation - # considering as body the link (i.e. L_f_L and LW_f_L). - def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: - - return jax.vmap( - lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( - array=W_f_L, - other_representation=self.velocity_representation, - transform=W_H_L, - is_force=True, - ) - )(W_f_L, W_H_L) + # If not inertial-fixed representation, we need the model data. + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, + ), + msg="Missing model data to use a representation different from `VelRepr.Inertial`", + ) # The f_L output is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) - f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, + ), + msg="Missing model data to use a representation different from `VelRepr.Inertial`", + ) + + def not_inertial(velocity_representation: jtp.VelRepr) -> jtp.Matrix: + # Helper function to convert a single 6D force to the active representation + # considering as body the link (i.e. L_f_L and LW_f_L). + def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: + + return jax.vmap( + lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( + array=W_f_L, + other_representation=velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(W_f_L, W_H_L) + + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ) + f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) - return f_L + return f_L + + # In inertial-fixed representation, we already have the link forces. + return jax.lax.cond( + pred=jnp.equal(self.velocity_representation, VelRepr.Inertial), + true_fun=lambda _: W_f_L[link_idxs, :], + false_fun=not_inertial, + operand=self.velocity_representation, + ) def joint_force_references( self, @@ -370,14 +391,18 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # In this case, we allow only to set the inertial 6D forces to all links # using the implicit link serialization. - if model is None: - if self.velocity_representation is not VelRepr.Inertial: - msg = "Missing model to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) + exceptions.raise_value_error_if( + condition=jnp.not_equal(self.velocity_representation, VelRepr.Inertial) + & (model is None), + msg="Missing model to use a representation different from `VelRepr.Inertial`", + ) - if link_names is not None: - raise ValueError("Link names cannot be provided without a model") + exceptions.raise_value_error_if( + condition=jnp.logical_and(link_names is not None, model is None), + msg="Link names cannot be provided without a model", + ) + if model is None: W_f_L = f_L W_f0_L = ( @@ -408,8 +433,16 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: else self.input.physics_model.f_ext[link_idxs, :] ) + exceptions.raise_value_error_if( + condition=jnp.logical_and( + jnp.not_equal(self.velocity_representation, VelRepr.Inertial), + data is None, + ), + msg="Missing model data to use a representation different from `VelRepr.Inertial`", + ) + # If inertial-fixed representation, we can directly store the link forces. - if self.velocity_representation is VelRepr.Inertial: + def inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: W_f_L = f_L return replace( forces=self.input.physics_model.f_ext.at[link_idxs, :].set( @@ -417,34 +450,40 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: ) ) - if data is None: - msg = "Missing model data to use a representation different from {}" - raise ValueError(msg.format(VelRepr.Inertial.name)) + def not_inertial(velocity_representation: jtp.VelRepr) -> JaxSimModelReferences: + # Helper function to convert a single 6D force to the inertial representation + # considering as body the link (i.e. L_f_L and LW_f_L). + def convert_using_link_frame( + f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike + ) -> jtp.Matrix: + + return jax.vmap( + lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial( + array=f_L, + other_representation=velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(f_L, W_H_L) - if not_tracing(forces) and not data.valid(model=model): - raise ValueError("The provided data is not valid for the model") + W_H_L = js.model.forward_kinematics( + model=model, data=data or JaxSimModelData.zero(model=model) + ) - # Helper function to convert a single 6D force to the inertial representation - # considering as body the link (i.e. L_f_L and LW_f_L). - def convert_using_link_frame( - f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike - ) -> jtp.Matrix: - - return jax.vmap( - lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial( - array=f_L, - other_representation=self.velocity_representation, - transform=W_H_L, - is_force=True, - ) - )(f_L, W_H_L) + # The f_L input is either L_f_L or LW_f_L, depending on the representation. + W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) - # The f_L input is either L_f_L or LW_f_L, depending on the representation. - W_H_L = js.model.forward_kinematics(model=model, data=data) - W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :]) + return replace( + forces=self.input.physics_model.f_ext.at[link_idxs, :].set( + W_f0_L + W_f_L + ) + ) - return replace( - forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L) + return jax.lax.cond( + pred=jnp.equal(self.velocity_representation, VelRepr.Inertial), + true_fun=inertial, + false_fun=not_inertial, + operand=self.velocity_representation, ) def apply_frame_forces( @@ -516,15 +555,13 @@ def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix: is_force=True, ) - match self.velocity_representation: - case VelRepr.Inertial: - W_f_F = f_F - - case VelRepr.Body | VelRepr.Mixed: - W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi) - - case _: - raise ValueError("Invalid velocity representation.") + W_f_F = jax.lax.switch( + index=self.velocity_representation, + branches=( + lambda: f_F, + lambda: jax.vmap(to_inertial)(f_F, W_H_Fi), + ), + ) # Sum the forces on the parent links. mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links()) diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 5a2c0a54e..1e1ed5da3 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -37,3 +37,5 @@ IntLike = int | Int | jax.typing.ArrayLike BoolLike = bool | Bool | jax.typing.ArrayLike FloatLike = float | Float | jax.typing.ArrayLike + +VelRepr = Int diff --git a/tests/test_api_com.py b/tests/test_api_com.py index 559689401..e019a53fd 100644 --- a/tests/test_api_com.py +++ b/tests/test_api_com.py @@ -2,6 +2,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -9,7 +10,7 @@ def test_com_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -53,7 +54,11 @@ def test_com_properties( assert pytest.approx(v_avg_com_idt) == v_avg_com_js # https://github.com/ami-iit/jaxsim/pull/117#discussion_r1535486123 - if data.velocity_representation is not VelRepr.Body: + with data.switch_velocity_representation( + data.velocity_representation + if data.velocity_representation != VelRepr.Body + else VelRepr.Mixed + ): vl_com_idt = kin_dyn.com_velocity() vl_com_js = js.com.com_linear_velocity(model=model, data=data) assert pytest.approx(vl_com_idt) == vl_com_js @@ -61,7 +66,7 @@ def test_com_properties( # iDynTree provides the bias acceleration in G[W] frame regardless of the velocity # representation. JaxSim, instead, returns the bias acceleration in G[B] when the # active representation is VelRepr.Body. - if data.velocity_representation is not VelRepr.Body: + if data.velocity_representation != VelRepr.Body: G_v̇_bias_WG_idt = kin_dyn.com_bias_acceleration() G_v̇_bias_WG_js = js.com.bias_acceleration(model=model, data=data) assert pytest.approx(G_v̇_bias_WG_idt) == G_v̇_bias_WG_js diff --git a/tests/test_api_data.py b/tests/test_api_data.py index 9db5e6561..996356a8b 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -3,6 +3,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.utils import Mutability @@ -21,7 +22,7 @@ def test_data_valid( def test_data_joint_indexing( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index 41965ae14..820e8ce31 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -4,6 +4,7 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.math.quaternion import Quaternion @@ -124,7 +125,7 @@ def test_frame_transforms( def test_frame_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_link.py b/tests/test_api_link.py index 418f78a99..74cba2def 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -5,6 +5,7 @@ import jaxsim.api as js import jaxsim.math +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -128,7 +129,7 @@ def test_link_transforms( def test_link_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -197,7 +198,7 @@ def test_link_jacobians( def test_link_bias_acceleration( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -298,7 +299,7 @@ def test_link_bias_acceleration( def test_link_jacobian_derivative( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_model.py b/tests/test_api_model.py index a0fe49db5..fe0643c7f 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -8,6 +8,7 @@ import jaxsim.api as js import jaxsim.math +import jaxsim.typing as jtp from jaxsim import VelRepr from . import utils_idyntree @@ -222,7 +223,7 @@ def test_model_creation_and_reduction( def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): @@ -269,7 +270,7 @@ def test_model_properties( def test_model_rbda( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, ): model = jaxsim_models_types @@ -480,7 +481,7 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: def test_model_fd_id_consistency( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_api_references.py b/tests/test_api_references.py new file mode 100644 index 000000000..7bc35823e --- /dev/null +++ b/tests/test_api_references.py @@ -0,0 +1,141 @@ +import jax +import jax.numpy as jnp +import pytest +from jaxlib.xla_extension import XlaRuntimeError + +import jaxsim.api as js +import jaxsim.typing as jtp +from jaxsim import VelRepr + + +def get_random_references( + model: js.model.JaxSimModel | None = None, + data: js.data.JaxSimModelData | None = None, + *, + velocity_representation: jtp.VelRepr, + key: jax.Array, +) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: + + _, subkey = jax.random.split(key, num=2) + + _, subkey1, subkey2 = jax.random.split(subkey, num=3) + + references = js.references.JaxSimModelReferences.build( + model=model, + joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), + link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), + data=data, + velocity_representation=velocity_representation, + ) + + # Remove the force applied to the base link if the model is fixed-base. + if not model.floating_base(): + references = references.apply_link_forces( + forces=jnp.atleast_2d(jnp.zeros(6)), + model=model, + data=data, + link_names=(model.base_link(),), + additive=False, + ) + + return references + + +def test_raise_errors_link_forces( + jaxsim_model_box: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_model_box + + _, subkey1, subkey2 = jax.random.split(prng_key, num=3) + + data = js.data.random_model_data(model=model, key=subkey1) + + # ================ + # VelRepr.Inertial + # ================ + + references_inertial = get_random_references( + model=model, data=None, velocity_representation=VelRepr.Inertial, key=subkey2 + ) + + # `model` is None and `link_names` is not None. + with pytest.raises( + ValueError, match="Link names cannot be provided without a model" + ): + references_inertial.link_forces(model=None, link_names=model.link_names()) + + # ============ + # VelRepr.Body + # ============ + + references_body = get_random_references( + model=model, data=data, velocity_representation=VelRepr.Body, key=subkey2 + ) + + # `model` is None and `link_names` is not None. + with pytest.raises( + ValueError, match="Link names cannot be provided without a model" + ): + references_body.link_forces(model=None, link_names=model.link_names()) + + # `model` is not None and `data` is None. + with pytest.raises( + XlaRuntimeError, + match="Missing model data to use a representation different from `VelRepr.Inertial`", + ): + references_body.link_forces(model=model, data=None) + + +def test_raise_errors_apply_link_forces( + jaxsim_model_box: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_model_box + + _, subkey1, subkey2 = jax.random.split(prng_key, num=3) + + data = js.data.random_model_data(model=model, key=subkey1) + + # ================ + # VelRepr.Inertial + # ================ + + references_inertial = get_random_references( + model=model, data=None, velocity_representation=VelRepr.Inertial, key=subkey2 + ) + + # `model` is None + with pytest.raises( + XlaRuntimeError, + match="Link names cannot be provided without a model", + ): + references_inertial.apply_link_forces( + forces=jnp.zeros(6), model=None, data=None, link_names=model.link_names() + ) + + # ============ + # VelRepr.Body + # ============ + + references_body = get_random_references( + model=model, data=data, velocity_representation=VelRepr.Body, key=subkey2 + ) + + # `model` is None + with pytest.raises( + XlaRuntimeError, + match="Missing model to use a representation different from `VelRepr.Inertial`", + ): + references_body.apply_link_forces(forces=jnp.zeros(6), model=None, data=None) + + # `model` is not None and `data` is None. + with pytest.raises( + XlaRuntimeError, + match="Missing model data to use a representation different from `VelRepr.Inertial`", + ): + references_body.apply_link_forces( + forces=jnp.zeros(6), model=model, data=None, link_names=model.link_names() + ) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 132cfcca0..7a6294dad 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -26,7 +26,7 @@ def get_random_data_and_references( model: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, key: jax.Array, ) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: diff --git a/tests/test_contact.py b/tests/test_contact.py index ff48385a7..5bcdc2cd6 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -2,12 +2,13 @@ import pytest import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr def test_collidable_point_jacobians( jaxsim_models_types: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jax.Array, ): diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 41807b8e2..6e5f85bad 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -5,13 +5,14 @@ import jaxsim.api as js import jaxsim.integrators import jaxsim.rbda +import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.utils import Mutability def test_box_with_external_forces( jaxsim_model_box: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, ): """ This test simulates a box falling due to gravity. @@ -96,7 +97,7 @@ def test_box_with_external_forces( def test_box_with_zero_gravity( jaxsim_model_box: js.model.JaxSimModel, - velocity_representation: VelRepr, + velocity_representation: jtp.VelRepr, prng_key: jnp.ndarray, ): diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index 5376a82ef..80070a6f5 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -8,6 +8,7 @@ import numpy.typing as npt import jaxsim.api as js +import jaxsim.typing as jtp from jaxsim import VelRepr @@ -118,7 +119,7 @@ def store_jaxsim_data_in_kindyncomputations( class KinDynComputations: """High-level wrapper of the iDynTree KinDynComputations class.""" - vel_repr: VelRepr + vel_repr: jtp.VelRepr gravity: npt.NDArray kin_dyn: idt.KinDynComputations @@ -126,7 +127,7 @@ class KinDynComputations: def build( urdf: pathlib.Path | str, considered_joints: list[str] | None = None, - vel_repr: VelRepr = VelRepr.Inertial, + vel_repr: jtp.VelRepr = VelRepr.Inertial, gravity: npt.NDArray = np.array([0, 0, -10.0]), removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: