diff --git a/examples/22_healda.py b/examples/22_healda.py index bfe39e75a..eaff6c7ea 100644 --- a/examples/22_healda.py +++ b/examples/22_healda.py @@ -227,7 +227,7 @@ for row, (title, da) in enumerate(zip(titles, results)): for col, var in enumerate(plot_vars): ax = axes[row, col] - field = da.sel(variable=var).data[0].get() # [nlat, nlon] cupy -> numpy + field = da.sel(variable=var).values[0] # [nlat, nlon] cupy -> numpy im = ax.pcolormesh( lon, lat, @@ -276,10 +276,8 @@ diff_results = [result_both, result_sat, result_conv] for title, da_pred in zip(diff_titles, diff_results): for var in plot_vars: - field_pred = da_pred.sel(variable=var).data[0] - if hasattr(field_pred, "get"): - field_pred = field_pred.get() - field_era5 = era5_interp.sel(variable=var).data[0] + field_pred = da_pred.sel(variable=var).values[0] + field_era5 = era5_interp.sel(variable=var).values[0] mae = float(np.abs(field_pred - field_era5).mean()) logger.info(f"{title} | {var} MAE: {mae:.4f}") @@ -298,10 +296,8 @@ for row, (title, da_pred) in enumerate(zip(diff_titles, diff_results)): for col, var in enumerate(plot_vars): ax = axes[row, col] - field_pred = ( - da_pred.sel(variable=var).data[0].get() - ) # [nlat, nlon] cupy -> numpy - field_era5 = era5_interp.sel(variable=var).data[0] # [nlat, nlon] + field_pred = da_pred.sel(variable=var).values[0] # [nlat, nlon] cupy -> numpy + field_era5 = era5_interp.sel(variable=var).values[0] # [nlat, nlon] diff = field_pred - field_era5 im = ax.pcolormesh( lon,