diff --git a/src/adam_core/_version.py b/src/adam_core/_version.py new file mode 100644 index 00000000..d7bee229 --- /dev/null +++ b/src/adam_core/_version.py @@ -0,0 +1 @@ +__version__ = "0.2.1.dev5+g4d3da0c.d20240731" diff --git a/src/adam_core/coordinates/origin.py b/src/adam_core/coordinates/origin.py index 38775339..fa7231c1 100644 --- a/src/adam_core/coordinates/origin.py +++ b/src/adam_core/coordinates/origin.py @@ -86,6 +86,20 @@ def SOLAR_SYSTEM_BARYCENTER(cls) -> float: class Origin(qv.Table): code = qv.LargeStringColumn() + def as_OriginCodes(self) -> OriginCodes: + """ + Convert the origin codes to an `~adam_core.coordinates.origin.OriginCodes` object. + + Returns + ------- + OriginCodes + Origin codes as an `~adam_core.coordinates.origin.OriginCodes` object. + """ + assert ( + len(self.code.unique()) == 1 + ), "Only one origin code can be converted at a time." + return OriginCodes[self.code.unique()[0].as_py()] + def __eq__(self, other: object) -> np.ndarray: if isinstance(other, (str, np.ndarray)): codes = self.code.to_numpy(zero_copy_only=False) diff --git a/src/adam_core/coordinates/residuals.py b/src/adam_core/coordinates/residuals.py index 71490b83..e8c0cbec 100644 --- a/src/adam_core/coordinates/residuals.py +++ b/src/adam_core/coordinates/residuals.py @@ -194,7 +194,7 @@ def calculate( raise TypeError( f"Predicted coordinates must be one of {SUPPORTED_COORDINATES}, not {type(predicted)}." ) - if type(observed) != type(predicted): + if type(observed) is not type(predicted): raise TypeError( "Observed and predicted coordinates must be the same type, " f"not {type(observed)} and {type(predicted)}." diff --git a/src/adam_core/coordinates/transform.py b/src/adam_core/coordinates/transform.py index d4265f47..f531d5c9 100644 --- a/src/adam_core/coordinates/transform.py +++ b/src/adam_core/coordinates/transform.py @@ -1368,7 +1368,7 @@ def transform_coordinates( # `~adam_core.coordinates.origin.OriginCodes` so we can compare them directly. # If its not an OriginCodes enum then origin_out will be an array of strings which # also can be checked for equality. - if type(coords) == representation_out_: + if type(coords) is representation_out_: if coord_frame == frame_out and np.all(coord_origin == origin_out): return coords diff --git a/src/adam_core/orbits/query/horizons.py b/src/adam_core/orbits/query/horizons.py index ec9d02ce..c4091d0e 100644 --- a/src/adam_core/orbits/query/horizons.py +++ b/src/adam_core/orbits/query/horizons.py @@ -2,14 +2,17 @@ import numpy.typing as npt import pandas as pd +import pyarrow as pa from astroquery.jplhorizons import Horizons from ...coordinates.cartesian import CartesianCoordinates from ...coordinates.cometary import CometaryCoordinates from ...coordinates.keplerian import KeplerianCoordinates from ...coordinates.origin import Origin +from ...coordinates.spherical import SphericalCoordinates from ...observers import Observers from ...time import Timestamp +from ..ephemeris import Ephemeris from ..orbits import Orbits @@ -53,7 +56,7 @@ def _get_horizons_vectors( for i, obj_id in enumerate(object_ids): obj = Horizons( id=obj_id, - epochs=times.rescale("tdb").mjd().to_numpy(zero_copy_only=False), + epochs=times.rescale("tdb").jd().to_numpy(zero_copy_only=False), location=location, id_type=id_type, ) @@ -157,10 +160,11 @@ def _get_horizons_ephemeris( as seen from the observer location at the given times. """ dfs = [] + jd_utc = times.rescale("utc").jd().to_numpy(zero_copy_only=False) for i, obj_id in enumerate(object_ids): obj = Horizons( id=obj_id, - epochs=times.rescale("utc").mjd().to_numpy(zero_copy_only=False), + epochs=jd_utc, location=location, id_type=id_type, ) @@ -171,7 +175,7 @@ def _get_horizons_ephemeris( cache=False, ).to_pandas() ephemeris.insert(0, "orbit_id", f"{i:05d}") - ephemeris.insert(2, "mjd_utc", times.utc.mjd) + ephemeris.insert(2, "jd_utc", jd_utc) ephemeris.insert(3, "observatory_code", location) dfs.append(ephemeris) @@ -187,7 +191,7 @@ def _get_horizons_ephemeris( def query_horizons_ephemeris( object_ids: Union[List, npt.ArrayLike], observers: Observers -) -> pd.DataFrame: +) -> Ephemeris: """ Query JPL Horizons (through astroquery) for an object's predicted ephemeris as seen from a given location at the given times. @@ -208,19 +212,35 @@ def query_horizons_ephemeris( """ dfs = [] for observatory_code, observers_i in observers.iterate_codes(): - ephemeris = _get_horizons_ephemeris( + _ephemeris = _get_horizons_ephemeris( object_ids, observers_i.coordinates.time, observatory_code, ) - dfs.append(ephemeris) + dfs.append(_ephemeris) - ephemeris = pd.concat(dfs, ignore_index=True) - ephemeris.sort_values( + dfs = pd.concat(dfs, ignore_index=True) + dfs.sort_values( by=["orbit_id", "datetime_jd", "observatory_code"], inplace=True, ignore_index=True, ) + + ephemeris = Ephemeris.from_kwargs( + orbit_id=dfs["orbit_id"], + object_id=dfs["targetname"], + # Convert from minutes to days + light_time=dfs["lighttime"] / 1440, + alpha=dfs["alpha"], + coordinates=SphericalCoordinates.from_kwargs( + time=Timestamp.from_jd(pa.array(dfs["datetime_jd"]), scale="utc"), + lon=dfs["RA"], + lat=dfs["DEC"], + origin=Origin.from_kwargs(code=dfs["observatory_code"]), + frame="ecliptic", + ), + ) + return ephemeris diff --git a/src/adam_core/propagator/propagator.py b/src/adam_core/propagator/propagator.py index 0f23da35..713438b4 100644 --- a/src/adam_core/propagator/propagator.py +++ b/src/adam_core/propagator/propagator.py @@ -6,17 +6,23 @@ import numpy.typing as npt import quivr as qv -from adam_core.ray_cluster import initialize_use_ray - +from ..constants import Constants as c +from ..coordinates.cartesian import CartesianCoordinates +from ..coordinates.origin import Origin, OriginCodes +from ..coordinates.spherical import SphericalCoordinates +from ..coordinates.transform import transform_coordinates from ..observers.observers import Observers from ..orbits.ephemeris import Ephemeris from ..orbits.orbits import Orbits from ..orbits.variants import VariantEphemeris, VariantOrbits +from ..ray_cluster import initialize_use_ray from ..time import Timestamp from .utils import _iterate_chunks logger = logging.getLogger(__name__) +C = c.C + RAY_INSTALLED = False try: import ray @@ -89,17 +95,162 @@ class EphemerisMixin: Subclasses should implement the _generate_ephemeris method. """ - @abstractmethod + def _add_light_time( + self, + orbits, + observers, + lt_tol: float = 1e-12, + max_iter: int = 10, + ): + orbits_aberrated = orbits.empty() + lts = np.zeros(len(orbits)) + for i, (orbit, observer) in enumerate(zip(orbits, observers)): + # Set the running variables + lt_prev = 0 + dlt = float("inf") + orbit_i = orbit + lt = 0 + + # Extract the observer's position which remains + # constant for all iterations + observer_position = observer.coordinates.r + + # Calculate the orbit's current epoch (the epoch from which + # the light travel time will be calculated) + t0 = orbit_i.coordinates.time.rescale("tdb").mjd()[0].as_py() + + iterations = 0 + while dlt > lt_tol and iterations < max_iter: + iterations += 1 + + # Calculate the topocentric distance + rho = np.linalg.norm(orbit_i.coordinates.r - observer_position) + + # Calculate the light travel time + lt = rho / C + + # Calculate the change in light travel time since the previous iteration + dlt = np.abs(lt - lt_prev) + + # Calculate the new epoch and propagate the initial orbit to that epoch + orbit_i = self.propagate_orbits( + orbit, Timestamp.from_mjd([t0 - lt], scale="tdb") + ) + + # Update the previous light travel time to this iteration's light travel time + lt_prev = lt + + orbits_aberrated = qv.concatenate([orbits_aberrated, orbit_i]) + lts[i] = lt + + return orbits_aberrated, lts + def _generate_ephemeris( - self, orbits: EphemerisType, observers: ObserverType + self, orbits: OrbitType, observers: ObserverType ) -> EphemerisType: """ - Generate ephemerides for the given orbits as observed by - the observers. - - THIS FUNCTION SHOULD BE DEFINED BY THE USER. + A generic ephemeris implementation, which can be used or overridden by subclasses. """ - pass + + if isinstance(orbits, Orbits): + ephemeris_total = Ephemeris.empty() + elif isinstance(orbits, VariantOrbits): + ephemeris_total = VariantEphemeris.empty() + + for orbit in orbits: + propagated_orbits = self.propagate_orbits(orbit, observers.coordinates.time) + + # Transform both the orbits and observers to the barycenter if they are not already. + propagated_orbits_barycentric = propagated_orbits.set_column( + "coordinates", + transform_coordinates( + propagated_orbits.coordinates, + CartesianCoordinates, + frame_out="ecliptic", + origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER, + ), + ) + observers_barycentric = observers.set_column( + "coordinates", + transform_coordinates( + observers.coordinates, + CartesianCoordinates, + frame_out="ecliptic", + origin_out=OriginCodes.SOLAR_SYSTEM_BARYCENTER, + ), + ) + num_orbits = len(propagated_orbits_barycentric.orbit_id.unique()) + + observer_codes = np.tile( + observers.code.to_numpy(zero_copy_only=False), num_orbits + ) + + propagated_orbits_aberrated, light_time = self._add_light_time( + propagated_orbits_barycentric, + observers_barycentric, + lt_tol=1e-12, + ) + + topocentric_state = ( + propagated_orbits_aberrated.coordinates.values + - observers_barycentric.coordinates.values + ) + topocentric_coordinates = CartesianCoordinates.from_kwargs( + x=topocentric_state[:, 0], + y=topocentric_state[:, 1], + z=topocentric_state[:, 2], + vx=topocentric_state[:, 3], + vy=topocentric_state[:, 4], + vz=topocentric_state[:, 5], + covariance=None, + # The ephemeris times are at the point of the observer, + # not the aberrated orbit + time=observers.coordinates.time, + origin=Origin.from_kwargs(code=observer_codes), + frame="ecliptic", + ) + + spherical_coordinates = SphericalCoordinates.from_cartesian( + topocentric_coordinates + ) + + light_time = np.array(light_time) + + spherical_coordinates = transform_coordinates( + spherical_coordinates, SphericalCoordinates, frame_out="equatorial" + ) + + # Ephemeris are generally compared in UTC, so rescale the time + spherical_coordinates = spherical_coordinates.set_column( + "time", + spherical_coordinates.time.rescale("utc"), + ) + + if isinstance(orbits, Orbits): + + ephemeris = Ephemeris.from_kwargs( + orbit_id=propagated_orbits_barycentric.orbit_id, + object_id=propagated_orbits_barycentric.object_id, + coordinates=spherical_coordinates, + light_time=light_time, + aberrated_coordinates=propagated_orbits_aberrated.coordinates, + ) + + elif isinstance(orbits, VariantOrbits): + weights = orbits.weights + weights_cov = orbits.weights_cov + + ephemeris = VariantEphemeris.from_kwargs( + orbit_id=propagated_orbits_barycentric.orbit_id, + object_id=propagated_orbits_barycentric.object_id, + coordinates=spherical_coordinates, + weights=weights, + weights_cov=weights_cov, + ) + + ephemeris_total = qv.concatenate([ephemeris_total, ephemeris]) + + return ephemeris_total def generate_ephemeris( self, @@ -261,7 +412,7 @@ def generate_ephemeris( ) -class Propagator(ABC): +class Propagator(ABC, EphemerisMixin): """ Abstract class for propagating orbits and related functions. diff --git a/src/adam_core/time/time.py b/src/adam_core/time/time.py index c5d57f03..beabc283 100644 --- a/src/adam_core/time/time.py +++ b/src/adam_core/time/time.py @@ -392,8 +392,10 @@ def add_fractional_days( nano_part = pc.subtract(fractional_days, day_part) days = pc.cast(day_part, pa.int64()) - nanos = pc.cast(pc.multiply(nano_part, 86400 * 1e9), pa.int64()) - + nanos = pc.cast( + pc.multiply(nano_part, 86400 * 1e9), + options=pc.CastOptions(target_type=pa.int64(), allow_float_truncate=True), + ) return self.add_days(days).add_nanos(nanos) def difference_scalar(