Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .claude/commands/test.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,16 @@ uv run pytest test/path/to/test_file.py -v
uv run pytest test/path/to/test_file.py -k "test_name" -v
```

## Slow tests

Many tests are marked `@pytest.mark.slow` and are skipped by default because they make
network calls or download large files. To include them, pass `--slow`:

```bash
uv run pytest test/path/to/test_file.py -v --slow
```

Slow tests are often also marked `@pytest.mark.xfail` — an `XPASS` result (unexpected
pass) is fine and means the test succeeded against live data.

Report the test results, including any failures with their tracebacks.
93 changes: 43 additions & 50 deletions earth2studio/data/ecmwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,7 @@ def __call__( # type: ignore[override]
"""Retrieve ECMWF data. The child class should override this"""
pass

@abstractmethod
async def fetch( # type: ignore[override]
self,
time: datetime | list[datetime] | TimeArray,
lead_time: timedelta | list[timedelta] | LeadTimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data, the child class should over ride this and call/"""
pass

def _call( # type: ignore[override]
def _call(
self,
time: datetime | list[datetime] | TimeArray,
lead_time: timedelta | list[timedelta] | LeadTimeArray,
Expand All @@ -171,7 +161,7 @@ def _call( # type: ignore[override]
Note
----
For perturbed data from ensemble models, the returned data array will have an
extra `sample` dimension added to it.
extra `ensemble` dimension added to it.

Returns
-------
Expand All @@ -187,13 +177,14 @@ def _call( # type: ignore[override]

xr_array = loop.run_until_complete(
asyncio.wait_for(
self._fetch(time, lead_time, variable), timeout=self.async_timeout
self._ecmwf_fetch(time, lead_time, variable),
timeout=self.async_timeout,
)
)

return xr_array

async def _fetch( # type: ignore[override]
async def _ecmwf_fetch(
self,
time: datetime | list[datetime] | TimeArray,
lead_time: timedelta | list[timedelta] | LeadTimeArray,
Expand Down Expand Up @@ -241,20 +232,20 @@ async def _fetch( # type: ignore[override]
len(self.LON),
)
),
dims=["time", "lead_time", "variable", "sample", "lat", "lon"],
dims=["time", "lead_time", "variable", "ensemble", "lat", "lon"],
coords={
"time": time,
"lead_time": lead_time,
"variable": variable,
"sample": np.array(self._members),
"ensemble": np.array(self._members),
"lat": self.LAT,
"lon": self.LON,
},
)

async_tasks = await self._create_tasks(time, lead_time, variable)
func_map = map(
functools.partial(self.fetch_wrapper, xr_array=xr_array), async_tasks
functools.partial(self._download_wrapper, xr_array=xr_array), async_tasks
)

await tqdm.gather(
Expand Down Expand Up @@ -317,7 +308,7 @@ async def _create_tasks(
)
return tasks

async def fetch_wrapper(
async def _download_wrapper(
self,
task: ECMWFOpenDataAsyncTask,
xr_array: xr.DataArray,
Expand Down Expand Up @@ -347,7 +338,7 @@ async def fetch_wrapper(
f"No GRIB messages found for ensemble member {m} in {grib_file}"
)
member_arrays.append(msgs[0].values)
values = np.stack(member_arrays, axis=0) # [sample, y, x]
values = np.stack(member_arrays, axis=0) # [ensemble, y, x]
else:
values = grbs[1].values # [y, x]
# Provided [-180, 180], roll to [0, 360] along x dimension
Expand Down Expand Up @@ -526,14 +517,14 @@ def __call__( # type: ignore[override]
IFS analysis data array
"""
da = self._call(time, np.array([0], dtype="datetime64[h]"), variable)
return da.isel(lead_time=0)
return da.isel(lead_time=0).drop_vars("lead_time")

async def fetch( # type: ignore[override]
self,
time: datetime | list[datetime] | TimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data.
"""Async method to retrieve IFS analysis data.

Parameters
----------
Expand All @@ -548,8 +539,8 @@ async def fetch( # type: ignore[override]
xr.DataArray
IFS analysis data array.
"""
da = await self._fetch(time, np.array([0], dtype="datetime64[h]"), variable)
return da.isel(lead_time=0)
da = await super()._ecmwf_fetch(time, timedelta(hours=0), variable)
return da.isel(lead_time=0).drop_vars("lead_time")

def _validate_time(self, times: list[datetime]) -> None:
"""Verify all times are valid based on offline knowledge.
Expand Down Expand Up @@ -641,13 +632,13 @@ def __call__(
"""
return self._call(time, lead_time, variable)

async def fetch( # type: ignore[override]
async def fetch(
self,
time: datetime | list[datetime] | TimeArray,
lead_time: timedelta | list[timedelta] | LeadTimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data.
"""Async method to retrieve IFS forecast data.

Parameters
----------
Expand All @@ -664,7 +655,7 @@ async def fetch( # type: ignore[override]
xr.DataArray
IFS forecast data array.
"""
return await self._fetch(time, lead_time, variable)
return await super()._ecmwf_fetch(time, lead_time, variable)

def _validate_time(self, times: list[datetime]) -> None:
validate_time(
Expand Down Expand Up @@ -778,16 +769,16 @@ def __call__( # type: ignore[override]
IFS ENS initial state data array.
"""
da = self._call(time, np.array([0], dtype="datetime64[h]"), variable)
if "sample" in da.dims:
da = da.isel(sample=0)
return da.isel(lead_time=0)
if "ensemble" in da.dims:
da = da.isel(ensemble=0).drop_vars("ensemble")
return da.isel(lead_time=0).drop_vars("lead_time")

async def fetch( # type: ignore[override]
self,
time: datetime | list[datetime] | TimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data.
"""Async method to retrieve IFS ENS initial state data.

Parameters
----------
Expand All @@ -802,10 +793,10 @@ async def fetch( # type: ignore[override]
xr.DataArray
IFS ENS initial state data array.
"""
da = await self._fetch(time, np.array([0], dtype="datetime64[h]"), variable)
if "sample" in da.dims:
da = da.isel(sample=0)
return da.isel(lead_time=0)
da = await super()._ecmwf_fetch(time, timedelta(hours=0), variable)
if "ensemble" in da.dims:
da = da.isel(ensemble=0).drop_vars("ensemble")
return da.isel(lead_time=0).drop_vars("lead_time")

def _validate_time(self, times: list[datetime]) -> None:
validate_time(
Expand Down Expand Up @@ -922,8 +913,8 @@ def __call__(
IFS ENS forecast data array
"""
da = self._call(time, lead_time, variable)
if "sample" in da.dims:
da = da.isel(sample=0)
if "ensemble" in da.dims:
da = da.isel(ensemble=0).drop_vars("ensemble")
return da

async def fetch(
Expand All @@ -932,7 +923,7 @@ async def fetch(
lead_time: timedelta | list[timedelta] | LeadTimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data.
"""Async method to retrieve IFS ENS forecast data.

Parameters
----------
Expand All @@ -949,9 +940,9 @@ async def fetch(
xr.DataArray
IFS ENS forecast data array.
"""
da = await self._fetch(time, lead_time, variable)
if "sample" in da.dims:
da = da.isel(sample=0)
da = await super()._ecmwf_fetch(time, lead_time, variable)
if "ensemble" in da.dims:
da = da.isel(ensemble=0).drop_vars("ensemble")
return da

def _validate_time(self, times: list[datetime]) -> None:
Expand Down Expand Up @@ -1057,13 +1048,13 @@ def __call__(
"""
return self._call(time, lead_time, variable)

async def fetch( # type: ignore[override]
async def fetch(
self,
time: datetime | list[datetime] | TimeArray,
lead_time: timedelta | list[timedelta] | LeadTimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data.
"""Async method to retrieve AIFS forecast data.

Parameters
----------
Expand All @@ -1080,7 +1071,7 @@ async def fetch( # type: ignore[override]
xr.DataArray
AIFS forecast data array.
"""
return await self._fetch(time, lead_time, variable)
return await super()._ecmwf_fetch(time, lead_time, variable)

def _validate_time(self, times: list[datetime]) -> None:
validate_time(
Expand Down Expand Up @@ -1183,17 +1174,17 @@ def __call__(
AIFS ENS forecast data array
"""
da = self._call(time, lead_time, variable)
if "sample" in da.dims:
da = da.isel(sample=0)
if "ensemble" in da.dims:
da = da.isel(ensemble=0).drop_vars("ensemble")
return da

async def fetch( # type: ignore[override]
async def fetch(
self,
time: datetime | list[datetime] | TimeArray,
lead_time: timedelta | list[timedelta] | LeadTimeArray,
variable: str | list[str] | VariableArray,
) -> xr.DataArray:
"""Async function to get data.
"""Async method to retrieve AIFS ENS forecast data.

Parameters
----------
Expand All @@ -1208,10 +1199,12 @@ async def fetch( # type: ignore[override]
Returns
-------
xr.DataArray
ECMWF weather data array.
AIFS ENS forecast data array
"""
da = await self._fetch(time, lead_time, variable)
return da.isel(sample=0)
da = await super()._ecmwf_fetch(time, lead_time, variable)
if "ensemble" in da.dims:
da = da.isel(ensemble=0).drop_vars("ensemble")
return da

def _validate_time(self, times: list[datetime]) -> None:
validate_time(
Expand Down
22 changes: 21 additions & 1 deletion test/data/test_ecmwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,16 @@ async def test_ifs_async_fetch():
ds_ifs_fx = IFS_FX(cache=False)
ds_ifs_ens = IFS_ENS(cache=False, member=1)
ds_ifs_ens_fx = IFS_ENS_FX(cache=False, member=1)
ds_aifs_fx = AIFS_FX(cache=False)
ds_aifs_ens_fx = AIFS_ENS_FX(cache=False, member=1)

da_ifs, da_fx, da_ens, da_ens_fx = await asyncio.gather(
da_ifs, da_fx, da_ens, da_ens_fx, da_aifs_fx, da_aifs_ens_fx = await asyncio.gather(
ds_ifs.fetch(t, variable),
ds_ifs_fx.fetch(t, lt, variable),
ds_ifs_ens.fetch(t, variable),
ds_ifs_ens_fx.fetch(t, lt, variable),
ds_aifs_fx.fetch(t, lt, variable),
ds_aifs_ens_fx.fetch(t, lt, variable),
)

# IFS (analysis): [time, variable, lat, lon]
Expand Down Expand Up @@ -300,6 +304,22 @@ async def test_ifs_async_fetch():
assert da_ens_fx.shape[4] == 1440
assert not np.isnan(da_ens_fx.values).any()

# AIFS_FX (forecast): [time, lead_time, variable, lat, lon]
assert da_aifs_fx.shape[0] == 1
assert da_aifs_fx.shape[1] == 1
assert da_aifs_fx.shape[2] == 1
assert da_aifs_fx.shape[3] == 721
assert da_aifs_fx.shape[4] == 1440
assert not np.isnan(da_aifs_fx.values).any()

# AIFS_ENS_FX (forecast): [time, lead_time, variable, lat, lon]
assert da_aifs_ens_fx.shape[0] == 1
assert da_aifs_ens_fx.shape[1] == 1
assert da_aifs_ens_fx.shape[2] == 1
assert da_aifs_ens_fx.shape[3] == 721
assert da_aifs_ens_fx.shape[4] == 1440
assert not np.isnan(da_aifs_ens_fx.values).any()


@pytest.mark.timeout(30)
@pytest.mark.parametrize("data_source", [IFS, IFS_ENS])
Expand Down