diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8babb885a5e..f290ef2bfd1 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -2,9 +2,10 @@ import collections.abc import copy +import inspect from collections import defaultdict -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload import numpy as np import pandas as pd @@ -348,7 +349,15 @@ def reindex_like(self, other: Self) -> dict[Hashable, Any]: """ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") - def equals(self, other: Index) -> bool: + @overload + def equals(self, other: Index) -> bool: ... + + @overload + def equals( + self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + ) -> bool: ... + + def equals(self, other: Index, **kwargs) -> bool: """Compare this index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -357,6 +366,16 @@ def equals(self, other: Index) -> bool: ---------- other : Index The other Index object to compare with this object. + exclude_dims : frozenset of hashable, optional + All the dimensions that are excluded from alignment, or None by default + (i.e., when this method is not called in the context of alignment). + For a n-dimensional index it allows ignoring any relevant dimension found + in ``exclude_dims`` when comparing this index with the other index. + For a 1-dimensional index it can be always safely ignored as this + method is not called when all of the index's dimensions are also excluded + from alignment + (note: the index's dimensions correspond to the union of the dimensions + of all coordinate variables associated with this index). Returns ------- @@ -863,7 +882,7 @@ def sel( return IndexSelResult({self.dim: indexer}) - def equals(self, other: Index): + def equals(self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None): if not isinstance(other, PandasIndex): return False return self.index.equals(other.index) and self.dim == other.dim @@ -1542,7 +1561,9 @@ def sel( return IndexSelResult(results) - def equals(self, other: Index) -> bool: + def equals( + self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None + ) -> bool: if not isinstance(other, CoordinateTransformIndex): return False return self.transform.equals(other.transform) @@ -1925,6 +1946,36 @@ def default_indexes( return indexes +def _wrap_index_equals( + index: Index, +) -> Callable[[Index, frozenset[Hashable]], bool]: + # TODO: remove this Index.equals() wrapper (backward compatibility) + + sig = inspect.signature(index.equals) + + if len(sig.parameters) == 1: + index_cls_name = type(index).__module__ + "." + type(index).__qualname__ + emit_user_level_warning( + f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. " + f"Please update it to " + f"``{index_cls_name}.equals(self, other, *, exclude_dims=None)`` " + "or kindly ask the maintainers doing it. " + "See documentation of xarray.Index.equals() for more info.", + FutureWarning, + ) + exclude_dims_kwarg = False + else: + exclude_dims_kwarg = True + + def equals_wrapper(other: Index, exclude_dims: frozenset[Hashable]) -> bool: + if exclude_dims_kwarg: + return index.equals(other, exclude_dims=exclude_dims) + else: + return index.equals(other) + + return equals_wrapper + + def indexes_equal( index: Index, other_index: Index, @@ -1966,6 +2017,7 @@ def indexes_equal( def indexes_all_equal( elements: Sequence[tuple[Index, dict[Hashable, Variable]]], + exclude_dims: frozenset[Hashable], ) -> bool: """Check if indexes are all equal. @@ -1990,9 +2042,11 @@ def check_variables(): same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) if same_type: + index_equals_func = _wrap_index_equals(indexes[0]) try: not_equal = any( - not indexes[0].equals(other_idx) for other_idx in indexes[1:] + not index_equals_func(other_idx, exclude_dims) + for other_idx in indexes[1:] ) except NotImplementedError: not_equal = check_variables() diff --git a/xarray/structure/alignment.py b/xarray/structure/alignment.py index ea90519143c..49d2709343e 100644 --- a/xarray/structure/alignment.py +++ b/xarray/structure/alignment.py @@ -216,30 +216,37 @@ def _normalize_indexes( normalized_indexes = {} normalized_index_vars = {} - for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): - coord_names_and_dims = [] - all_dims: set[Hashable] = set() - for name, var in index_vars.items(): + for idx, idx_vars in Indexes(xr_indexes, xr_variables).group_by_index(): + idx_coord_names_and_dims = [] + idx_all_dims: set[Hashable] = set() + + for name, var in idx_vars.items(): dims = var.dims - coord_names_and_dims.append((name, dims)) - all_dims.update(dims) - - exclude_dims = all_dims & self.exclude_dims - if exclude_dims == all_dims: - continue - elif exclude_dims: - excl_dims_str = ", ".join(str(d) for d in exclude_dims) - incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims) - raise AlignmentError( - f"cannot exclude dimension(s) {excl_dims_str} from alignment because " - "these are used by an index together with non-excluded dimensions " - f"{incl_dims_str}" - ) + idx_coord_names_and_dims.append((name, dims)) + idx_all_dims.update(dims) + + # We can ignore an index if all the dimensions it uses are also excluded + # from the alignment (do not ignore the index if it has no related dimension, i.e., + # it is associated with one or more scalar coordinates). + if idx_all_dims: + exclude_dims = idx_all_dims & self.exclude_dims + if exclude_dims == idx_all_dims: + continue + elif exclude_dims and self.join != "exact": + excl_dims_str = ", ".join(str(d) for d in exclude_dims) + incl_dims_str = ", ".join( + str(d) for d in idx_all_dims - exclude_dims + ) + raise AlignmentError( + f"cannot exclude dimension(s) {excl_dims_str} from non-exact alignment " + "because these are used by an index together with non-excluded dimensions " + f"{incl_dims_str}" + ) - key = (tuple(coord_names_and_dims), type(idx)) + key = (tuple(idx_coord_names_and_dims), type(idx)) normalized_indexes[key] = idx - normalized_index_vars[key] = index_vars + normalized_index_vars[key] = idx_vars return normalized_indexes, normalized_index_vars @@ -298,7 +305,7 @@ def _need_reindex(self, dim, cmp_indexes) -> bool: pandas). This is useful, e.g., for overwriting such duplicate indexes. """ - if not indexes_all_equal(cmp_indexes): + if not indexes_all_equal(cmp_indexes, self.exclude_dims): # always reindex when matching indexes are not equal return True diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ac186a7d351..c001b3b69fc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2615,6 +2615,93 @@ def test_align_index_var_attrs(self, join) -> None: assert ds.x.attrs == {"units": "m"} assert ds_noattr.x.attrs == {} + def test_align_scalar_index(self) -> None: + # ensure that indexes associated with scalar coordinates are not ignored + # during alignment + class ScalarIndex(Index): + def __init__(self, value: int): + self.value = value + + @classmethod + def from_variables(cls, variables, *, options): + var = next(iter(variables.values())) + return cls(int(var.values)) + + def equals(self, other, *, exclude_dims=None): + return isinstance(other, ScalarIndex) and other.value == self.value + + ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex) + ds2 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex) + + actual = xr.align(ds1, ds2, join="exact") + assert_identical(actual[0], ds1, check_default_indexes=False) + assert_identical(actual[1], ds2, check_default_indexes=False) + + ds3 = Dataset(coords={"x": 1}).set_xindex("x", ScalarIndex) + + with pytest.raises(AlignmentError, match="cannot align objects"): + xr.align(ds1, ds3, join="exact") + + def test_align_multi_dim_index_exclude_dims(self) -> None: + class XYIndex(Index): + def __init__(self, x: PandasIndex, y: PandasIndex): + self.x: PandasIndex = x + self.y: PandasIndex = y + + @classmethod + def from_variables(cls, variables, *, options): + return cls( + x=PandasIndex.from_variables( + {"x": variables["x"]}, options=options + ), + y=PandasIndex.from_variables( + {"y": variables["y"]}, options=options + ), + ) + + def equals(self, other, exclude_dims=None): + x_eq = True if self.x.dim in exclude_dims else self.x.equals(other.x) + y_eq = True if self.y.dim in exclude_dims else self.y.equals(other.y) + return x_eq and y_eq + + ds1 = ( + Dataset(coords={"x": [1, 2], "y": [3, 4]}) + .drop_indexes(["x", "y"]) + .set_xindex(["x", "y"], XYIndex) + ) + ds2 = ( + Dataset(coords={"x": [1, 2], "y": [5, 6]}) + .drop_indexes(["x", "y"]) + .set_xindex(["x", "y"], XYIndex) + ) + + actual = xr.align(ds1, ds2, join="exact", exclude="y") + assert_identical(actual[0], ds1, check_default_indexes=False) + assert_identical(actual[1], ds2, check_default_indexes=False) + + with pytest.raises( + AlignmentError, match="cannot align objects.*index.*not equal" + ): + xr.align(ds1, ds2, join="exact") + + with pytest.raises(AlignmentError, match="cannot exclude dimension"): + xr.align(ds1, ds2, join="outer", exclude="y") + + def test_align_index_equals_future_warning(self) -> None: + # TODO: remove this test once the deprecation cycle is completed + class DeprecatedEqualsSignatureIndex(PandasIndex): + def equals(self, other: Index) -> bool: # type: ignore[override] + return super().equals(other, exclude_dims=None) + + ds = ( + Dataset(coords={"x": [1, 2]}) + .drop_indexes("x") + .set_xindex("x", DeprecatedEqualsSignatureIndex) + ) + + with pytest.warns(FutureWarning, match="signature.*deprecated"): + xr.align(ds, ds.copy(), join="exact") + def test_broadcast(self) -> None: ds = Dataset( {"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])}