Description
Version
0.11.0
On which installation method(s) does this occur?
source
Describe the issue
I am attempting to use the run.deterministic workflow on prognostic models with batched input data at a single time. I have created my own data source, which simply wraps an in-memory xarray.Dataset. The batch entries can, for example, correspond to different perturbations, although I am not using the Earth2Studio perturbation API.
Using an xarray object without a batch dimension works without issue for many prognostic models.
I am now considering an xarray.DataArray with dimensions and coordinates batch, time, variable, lat, lon, as requested by the API. I get a ValueError saying that the "lead_time" dimension could not be found at the expected index (index 2). On further investigation, I find that fetch_data returns coords in the order batch, lead_time, time, lat, lon, which is then passed down to the iterator and fails the handshake_dim check.
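To make the ordering problem concrete, here is a minimal standalone sketch of the kind of positional check that fails. This is not the Earth2Studio implementation: `check_dim_index` is a hypothetical stand-in for handshake_dim, and the coordinate values are left empty because only the key order matters here.

```python
from collections import OrderedDict


def check_dim_index(coords, required_dim, required_index):
    """Hypothetical stand-in for a positional dimension check.

    Raises ValueError unless `required_dim` is the key at position
    `required_index` of the ordered coordinate dictionary.
    """
    dims = list(coords)
    if required_dim not in dims:
        raise ValueError(f"Required dimension {required_dim} not found in {dims}")
    actual_index = dims.index(required_dim)
    if actual_index != required_index:
        raise ValueError(
            f"Required dimension {required_dim} found at index {actual_index}, "
            f"expected index {required_index} in {dims}"
        )


# Ordering reported above as coming back from fetch_data.
reported = OrderedDict(
    (name, []) for name in ["batch", "lead_time", "time", "lat", "lon"]
)
# Ordering that places lead_time at index 2.
patched = OrderedDict(
    (name, []) for name in ["batch", "time", "lead_time", "lat", "lon"]
)

for label, coords in [("reported", reported), ("patched", patched)]:
    try:
        check_dim_index(coords, "lead_time", 2)
        print(f"{label}: lead_time is at index 2")
    except ValueError as err:
        print(f"{label}: {err}")
```

Note that this only inspects the coordinate dictionary. Reordering the keys alone does not move the corresponding axes of the underlying data, so the two would need to be changed together.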
I have tried patching fetch_data so that it returns the expected ordering of batch, time, lead_time, lat, lon, but I then get further errors from io.write. I also see that, while the model has generated outputs for the first step, they are all NaN.
Do you have any tips on how to operate with batched inputs? Is my IO handling the issue, or does the batch.py module only work in conjunction with the perturbation API?
Thanks!