From d9e4fd242cec742a3a818aeed9d752951d4f1472 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 7 May 2025 13:54:52 +0200 Subject: [PATCH] alignment of n-dimensional indexes vs. exclude dim Support checking exact alignment of indexes that use multiple dimensions in the cases where some of those dimensions are included in alignment while others are excluded. Added `exclude_dims` keyword argument to `Index.equals()` (and still support old signature with future warning). Also fixed bug: indexes associated with scalar coordinates were ignored during alignment. Added tests as well. --- xarray/core/indexes.py | 66 +++++++++++++++++++++++--- xarray/structure/alignment.py | 49 +++++++++++--------- xarray/tests/test_dataset.py | 87 +++++++++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 27 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8babb885a5e..f290ef2bfd1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -2,9 +2,10 @@ import collections.abc import copy +import inspect from collections import defaultdict -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload import numpy as np import pandas as pd @@ -348,7 +349,15 @@ def reindex_like(self, other: Self) -> dict[Hashable, Any]: """ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") - def equals(self, other: Index) -> bool: + @overload + def equals(self, other: Index) -> bool: ... + + @overload + def equals( + self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + ) -> bool: ... + + def equals(self, other: Index, **kwargs) -> bool: """Compare this index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -357,6 +366,16 @@ def equals(self, other: Index) -> bool: ---------- other : Index The other Index object to compare with this object. + exclude_dims : frozenset of hashable, optional + All the dimensions that are excluded from alignment, or None by default + (i.e., when this method is not called in the context of alignment). + For a n-dimensional index it allows ignoring any relevant dimension found + in ``exclude_dims`` when comparing this index with the other index. + For a 1-dimensional index it can be always safely ignored as this + method is not called when all of the index's dimensions are also excluded + from alignment + (note: the index's dimensions correspond to the union of the dimensions + of all coordinate variables associated with this index). Returns ------- @@ -863,7 +882,7 @@ def sel( return IndexSelResult({self.dim: indexer}) - def equals(self, other: Index): + def equals(self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None): if not isinstance(other, PandasIndex): return False return self.index.equals(other.index) and self.dim == other.dim @@ -1542,7 +1561,9 @@ def sel( return IndexSelResult(results) - def equals(self, other: Index) -> bool: + def equals( + self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + ) -> bool: if not isinstance(other, CoordinateTransformIndex): return False return self.transform.equals(other.transform) @@ -1925,6 +1946,36 @@ def default_indexes( return indexes +def _wrap_index_equals( + index: Index, +) -> Callable[[Index, frozenset[Hashable]], bool]: + # TODO: remove this Index.equals() wrapper (backward compatibility) + + sig = inspect.signature(index.equals) + + if len(sig.parameters) == 1: + index_cls_name = type(index).__module__ + "." + type(index).__qualname__ + emit_user_level_warning( + f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. " + f"Please update it to " + f"``{index_cls_name}.equals(self, other, *, exclude_dims=None)`` " + "or kindly ask the maintainers doing it. " + "See documentation of xarray.Index.equals() for more info.", + FutureWarning, + ) + exclude_dims_kwarg = False + else: + exclude_dims_kwarg = True + + def equals_wrapper(other: Index, exclude_dims: frozenset[Hashable]) -> bool: + if exclude_dims_kwarg: + return index.equals(other, exclude_dims=exclude_dims) + else: + return index.equals(other) + + return equals_wrapper + + def indexes_equal( index: Index, other_index: Index, @@ -1966,6 +2017,7 @@ def indexes_equal( def indexes_all_equal( elements: Sequence[tuple[Index, dict[Hashable, Variable]]], + exclude_dims: frozenset[Hashable], ) -> bool: """Check if indexes are all equal. @@ -1990,9 +2042,11 @@ def check_variables(): same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) if same_type: + index_equals_func = _wrap_index_equals(indexes[0]) try: not_equal = any( - not indexes[0].equals(other_idx) for other_idx in indexes[1:] + not index_equals_func(other_idx, exclude_dims) + for other_idx in indexes[1:] ) except NotImplementedError: not_equal = check_variables() diff --git a/xarray/structure/alignment.py b/xarray/structure/alignment.py index ea90519143c..49d2709343e 100644 --- a/xarray/structure/alignment.py +++ b/xarray/structure/alignment.py @@ -216,30 +216,37 @@ def _normalize_indexes( normalized_indexes = {} normalized_index_vars = {} - for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): - coord_names_and_dims = [] - all_dims: set[Hashable] = set() - for name, var in index_vars.items(): + for idx, idx_vars in Indexes(xr_indexes, xr_variables).group_by_index(): + idx_coord_names_and_dims = [] + idx_all_dims: set[Hashable] = set() + + for name, var in idx_vars.items(): dims = var.dims - coord_names_and_dims.append((name, dims)) - all_dims.update(dims) - - exclude_dims = all_dims & self.exclude_dims - if exclude_dims == all_dims: - continue - elif exclude_dims: - excl_dims_str = ", ".join(str(d) for d in exclude_dims) - incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims) - raise AlignmentError( - f"cannot exclude dimension(s) {excl_dims_str} from alignment because " - "these are used by an index together with non-excluded dimensions " - f"{incl_dims_str}" - ) + idx_coord_names_and_dims.append((name, dims)) + idx_all_dims.update(dims) + + # We can ignore an index if all the dimensions it uses are also excluded + # from the alignment (do not ignore the index if it has no related dimension, i.e., + # it is associated with one or more scalar coordinates). + if idx_all_dims: + exclude_dims = idx_all_dims & self.exclude_dims + if exclude_dims == idx_all_dims: + continue + elif exclude_dims and self.join != "exact": + excl_dims_str = ", ".join(str(d) for d in exclude_dims) + incl_dims_str = ", ".join( + str(d) for d in idx_all_dims - exclude_dims + ) + raise AlignmentError( + f"cannot exclude dimension(s) {excl_dims_str} from non-exact alignment " + "because these are used by an index together with non-excluded dimensions " + f"{incl_dims_str}" + ) - key = (tuple(coord_names_and_dims), type(idx)) + key = (tuple(idx_coord_names_and_dims), type(idx)) normalized_indexes[key] = idx - normalized_index_vars[key] = index_vars + normalized_index_vars[key] = idx_vars return normalized_indexes, normalized_index_vars @@ -298,7 +305,7 @@ def _need_reindex(self, dim, cmp_indexes) -> bool: pandas). This is useful, e.g., for overwriting such duplicate indexes. """ - if not indexes_all_equal(cmp_indexes): + if not indexes_all_equal(cmp_indexes, self.exclude_dims): # always reindex when matching indexes are not equal return True diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ac186a7d351..c001b3b69fc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2615,6 +2615,93 @@ def test_align_index_var_attrs(self, join) -> None: assert ds.x.attrs == {"units": "m"} assert ds_noattr.x.attrs == {} + def test_align_scalar_index(self) -> None: + # ensure that indexes associated with scalar coordinates are not ignored + # during alignment + class ScalarIndex(Index): + def __init__(self, value: int): + self.value = value + + @classmethod + def from_variables(cls, variables, *, options): + var = next(iter(variables.values())) + return cls(int(var.values)) + + def equals(self, other, *, exclude_dims=None): + return isinstance(other, ScalarIndex) and other.value == self.value + + ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex) + ds2 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex) + + actual = xr.align(ds1, ds2, join="exact") + assert_identical(actual[0], ds1, check_default_indexes=False) + assert_identical(actual[1], ds2, check_default_indexes=False) + + ds3 = Dataset(coords={"x": 1}).set_xindex("x", ScalarIndex) + + with pytest.raises(AlignmentError, match="cannot align objects"): + xr.align(ds1, ds3, join="exact") + + def test_align_multi_dim_index_exclude_dims(self) -> None: + class XYIndex(Index): + def __init__(self, x: PandasIndex, y: PandasIndex): + self.x: PandasIndex = x + self.y: PandasIndex = y + + @classmethod + def from_variables(cls, variables, *, options): + return cls( + x=PandasIndex.from_variables( + {"x": variables["x"]}, options=options + ), + y=PandasIndex.from_variables( + {"y": variables["y"]}, options=options + ), + ) + + def equals(self, other, exclude_dims=None): + x_eq = True if self.x.dim in exclude_dims else self.x.equals(other.x) + y_eq = True if self.y.dim in exclude_dims else self.y.equals(other.y) + return x_eq and y_eq + + ds1 = ( + Dataset(coords={"x": [1, 2], "y": [3, 4]}) + .drop_indexes(["x", "y"]) + .set_xindex(["x", "y"], XYIndex) + ) + ds2 = ( + Dataset(coords={"x": [1, 2], "y": [5, 6]}) + .drop_indexes(["x", "y"]) + .set_xindex(["x", "y"], XYIndex) + ) + + actual = xr.align(ds1, ds2, join="exact", exclude="y") + assert_identical(actual[0], ds1, check_default_indexes=False) + assert_identical(actual[1], ds2, check_default_indexes=False) + + with pytest.raises( + AlignmentError, match="cannot align objects.*index.*not equal" + ): + xr.align(ds1, ds2, join="exact") + + with pytest.raises(AlignmentError, match="cannot exclude dimension"): + xr.align(ds1, ds2, join="outer", exclude="y") + + def test_align_index_equals_future_warning(self) -> None: + # TODO: remove this test once the deprecation cycle is completed + class DeprecatedEqualsSignatureIndex(PandasIndex): + def equals(self, other: Index) -> bool: # type: ignore[override] + return super().equals(other, exclude_dims=None) + + ds = ( + Dataset(coords={"x": [1, 2]}) + .drop_indexes("x") + .set_xindex("x", DeprecatedEqualsSignatureIndex) + ) + + with pytest.warns(FutureWarning, match="signature.*deprecated"): + xr.align(ds, ds.copy(), join="exact") + def test_broadcast(self) -> None: ds = Dataset( {"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])}