diff --git a/.claude/commands/test.md b/.claude/commands/test.md index 115cca8d3..d665d1ab0 100644 --- a/.claude/commands/test.md +++ b/.claude/commands/test.md @@ -48,4 +48,16 @@ uv run pytest test/path/to/test_file.py -v uv run pytest test/path/to/test_file.py -k "test_name" -v ``` +## Slow tests + +Many tests are marked `@pytest.mark.slow` and skipped by default (network calls, large +downloads). To include them, pass `--slow`: + +```bash +uv run pytest test/path/to/test_file.py -v --slow +``` + +Slow tests are often also marked `@pytest.mark.xfail` — an `XPASS` result (unexpected +pass) is fine and means the test succeeded against live data. + Report the test results, including any failures with their tracebacks. diff --git a/earth2studio/data/ecmwf.py b/earth2studio/data/ecmwf.py index 290c0c370..65942c5ed 100644 --- a/earth2studio/data/ecmwf.py +++ b/earth2studio/data/ecmwf.py @@ -140,17 +140,7 @@ def __call__( # type: ignore[override] """Retrieve ECMWF data. The child class should override this""" pass - @abstractmethod - async def fetch( # type: ignore[override] - self, - time: datetime | list[datetime] | TimeArray, - lead_time: timedelta | list[timedelta] | LeadTimeArray, - variable: str | list[str] | VariableArray, - ) -> xr.DataArray: - """Async function to get data, the child class should over ride this and call/""" - pass - - def _call( # type: ignore[override] + def _call( self, time: datetime | list[datetime] | TimeArray, lead_time: timedelta | list[timedelta] | LeadTimeArray, @@ -171,7 +161,7 @@ def _call( # type: ignore[override] Note ---- For peturbed data from ensemble models, the returned data array will have an - extra `sample` dimension added to it. + extra `ensemble` dimension added to it. Returns ------- @@ -187,13 +177,14 @@ def _call( # type: ignore[override] xr_array = loop.run_until_complete( asyncio.wait_for( - self._fetch(time, lead_time, variable), timeout=self.async_timeout + self._ecmwf_fetch(time, lead_time, variable), + timeout=self.async_timeout, ) ) return xr_array - async def _fetch( # type: ignore[override] + async def _ecmwf_fetch( self, time: datetime | list[datetime] | TimeArray, lead_time: timedelta | list[timedelta] | LeadTimeArray, @@ -241,12 +232,12 @@ async def _fetch( # type: ignore[override] len(self.LON), ) ), - dims=["time", "lead_time", "variable", "sample", "lat", "lon"], + dims=["time", "lead_time", "variable", "ensemble", "lat", "lon"], coords={ "time": time, "lead_time": lead_time, "variable": variable, - "sample": np.array(self._members), + "ensemble": np.array(self._members), "lat": self.LAT, "lon": self.LON, }, @@ -254,7 +245,7 @@ async def _fetch( # type: ignore[override] async_tasks = await self._create_tasks(time, lead_time, variable) func_map = map( - functools.partial(self.fetch_wrapper, xr_array=xr_array), async_tasks + functools.partial(self._download_wrapper, xr_array=xr_array), async_tasks ) await tqdm.gather( @@ -317,7 +308,7 @@ async def _create_tasks( ) return tasks - async def fetch_wrapper( + async def _download_wrapper( self, task: ECMWFOpenDataAsyncTask, xr_array: xr.DataArray, @@ -347,7 +338,7 @@ async def fetch_wrapper( f"No GRIB messages found for ensemble member {m} in {grib_file}" ) member_arrays.append(msgs[0].values) - values = np.stack(member_arrays, axis=0) # [sample, y, x] + values = np.stack(member_arrays, axis=0) # [ensemble, y, x] else: values = grbs[1].values # [y, x] # Provided [-180, 180], roll to [0, 360] along x dimension @@ -526,14 +517,14 @@ def __call__( # type: ignore[override] IFS analysis data array """ da = self._call(time, np.array([0], dtype="datetime64[h]"), variable) - return da.isel(lead_time=0) + return da.isel(lead_time=0).drop_vars("lead_time") async def fetch( # type: ignore[override] self, time: datetime | list[datetime] | TimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Async function to get data. + """Async method to retrieve IFS analysis data. Parameters ---------- @@ -548,8 +539,8 @@ async def fetch( # type: ignore[override] xr.DataArray IFS analysis data array. """ - da = await self._fetch(time, np.array([0], dtype="datetime64[h]"), variable) - return da.isel(lead_time=0) + da = await super()._ecmwf_fetch(time, timedelta(hours=0), variable) + return da.isel(lead_time=0).drop_vars("lead_time") def _validate_time(self, times: list[datetime]) -> None: """Verify all times are valid based on offline knowledge. @@ -641,13 +632,13 @@ def __call__( """ return self._call(time, lead_time, variable) - async def fetch( # type: ignore[override] + async def fetch( self, time: datetime | list[datetime] | TimeArray, lead_time: timedelta | list[timedelta] | LeadTimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Async function to get data. + """Async method to retrieve IFS forecast data. Parameters ---------- @@ -664,7 +655,7 @@ async def fetch( # type: ignore[override] xr.DataArray IFS forecast data array. """ - return await self._fetch(time, lead_time, variable) + return await super()._ecmwf_fetch(time, lead_time, variable) def _validate_time(self, times: list[datetime]) -> None: validate_time( @@ -778,16 +769,16 @@ def __call__( # type: ignore[override] IFS ENS initial state data array. """ da = self._call(time, np.array([0], dtype="datetime64[h]"), variable) - if "sample" in da.dims: - da = da.isel(sample=0) - return da.isel(lead_time=0) + if "ensemble" in da.dims: + da = da.isel(ensemble=0).drop_vars("ensemble") + return da.isel(lead_time=0).drop_vars("lead_time") async def fetch( # type: ignore[override] self, time: datetime | list[datetime] | TimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Async function to get data. + """Async method to retrieve IFS ENS initial state data. Parameters ---------- @@ -802,10 +793,10 @@ async def fetch( # type: ignore[override] xr.DataArray IFS ENS initial state data array. """ - da = await self._fetch(time, np.array([0], dtype="datetime64[h]"), variable) - if "sample" in da.dims: - da = da.isel(sample=0) - return da.isel(lead_time=0) + da = await super()._ecmwf_fetch(time, timedelta(hours=0), variable) + if "ensemble" in da.dims: + da = da.isel(ensemble=0).drop_vars("ensemble") + return da.isel(lead_time=0).drop_vars("lead_time") def _validate_time(self, times: list[datetime]) -> None: validate_time( @@ -922,8 +913,8 @@ def __call__( IFS ENS forecast data array """ da = self._call(time, lead_time, variable) - if "sample" in da.dims: - da = da.isel(sample=0) + if "ensemble" in da.dims: + da = da.isel(ensemble=0).drop_vars("ensemble") return da async def fetch( @@ -932,7 +923,7 @@ async def fetch( lead_time: timedelta | list[timedelta] | LeadTimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Async function to get data. + """Async method to retrieve IFS ENS forecast data. Parameters ---------- @@ -949,9 +940,9 @@ async def fetch( xr.DataArray IFS ENS forecast data array. """ - da = await self._fetch(time, lead_time, variable) - if "sample" in da.dims: - da = da.isel(sample=0) + da = await super()._ecmwf_fetch(time, lead_time, variable) + if "ensemble" in da.dims: + da = da.isel(ensemble=0).drop_vars("ensemble") return da def _validate_time(self, times: list[datetime]) -> None: @@ -1057,13 +1048,13 @@ def __call__( """ return self._call(time, lead_time, variable) - async def fetch( # type: ignore[override] + async def fetch( self, time: datetime | list[datetime] | TimeArray, lead_time: timedelta | list[timedelta] | LeadTimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Async function to get data. + """Async method to retrieve AIFS forecast data. Parameters ---------- @@ -1080,7 +1071,7 @@ async def fetch( # type: ignore[override] xr.DataArray AIFS forecast data array. """ - return await self._fetch(time, lead_time, variable) + return await super()._ecmwf_fetch(time, lead_time, variable) def _validate_time(self, times: list[datetime]) -> None: validate_time( @@ -1183,17 +1174,17 @@ def __call__( AIFS ENS forecast data array """ da = self._call(time, lead_time, variable) - if "sample" in da.dims: - da = da.isel(sample=0) + if "ensemble" in da.dims: + da = da.isel(ensemble=0).drop_vars("ensemble") return da - async def fetch( # type: ignore[override] + async def fetch( self, time: datetime | list[datetime] | TimeArray, lead_time: timedelta | list[timedelta] | LeadTimeArray, variable: str | list[str] | VariableArray, ) -> xr.DataArray: - """Async function to get data. + """Async method to retrieve AIFS ENS forecast data. Parameters ---------- @@ -1208,10 +1199,12 @@ async def fetch( # type: ignore[override] Returns ------- xr.DataArray - ECMWF weather data array. + AIFS ENS forecast data array """ - da = await self._fetch(time, lead_time, variable) - return da.isel(sample=0) + da = await super()._ecmwf_fetch(time, lead_time, variable) + if "ensemble" in da.dims: + da = da.isel(ensemble=0).drop_vars("ensemble") + return da def _validate_time(self, times: list[datetime]) -> None: validate_time( diff --git a/test/data/test_ecmwf.py b/test/data/test_ecmwf.py index 5c1ca59b6..3b4208a4f 100644 --- a/test/data/test_ecmwf.py +++ b/test/data/test_ecmwf.py @@ -262,12 +262,16 @@ async def test_ifs_async_fetch(): ds_ifs_fx = IFS_FX(cache=False) ds_ifs_ens = IFS_ENS(cache=False, member=1) ds_ifs_ens_fx = IFS_ENS_FX(cache=False, member=1) + ds_aifs_fx = AIFS_FX(cache=False) + ds_aifs_ens_fx = AIFS_ENS_FX(cache=False, member=1) - da_ifs, da_fx, da_ens, da_ens_fx = await asyncio.gather( + da_ifs, da_fx, da_ens, da_ens_fx, da_aifs_fx, da_aifs_ens_fx = await asyncio.gather( ds_ifs.fetch(t, variable), ds_ifs_fx.fetch(t, lt, variable), ds_ifs_ens.fetch(t, variable), ds_ifs_ens_fx.fetch(t, lt, variable), + ds_aifs_fx.fetch(t, lt, variable), + ds_aifs_ens_fx.fetch(t, lt, variable), ) # IFS (analysis): [time, variable, lat, lon] @@ -300,6 +304,22 @@ async def test_ifs_async_fetch(): assert da_ens_fx.shape[4] == 1440 assert not np.isnan(da_ens_fx.values).any() + # AIFS_FX (forecast): [time, lead_time, variable, lat, lon] + assert da_aifs_fx.shape[0] == 1 + assert da_aifs_fx.shape[1] == 1 + assert da_aifs_fx.shape[2] == 1 + assert da_aifs_fx.shape[3] == 721 + assert da_aifs_fx.shape[4] == 1440 + assert not np.isnan(da_aifs_fx.values).any() + + # AIFS_ENS_FX (forecast): [time, lead_time, variable, lat, lon] + assert da_aifs_ens_fx.shape[0] == 1 + assert da_aifs_ens_fx.shape[1] == 1 + assert da_aifs_ens_fx.shape[2] == 1 + assert da_aifs_ens_fx.shape[3] == 721 + assert da_aifs_ens_fx.shape[4] == 1440 + assert not np.isnan(da_aifs_ens_fx.values).any() + @pytest.mark.timeout(30) @pytest.mark.parametrize("data_source", [IFS, IFS_ENS])