Skip to content

Commit

Permalink
Refactor heun2_integration to remove aux_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
xela-95 committed Jan 21, 2025
1 parent bb2afa3 commit d7494de
Showing 1 changed file with 35 additions and 24 deletions.
59 changes: 35 additions & 24 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
import jaxsim
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math.skew import Skew


def semi_implicit_euler_integration(model, data, link_forces, joint_force_references):
"""Integrate the system state using the semi-implicit Euler method."""
# Step the dynamics forward.

with (data.switch_velocity_representation(jaxsim.api.common.VelRepr.Inertial)):
with data.switch_velocity_representation(jaxsim.api.common.VelRepr.Inertial):
a_b, dds, _ = js.ode.system_velocity_dynamics(
model=model, data=data, link_forces=link_forces, joint_force_references=joint_force_references
model=model,
data=data,
link_forces=link_forces,
joint_force_references=joint_force_references,
)
generalized_acceleration = jnp.hstack(((a_b), (dds)))
new_velocity = (
Expand All @@ -27,7 +31,9 @@ def semi_implicit_euler_integration(model, data, link_forces, joint_force_refere

quat = data.base_orientation(dcm=False)
angular_velocity_norm = jnp.linalg.norm(base_ang_velocity)
axis_angular_velocity = base_ang_velocity / (angular_velocity_norm + 1e-6 * (angular_velocity_norm == 0))
axis_angular_velocity = base_ang_velocity / (
angular_velocity_norm + 1e-6 * (angular_velocity_norm == 0)
)
angle_rotation = model.time_step * angular_velocity_norm
delta_quat = axis_angle_to_quat(axis_angular_velocity, angle_rotation)
new_quaternion = quat_mul(quat, delta_quat)
Expand All @@ -53,10 +59,13 @@ def semi_implicit_euler_integration(model, data, link_forces, joint_force_refere

def heun2_integration(model, data, link_forces, joint_force_references):
"""Integrate the system state using the Heun's method."""
A: jtp.Matrix = jnp.array([
[0, 0],
[1, 0],
], dtype=float)
A: jtp.Matrix = jnp.array(
[
[0, 0],
[1, 0],
],
dtype=float,
)

b: jtp.Matrix = jnp.array([[1 / 2, 1 / 2]], dtype=float).transpose()
c: jtp.Vector = jnp.array([0, 1], dtype=float)
Expand Down Expand Up @@ -85,7 +94,7 @@ def scan_body(carry, i):
with data.editable(validate=True) as data_rw:
data_rw.state = xi

ki, _ = js.ode.system_dynamics(
ki = js.ode.system_dynamics(
model,
data,
link_forces=link_forces,
Expand Down Expand Up @@ -116,18 +125,18 @@ def scan_body(carry, i):


def axis_angle_to_quat(axis: jax.Array, angle: jax.Array) -> jax.Array:
"""
Provides a quaternion that describes rotating around axis by angle.
"""
Provide a quaternion that describes rotating around axis by angle.
Args:
axis: (3,) axis (x,y,z)
angle: () float angle to rotate by
Args:
axis: (3,) axis (x,y,z)
angle: () float angle to rotate by
Returns:
A quaternion that rotates around axis by angle
"""
s, c = jnp.sin(angle * 0.5), jnp.cos(angle * 0.5)
return jnp.insert(axis * s, 0, c)
Returns:
A quaternion that rotates around axis by angle
"""
s, c = jnp.sin(angle * 0.5), jnp.cos(angle * 0.5)
return jnp.insert(axis * s, 0, c)


def quat_mul(u: jax.Array, v: jax.Array) -> jax.Array:
Expand All @@ -141,9 +150,11 @@ def quat_mul(u: jax.Array, v: jax.Array) -> jax.Array:
Returns:
A quaternion u * v.
"""
return jnp.array([
u[0] * v[0] - u[1] * v[1] - u[2] * v[2] - u[3] * v[3],
u[0] * v[1] + u[1] * v[0] + u[2] * v[3] - u[3] * v[2],
u[0] * v[2] - u[1] * v[3] + u[2] * v[0] + u[3] * v[1],
u[0] * v[3] + u[1] * v[2] - u[2] * v[1] + u[3] * v[0],
])
return jnp.array(
[
u[0] * v[0] - u[1] * v[1] - u[2] * v[2] - u[3] * v[3],
u[0] * v[1] + u[1] * v[0] + u[2] * v[3] - u[3] * v[2],
u[0] * v[2] - u[1] * v[3] + u[2] * v[0] + u[3] * v[1],
u[0] * v[3] + u[1] * v[2] - u[2] * v[1] + u[3] * v[0],
]
)

0 comments on commit d7494de

Please sign in to comment.