Skip to content

Commit

Permalink
Simplify integrators module structure
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 8, 2024
1 parent 9ba607b commit c3808bb
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 268 deletions.
312 changes: 60 additions & 252 deletions src/jaxsim/simulation/integrators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from typing import Any, Callable, Dict, Tuple, Union

import jax
Expand All @@ -22,43 +23,39 @@
[State, Time], Tuple[StateDerivative, Dict[str, Any]]
]

Carry = Tuple[State, Time]


class IntegratorType(enum.IntEnum):
RungeKutta4 = enum.auto()
EulerForward = enum.auto()
EulerSemiImplicit = enum.auto()
EulerSemiImplicitManifold = enum.auto()


# =======================
# Single-step integration
# =======================


def odeint_euler_one_step(
def single_step(
dx_dt: StateDerivativeCallable,
x0: State,
t0: Time,
tf: Time,
integrator_type: IntegratorType,
num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
"""
Forward Euler integrator.
Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.
Returns:
The final state and a dictionary including auxiliary data at t0.
"""

""""""
# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
sub_step_dt = dt / num_sub_steps

# Initialize the carry
Carry = Tuple[State, Time]
carry_init: Carry = (x0, t0)

def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
def forward_euler_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
# Unpack the carry
x_t0, t0 = carry

Expand All @@ -78,48 +75,7 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:

return carry, None

# Integrate over the given horizon
(x_tf, _), _ = jax.lax.scan(
f=body_fun, init=carry_init, xs=None, length=num_sub_steps
)

# Compute the aux dictionary at t0
_, aux_t0 = dx_dt(x0, t0)

return x_tf, aux_t0


def odeint_rk4_one_step(
dx_dt: StateDerivativeCallable,
x0: State,
t0: Time,
tf: Time,
num_sub_steps: int = 1,
) -> Tuple[State, Dict[str, Any]]:
"""
Runge-Kutta 4 integrator.
Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.
Returns:
The final state and a dictionary including auxiliary data at t0.
"""

# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
sub_step_dt = dt / num_sub_steps

# Initialize the carry
Carry = Tuple[State, Time]
carry_init: Carry = (x0, t0)

def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
def rk4_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
# Unpack the carry
x_t0, t0 = carry

Expand Down Expand Up @@ -148,49 +104,7 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:

return carry, None

# Integrate over the given horizon
(x_tf, _), _ = jax.lax.scan(
f=body_fun, init=carry_init, xs=None, length=num_sub_steps
)

# Compute the aux dictionary at t0
_, aux_t0 = dx_dt(x0, t0)

return x_tf, aux_t0


def odeint_euler_semi_implicit_one_step(
dx_dt: StateDerivativeCallable,
x0: ODEState,
t0: Time,
tf: Time,
num_sub_steps: int = 1,
) -> Tuple[ODEState, Dict[str, Any]]:
"""
Semi-implicit Euler integrator.
Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state as ODEState object.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.
Returns:
A tuple having as first element the final state as ODEState object,
and as second element a dictionary including auxiliary data at t0.
"""

# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
sub_step_dt = dt / num_sub_steps

# Initialize the carry
Carry = Tuple[ODEState, Time]
carry_init: Carry = (x0, t0)

def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
# Unpack the carry
x_t0, t0 = carry

Expand Down Expand Up @@ -294,51 +208,9 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:

return carry, None

# Integrate over the given horizon
(x_tf, _), _ = jax.lax.scan(
f=body_fun, init=carry_init, xs=None, length=num_sub_steps
)

# Compute the aux dictionary at t0
_, aux_t0 = dx_dt(x0, t0)

return x_tf, aux_t0


def odeint_euler_semi_implicit_manifold_one_step(
dx_dt: StateDerivativeCallable,
x0: ODEState,
t0: Time,
tf: Time,
num_sub_steps: int = 1,
) -> Tuple[ODEState, Dict[str, Any]]:
"""
Semi-implicit Euler integrator with quaternion integration on SO(3).
Args:
dx_dt: Callable that computes the state derivative.
x0: Initial state as ODEState object.
t0: Initial time.
tf: Final time.
num_sub_steps: Number of sub-steps to break the integration into.
Returns:
A tuple having as first element the final state as ODEState object,
and as second element a dictionary including auxiliary data at t0.
"""

# Compute the sub-step size.
# We break dt in configurable sub-steps.
dt = tf - t0
sub_step_dt = dt / num_sub_steps

# Integrate the quaternion on its manifold using the new angular velocity

# Initialize the carry
Carry = Tuple[ODEState, Time]
carry_init: Carry = (x0, t0)

def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
def semi_implicit_euler_manifold_body_fun(
carry: Carry, xs: None
) -> Tuple[Carry, None]:
# Unpack the carry
x_t0, t0 = carry

Expand Down Expand Up @@ -436,34 +308,43 @@ def body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:

return carry, None

_integrator_registry = {
IntegratorType.RungeKutta4: rk4_body_fun,
IntegratorType.EulerForward: forward_euler_body_fun,
IntegratorType.EulerSemiImplicit: semi_implicit_euler_body_fun,
IntegratorType.EulerSemiImplicitManifold: semi_implicit_euler_manifold_body_fun,
}

# Get the body function for the selected integrator
body_fun = _integrator_registry[integrator_type]

# Integrate over the given horizon
(x_no_quat_tf, _), _ = jax.lax.scan(
(x_tf, _), _ = jax.lax.scan(
f=body_fun, init=carry_init, xs=None, length=num_sub_steps
)

# ---------------------------------------------
# 5. Integrate the quaternion on SO(3) manifold
# ---------------------------------------------
if integrator_type is IntegratorType.EulerSemiImplicitManifold:
# Indices to convert quaternions between serializations
to_xyzw = np.array([1, 2, 3, 0])
to_wxyz = np.array([3, 0, 1, 2])

# Indices to convert quaternions between serializations
to_xyzw = jnp.array([1, 2, 3, 0])
to_wxyz = jnp.array([3, 0, 1, 2])

# Get the initial quaternion and the implicitly integrated angular velocity
W_ω_WB_tf = x_no_quat_tf.physics_model.base_angular_velocity
W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(x0.physics_model.base_quaternion[to_xyzw])
# Get the initial quaternion and the implicitly integrated angular velocity
W_ω_WB_tf = x_no_quat_tf.physics_model.base_angular_velocity
W_Q_B_t0 = so3.SO3.from_quaternion_xyzw(
x0.physics_model.base_quaternion[to_xyzw]
)

# Integrate the quaternion on its manifold using the implicit angular velocity,
# transformed in body-fixed representation since jaxlie uses this convention
B_R_W = W_Q_B_t0.inverse().as_matrix()
W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)
# Integrate the quaternion on its manifold using the implicit angular velocity,
# transformed in body-fixed representation since jaxlie uses this convention
B_R_W = W_Q_B_t0.inverse().as_matrix()
W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf)

# Store the quaternion in the final state
x_tf = x_no_quat_tf.replace(
physics_model=x_no_quat_tf.physics_model.replace(
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
# Store the quaternion in the final state
x_tf = x_no_quat_tf.replace(
physics_model=x_no_quat_tf.physics_model.replace(
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
)
)
)

# Compute the aux dictionary at t0
_, aux_t0 = dx_dt(x0, t0)
Expand Down Expand Up @@ -526,94 +407,16 @@ def body_fun(carry: Tuple, idx: int) -> Tuple[Tuple, jtp.PyTree]:
# ===================================================================


def odeint_euler(
func,
y0: State,
t: TimeHorizon,
*args,
num_sub_steps: int = 1,
return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Euler method.
Args:
func: A function that computes the time-derivative of the state.
y0: The initial state.
t: The vector of time instants of the integration horizon.
*args: Additional arguments to be passed to the function func.
num_sub_steps: The number of sub-steps to be performed within each integration step.
return_aux: Whether to return the auxiliary data produced by the integrator.
Returns:
The state of the system at the end of the integration horizon, and optionally
the auxiliary data produced by the integrator.
"""

# Close func over additional inputs and parameters
dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)

# Close one-step integration over its arguments
integrator_single_step = lambda t0, tf, x0: odeint_euler_one_step(
dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
)

# Integrate the state and compute optional auxiliary data over the horizon
out, aux = integrate_single_step_over_horizon(
integrator_single_step=integrator_single_step, t=t, x0=y0
)

return (out, aux) if return_aux else out


def odeint_euler_semi_implicit(
func,
y0: ODEState,
t: TimeHorizon,
*args,
num_sub_steps: int = 1,
return_aux: bool = False
) -> Union[ODEState, Tuple[ODEState, Dict[str, Any]]]:
"""
Integrate a system of ODEs using the Semi-Implicit Euler method.
Args:
func: A function that computes the time-derivative of the state.
y0: The initial state as ODEState object.
t: The vector of time instants of the integration horizon.
*args: Additional arguments to be passed to the function func.
num_sub_steps: The number of sub-steps to be performed within each integration step.
return_aux: Whether to return the auxiliary data produced by the integrator.
Returns:
The state of the system at the end of the integration horizon as ODEState object,
and optionally the auxiliary data produced by the integrator.
"""

# Close func over additional inputs and parameters
dx_dt_closure_aux = lambda x, ts: func(x, ts, *args)

# Close one-step integration over its arguments
integrator_single_step = lambda t0, tf, x0: odeint_euler_semi_implicit_one_step(
dx_dt=dx_dt_closure_aux, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
)

# Integrate the state and compute optional auxiliary data over the horizon
out, aux = integrate_single_step_over_horizon(
integrator_single_step=integrator_single_step, t=t, x0=y0
)

return (out, aux) if return_aux else out


def odeint_rk4(
def odeint(
func,
y0: State,
t: TimeHorizon,
*args,
method: str = "euler",
num_sub_steps: int = 1,
return_aux: bool = False
) -> Union[State, Tuple[State, Dict[str, Any]]]:
return_aux: bool = False,
integrator_type: IntegratorType = None,
):
"""
Integrate a system of ODEs using the Runge-Kutta 4 method.
Expand All @@ -634,8 +437,13 @@ def odeint_rk4(
dx_dt_closure = lambda x, ts: func(x, ts, *args)

# Close one-step integration over its arguments
integrator_single_step = lambda t0, tf, x0: odeint_rk4_one_step(
dx_dt=dx_dt_closure, x0=x0, t0=t0, tf=tf, num_sub_steps=num_sub_steps
integrator_single_step = lambda t0, tf, x0: single_step(
dx_dt=dx_dt_closure,
x0=x0,
t0=t0,
tf=tf,
num_sub_steps=num_sub_steps,
integrator_type=integrator_type,
)

# Integrate the state and compute optional auxiliary data over the horizon
Expand Down
Loading

0 comments on commit c3808bb

Please sign in to comment.