diff --git a/src/adam_core/dynamics/impacts.py b/src/adam_core/dynamics/impacts.py index f5502517..5d3d6632 100644 --- a/src/adam_core/dynamics/impacts.py +++ b/src/adam_core/dynamics/impacts.py @@ -1,4 +1,3 @@ -import importlib.util import logging from abc import abstractmethod from typing import List, Optional, Tuple @@ -21,13 +20,6 @@ logger = logging.getLogger(__name__) -# Test to see that at least one impact-enabled propagator is -# installed and if not print a warning -if importlib.util.find_spec("adam_core.propagator.adam_assist") is None: - logger.warning( - "No impact-enabled propagator installed. Impact calculations will not be possible." - ) - RAY_INSTALLED = False try: import ray diff --git a/src/adam_core/propagator/propagator.py b/src/adam_core/propagator/propagator.py index 713438b4..c0b1f671 100644 --- a/src/adam_core/propagator/propagator.py +++ b/src/adam_core/propagator/propagator.py @@ -4,6 +4,7 @@ import numpy as np import numpy.typing as npt +import pyarrow.compute as pc import quivr as qv from ..constants import Constants as c @@ -479,6 +480,7 @@ def propagate_orbits( propagated : `~adam_core.orbits.orbits.Orbits` Propagated orbits. """ + if max_processes is None or max_processes > 1: propagated_list: List[Orbits] = [] variants_list: List[VariantOrbits] = [] @@ -575,6 +577,29 @@ def propagate_orbits( if propagated_variants is not None: propagated = propagated_variants.collapse(propagated) - return propagated.sort_by( + # Return the results with the original origin and frame + # Preserve the original output origin for the input orbits + # by orbit id + final_results = None + unique_origins = pc.unique(orbits.coordinates.origin.code) + for origin_code in unique_origins: + origin_orbits = orbits.select("coordinates.origin.code", origin_code) + result_origin_orbits = propagated.apply_mask( + pc.is_in(propagated.orbit_id, origin_orbits.orbit_id) + ) + partial_results = result_origin_orbits.set_column( + "coordinates", + transform_coordinates( + result_origin_orbits.coordinates, + origin_out=OriginCodes[origin_code.as_py()], + frame_out=orbits.coordinates.frame, + ), + ) + if final_results is None: + final_results = partial_results + else: + final_results = qv.concatenate([final_results, partial_results]) + + return final_results.sort_by( ["orbit_id", "coordinates.time.days", "coordinates.time.nanos"] ) diff --git a/src/adam_core/propagator/tests/test_propagator.py b/src/adam_core/propagator/tests/test_propagator.py index 248772ac..c3192e88 100644 --- a/src/adam_core/propagator/tests/test_propagator.py +++ b/src/adam_core/propagator/tests/test_propagator.py @@ -4,7 +4,8 @@ import quivr as qv from ...coordinates.cartesian import CartesianCoordinates -from ...coordinates.origin import Origin +from ...coordinates.origin import Origin, OriginCodes +from ...coordinates.transform import transform_coordinates from ...observers.observers import Observers from ...orbits.ephemeris import Ephemeris from ...orbits.orbits import Orbits @@ -21,8 +22,19 @@ def _propagate_orbits(self, orbits: Orbits, times: Timestamp) -> Orbits: repeated_time = qv.concatenate([t] * len(orbits)) orbits.coordinates.time = repeated_time all_times.append(orbits) + all_times = qv.concatenate(all_times) - return qv.concatenate(all_times) + # Artifically change origin to test that it is preserved in the final output + output = all_times.set_column( + "coordinates", + transform_coordinates( + all_times.coordinates, + origin_out=OriginCodes["SATURN_BARYCENTER"], + frame_out="equatorial", + ), + ) + + return output # MockPropagator generated ephemeris by just subtracting the state from # the state of the observers @@ -105,3 +117,45 @@ def test_propagator_multiple_workers_ray(): have = prop.generate_ephemeris(orbits_ref, observers_ref, max_processes=4) assert len(have) == len(orbits) * len(times) + + +def test_propagate_different_origins(): + """ + Test that we are returning propagated orbits with their original origins + """ + orbits = Orbits.from_kwargs( + orbit_id=["1", "2"], + object_id=["1", "2"], + coordinates=CartesianCoordinates.from_kwargs( + x=[1, 1], + y=[1, 1], + z=[1, 1], + vx=[1, 1], + vy=[1, 1], + vz=[1, 1], + time=Timestamp.from_mjd([60000, 60000], scale="tdb"), + frame="ecliptic", + origin=Origin.from_kwargs( + code=["SOLAR_SYSTEM_BARYCENTER", "EARTH_MOON_BARYCENTER"] + ), + ), + ) + + prop = MockPropagator() + propagated_orbits = prop.propagate_orbits( + orbits, Timestamp.from_mjd([60001, 60002, 60003], scale="tdb") + ) + orbit_one_results = propagated_orbits.select("orbit_id", "1") + orbit_two_results = propagated_orbits.select("orbit_id", "2") + # Assert that the origin codes for each set of results is unique + # and that it matches the original input + assert len(orbit_one_results.coordinates.origin.code.unique()) == 1 + assert ( + orbit_one_results.coordinates.origin.code.unique()[0].as_py() + == "SOLAR_SYSTEM_BARYCENTER" + ) + assert len(orbit_two_results.coordinates.origin.code.unique()) == 1 + assert ( + orbit_two_results.coordinates.origin.code.unique()[0].as_py() + == "EARTH_MOON_BARYCENTER" + )