Skip to content

Commit 4484ae3

Browse files
swbgNickGeneva
andauthored
Clean up ECMWF data sources (#747)
* fix output coords * rename fetch_wrapper --------- Co-authored-by: Nicholas Geneva <ngeneva@nvidia.com>
1 parent bc52726 commit 4484ae3

3 files changed

Lines changed: 76 additions & 51 deletions

File tree

.claude/commands/test.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,16 @@ uv run pytest test/path/to/test_file.py -v
4848
uv run pytest test/path/to/test_file.py -k "test_name" -v
4949
```
5050

51+
## Slow tests
52+
53+
Many tests are marked `@pytest.mark.slow` and skipped by default (network calls, large
54+
downloads). To include them, pass `--slow`:
55+
56+
```bash
57+
uv run pytest test/path/to/test_file.py -v --slow
58+
```
59+
60+
Slow tests are often also marked `@pytest.mark.xfail` — an `XPASS` result (unexpected
61+
pass) is fine and means the test succeeded against live data.
62+
5163
Report the test results, including any failures with their tracebacks.

earth2studio/data/ecmwf.py

Lines changed: 43 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,7 @@ def __call__( # type: ignore[override]
140140
"""Retrieve ECMWF data. The child class should override this"""
141141
pass
142142

143-
@abstractmethod
144-
async def fetch( # type: ignore[override]
145-
self,
146-
time: datetime | list[datetime] | TimeArray,
147-
lead_time: timedelta | list[timedelta] | LeadTimeArray,
148-
variable: str | list[str] | VariableArray,
149-
) -> xr.DataArray:
150-
"""Async function to get data, the child class should over ride this and call/"""
151-
pass
152-
153-
def _call( # type: ignore[override]
143+
def _call(
154144
self,
155145
time: datetime | list[datetime] | TimeArray,
156146
lead_time: timedelta | list[timedelta] | LeadTimeArray,
@@ -171,7 +161,7 @@ def _call( # type: ignore[override]
171161
Note
172162
----
173163
For peturbed data from ensemble models, the returned data array will have an
174-
extra `sample` dimension added to it.
164+
extra `ensemble` dimension added to it.
175165
176166
Returns
177167
-------
@@ -187,13 +177,14 @@ def _call( # type: ignore[override]
187177

188178
xr_array = loop.run_until_complete(
189179
asyncio.wait_for(
190-
self._fetch(time, lead_time, variable), timeout=self.async_timeout
180+
self._ecmwf_fetch(time, lead_time, variable),
181+
timeout=self.async_timeout,
191182
)
192183
)
193184

194185
return xr_array
195186

196-
async def _fetch( # type: ignore[override]
187+
async def _ecmwf_fetch(
197188
self,
198189
time: datetime | list[datetime] | TimeArray,
199190
lead_time: timedelta | list[timedelta] | LeadTimeArray,
@@ -241,20 +232,20 @@ async def _fetch( # type: ignore[override]
241232
len(self.LON),
242233
)
243234
),
244-
dims=["time", "lead_time", "variable", "sample", "lat", "lon"],
235+
dims=["time", "lead_time", "variable", "ensemble", "lat", "lon"],
245236
coords={
246237
"time": time,
247238
"lead_time": lead_time,
248239
"variable": variable,
249-
"sample": np.array(self._members),
240+
"ensemble": np.array(self._members),
250241
"lat": self.LAT,
251242
"lon": self.LON,
252243
},
253244
)
254245

255246
async_tasks = await self._create_tasks(time, lead_time, variable)
256247
func_map = map(
257-
functools.partial(self.fetch_wrapper, xr_array=xr_array), async_tasks
248+
functools.partial(self._download_wrapper, xr_array=xr_array), async_tasks
258249
)
259250

260251
await tqdm.gather(
@@ -317,7 +308,7 @@ async def _create_tasks(
317308
)
318309
return tasks
319310

320-
async def fetch_wrapper(
311+
async def _download_wrapper(
321312
self,
322313
task: ECMWFOpenDataAsyncTask,
323314
xr_array: xr.DataArray,
@@ -347,7 +338,7 @@ async def fetch_wrapper(
347338
f"No GRIB messages found for ensemble member {m} in {grib_file}"
348339
)
349340
member_arrays.append(msgs[0].values)
350-
values = np.stack(member_arrays, axis=0) # [sample, y, x]
341+
values = np.stack(member_arrays, axis=0) # [ensemble, y, x]
351342
else:
352343
values = grbs[1].values # [y, x]
353344
# Provided [-180, 180], roll to [0, 360] along x dimension
@@ -526,14 +517,14 @@ def __call__( # type: ignore[override]
526517
IFS analysis data array
527518
"""
528519
da = self._call(time, np.array([0], dtype="datetime64[h]"), variable)
529-
return da.isel(lead_time=0)
520+
return da.isel(lead_time=0).drop_vars("lead_time")
530521

531522
async def fetch( # type: ignore[override]
532523
self,
533524
time: datetime | list[datetime] | TimeArray,
534525
variable: str | list[str] | VariableArray,
535526
) -> xr.DataArray:
536-
"""Async function to get data.
527+
"""Async method to retrieve IFS analysis data.
537528
538529
Parameters
539530
----------
@@ -548,8 +539,8 @@ async def fetch( # type: ignore[override]
548539
xr.DataArray
549540
IFS analysis data array.
550541
"""
551-
da = await self._fetch(time, np.array([0], dtype="datetime64[h]"), variable)
552-
return da.isel(lead_time=0)
542+
da = await super()._ecmwf_fetch(time, timedelta(hours=0), variable)
543+
return da.isel(lead_time=0).drop_vars("lead_time")
553544

554545
def _validate_time(self, times: list[datetime]) -> None:
555546
"""Verify all times are valid based on offline knowledge.
@@ -641,13 +632,13 @@ def __call__(
641632
"""
642633
return self._call(time, lead_time, variable)
643634

644-
async def fetch( # type: ignore[override]
635+
async def fetch(
645636
self,
646637
time: datetime | list[datetime] | TimeArray,
647638
lead_time: timedelta | list[timedelta] | LeadTimeArray,
648639
variable: str | list[str] | VariableArray,
649640
) -> xr.DataArray:
650-
"""Async function to get data.
641+
"""Async method to retrieve IFS forecast data.
651642
652643
Parameters
653644
----------
@@ -664,7 +655,7 @@ async def fetch( # type: ignore[override]
664655
xr.DataArray
665656
IFS forecast data array.
666657
"""
667-
return await self._fetch(time, lead_time, variable)
658+
return await super()._ecmwf_fetch(time, lead_time, variable)
668659

669660
def _validate_time(self, times: list[datetime]) -> None:
670661
validate_time(
@@ -778,16 +769,16 @@ def __call__( # type: ignore[override]
778769
IFS ENS initial state data array.
779770
"""
780771
da = self._call(time, np.array([0], dtype="datetime64[h]"), variable)
781-
if "sample" in da.dims:
782-
da = da.isel(sample=0)
783-
return da.isel(lead_time=0)
772+
if "ensemble" in da.dims:
773+
da = da.isel(ensemble=0).drop_vars("ensemble")
774+
return da.isel(lead_time=0).drop_vars("lead_time")
784775

785776
async def fetch( # type: ignore[override]
786777
self,
787778
time: datetime | list[datetime] | TimeArray,
788779
variable: str | list[str] | VariableArray,
789780
) -> xr.DataArray:
790-
"""Async function to get data.
781+
"""Async method to retrieve IFS ENS initial state data.
791782
792783
Parameters
793784
----------
@@ -802,10 +793,10 @@ async def fetch( # type: ignore[override]
802793
xr.DataArray
803794
IFS ENS initial state data array.
804795
"""
805-
da = await self._fetch(time, np.array([0], dtype="datetime64[h]"), variable)
806-
if "sample" in da.dims:
807-
da = da.isel(sample=0)
808-
return da.isel(lead_time=0)
796+
da = await super()._ecmwf_fetch(time, timedelta(hours=0), variable)
797+
if "ensemble" in da.dims:
798+
da = da.isel(ensemble=0).drop_vars("ensemble")
799+
return da.isel(lead_time=0).drop_vars("lead_time")
809800

810801
def _validate_time(self, times: list[datetime]) -> None:
811802
validate_time(
@@ -922,8 +913,8 @@ def __call__(
922913
IFS ENS forecast data array
923914
"""
924915
da = self._call(time, lead_time, variable)
925-
if "sample" in da.dims:
926-
da = da.isel(sample=0)
916+
if "ensemble" in da.dims:
917+
da = da.isel(ensemble=0).drop_vars("ensemble")
927918
return da
928919

929920
async def fetch(
@@ -932,7 +923,7 @@ async def fetch(
932923
lead_time: timedelta | list[timedelta] | LeadTimeArray,
933924
variable: str | list[str] | VariableArray,
934925
) -> xr.DataArray:
935-
"""Async function to get data.
926+
"""Async method to retrieve IFS ENS forecast data.
936927
937928
Parameters
938929
----------
@@ -949,9 +940,9 @@ async def fetch(
949940
xr.DataArray
950941
IFS ENS forecast data array.
951942
"""
952-
da = await self._fetch(time, lead_time, variable)
953-
if "sample" in da.dims:
954-
da = da.isel(sample=0)
943+
da = await super()._ecmwf_fetch(time, lead_time, variable)
944+
if "ensemble" in da.dims:
945+
da = da.isel(ensemble=0).drop_vars("ensemble")
955946
return da
956947

957948
def _validate_time(self, times: list[datetime]) -> None:
@@ -1057,13 +1048,13 @@ def __call__(
10571048
"""
10581049
return self._call(time, lead_time, variable)
10591050

1060-
async def fetch( # type: ignore[override]
1051+
async def fetch(
10611052
self,
10621053
time: datetime | list[datetime] | TimeArray,
10631054
lead_time: timedelta | list[timedelta] | LeadTimeArray,
10641055
variable: str | list[str] | VariableArray,
10651056
) -> xr.DataArray:
1066-
"""Async function to get data.
1057+
"""Async method to retrieve AIFS forecast data.
10671058
10681059
Parameters
10691060
----------
@@ -1080,7 +1071,7 @@ async def fetch( # type: ignore[override]
10801071
xr.DataArray
10811072
AIFS forecast data array.
10821073
"""
1083-
return await self._fetch(time, lead_time, variable)
1074+
return await super()._ecmwf_fetch(time, lead_time, variable)
10841075

10851076
def _validate_time(self, times: list[datetime]) -> None:
10861077
validate_time(
@@ -1183,17 +1174,17 @@ def __call__(
11831174
AIFS ENS forecast data array
11841175
"""
11851176
da = self._call(time, lead_time, variable)
1186-
if "sample" in da.dims:
1187-
da = da.isel(sample=0)
1177+
if "ensemble" in da.dims:
1178+
da = da.isel(ensemble=0).drop_vars("ensemble")
11881179
return da
11891180

1190-
async def fetch( # type: ignore[override]
1181+
async def fetch(
11911182
self,
11921183
time: datetime | list[datetime] | TimeArray,
11931184
lead_time: timedelta | list[timedelta] | LeadTimeArray,
11941185
variable: str | list[str] | VariableArray,
11951186
) -> xr.DataArray:
1196-
"""Async function to get data.
1187+
"""Async method to retrieve AIFS ENS forecast data.
11971188
11981189
Parameters
11991190
----------
@@ -1208,10 +1199,12 @@ async def fetch( # type: ignore[override]
12081199
Returns
12091200
-------
12101201
xr.DataArray
1211-
ECMWF weather data array.
1202+
AIFS ENS forecast data array
12121203
"""
1213-
da = await self._fetch(time, lead_time, variable)
1214-
return da.isel(sample=0)
1204+
da = await super()._ecmwf_fetch(time, lead_time, variable)
1205+
if "ensemble" in da.dims:
1206+
da = da.isel(ensemble=0).drop_vars("ensemble")
1207+
return da
12151208

12161209
def _validate_time(self, times: list[datetime]) -> None:
12171210
validate_time(

test/data/test_ecmwf.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,16 @@ async def test_ifs_async_fetch():
262262
ds_ifs_fx = IFS_FX(cache=False)
263263
ds_ifs_ens = IFS_ENS(cache=False, member=1)
264264
ds_ifs_ens_fx = IFS_ENS_FX(cache=False, member=1)
265+
ds_aifs_fx = AIFS_FX(cache=False)
266+
ds_aifs_ens_fx = AIFS_ENS_FX(cache=False, member=1)
265267

266-
da_ifs, da_fx, da_ens, da_ens_fx = await asyncio.gather(
268+
da_ifs, da_fx, da_ens, da_ens_fx, da_aifs_fx, da_aifs_ens_fx = await asyncio.gather(
267269
ds_ifs.fetch(t, variable),
268270
ds_ifs_fx.fetch(t, lt, variable),
269271
ds_ifs_ens.fetch(t, variable),
270272
ds_ifs_ens_fx.fetch(t, lt, variable),
273+
ds_aifs_fx.fetch(t, lt, variable),
274+
ds_aifs_ens_fx.fetch(t, lt, variable),
271275
)
272276

273277
# IFS (analysis): [time, variable, lat, lon]
@@ -300,6 +304,22 @@ async def test_ifs_async_fetch():
300304
assert da_ens_fx.shape[4] == 1440
301305
assert not np.isnan(da_ens_fx.values).any()
302306

307+
# AIFS_FX (forecast): [time, lead_time, variable, lat, lon]
308+
assert da_aifs_fx.shape[0] == 1
309+
assert da_aifs_fx.shape[1] == 1
310+
assert da_aifs_fx.shape[2] == 1
311+
assert da_aifs_fx.shape[3] == 721
312+
assert da_aifs_fx.shape[4] == 1440
313+
assert not np.isnan(da_aifs_fx.values).any()
314+
315+
# AIFS_ENS_FX (forecast): [time, lead_time, variable, lat, lon]
316+
assert da_aifs_ens_fx.shape[0] == 1
317+
assert da_aifs_ens_fx.shape[1] == 1
318+
assert da_aifs_ens_fx.shape[2] == 1
319+
assert da_aifs_ens_fx.shape[3] == 721
320+
assert da_aifs_ens_fx.shape[4] == 1440
321+
assert not np.isnan(da_aifs_ens_fx.values).any()
322+
303323

304324
@pytest.mark.timeout(30)
305325
@pytest.mark.parametrize("data_source", [IFS, IFS_ENS])

0 commit comments

Comments
 (0)