@@ -140,17 +140,7 @@ def __call__( # type: ignore[override]
140140 """Retrieve ECMWF data. The child class should override this"""
141141 pass
142142
143- @abstractmethod
144- async def fetch ( # type: ignore[override]
145- self ,
146- time : datetime | list [datetime ] | TimeArray ,
147- lead_time : timedelta | list [timedelta ] | LeadTimeArray ,
148- variable : str | list [str ] | VariableArray ,
149- ) -> xr .DataArray :
150- """Async function to get data, the child class should over ride this and call/"""
151- pass
152-
153- def _call ( # type: ignore[override]
143+ def _call (
154144 self ,
155145 time : datetime | list [datetime ] | TimeArray ,
156146 lead_time : timedelta | list [timedelta ] | LeadTimeArray ,
@@ -171,7 +161,7 @@ def _call( # type: ignore[override]
171161 Note
172162 ----
173163 For peturbed data from ensemble models, the returned data array will have an
174- extra `sample ` dimension added to it.
164+ extra `ensemble ` dimension added to it.
175165
176166 Returns
177167 -------
@@ -187,13 +177,14 @@ def _call( # type: ignore[override]
187177
188178 xr_array = loop .run_until_complete (
189179 asyncio .wait_for (
190- self ._fetch (time , lead_time , variable ), timeout = self .async_timeout
180+ self ._ecmwf_fetch (time , lead_time , variable ),
181+ timeout = self .async_timeout ,
191182 )
192183 )
193184
194185 return xr_array
195186
196- async def _fetch ( # type: ignore[override]
187+ async def _ecmwf_fetch (
197188 self ,
198189 time : datetime | list [datetime ] | TimeArray ,
199190 lead_time : timedelta | list [timedelta ] | LeadTimeArray ,
@@ -241,20 +232,20 @@ async def _fetch( # type: ignore[override]
241232 len (self .LON ),
242233 )
243234 ),
244- dims = ["time" , "lead_time" , "variable" , "sample " , "lat" , "lon" ],
235+ dims = ["time" , "lead_time" , "variable" , "ensemble " , "lat" , "lon" ],
245236 coords = {
246237 "time" : time ,
247238 "lead_time" : lead_time ,
248239 "variable" : variable ,
249- "sample " : np .array (self ._members ),
240+ "ensemble " : np .array (self ._members ),
250241 "lat" : self .LAT ,
251242 "lon" : self .LON ,
252243 },
253244 )
254245
255246 async_tasks = await self ._create_tasks (time , lead_time , variable )
256247 func_map = map (
257- functools .partial (self .fetch_wrapper , xr_array = xr_array ), async_tasks
248+ functools .partial (self ._download_wrapper , xr_array = xr_array ), async_tasks
258249 )
259250
260251 await tqdm .gather (
@@ -317,7 +308,7 @@ async def _create_tasks(
317308 )
318309 return tasks
319310
320- async def fetch_wrapper (
311+ async def _download_wrapper (
321312 self ,
322313 task : ECMWFOpenDataAsyncTask ,
323314 xr_array : xr .DataArray ,
@@ -347,7 +338,7 @@ async def fetch_wrapper(
347338 f"No GRIB messages found for ensemble member { m } in { grib_file } "
348339 )
349340 member_arrays .append (msgs [0 ].values )
350- values = np .stack (member_arrays , axis = 0 ) # [sample , y, x]
341+ values = np .stack (member_arrays , axis = 0 ) # [ensemble , y, x]
351342 else :
352343 values = grbs [1 ].values # [y, x]
353344 # Provided [-180, 180], roll to [0, 360] along x dimension
@@ -526,14 +517,14 @@ def __call__( # type: ignore[override]
526517 IFS analysis data array
527518 """
528519 da = self ._call (time , np .array ([0 ], dtype = "datetime64[h]" ), variable )
529- return da .isel (lead_time = 0 )
520+ return da .isel (lead_time = 0 ). drop_vars ( "lead_time" )
530521
531522 async def fetch ( # type: ignore[override]
532523 self ,
533524 time : datetime | list [datetime ] | TimeArray ,
534525 variable : str | list [str ] | VariableArray ,
535526 ) -> xr .DataArray :
536- """Async function to get data.
527+ """Async method to retrieve IFS analysis data.
537528
538529 Parameters
539530 ----------
@@ -548,8 +539,8 @@ async def fetch( # type: ignore[override]
548539 xr.DataArray
549540 IFS analysis data array.
550541 """
551- da = await self . _fetch (time , np . array ([ 0 ], dtype = "datetime64[h]" ), variable )
552- return da .isel (lead_time = 0 )
542+ da = await super (). _ecmwf_fetch (time , timedelta ( hours = 0 ), variable )
543+ return da .isel (lead_time = 0 ). drop_vars ( "lead_time" )
553544
554545 def _validate_time (self , times : list [datetime ]) -> None :
555546 """Verify all times are valid based on offline knowledge.
@@ -641,13 +632,13 @@ def __call__(
641632 """
642633 return self ._call (time , lead_time , variable )
643634
644- async def fetch ( # type: ignore[override]
635+ async def fetch (
645636 self ,
646637 time : datetime | list [datetime ] | TimeArray ,
647638 lead_time : timedelta | list [timedelta ] | LeadTimeArray ,
648639 variable : str | list [str ] | VariableArray ,
649640 ) -> xr .DataArray :
650- """Async function to get data.
641+ """Async method to retrieve IFS forecast data.
651642
652643 Parameters
653644 ----------
@@ -664,7 +655,7 @@ async def fetch( # type: ignore[override]
664655 xr.DataArray
665656 IFS forecast data array.
666657 """
667- return await self . _fetch (time , lead_time , variable )
658+ return await super (). _ecmwf_fetch (time , lead_time , variable )
668659
669660 def _validate_time (self , times : list [datetime ]) -> None :
670661 validate_time (
@@ -778,16 +769,16 @@ def __call__( # type: ignore[override]
778769 IFS ENS initial state data array.
779770 """
780771 da = self ._call (time , np .array ([0 ], dtype = "datetime64[h]" ), variable )
781- if "sample " in da .dims :
782- da = da .isel (sample = 0 )
783- return da .isel (lead_time = 0 )
772+ if "ensemble " in da .dims :
773+ da = da .isel (ensemble = 0 ). drop_vars ( "ensemble" )
774+ return da .isel (lead_time = 0 ). drop_vars ( "lead_time" )
784775
785776 async def fetch ( # type: ignore[override]
786777 self ,
787778 time : datetime | list [datetime ] | TimeArray ,
788779 variable : str | list [str ] | VariableArray ,
789780 ) -> xr .DataArray :
790- """Async function to get data.
781+ """Async method to retrieve IFS ENS initial state data.
791782
792783 Parameters
793784 ----------
@@ -802,10 +793,10 @@ async def fetch( # type: ignore[override]
802793 xr.DataArray
803794 IFS ENS initial state data array.
804795 """
805- da = await self . _fetch (time , np . array ([ 0 ], dtype = "datetime64[h]" ), variable )
806- if "sample " in da .dims :
807- da = da .isel (sample = 0 )
808- return da .isel (lead_time = 0 )
796+ da = await super (). _ecmwf_fetch (time , timedelta ( hours = 0 ), variable )
797+ if "ensemble " in da .dims :
798+ da = da .isel (ensemble = 0 ). drop_vars ( "ensemble" )
799+ return da .isel (lead_time = 0 ). drop_vars ( "lead_time" )
809800
810801 def _validate_time (self , times : list [datetime ]) -> None :
811802 validate_time (
@@ -922,8 +913,8 @@ def __call__(
922913 IFS ENS forecast data array
923914 """
924915 da = self ._call (time , lead_time , variable )
925- if "sample " in da .dims :
926- da = da .isel (sample = 0 )
916+ if "ensemble " in da .dims :
917+ da = da .isel (ensemble = 0 ). drop_vars ( "ensemble" )
927918 return da
928919
929920 async def fetch (
@@ -932,7 +923,7 @@ async def fetch(
932923 lead_time : timedelta | list [timedelta ] | LeadTimeArray ,
933924 variable : str | list [str ] | VariableArray ,
934925 ) -> xr .DataArray :
935- """Async function to get data.
926+ """Async method to retrieve IFS ENS forecast data.
936927
937928 Parameters
938929 ----------
@@ -949,9 +940,9 @@ async def fetch(
949940 xr.DataArray
950941 IFS ENS forecast data array.
951942 """
952- da = await self . _fetch (time , lead_time , variable )
953- if "sample " in da .dims :
954- da = da .isel (sample = 0 )
943+ da = await super (). _ecmwf_fetch (time , lead_time , variable )
944+ if "ensemble " in da .dims :
945+ da = da .isel (ensemble = 0 ). drop_vars ( "ensemble" )
955946 return da
956947
957948 def _validate_time (self , times : list [datetime ]) -> None :
@@ -1057,13 +1048,13 @@ def __call__(
10571048 """
10581049 return self ._call (time , lead_time , variable )
10591050
1060- async def fetch ( # type: ignore[override]
1051+ async def fetch (
10611052 self ,
10621053 time : datetime | list [datetime ] | TimeArray ,
10631054 lead_time : timedelta | list [timedelta ] | LeadTimeArray ,
10641055 variable : str | list [str ] | VariableArray ,
10651056 ) -> xr .DataArray :
1066- """Async function to get data.
1057+ """Async method to retrieve AIFS forecast data.
10671058
10681059 Parameters
10691060 ----------
@@ -1080,7 +1071,7 @@ async def fetch( # type: ignore[override]
10801071 xr.DataArray
10811072 AIFS forecast data array.
10821073 """
1083- return await self . _fetch (time , lead_time , variable )
1074+ return await super (). _ecmwf_fetch (time , lead_time , variable )
10841075
10851076 def _validate_time (self , times : list [datetime ]) -> None :
10861077 validate_time (
@@ -1183,17 +1174,17 @@ def __call__(
11831174 AIFS ENS forecast data array
11841175 """
11851176 da = self ._call (time , lead_time , variable )
1186- if "sample " in da .dims :
1187- da = da .isel (sample = 0 )
1177+ if "ensemble " in da .dims :
1178+ da = da .isel (ensemble = 0 ). drop_vars ( "ensemble" )
11881179 return da
11891180
1190- async def fetch ( # type: ignore[override]
1181+ async def fetch (
11911182 self ,
11921183 time : datetime | list [datetime ] | TimeArray ,
11931184 lead_time : timedelta | list [timedelta ] | LeadTimeArray ,
11941185 variable : str | list [str ] | VariableArray ,
11951186 ) -> xr .DataArray :
1196- """Async function to get data.
1187+ """Async method to retrieve AIFS ENS forecast data.
11971188
11981189 Parameters
11991190 ----------
@@ -1208,10 +1199,12 @@ async def fetch( # type: ignore[override]
12081199 Returns
12091200 -------
12101201 xr.DataArray
1211- ECMWF weather data array.
1202+ AIFS ENS forecast data array
12121203 """
1213- da = await self ._fetch (time , lead_time , variable )
1214- return da .isel (sample = 0 )
1204+ da = await super ()._ecmwf_fetch (time , lead_time , variable )
1205+ if "ensemble" in da .dims :
1206+ da = da .isel (ensemble = 0 ).drop_vars ("ensemble" )
1207+ return da
12151208
12161209 def _validate_time (self , times : list [datetime ]) -> None :
12171210 validate_time (
0 commit comments