Rechunk derived #6516

Open · wants to merge 9 commits into base: main
143 changes: 116 additions & 27 deletions lib/iris/aux_factory.py
@@ -11,7 +11,7 @@
import dask.array as da
import numpy as np

from iris._lazy_data import concatenate
from iris._lazy_data import _optimum_chunksize, concatenate, is_lazy_data
from iris.common import CFVariableMixin, CoordMetadata, metadata_manager_factory
import iris.coords
from iris.warnings import IrisIgnoringBoundsWarning
@@ -76,6 +76,93 @@ def dependencies(self):

"""

@abstractmethod
def _calculate_array(self, *dep_arrays, **other_args):
"""Make a coordinate array from a complete set of dependency arrays.

Parameters
----------
*dep_arrays : tuple of array-like
    Arrays of data for each dependency.
    Must match the number of declared dependencies, in the standard order.
    All are aligned with the leading result dimensions, but may have fewer
    than the full number of dimensions. They can be lazy or real data.

**other_args
    Class-specific additional arguments, passed by keyword.

Returns
-------
array-like
    The result array: lazy if any of the inputs are lazy.

Notes
-----
This is the basic derived calculation, implemented by each hybrid-coordinate
class, which specifies how the dependency values are combined to make the
derived result.
"""
pass

def _derive_array(self, *dep_arrays, **other_args):
"""Build an array of coordinate values.

Call arguments as for :meth:`_calculate_array`.

This routine calls :meth:`_calculate_array` to construct a derived result array.

It then checks the chunk size of that result and, if it exceeds the Dask
chunk-size limit currently in force, re-chunks some of the input arrays and
re-calculates the result to reduce the memory cost.

This routine is usually called twice by :meth:`make_coord`: once to make the
points array, and again to make the bounds.
"""
# Make an initial result calculation.
# First make all dependencies lazy, to ensure a lazy calculation and avoid
# potentially spending a lot of time + memory.
lazy_deps = [
# Note: real inputs are wrapped as a single chunk, which keeps the graph
# overhead small; already-lazy inputs keep their existing chunking.
# No clever chunking choices are made here: if the result chunks come
# out too big, that gets fixed below.
dep if is_lazy_data(dep) else da.from_array(dep, chunks=-1)
Contributor:

It looks like this only guarantees a single chunk if the data was initially non-lazy. From what I can tell of the tests it seems like you're only testing the case where there is a single chunk given. I think it would be worth making sure there is testing for the case where lazy_deps contains chunked arrays.

@pp-mo (Member Author), Jun 24, 2025:

OK, I think the comment is really the problem here:
this "single chunks" statement really only applies to the real arrays which we wrap as lazy.
I will try to fix this ...

Background:

The initial calculation is supposed to produce a result that we can simply use, if its chunksize is OK, but we need it to be definitely lazy so that we can pre-check the chunksize before committing to do the calculation.
So we need to ensure that the initial 'test' calculation is lazy.
I did consider ensuring that just the first, or smallest, term was lazy, but I realised that in the calculation dask itself would then wrap any other real terms, using "auto" chunking by default, which is probably sub-optimal for our purposes.

If we were making our best single effort at producing a usable result array, we might logically use our `_optimum_chunksize` scheme here in wrapping the real terms.
But in fact that is not a good approach, because the whole point is that you need to consider the terms (and especially their chunking) in alignment with all dimensions of the calculated result, and not just in their own individual dimensions. That's effectively the whole problem here.

So, I chose to first wrap all real terms as single chunks, and then assess the chunksize of the calculated result.
Only if that simplistic approach produces a chunksize which is too large, does the code then make a bigger effort to re-consider the chunking across all the terms, and re-chunk everything in certain dimensions.
I thought it was probably "safer" not to do that co-optimisation unless it is clearly needed, as the results might be a bit sub-optimal.
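
A minimal standalone sketch of that approach (an editorial illustration, not code from this PR; the array shapes and chunk-size figures are assumed purely for demonstration):

```python
import dask.array as da
import numpy as np

# Two dependency terms, aligned with the result's leading dimensions:
# a (70, 1, 1) vertical term and a (1, 1000, 1000) surface term.
delta = np.linspace(0.0, 1.0, 70).reshape(70, 1, 1)
surface = np.ones((1, 1000, 1000))

# Wrap each real term as a single chunk: no clever chunking choices yet.
lazy_delta = da.from_array(delta, chunks=-1)
lazy_surface = da.from_array(surface, chunks=-1)

# The trial calculation is lazy, so nothing is computed yet ...
result = lazy_delta + lazy_surface

# ... but its chunking can already be inspected: a single (70, 1000, 1000)
# float64 chunk is ~560 MB, which a chunk-size limit would reject, and only
# then would the inputs be re-chunked along the result's dimensions.
print(result.chunksize)       # (70, 1000, 1000)
print(result.nbytes / 1e6)    # 560.0  (MB)
```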

for dep in dep_arrays
]
result = self._calculate_array(*lazy_deps, **other_args)

# Now check if we need to improve on the chunking of the result.
adjusted_chunks = _optimum_chunksize(
chunks=result.chunksize,
shape=result.shape,
dtype=result.dtype,
)

# Does _optimum_chunksize say we should have smaller chunks in some dimensions?
if not all(a >= b for a, b in zip(adjusted_chunks, result.chunksize)):
# Re-do the result calculation, but first re-chunking each dep in the
# dimensions which it is suggested to reduce.
new_deps = []
for dep, original_dep in zip(lazy_deps, dep_arrays):
# For each dependency, reduce chunksize in each dim to the new result
# chunksize, if smaller.
dep_chunks = dep.chunksize
new_chunks = tuple(
[
min(dep_chunk, adj_chunk)
for dep_chunk, adj_chunk in zip(dep_chunks, adjusted_chunks)
]
)
if new_chunks != dep_chunks:
# When dep chunksize needs to change, produce a rechunked version.
if is_lazy_data(original_dep):
dep = original_dep.rechunk(new_chunks)
Contributor:
It might be worthwhile ensuring that this line gets test coverage.
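
A hypothetical sketch of such a test (editorial illustration, not from this PR; the `sample_hybrid_height_factory` fixture and the array shapes are assumptions):

```python
import dask.array as da

def test_derive_array_rechunks_oversized_lazy_deps(sample_hybrid_height_factory):
    # Already-lazy dependencies, each a single chunk, in the factory's
    # standard order: delta, sigma, orography.
    delta = da.ones((70, 1, 1), chunks=-1)
    sigma = da.ones((70, 1, 1), chunks=-1)
    orography = da.ones((1, 2500, 2500), chunks=-1)

    result = sample_hybrid_height_factory._derive_array(delta, sigma, orography)

    # The result must still be lazy ...
    assert isinstance(result, da.Array)
    # ... and since a single (70, 2500, 2500) float64 chunk (~3.5 GB) exceeds
    # any sensible chunk-size limit, the lazy inputs should have been
    # re-chunked, leaving the result with smaller chunks.
    assert result.chunksize != result.shape
```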

else:
# Make new lazy array from real original, rather than re-chunk.
dep = da.from_array(original_dep, chunks=new_chunks)
new_deps.append(dep)

# Finally, re-do the calculation, which hopefully results in a better
# overall chunksize for the result
result = self._calculate_array(*new_deps, **other_args)

return result
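
As an editorial aside (not part of the diff), the dimension-wise chunk combination in the loop above can be shown in isolation; the chunk tuples here are assumed values:

```python
# Each dependency chunk is capped at the adjusted result chunk suggested by
# _optimum_chunksize for the same dimension.
dep_chunks = (70, 1000, 1000)      # assumed dependency chunksize
adjusted_chunks = (10, 1000, 250)  # assumed _optimum_chunksize suggestion

new_chunks = tuple(
    min(dep_chunk, adj_chunk)
    for dep_chunk, adj_chunk in zip(dep_chunks, adjusted_chunks)
)
print(new_chunks)  # (10, 1000, 250): differs, so this dep gets re-chunked
```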

@abstractmethod
def make_coord(self, coord_dims_func):
"""Return a new :class:`iris.coords.AuxCoord` as defined by this factory.
@@ -463,7 +550,7 @@ def dependencies(self):
return dependencies

@staticmethod
def _derive(pressure_at_top, sigma, surface_air_pressure):
def _calculate_array(pressure_at_top, sigma, surface_air_pressure):
"""Derive coordinate."""
return pressure_at_top + sigma * (surface_air_pressure - pressure_at_top)

@@ -485,7 +572,7 @@ def make_coord(self, coord_dims_func):

# Build the points array
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["pressure_at_top"],
nd_points_by_key["sigma"],
nd_points_by_key["surface_air_pressure"],
@@ -519,7 +606,7 @@ def make_coord(self, coord_dims_func):
surface_air_pressure_pts = nd_points_by_key["surface_air_pressure"]
bds_shape = list(surface_air_pressure_pts.shape) + [1]
surface_air_pressure = surface_air_pressure_pts.reshape(bds_shape)
bounds = self._derive(pressure_at_top, sigma, surface_air_pressure)
bounds = self._derive_array(pressure_at_top, sigma, surface_air_pressure)

# Create coordinate
return iris.coords.AuxCoord(
@@ -608,7 +695,7 @@ def dependencies(self):
"orography": self.orography,
}

def _derive(self, delta, sigma, orography):
def _calculate_array(self, delta, sigma, orography):
return delta + sigma * orography

def make_coord(self, coord_dims_func):
@@ -629,7 +716,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["delta"],
nd_points_by_key["sigma"],
nd_points_by_key["orography"],
@@ -657,7 +744,7 @@ def make_coord(self, coord_dims_func):
bds_shape = list(orography_pts.shape) + [1]
orography = orography_pts.reshape(bds_shape)

bounds = self._derive(delta, sigma, orography)
bounds = self._derive_array(delta, sigma, orography)

hybrid_height = iris.coords.AuxCoord(
points,
@@ -814,7 +901,7 @@ def dependencies(self):
"surface_air_pressure": self.surface_air_pressure,
}

def _derive(self, delta, sigma, surface_air_pressure):
def _calculate_array(self, delta, sigma, surface_air_pressure):
return delta + sigma * surface_air_pressure

def make_coord(self, coord_dims_func):
@@ -835,7 +922,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["delta"],
nd_points_by_key["sigma"],
nd_points_by_key["surface_air_pressure"],
@@ -863,7 +950,7 @@ def make_coord(self, coord_dims_func):
bds_shape = list(surface_air_pressure_pts.shape) + [1]
surface_air_pressure = surface_air_pressure_pts.reshape(bds_shape)

bounds = self._derive(delta, sigma, surface_air_pressure)
bounds = self._derive_array(delta, sigma, surface_air_pressure)

hybrid_pressure = iris.coords.AuxCoord(
points,
@@ -1022,7 +1109,9 @@ def dependencies(self):
zlev=self.zlev,
)

def _derive(self, sigma, eta, depth, depth_c, zlev, nsigma, coord_dims_func):
def _calculate_array(
self, sigma, eta, depth, depth_c, zlev, nsigma, coord_dims_func
):
# Calculate the index of the 'z' dimension in the input arrays.
# First find the cube 'z' dimension ...
[cube_z_dim] = coord_dims_func(self.dependencies["zlev"])
@@ -1097,14 +1186,14 @@ def make_coord(self, coord_dims_func):
nd_points_by_key = self._remap(dependency_dims, derived_dims)

[nsigma] = nd_points_by_key["nsigma"]
points = self._derive(
points = self._derive_array(
nd_points_by_key["sigma"],
nd_points_by_key["eta"],
nd_points_by_key["depth"],
nd_points_by_key["depth_c"],
nd_points_by_key["zlev"],
nsigma,
coord_dims_func,
coord_dims_func=coord_dims_func,
)

bounds = None
@@ -1131,14 +1220,14 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["sigma"],
nd_values_by_key["eta"],
nd_values_by_key["depth"],
nd_values_by_key["depth_c"],
nd_values_by_key["zlev"],
nsigma,
coord_dims_func,
coord_dims_func=coord_dims_func,
)

coord = iris.coords.AuxCoord(
@@ -1238,7 +1327,7 @@ def dependencies(self):
"""
return dict(sigma=self.sigma, eta=self.eta, depth=self.depth)

def _derive(self, sigma, eta, depth):
def _calculate_array(self, sigma, eta, depth):
return eta + sigma * (depth + eta)

def make_coord(self, coord_dims_func):
@@ -1257,7 +1346,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["sigma"],
nd_points_by_key["eta"],
nd_points_by_key["depth"],
@@ -1287,7 +1376,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["sigma"],
nd_values_by_key["eta"],
nd_values_by_key["depth"],
@@ -1419,7 +1508,7 @@ def dependencies(self):
depth_c=self.depth_c,
)

def _derive(self, s, c, eta, depth, depth_c):
def _calculate_array(self, s, c, eta, depth, depth_c):
S = depth_c * s + (depth - depth_c) * c
return S + eta * (1 + S / depth)

@@ -1439,7 +1528,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["s"],
nd_points_by_key["c"],
nd_points_by_key["eta"],
@@ -1471,7 +1560,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["s"],
nd_values_by_key["c"],
nd_values_by_key["eta"],
@@ -1608,7 +1697,7 @@ def dependencies(self):
depth_c=self.depth_c,
)

def _derive(self, s, eta, depth, a, b, depth_c):
def _calculate_array(self, s, eta, depth, a, b, depth_c):
c = (1 - b) * da.sinh(a * s) / da.sinh(a) + b * (
da.tanh(a * (s + 0.5)) / (2 * da.tanh(0.5 * a)) - 0.5
)
@@ -1630,7 +1719,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["s"],
nd_points_by_key["eta"],
nd_points_by_key["depth"],
@@ -1663,7 +1752,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["s"],
nd_values_by_key["eta"],
nd_values_by_key["depth"],
@@ -1799,7 +1888,7 @@ def dependencies(self):
depth_c=self.depth_c,
)

def _derive(self, s, c, eta, depth, depth_c):
def _calculate_array(self, s, c, eta, depth, depth_c):
S = (depth_c * s + depth * c) / (depth_c + depth)
return eta + (eta + depth) * S

@@ -1819,7 +1908,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["s"],
nd_points_by_key["c"],
nd_points_by_key["eta"],
@@ -1851,7 +1940,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["s"],
nd_values_by_key["c"],
nd_values_by_key["eta"],
18 changes: 9 additions & 9 deletions lib/iris/tests/unit/aux_factory/test_AtmosphereSigmaFactory.py
@@ -135,21 +135,21 @@ def test_values(self, sample_kwargs):

class Test__derive:
def test_function_scalar(self):
assert AtmosphereSigmaFactory._derive(0, 0, 0) == 0
assert AtmosphereSigmaFactory._derive(3, 0, 0) == 3
assert AtmosphereSigmaFactory._derive(0, 5, 0) == 0
assert AtmosphereSigmaFactory._derive(0, 0, 7) == 0
assert AtmosphereSigmaFactory._derive(3, 5, 0) == -12
assert AtmosphereSigmaFactory._derive(3, 0, 7) == 3
assert AtmosphereSigmaFactory._derive(0, 5, 7) == 35
assert AtmosphereSigmaFactory._derive(3, 5, 7) == 23
assert AtmosphereSigmaFactory._calculate_array(0, 0, 0) == 0
assert AtmosphereSigmaFactory._calculate_array(3, 0, 0) == 3
assert AtmosphereSigmaFactory._calculate_array(0, 5, 0) == 0
assert AtmosphereSigmaFactory._calculate_array(0, 0, 7) == 0
assert AtmosphereSigmaFactory._calculate_array(3, 5, 0) == -12
assert AtmosphereSigmaFactory._calculate_array(3, 0, 7) == 3
assert AtmosphereSigmaFactory._calculate_array(0, 5, 7) == 35
assert AtmosphereSigmaFactory._calculate_array(3, 5, 7) == 23

def test_function_array(self):
ptop = 3
sigma = np.array([2, 4])
ps = np.arange(4).reshape(2, 2)
np.testing.assert_equal(
AtmosphereSigmaFactory._derive(ptop, sigma, ps),
AtmosphereSigmaFactory._calculate_array(ptop, sigma, ps),
[[-3, -5], [1, 3]],
)
