Skip to content

Commit

Permalink
Fix running AD through jaxsim.math.Rotation.from_axis_angle
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 30, 2024
1 parent 78c55cc commit be9f3bf
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,18 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
Generate a 3D rotation matrix from an axis-angle representation.
Args:
vector (jtp.Vector): Axis-angle representation as a 3D vector.
vector: Axis-angle representation or the rotation as a 3D vector.
Returns:
jtp.Matrix: 3D rotation matrix.
The SO(3) rotation matrix.
"""

vector = vector.squeeze()
theta = jnp.linalg.norm(vector)

def theta_is_not_zero(theta_and_v: tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:
theta, v = theta_and_v
def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:

v = axis
theta = jnp.linalg.norm(v)

s = jnp.sin(theta)
c = jnp.cos(theta)
Expand All @@ -77,9 +78,19 @@ def theta_is_not_zero(theta_and_v: tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:

return R.transpose()

return jax.lax.cond(
pred=(theta == 0.0),
true_fun=lambda operand: jnp.eye(3),
false_fun=theta_is_not_zero,
operand=(theta, vector),
# Use the double-where trick to prevent JAX problems when the
# jax.jit and jax.grad transforms are applied.
return jnp.where(
jnp.linalg.norm(vector) > 0,
theta_is_not_zero(
axis=jnp.where(
jnp.linalg.norm(vector) > 0,
vector,
# The following line is a workaround to prevent division by 0.
# Considering the outer where, this branch is never executed.
jnp.ones(3),
)
),
# Return an identity rotation matrix when the input vector is zero.
jnp.eye(3),
)

0 comments on commit be9f3bf

Please sign in to comment.