Skip to content

Commit

Permalink
Set static_argnames for relevant jax-jitted functions
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Sep 17, 2024
1 parent 941e3fb commit 7c232b1
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 11 deletions.
9 changes: 5 additions & 4 deletions src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from functools import partial
from typing import Literal, Optional, Union

import jax.numpy as jnp
Expand Down Expand Up @@ -563,7 +564,7 @@ def cartesian_to_keplerian(
return coords_keplerian


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _keplerian_to_cartesian_p(
coords_keplerian: Union[np.ndarray, jnp.ndarray],
mu: float,
Expand Down Expand Up @@ -699,7 +700,7 @@ def _keplerian_to_cartesian_p(
)


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _keplerian_to_cartesian_a(
coords_keplerian: Union[np.ndarray, jnp.ndarray],
mu: float,
Expand Down Expand Up @@ -782,7 +783,7 @@ def _keplerian_to_cartesian_a(
)


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _keplerian_to_cartesian_q(
coords_keplerian: Union[np.ndarray, jnp.ndarray],
mu: float,
Expand Down Expand Up @@ -1039,7 +1040,7 @@ def cartesian_to_cometary(
return coords_cometary


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _cometary_to_cartesian(
coords_cometary: Union[np.ndarray, jnp.ndarray],
t0: float,
Expand Down
5 changes: 3 additions & 2 deletions src/adam_core/dynamics/aberrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -12,7 +13,7 @@
C = c.C


@jit
@partial(jit, static_argnames=("lt_tol", "mu", "tol", "max_iter"))
def _add_light_time(
orbit: jnp.ndarray,
t0: float,
Expand Down Expand Up @@ -108,7 +109,7 @@ def _while_condition(p):
)


@jit
@partial(jit, static_argnames=("lt_tol", "mu", "tol", "max_iter"))
def add_light_time(
orbits: jnp.ndarray,
t0: jnp.ndarray,
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/chi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -12,7 +13,7 @@
MU = c.MU


@jit
@partial(jit, static_argnames=("mu", "max_iter", "tol"))
def calc_chi(
r: jnp.ndarray,
v: jnp.ndarray,
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -18,7 +19,7 @@
from .aberrations import _add_light_time, add_stellar_aberration


@jit
@partial(jit, static_argnames=("lt_tol", "max_iter", "tol", "stellar_aberration"))
def _generate_ephemeris_2body(
propagated_orbit: np.ndarray,
observation_time: float,
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/kepler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand Down Expand Up @@ -194,7 +195,7 @@ def _calc_parabolic_anomalies(nu: float, e: float) -> Tuple[float, float]:
return D, M


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def solve_kepler(e: float, M: float, max_iter: int = 100, tol: float = 1e-15) -> float:
"""
Solve Kepler's equation for true anomaly (nu) given eccentricity
Expand Down
3 changes: 2 additions & 1 deletion src/adam_core/dynamics/lagrange.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
Expand All @@ -14,7 +15,7 @@
LAGRANGE_TYPES = Tuple[jnp.float64, jnp.float64, jnp.float64, jnp.float64]


@jit
@partial(jit, static_argnames=("mu", "max_iter", "tol"))
def calc_lagrange_coefficients(
r: jnp.ndarray,
v: jnp.ndarray,
Expand Down
4 changes: 3 additions & 1 deletion src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import config, jit, vmap
Expand All @@ -15,7 +17,7 @@
config.update("jax_enable_x64", True)


@jit
@partial(jit, static_argnames=("max_iter", "tol"))
def _propagate_2body(
orbit: jnp.ndarray,
t0: float,
Expand Down

0 comments on commit 7c232b1

Please sign in to comment.