Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generic Ephemeris implementation #115

Merged
merged 11 commits into from
Aug 1, 2024
1 change: 1 addition & 0 deletions src/adam_core/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.2.1.dev5+g4d3da0c.d20240731"
14 changes: 14 additions & 0 deletions src/adam_core/coordinates/origin.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def SOLAR_SYSTEM_BARYCENTER(cls) -> float:
class Origin(qv.Table):
code = qv.LargeStringColumn()

def as_OriginCodes(self) -> OriginCodes:
"""
Convert the origin codes to an `~adam_core.coordinates.origin.OriginCodes` object.

Returns
-------
OriginCodes
Origin codes as an `~adam_core.coordinates.origin.OriginCodes` object.
"""
assert (
len(self.code.unique()) == 1
), "Only one origin code can be converted at a time."
return OriginCodes[self.code.unique()[0].as_py()]

def __eq__(self, other: object) -> np.ndarray:
if isinstance(other, (str, np.ndarray)):
codes = self.code.to_numpy(zero_copy_only=False)
Expand Down
2 changes: 1 addition & 1 deletion src/adam_core/coordinates/residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def calculate(
raise TypeError(
f"Predicted coordinates must be one of {SUPPORTED_COORDINATES}, not {type(predicted)}."
)
if type(observed) != type(predicted):
if type(observed) is not type(predicted):
raise TypeError(
"Observed and predicted coordinates must be the same type, "
f"not {type(observed)} and {type(predicted)}."
Expand Down
2 changes: 1 addition & 1 deletion src/adam_core/coordinates/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ def transform_coordinates(
# `~adam_core.coordinates.origin.OriginCodes` so we can compare them directly.
# If its not an OriginCodes enum then origin_out will be an array of strings which
# also can be checked for equality.
if type(coords) == representation_out_:
if type(coords) is representation_out_:
if coord_frame == frame_out and np.all(coord_origin == origin_out):
return coords

Expand Down
36 changes: 28 additions & 8 deletions src/adam_core/orbits/query/horizons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import numpy.typing as npt
import pandas as pd
import pyarrow as pa
from astroquery.jplhorizons import Horizons

from ...coordinates.cartesian import CartesianCoordinates
from ...coordinates.cometary import CometaryCoordinates
from ...coordinates.keplerian import KeplerianCoordinates
from ...coordinates.origin import Origin
from ...coordinates.spherical import SphericalCoordinates
from ...observers import Observers
from ...time import Timestamp
from ..ephemeris import Ephemeris
from ..orbits import Orbits


Expand Down Expand Up @@ -53,7 +56,7 @@ def _get_horizons_vectors(
for i, obj_id in enumerate(object_ids):
obj = Horizons(
id=obj_id,
epochs=times.rescale("tdb").mjd().to_numpy(zero_copy_only=False),
epochs=times.rescale("tdb").jd().to_numpy(zero_copy_only=False),
location=location,
id_type=id_type,
)
Expand Down Expand Up @@ -157,10 +160,11 @@ def _get_horizons_ephemeris(
as seen from the observer location at the given times.
"""
dfs = []
jd_utc = times.rescale("utc").jd().to_numpy(zero_copy_only=False)
for i, obj_id in enumerate(object_ids):
obj = Horizons(
id=obj_id,
epochs=times.rescale("utc").mjd().to_numpy(zero_copy_only=False),
epochs=jd_utc,
location=location,
id_type=id_type,
)
Expand All @@ -171,7 +175,7 @@ def _get_horizons_ephemeris(
cache=False,
).to_pandas()
ephemeris.insert(0, "orbit_id", f"{i:05d}")
ephemeris.insert(2, "mjd_utc", times.utc.mjd)
ephemeris.insert(2, "jd_utc", jd_utc)
ephemeris.insert(3, "observatory_code", location)

dfs.append(ephemeris)
Expand All @@ -187,7 +191,7 @@ def _get_horizons_ephemeris(

def query_horizons_ephemeris(
object_ids: Union[List, npt.ArrayLike], observers: Observers
) -> pd.DataFrame:
) -> Ephemeris:
"""
Query JPL Horizons (through astroquery) for an object's predicted ephemeris
as seen from a given location at the given times.
Expand All @@ -208,19 +212,35 @@ def query_horizons_ephemeris(
"""
dfs = []
for observatory_code, observers_i in observers.iterate_codes():
ephemeris = _get_horizons_ephemeris(
_ephemeris = _get_horizons_ephemeris(
object_ids,
observers_i.coordinates.time,
observatory_code,
)
dfs.append(ephemeris)
dfs.append(_ephemeris)

ephemeris = pd.concat(dfs, ignore_index=True)
ephemeris.sort_values(
dfs = pd.concat(dfs, ignore_index=True)
dfs.sort_values(
by=["orbit_id", "datetime_jd", "observatory_code"],
inplace=True,
ignore_index=True,
)

ephemeris = Ephemeris.from_kwargs(
orbit_id=dfs["orbit_id"],
object_id=dfs["targetname"],
# Convert from minutes to days
light_time=dfs["lighttime"] / 1440,
alpha=dfs["alpha"],
coordinates=SphericalCoordinates.from_kwargs(
time=Timestamp.from_jd(pa.array(dfs["datetime_jd"]), scale="utc"),
lon=dfs["RA"],
lat=dfs["DEC"],
origin=Origin.from_kwargs(code=dfs["observatory_code"]),
frame="ecliptic",
),
)

return ephemeris


Expand Down
171 changes: 161 additions & 10 deletions src/adam_core/propagator/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
import numpy.typing as npt
import quivr as qv

from adam_core.ray_cluster import initialize_use_ray

from ..constants import Constants as c
from ..coordinates.cartesian import CartesianCoordinates
from ..coordinates.origin import Origin, OriginCodes
from ..coordinates.spherical import SphericalCoordinates
from ..coordinates.transform import transform_coordinates
from ..observers.observers import Observers
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from ..orbits.variants import VariantEphemeris, VariantOrbits
from ..ray_cluster import initialize_use_ray
from ..time import Timestamp
from .utils import _iterate_chunks

logger = logging.getLogger(__name__)

C = c.C

RAY_INSTALLED = False
try:
import ray
Expand Down Expand Up @@ -89,17 +95,162 @@ class EphemerisMixin:
Subclasses should implement the _generate_ephemeris method.
"""

@abstractmethod
def _add_light_time(
self,
orbits,
observers,
lt_tol: float = 1e-12,
max_iter: int = 10,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had set this to 10 just because its a nice number but feel free to change this. Our examples from yesterday seemed to complete in about 4 iterations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really have a sense of the long tail curve for convergence (if there even is one). In my mind, it's perfectly okay if the max_iter is literally never reached, it's only there to keep something that won't converge from running on forever.

):
orbits_aberrated = orbits.empty()
lts = np.zeros(len(orbits))
for i, (orbit, observer) in enumerate(zip(orbits, observers)):
# Set the running variables
lt_prev = 0
dlt = float("inf")
orbit_i = orbit
lt = 0

# Extract the observer's position which remains
# constant for all iterations
observer_position = observer.coordinates.r

# Calculate the orbit's current epoch (the epoch from which
# the light travel time will be calculated)
t0 = orbit_i.coordinates.time.rescale("tdb").mjd()[0].as_py()

iterations = 0
while dlt > lt_tol and iterations < max_iter:
iterations += 1

# Calculate the topocentric distance
rho = np.linalg.norm(orbit_i.coordinates.r - observer_position)

# Calculate the light travel time
lt = rho / C

# Calculate the change in light travel time since the previous iteration
dlt = np.abs(lt - lt_prev)

# Calculate the new epoch and propagate the initial orbit to that epoch
orbit_i = self.propagate_orbits(
orbit, Timestamp.from_mjd([t0 - lt], scale="tdb")
)

# Update the previous light travel time to this iteration's light travel time
lt_prev = lt

orbits_aberrated = qv.concatenate([orbits_aberrated, orbit_i])
lts[i] = lt

return orbits_aberrated, lts

def _generate_ephemeris(
self, orbits: EphemerisType, observers: ObserverType
self, orbits: OrbitType, observers: ObserverType
) -> EphemerisType:
"""
Generate ephemerides for the given orbits as observed by
the observers.

THIS FUNCTION SHOULD BE DEFINED BY THE USER.
A generic ephemeris implementation, which can be used or overridden by subclasses.
"""
pass

if isinstance(orbits, Orbits):
ephemeris_total = Ephemeris.empty()
elif isinstance(orbits, VariantOrbits):
ephemeris_total = VariantEphemeris.empty()

for orbit in orbits:
propagated_orbits = self.propagate_orbits(orbit, observers.coordinates.time)

# Transform both the orbits and observers to the barycenter if they are not already.
propagated_orbits_barycentric = propagated_orbits.set_column(
"coordinates",
transform_coordinates(
propagated_orbits.coordinates,
CartesianCoordinates,
frame_out="ecliptic",
origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
),
)
observers_barycentric = observers.set_column(
"coordinates",
transform_coordinates(
observers.coordinates,
CartesianCoordinates,
frame_out="ecliptic",
origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER,
),
)
num_orbits = len(propagated_orbits_barycentric.orbit_id.unique())

observer_codes = np.tile(
observers.code.to_numpy(zero_copy_only=False), num_orbits
)

propagated_orbits_aberrated, light_time = self._add_light_time(
propagated_orbits_barycentric,
observers_barycentric,
lt_tol=1e-12,
)

topocentric_state = (
propagated_orbits_aberrated.coordinates.values
- observers_barycentric.coordinates.values
)
topocentric_coordinates = CartesianCoordinates.from_kwargs(
x=topocentric_state[:, 0],
y=topocentric_state[:, 1],
z=topocentric_state[:, 2],
vx=topocentric_state[:, 3],
vy=topocentric_state[:, 4],
vz=topocentric_state[:, 5],
covariance=None,
# The ephemeris times are at the point of the observer,
# not the aberrated orbit
time=observers.coordinates.time,
origin=Origin.from_kwargs(code=observer_codes),
frame="ecliptic",
)

spherical_coordinates = SphericalCoordinates.from_cartesian(
topocentric_coordinates
)

light_time = np.array(light_time)

spherical_coordinates = transform_coordinates(
spherical_coordinates, SphericalCoordinates, frame_out="equatorial"
)

# Ephemeris are generally compared in UTC, so rescale the time
spherical_coordinates = spherical_coordinates.set_column(
"time",
spherical_coordinates.time.rescale("utc"),
)

if isinstance(orbits, Orbits):

ephemeris = Ephemeris.from_kwargs(
orbit_id=propagated_orbits_barycentric.orbit_id,
object_id=propagated_orbits_barycentric.object_id,
coordinates=spherical_coordinates,
light_time=light_time,
aberrated_coordinates=propagated_orbits_aberrated.coordinates,
)

elif isinstance(orbits, VariantOrbits):
weights = orbits.weights
weights_cov = orbits.weights_cov

ephemeris = VariantEphemeris.from_kwargs(
orbit_id=propagated_orbits_barycentric.orbit_id,
object_id=propagated_orbits_barycentric.object_id,
coordinates=spherical_coordinates,
weights=weights,
weights_cov=weights_cov,
)

ephemeris_total = qv.concatenate([ephemeris_total, ephemeris])

return ephemeris_total

def generate_ephemeris(
self,
Expand Down Expand Up @@ -261,7 +412,7 @@ def generate_ephemeris(
)


class Propagator(ABC):
class Propagator(ABC, EphemerisMixin):
"""
Abstract class for propagating orbits and related functions.

Expand Down
6 changes: 4 additions & 2 deletions src/adam_core/time/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,10 @@ def add_fractional_days(
nano_part = pc.subtract(fractional_days, day_part)

days = pc.cast(day_part, pa.int64())
nanos = pc.cast(pc.multiply(nano_part, 86400 * 1e9), pa.int64())

nanos = pc.cast(
pc.multiply(nano_part, 86400 * 1e9),
options=pc.CastOptions(target_type=pa.int64(), allow_float_truncate=True),
)
return self.add_days(days).add_nanos(nanos)

def difference_scalar(
Expand Down
Loading