Skip to content

Commit

Permalink
Refactor typing
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 12, 2024
1 parent 28c8a69 commit 29a6ba3
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/jaxsim/simulation/integrators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import enum
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable

import jax
import jax.numpy as jnp
Expand All @@ -20,7 +20,7 @@
StateDerivative = jtp.PyTree

StateDerivativeCallable = Callable[
[State, Time], Tuple[StateDerivative, Dict[str, Any]]
[State, Time], tuple[StateDerivative, dict[str, Any]]
]


Expand All @@ -43,7 +43,7 @@ def integrator_fixed_single_step(
tf: Time,
integrator_type: IntegratorType,
num_sub_steps: int = 1,
) -> Tuple[State | ODEState, Dict[str, Any]]:
) -> tuple[State | ODEState, dict[str, Any]]:
"""
Advance a state vector by integrating a sytem dynamics with a fixed-step integrator.
Expand All @@ -65,10 +65,10 @@ def integrator_fixed_single_step(
sub_step_dt = dt / num_sub_steps

# Initialize the carry
Carry = Tuple[State | ODEState, Time]
Carry = tuple[State | ODEState, Time]
carry_init: Carry = (x0, t0)

def forward_euler_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
def forward_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
"""
Forward Euler integrator.
"""
Expand All @@ -92,7 +92,7 @@ def forward_euler_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:

return carry, None

def rk4_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
def rk4_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
"""
Runge-Kutta 4 integrator.
"""
Expand Down Expand Up @@ -125,7 +125,7 @@ def rk4_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:

return carry, None

def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]:
"""
Semi-implicit Euler integrator.
"""
Expand Down Expand Up @@ -170,7 +170,6 @@ def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:
# -----------------------------------------------------

# Extract the implicit angular velocity and the initial base quaternion

W_ω_WB = vel_tf[3:6]
W_Q_B = x_t0.physics_model.base_quaternion

Expand Down Expand Up @@ -301,10 +300,10 @@ def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> Tuple[Carry, None]:


def integrate_single_step_over_horizon(
integrator_single_step: Callable[[Time, Time, State], Tuple[State, Dict[str, Any]]],
integrator_single_step: Callable[[Time, Time, State], tuple[State, dict[str, Any]]],
t: TimeHorizon,
x0: State,
) -> Tuple[State, Dict[str, Any]]:
) -> tuple[State, dict[str, Any]]:
"""
Integrate a single-step integrator over a given horizon.
Expand All @@ -320,7 +319,7 @@ def integrate_single_step_over_horizon(
# Initialize the carry
carry_init = (x0, t)

def body_fun(carry: Tuple, idx: int) -> Tuple[Tuple, jtp.PyTree]:
def body_fun(carry: tuple, idx: int) -> tuple[tuple, jtp.PyTree]:
# Unpack the carry
x_t0, horizon = carry

Expand Down

0 comments on commit 29a6ba3

Please sign in to comment.