-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Alignment of n-dimensional indexes with partially excluded dims #10293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,9 +2,10 @@ | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
import collections.abc | ||||||||||||||||||||||||||||||||||||
import copy | ||||||||||||||||||||||||||||||||||||
import inspect | ||||||||||||||||||||||||||||||||||||
from collections import defaultdict | ||||||||||||||||||||||||||||||||||||
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence | ||||||||||||||||||||||||||||||||||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast | ||||||||||||||||||||||||||||||||||||
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence | ||||||||||||||||||||||||||||||||||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||||||||||||||
|
@@ -348,7 +349,15 @@ def reindex_like(self, other: Self) -> dict[Hashable, Any]: | |||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def equals(self, other: Index) -> bool: | ||||||||||||||||||||||||||||||||||||
@overload | ||||||||||||||||||||||||||||||||||||
def equals(self, other: Index) -> bool: ... | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@overload | ||||||||||||||||||||||||||||||||||||
def equals( | ||||||||||||||||||||||||||||||||||||
self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None | ||||||||||||||||||||||||||||||||||||
) -> bool: ... | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def equals(self, other: Index, **kwargs) -> bool: | ||||||||||||||||||||||||||||||||||||
"""Compare this index with another index of the same type. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Implementation is optional but required in order to support alignment. | ||||||||||||||||||||||||||||||||||||
|
@@ -357,6 +366,16 @@ def equals(self, other: Index) -> bool: | |||||||||||||||||||||||||||||||||||
---------- | ||||||||||||||||||||||||||||||||||||
other : Index | ||||||||||||||||||||||||||||||||||||
The other Index object to compare with this object. | ||||||||||||||||||||||||||||||||||||
exclude_dims : frozenset of hashable, optional | ||||||||||||||||||||||||||||||||||||
All the dimensions that are excluded from alignment, or None by default | ||||||||||||||||||||||||||||||||||||
(i.e., when this method is not called in the context of alignment). | ||||||||||||||||||||||||||||||||||||
For a n-dimensional index it allows ignoring any relevant dimension found | ||||||||||||||||||||||||||||||||||||
in ``exclude_dims`` when comparing this index with the other index. | ||||||||||||||||||||||||||||||||||||
For a 1-dimensional index it can be always safely ignored as this | ||||||||||||||||||||||||||||||||||||
method is not called when all of the index's dimensions are also excluded | ||||||||||||||||||||||||||||||||||||
from alignment | ||||||||||||||||||||||||||||||||||||
(note: the index's dimensions correspond to the union of the dimensions | ||||||||||||||||||||||||||||||||||||
of all coordinate variables associated with this index). | ||||||||||||||||||||||||||||||||||||
Comment on lines
+369
to
+378
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could probably be improved a bit.
Comment on lines
+370
to
+378
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Returns | ||||||||||||||||||||||||||||||||||||
------- | ||||||||||||||||||||||||||||||||||||
|
@@ -863,7 +882,7 @@ def sel( | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return IndexSelResult({self.dim: indexer}) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def equals(self, other: Index): | ||||||||||||||||||||||||||||||||||||
def equals(self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None): | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
if not isinstance(other, PandasIndex): | ||||||||||||||||||||||||||||||||||||
return False | ||||||||||||||||||||||||||||||||||||
return self.index.equals(other.index) and self.dim == other.dim | ||||||||||||||||||||||||||||||||||||
|
@@ -1542,7 +1561,9 @@ def sel( | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return IndexSelResult(results) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def equals(self, other: Index) -> bool: | ||||||||||||||||||||||||||||||||||||
def equals( | ||||||||||||||||||||||||||||||||||||
self, other: Index, *, exclude_dims: frozenset[Hashable] | None = None | ||||||||||||||||||||||||||||||||||||
) -> bool: | ||||||||||||||||||||||||||||||||||||
if not isinstance(other, CoordinateTransformIndex): | ||||||||||||||||||||||||||||||||||||
return False | ||||||||||||||||||||||||||||||||||||
return self.transform.equals(other.transform) | ||||||||||||||||||||||||||||||||||||
|
@@ -1925,6 +1946,36 @@ def default_indexes( | |||||||||||||||||||||||||||||||||||
return indexes | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def _wrap_index_equals( | ||||||||||||||||||||||||||||||||||||
index: Index, | ||||||||||||||||||||||||||||||||||||
) -> Callable[[Index, frozenset[Hashable]], bool]: | ||||||||||||||||||||||||||||||||||||
# TODO: remove this Index.equals() wrapper (backward compatibility) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
sig = inspect.signature(index.equals) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if len(sig.parameters) == 1: | ||||||||||||||||||||||||||||||||||||
index_cls_name = type(index).__module__ + "." + type(index).__qualname__ | ||||||||||||||||||||||||||||||||||||
emit_user_level_warning( | ||||||||||||||||||||||||||||||||||||
f"the signature ``{index_cls_name}.equals(self, other)`` is deprecated. " | ||||||||||||||||||||||||||||||||||||
f"Please update it to " | ||||||||||||||||||||||||||||||||||||
f"``{index_cls_name}.equals(self, other, *, exclude_dims=None)`` " | ||||||||||||||||||||||||||||||||||||
"or kindly ask the maintainers doing it. " | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
otherwise they might come to us ;) |
||||||||||||||||||||||||||||||||||||
"See documentation of xarray.Index.equals() for more info.", | ||||||||||||||||||||||||||||||||||||
FutureWarning, | ||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||
exclude_dims_kwarg = False | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
exclude_dims_kwarg = True | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def equals_wrapper(other: Index, exclude_dims: frozenset[Hashable]) -> bool: | ||||||||||||||||||||||||||||||||||||
if exclude_dims_kwarg: | ||||||||||||||||||||||||||||||||||||
return index.equals(other, exclude_dims=exclude_dims) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
return index.equals(other) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return equals_wrapper | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def indexes_equal( | ||||||||||||||||||||||||||||||||||||
index: Index, | ||||||||||||||||||||||||||||||||||||
other_index: Index, | ||||||||||||||||||||||||||||||||||||
|
@@ -1966,6 +2017,7 @@ def indexes_equal( | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def indexes_all_equal( | ||||||||||||||||||||||||||||||||||||
elements: Sequence[tuple[Index, dict[Hashable, Variable]]], | ||||||||||||||||||||||||||||||||||||
exclude_dims: frozenset[Hashable], | ||||||||||||||||||||||||||||||||||||
) -> bool: | ||||||||||||||||||||||||||||||||||||
"""Check if indexes are all equal. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -1990,9 +2042,11 @@ def check_variables(): | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:]) | ||||||||||||||||||||||||||||||||||||
if same_type: | ||||||||||||||||||||||||||||||||||||
index_equals_func = _wrap_index_equals(indexes[0]) | ||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||
not_equal = any( | ||||||||||||||||||||||||||||||||||||
not indexes[0].equals(other_idx) for other_idx in indexes[1:] | ||||||||||||||||||||||||||||||||||||
not index_equals_func(other_idx, exclude_dims) | ||||||||||||||||||||||||||||||||||||
for other_idx in indexes[1:] | ||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||
except NotImplementedError: | ||||||||||||||||||||||||||||||||||||
not_equal = check_variables() | ||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -216,30 +216,37 @@ def _normalize_indexes( | |
|
||
normalized_indexes = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a comment describing what "normalize" means here please? it would also be good to describe the purpose of each loop with a block comment. I can't figure out what's going on here |
||
normalized_index_vars = {} | ||
for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): | ||
coord_names_and_dims = [] | ||
all_dims: set[Hashable] = set() | ||
|
||
for name, var in index_vars.items(): | ||
for idx, idx_vars in Indexes(xr_indexes, xr_variables).group_by_index(): | ||
idx_coord_names_and_dims = [] | ||
idx_all_dims: set[Hashable] = set() | ||
|
||
for name, var in idx_vars.items(): | ||
dims = var.dims | ||
coord_names_and_dims.append((name, dims)) | ||
all_dims.update(dims) | ||
|
||
exclude_dims = all_dims & self.exclude_dims | ||
if exclude_dims == all_dims: | ||
continue | ||
elif exclude_dims: | ||
excl_dims_str = ", ".join(str(d) for d in exclude_dims) | ||
incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims) | ||
raise AlignmentError( | ||
f"cannot exclude dimension(s) {excl_dims_str} from alignment because " | ||
"these are used by an index together with non-excluded dimensions " | ||
f"{incl_dims_str}" | ||
) | ||
idx_coord_names_and_dims.append((name, dims)) | ||
idx_all_dims.update(dims) | ||
|
||
# We can ignore an index if all the dimensions it uses are also excluded | ||
# from the alignment (do not ignore the index if it has no related dimension, i.e., | ||
# it is associated with one or more scalar coordinates). | ||
if idx_all_dims: | ||
exclude_dims = idx_all_dims & self.exclude_dims | ||
if exclude_dims == idx_all_dims: | ||
continue | ||
elif exclude_dims and self.join != "exact": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm having trouble deciding if any of the other cases make any sense. Seems OK to remove anything that requires actual reindexing, which leaves |
||
excl_dims_str = ", ".join(str(d) for d in exclude_dims) | ||
incl_dims_str = ", ".join( | ||
str(d) for d in idx_all_dims - exclude_dims | ||
) | ||
raise AlignmentError( | ||
f"cannot exclude dimension(s) {excl_dims_str} from non-exact alignment " | ||
"because these are used by an index together with non-excluded dimensions " | ||
f"{incl_dims_str}" | ||
) | ||
|
||
key = (tuple(coord_names_and_dims), type(idx)) | ||
key = (tuple(idx_coord_names_and_dims), type(idx)) | ||
normalized_indexes[key] = idx | ||
normalized_index_vars[key] = index_vars | ||
normalized_index_vars[key] = idx_vars | ||
|
||
return normalized_indexes, normalized_index_vars | ||
|
||
|
@@ -298,7 +305,7 @@ def _need_reindex(self, dim, cmp_indexes) -> bool: | |
pandas). This is useful, e.g., for overwriting such duplicate indexes. | ||
|
||
""" | ||
if not indexes_all_equal(cmp_indexes): | ||
if not indexes_all_equal(cmp_indexes, self.exclude_dims): | ||
# always reindex when matching indexes are not equal | ||
return True | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added those overloads so Mypy doesn't fail for 3rd-party Xarray indexes. But maybe we want Mypy to fail and thereby encourage maintainers updating their code?