Skip to content

Alignment of n-dimensional indexes with partially excluded dims #10293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 60 additions & 6 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import collections.abc
import copy
import inspect
from collections import defaultdict
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -348,7 +349,15 @@ def reindex_like(self, other: Self) -> dict[Hashable, Any]:
"""
raise NotImplementedError(f"{self!r} doesn't support re-indexing labels")

def equals(self, other: Index) -> bool:
@overload
def equals(self, other: Index) -> bool: ...

@overload
def equals(
self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None
) -> bool: ...

def equals(self, other: Index, **kwargs) -> bool:
Comment on lines +352 to +360
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added those overloads so Mypy doesn't fail for 3rd-party Xarray indexes. But maybe we want Mypy to fail and thereby encourage maintainers updating their code?

"""Compare this index with another index of the same type.

Implementation is optional but required in order to support alignment.
Expand All @@ -357,6 +366,16 @@ def equals(self, other: Index) -> bool:
----------
other : Index
The other Index object to compare with this object.
exclude_dims : frozenset of hashable, optional
All the dimensions that are excluded from alignment, or None by default
(i.e., when this method is not called in the context of alignment).
For a n-dimensional index it allows ignoring any relevant dimension found
in ``exclude_dims`` when comparing this index with the other index.
For a 1-dimensional index it can be always safely ignored as this
method is not called when all of the index's dimensions are also excluded
from alignment
(note: the index's dimensions correspond to the union of the dimensions
of all coordinate variables associated with this index).
Comment on lines +369 to +378
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could probably be improved a bit.

Comment on lines +370 to +378
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
All the dimensions that are excluded from alignment, or None by default
(i.e., when this method is not called in the context of alignment).
For a n-dimensional index it allows ignoring any relevant dimension found
in ``exclude_dims`` when comparing this index with the other index.
For a 1-dimensional index it can be always safely ignored as this
method is not called when all of the index's dimensions are also excluded
from alignment
(note: the index's dimensions correspond to the union of the dimensions
of all coordinate variables associated with this index).
Dimensions excluded from checking. It is None by default,
(i.e., when this method is not called in the context of alignment).
For a n-dimensional index this option allows an Index to optionally,
ignore any dimension in ``exclude_dims`` when comparing
``self`` with ``other``. For a 1-dimensional index this kwarg can be safely ignored ,
as this method is not called when all of the index's dimensions are also excluded
from alignment (note: the index's dimensions correspond to the union of the dimensions
of all coordinate variables associated with this index).


