From d14b67b823c8f13a98fb23aab0d9df5201a1348f Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 2 Apr 2025 17:25:39 +0000 Subject: [PATCH 1/4] Remove `omega_in_body_fixed` argument --- src/jaxsim/api/integrators.py | 1 - src/jaxsim/api/ode.py | 1 - src/jaxsim/math/quaternion.py | 1 - tests/test_api_frame.py | 1 - tests/test_api_link.py | 1 - tests/test_api_model.py | 1 - 6 files changed, 6 deletions(-) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 71840e03e..494d85c1b 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -52,7 +52,6 @@ def semi_implicit_euler_integration( W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=W_ω_WB, - omega_in_body_fixed=False, ).squeeze() W_p_B = data.base_position + dt * W_ṗ_B diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 8c8a949b6..1969933f7 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -128,7 +128,6 @@ def system_position_dynamics( W_Q̇_B = Quaternion.derivative( quaternion=W_Q_B, omega=W_ω_WB, - omega_in_body_fixed=False, K=baumgarte_quaternion_regularization, ).squeeze() diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index d5a58869c..d2813a96c 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -68,7 +68,6 @@ def from_dcm(dcm: jtp.Matrix) -> jtp.Vector: def derivative( quaternion: jtp.Vector, omega: jtp.Vector, - omega_in_body_fixed: bool = False, K: float = 0.1, ) -> jtp.Vector: """ diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index e97030524..2c480598f 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -267,7 +267,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: W_Q̇_B = Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, - omega_in_body_fixed=True, K=0.0, ).squeeze() diff --git a/tests/test_api_link.py b/tests/test_api_link.py index ee7b3c965..0f653abd2 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -357,7 +357,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, - omega_in_body_fixed=True, K=0.0, ).squeeze() diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 27a0775d5..41993c64c 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -458,7 +458,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, - omega_in_body_fixed=True, K=0.0, ).squeeze() From a0fd04ac276a36602793b2823feb5588413fdecc Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 2 Apr 2025 17:25:57 +0000 Subject: [PATCH 2/4] Implement quaternion derivative via Hamilton product --- src/jaxsim/math/quaternion.py | 63 ++++++++++++----------------------- 1 file changed, 21 insertions(+), 42 deletions(-) diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index d2813a96c..cb8f15406 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -76,59 +76,38 @@ def derivative( Args: quaternion: Quaternion in XYZW representation. omega: Angular velocity vector. - omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame. K (float): A scaling factor. Returns: The derivative of the quaternion. """ ω = omega.squeeze() - quaternion = quaternion.squeeze() - - def Q_body(q: jtp.Vector) -> jtp.Matrix: - qw, qx, qy, qz = q - - return jnp.array( - [ - [qw, -qx, -qy, -qz], - [qx, qw, -qz, qy], - [qy, qz, qw, -qx], - [qz, -qy, qx, qw], - ] - ) - - def Q_inertial(q: jtp.Vector) -> jtp.Matrix: - qw, qx, qy, qz = q - - return jnp.array( - [ - [qw, -qx, -qy, -qz], - [qx, qw, qz, -qy], - [qy, -qz, qw, qx], - [qz, qy, -qx, qw], - ] - ) - - Q = jax.lax.cond( - pred=omega_in_body_fixed, - true_fun=Q_body, - false_fun=Q_inertial, - operand=quaternion, + q = quaternion.squeeze() + + # Construct pure quaternion: (scalar damping term, angular velocity components) + ω_quat = jnp.hstack([K * safe_norm(ω) * (1 - safe_norm(quaternion)), ω]) + + # Apply quaternion multiplication based on frame representation + i_idx = jnp.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]]) + j_idx = jnp.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 0, 1, 3], [3, 0, 2, 1]]) + sign_matrix = jnp.array( + [ + [1, -1, -1, -1], + [1, 1, 1, -1], + [1, 1, 1, -1], + [1, 1, 1, -1], + ] ) - norm_ω = safe_norm(ω) + # Compute quaternion derivative via Einstein summation + q_outer = jnp.einsum("...i,...j->...ij", q, ω_quat) - qd = 0.5 * ( - Q - @ jnp.hstack( - [ - K * norm_ω * (1 - safe_norm(quaternion)), - ω, - ] - ) + Qd = jnp.sum( + sign_matrix * q_outer[..., i_idx, j_idx], + axis=-1, ) - return jnp.vstack(qd) + return 0.5 * Qd @staticmethod def integration( From 94115930515783bbca77a55b26d3be3abd2d189f Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Tue, 15 Apr 2025 14:24:19 +0100 Subject: [PATCH 3/4] Update method to compute the outer product --- src/jaxsim/math/quaternion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index cb8f15406..a70e990ea 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -100,7 +100,7 @@ def derivative( ) # Compute quaternion derivative via Einstein summation - q_outer = jnp.einsum("...i,...j->...ij", q, ω_quat) + q_outer = jnp.outer(q, ω_quat) Qd = jnp.sum( sign_matrix * q_outer[..., i_idx, j_idx], From ee4761193cc0ee7f5900af6763ae8a1739a763bb Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> Date: Wed, 16 Apr 2025 11:59:30 +0100 Subject: [PATCH 4/4] Add detailed comments for quaternion multiplication --- src/jaxsim/math/quaternion.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index a70e990ea..3c88ebfa5 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -87,7 +87,17 @@ def derivative( # Construct pure quaternion: (scalar damping term, angular velocity components) ω_quat = jnp.hstack([K * safe_norm(ω) * (1 - safe_norm(quaternion)), ω]) - # Apply quaternion multiplication based on frame representation + # Quaternion multiplication using index tables. + # This approach avoids using the explicit quaternion multiplication formula + # by encoding the necessary element-wise products and signs via indexed operations. + # Given two quaternions q and w, their Hamilton product q ⊗ w can be written + # as a combination of q[i] * w[j] terms with appropriate signs. + # i_idx and j_idx define which elements of the outer product q ⊗ w to select. + # For example, i_idx[1][2] = 2 and j_idx[1][2] = 3 means: take q[2] * w[3] for this term. + # sign_matrix[i][j] gives the sign (+1 or -1) to apply to each q[i] * w[j] term, + # depending on quaternion multiplication rules. + # This indexed summation reproduces the Hamilton product of quaternions in a + # vectorized way, and is suitable for use with JAX. i_idx = jnp.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]]) j_idx = jnp.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 0, 1, 3], [3, 0, 2, 1]]) sign_matrix = jnp.array(