Returns
-------
Expand Down Expand Up @@ -863,7 +882,7 @@ def sel(

return IndexSelResult({self.dim: indexer})

def equals(self, other: Index):
def equals(self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None):
Copy link
Contributor

@dcherian dcherian May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def equals(self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None):
def equals(self, other: Index, *, exclude: frozenset[Hashable] | None = None):

align calls this exclude. Shall we keep that name for consistency?

if not isinstance(other, PandasIndex):
return False
return self.index.equals(other.index) and self.dim == other.dim
Expand Down Expand Up @@ -1542,7 +1561,9 @@ def sel(

return IndexSelResult(results)

def equals(self, other: Index) -> bool:
def equals(
self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None
) -> bool:
if not isinstance(other, CoordinateTransformIndex):
return False
return self.transform.equals(other.transform)
Expand Down Expand Up @@ -1925,6 +1946,36 @@ def default_indexes(
return indexes


def _wrap_index_equals(
index: Index,
) -> Callable[[Index, frozenset[Hashable]], bool]:
# TODO: remove this Index.equals() wrapper (backward compatibility)

sig = inspect.signature(index.equals)

if len(sig.parameters) == 1:
index_cls_name = type(index).__module__ + "." + type(index).__qualname__
emit_user_level_warning(
f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. "
f"Please update it to "
f"``{index_cls_name}.equals(self, other, *, exclude_dims=None)`` "
"or kindly ask the maintainers doing it. "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"or kindly ask the maintainers doing it. "
"or kindly ask the maintainers of ``{index_cls_name}`` to do it. "

otherwise they might come to us ;)

"See documentation of xarray.Index.equals() for more info.",
FutureWarning,
)
exclude_dims_kwarg = False
else:
exclude_dims_kwarg = True

def equals_wrapper(other: Index, exclude_dims: frozenset[Hashable]) -> bool:
if exclude_dims_kwarg:
return index.equals(other, exclude_dims=exclude_dims)
else:
return index.equals(other)

return equals_wrapper


def indexes_equal(
index: Index,
other_index: Index,
Expand Down Expand Up @@ -1966,6 +2017,7 @@ def indexes_equal(

def indexes_all_equal(
elements: Sequence[tuple[Index, dict[Hashable, Variable]]],
exclude_dims: frozenset[Hashable],
) -> bool:
"""Check if indexes are all equal.

Expand All @@ -1990,9 +2042,11 @@ def check_variables():

same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:])
if same_type:
index_equals_func = _wrap_index_equals(indexes[0])
try:
not_equal = any(
not indexes[0].equals(other_idx) for other_idx in indexes[1:]
not index_equals_func(other_idx, exclude_dims)
for other_idx in indexes[1:]
)
except NotImplementedError:
not_equal = check_variables()
Expand Down
49 changes: 28 additions & 21 deletions xarray/structure/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,30 +216,37 @@ def _normalize_indexes(

normalized_indexes = {}
Copy link
Contributor

@dcherian dcherian May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment describing what "normalize" means here please? it would also be good to describe the purpose of each loop with a block comment. I can't figure out what's going on here

normalized_index_vars = {}
for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index():
coord_names_and_dims = []
all_dims: set[Hashable] = set()

for name, var in index_vars.items():
for idx, idx_vars in Indexes(xr_indexes, xr_variables).group_by_index():
idx_coord_names_and_dims = []
idx_all_dims: set[Hashable] = set()

for name, var in idx_vars.items():
dims = var.dims
coord_names_and_dims.append((name, dims))
all_dims.update(dims)

exclude_dims = all_dims & self.exclude_dims
if exclude_dims == all_dims:
continue
elif exclude_dims:
excl_dims_str = ", ".join(str(d) for d in exclude_dims)
incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims)
raise AlignmentError(
f"cannot exclude dimension(s) {excl_dims_str} from alignment because "
"these are used by an index together with non-excluded dimensions "
f"{incl_dims_str}"
)
idx_coord_names_and_dims.append((name, dims))
idx_all_dims.update(dims)

# We can ignore an index if all the dimensions it uses are also excluded
# from the alignment (do not ignore the index if it has no related dimension, i.e.,
# it is associated with one or more scalar coordinates).
if idx_all_dims:
exclude_dims = idx_all_dims & self.exclude_dims
if exclude_dims == idx_all_dims:
continue
elif exclude_dims and self.join != "exact":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble deciding if any of the other cases make any sense. Seems OK to remove anything that requires actual reindexing, which leaves "override" and I'm not sure about that. Let's have the error message recommend the user to open an issue if they feel it should work.

excl_dims_str = ", ".join(str(d) for d in exclude_dims)
incl_dims_str = ", ".join(
str(d) for d in idx_all_dims - exclude_dims
)
raise AlignmentError(
f"cannot exclude dimension(s) {excl_dims_str} from non-exact alignment "
"because these are used by an index together with non-excluded dimensions "
f"{incl_dims_str}"
)

key = (tuple(coord_names_and_dims), type(idx))
key = (tuple(idx_coord_names_and_dims), type(idx))
normalized_indexes[key] = idx
normalized_index_vars[key] = index_vars
normalized_index_vars[key] = idx_vars

return normalized_indexes, normalized_index_vars

Expand Down Expand Up @@ -298,7 +305,7 @@ def _need_reindex(self, dim, cmp_indexes) -> bool:
pandas). This is useful, e.g., for overwriting such duplicate indexes.

"""
if not indexes_all_equal(cmp_indexes):
if not indexes_all_equal(cmp_indexes, self.exclude_dims):
# always reindex when matching indexes are not equal
return True

Expand Down
87 changes: 87 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2615,6 +2615,93 @@ def test_align_index_var_attrs(self, join) -> None:
assert ds.x.attrs == {"units": "m"}
assert ds_noattr.x.attrs == {}

def test_align_scalar_index(self) -> None:
# ensure that indexes associated with scalar coordinates are not ignored
# during alignment
class ScalarIndex(Index):
def __init__(self, value: int):
self.value = value

@classmethod
def from_variables(cls, variables, *, options):
var = next(iter(variables.values()))
return cls(int(var.values))

def equals(self, other, *, exclude_dims=None):
return isinstance(other, ScalarIndex) and other.value == self.value

ds1 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)
ds2 = Dataset(coords={"x": 0}).set_xindex("x", ScalarIndex)

actual = xr.align(ds1, ds2, join="exact")
assert_identical(actual[0], ds1, check_default_indexes=False)
assert_identical(actual[1], ds2, check_default_indexes=False)

ds3 = Dataset(coords={"x": 1}).set_xindex("x", ScalarIndex)

with pytest.raises(AlignmentError, match="cannot align objects"):
xr.align(ds1, ds3, join="exact")

def test_align_multi_dim_index_exclude_dims(self) -> None:
class XYIndex(Index):
def __init__(self, x: PandasIndex, y: PandasIndex):
self.x: PandasIndex = x
self.y: PandasIndex = y

@classmethod
def from_variables(cls, variables, *, options):
return cls(
x=PandasIndex.from_variables(
{"x": variables["x"]}, options=options
),
y=PandasIndex.from_variables(
{"y": variables["y"]}, options=options
),
)

def equals(self, other, exclude_dims=None):
x_eq = True if self.x.dim in exclude_dims else self.x.equals(other.x)
y_eq = True if self.y.dim in exclude_dims else self.y.equals(other.y)
return x_eq and y_eq

ds1 = (
Dataset(coords={"x": [1, 2], "y": [3, 4]})
.drop_indexes(["x", "y"])
.set_xindex(["x", "y"], XYIndex)
)
ds2 = (
Dataset(coords={"x": [1, 2], "y": [5, 6]})
.drop_indexes(["x", "y"])
.set_xindex(["x", "y"], XYIndex)
)

actual = xr.align(ds1, ds2, join="exact", exclude="y")
assert_identical(actual[0], ds1, check_default_indexes=False)
assert_identical(actual[1], ds2, check_default_indexes=False)

with pytest.raises(
AlignmentError, match="cannot align objects.*index.*not equal"
):
xr.align(ds1, ds2, join="exact")

with pytest.raises(AlignmentError, match="cannot exclude dimension"):
xr.align(ds1, ds2, join="outer", exclude="y")

def test_align_index_equals_future_warning(self) -> None:
# TODO: remove this test once the deprecation cycle is completed
class DeprecatedEqualsSignatureIndex(PandasIndex):
def equals(self, other: Index) -> bool: # type: ignore[override]
return super().equals(other, exclude_dims=None)

ds = (
Dataset(coords={"x": [1, 2]})
.drop_indexes("x")
.set_xindex("x", DeprecatedEqualsSignatureIndex)
)

with pytest.warns(FutureWarning, match="signature.*deprecated"):
xr.align(ds, ds.copy(), join="exact")

def test_broadcast(self) -> None:
ds = Dataset(
{"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])}
Expand Down
Loading