-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
enforced return types of groupby sequence arguments #10271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
zerothi
wants to merge
2
commits into
pydata:main
Choose a base branch
from
zerothi:10246-groupby-return-type
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -383,7 +383,7 @@ def _parse_group_and_groupers( | |||||||
groupers: dict[str, Grouper], | ||||||||
*, | ||||||||
eagerly_compute_group: bool, | ||||||||
) -> tuple[ResolvedGrouper, ...]: | ||||||||
) -> ResolvedGrouper | tuple[ResolvedGrouper, ...]: | ||||||||
from xarray.core.dataarray import DataArray | ||||||||
from xarray.core.variable import Variable | ||||||||
from xarray.groupers import UniqueGrouper | ||||||||
|
@@ -407,6 +407,7 @@ def _parse_group_and_groupers( | |||||||
|
||||||||
rgroupers: tuple[ResolvedGrouper, ...] | ||||||||
if isinstance(group, DataArray | Variable): | ||||||||
# TODO add test for this, see gh-10246 | ||||||||
rgroupers = ( | ||||||||
ResolvedGrouper( | ||||||||
UniqueGrouper(), group, obj, eagerly_compute_group=eagerly_compute_group | ||||||||
|
@@ -429,6 +430,8 @@ def _parse_group_and_groupers( | |||||||
) | ||||||||
for group, grouper in grouper_mapping.items() | ||||||||
) | ||||||||
if isinstance(group, str): | ||||||||
rgroupers = rgroupers[0] | ||||||||
return rgroupers | ||||||||
|
||||||||
|
||||||||
|
@@ -453,7 +456,6 @@ def _resolve_group( | |||||||
"match the length of this variable along its " | ||||||||
"dimensions" | ||||||||
) | ||||||||
|
||||||||
newgroup: T_Group | ||||||||
if isinstance(group, DataArray): | ||||||||
try: | ||||||||
|
@@ -602,7 +604,7 @@ class GroupBy(Generic[T_Xarray]): | |||||||
"groupers", | ||||||||
) | ||||||||
_obj: T_Xarray | ||||||||
groupers: tuple[ResolvedGrouper, ...] | ||||||||
groupers: ResolvedGrouper | tuple[ResolvedGrouper, ...] | ||||||||
_restore_coord_dims: bool | ||||||||
|
||||||||
_original_obj: T_Xarray | ||||||||
|
@@ -626,7 +628,7 @@ class GroupBy(Generic[T_Xarray]): | |||||||
def __init__( | ||||||||
self, | ||||||||
obj: T_Xarray, | ||||||||
groupers: tuple[ResolvedGrouper, ...], | ||||||||
groupers: ResolvedGrouper | tuple[ResolvedGrouper, ...], | ||||||||
restore_coord_dims: bool = True, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: "and then I'd handle this inside" (the rest of the comment was not captured).
|
||||||||
) -> None: | ||||||||
"""Create a GroupBy object | ||||||||
|
@@ -635,7 +637,7 @@ def __init__( | |||||||
---------- | ||||||||
obj : Dataset or DataArray | ||||||||
Object to group. | ||||||||
grouper : Grouper | ||||||||
groupers : ResolvedGrouper or tuple[ResolvedGrouper, ...] | ||||||||
Grouper object | ||||||||
restore_coord_dims : bool, default: True | ||||||||
If True, also restore the dimension order of multi-dimensional | ||||||||
|
@@ -645,9 +647,8 @@ def __init__( | |||||||
self._restore_coord_dims = restore_coord_dims | ||||||||
self.groupers = groupers | ||||||||
|
||||||||
if len(groupers) == 1: | ||||||||
(grouper,) = groupers | ||||||||
self.encoded = grouper.encoded | ||||||||
if isinstance(groupers, ResolvedGrouper): | ||||||||
self.encoded = groupers.encoded | ||||||||
else: | ||||||||
if any( | ||||||||
isinstance(obj._indexes.get(grouper.name, None), PandasMultiIndex) | ||||||||
|
@@ -699,6 +700,12 @@ def sizes(self) -> Mapping[Hashable, int]: | |||||||
self._sizes = self._obj.isel({self._group_dim: index}).sizes | ||||||||
return self._sizes | ||||||||
|
||||||||
@property | ||||||||
def _groupers_tuple(self) -> tuple[ResolvedGrouper, ...]: | ||||||||
if isinstance(self.groupers, ResolvedGrouper): | ||||||||
return (self.groupers,) | ||||||||
return self.groupers | ||||||||
|
||||||||
def shuffle_to_chunks(self, chunks: T_Chunks = None) -> T_Xarray: | ||||||||
""" | ||||||||
Sort or "shuffle" the underlying object. | ||||||||
|
@@ -756,7 +763,8 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: | |||||||
was_array = isinstance(self._obj, DataArray) | ||||||||
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj | ||||||||
|
||||||||
for grouper in self.groupers: | ||||||||
groupers = self._groupers_tuple | ||||||||
for grouper in groupers: | ||||||||
if grouper.name not in as_dataset._variables: | ||||||||
as_dataset.coords[grouper.name] = grouper.group | ||||||||
|
||||||||
|
@@ -801,7 +809,9 @@ def _raise_if_by_is_chunked(self): | |||||||
) | ||||||||
|
||||||||
def _raise_if_not_single_group(self): | ||||||||
if len(self.groupers) != 1: | ||||||||
# This ensures that one does not do `groupby_bins(..., ["x"])` | ||||||||
# TODO this should be enabled later. | ||||||||
if not isinstance(self.groupers, ResolvedGrouper): | ||||||||
raise NotImplementedError( | ||||||||
"This method is not supported for grouping by multiple variables yet." | ||||||||
) | ||||||||
|
@@ -836,12 +846,13 @@ def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: | |||||||
return zip(self.encoded.unique_coord.data, self._iter_grouped(), strict=True) | ||||||||
|
||||||||
def __repr__(self) -> str: | ||||||||
groupers = self._groupers_tuple | ||||||||
text = ( | ||||||||
f"<{self.__class__.__name__}, " | ||||||||
f"grouped over {len(self.groupers)} grouper(s)," | ||||||||
f"grouped over {len(groupers)} grouper(s)," | ||||||||
f" {self._len} groups in total:" | ||||||||
) | ||||||||
for grouper in self.groupers: | ||||||||
for grouper in groupers: | ||||||||
coord = grouper.unique_coord | ||||||||
labels = ", ".join(format_array_flat(coord, 30).split()) | ||||||||
text += f"\n {grouper.name!r}: {coord.size}/{grouper.full_index.size} groups present with labels {labels}" | ||||||||
|
@@ -871,7 +882,7 @@ def _binary_op(self, other, f, reflexive=False): | |||||||
g = f if not reflexive else lambda x, y: f(y, x) | ||||||||
|
||||||||
self._raise_if_not_single_group() | ||||||||
(grouper,) = self.groupers | ||||||||
(grouper,) = self._groupers_tuple | ||||||||
obj = self._original_obj | ||||||||
name = grouper.name | ||||||||
group = grouper.group | ||||||||
|
@@ -971,11 +982,12 @@ def _maybe_reindex(self, combined): | |||||||
self.encoded.unique_coord.size != self.encoded.full_index.size | ||||||||
) | ||||||||
indexers = {} | ||||||||
for grouper in self.groupers: | ||||||||
groupers = self._groupers_tuple | ||||||||
for grouper in groupers: | ||||||||
index = combined._indexes.get(grouper.name, None) | ||||||||
if has_missing_groups and index is not None: | ||||||||
indexers[grouper.name] = grouper.full_index | ||||||||
elif len(self.groupers) > 1: | ||||||||
elif len(groupers) > 1: | ||||||||
if not isinstance( | ||||||||
grouper.full_index, pd.RangeIndex | ||||||||
) and not index.index.equals(grouper.full_index): | ||||||||
|
@@ -989,6 +1001,7 @@ def _maybe_unstack(self, obj): | |||||||
multidimensional group.""" | ||||||||
from xarray.groupers import UniqueGrouper | ||||||||
|
||||||||
groupers = self._groupers_tuple | ||||||||
stacked_dim = self._stacked_dim | ||||||||
if stacked_dim is not None and stacked_dim in obj.dims: | ||||||||
inserted_dims = self._inserted_dims | ||||||||
|
@@ -997,7 +1010,7 @@ def _maybe_unstack(self, obj): | |||||||
if dim in obj.coords: | ||||||||
del obj.coords[dim] | ||||||||
obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) | ||||||||
elif len(self.groupers) > 1: | ||||||||
elif len(groupers) > 1: | ||||||||
# TODO: we could clean this up by setting the appropriate `stacked_dim` | ||||||||
# and `inserted_dims` | ||||||||
# if multiple groupers all share the same single dimension, then | ||||||||
|
@@ -1007,7 +1020,7 @@ def _maybe_unstack(self, obj): | |||||||
obj = obj.unstack(*dims_to_unstack) | ||||||||
to_drop = [ | ||||||||
grouper.name | ||||||||
for grouper in self.groupers | ||||||||
for grouper in groupers | ||||||||
if isinstance(grouper.group, _DummyGroup) | ||||||||
and isinstance(grouper.grouper, UniqueGrouper) | ||||||||
] | ||||||||
|
@@ -1044,7 +1057,7 @@ def _flox_reduce( | |||||||
kwargs.setdefault("method", "cohorts") | ||||||||
|
||||||||
midx_grouping_vars: tuple[Hashable, ...] = () | ||||||||
for grouper in self.groupers: | ||||||||
for grouper in self._groupers_tuple: | ||||||||
name = grouper.name | ||||||||
maybe_midx = obj._indexes.get(name, None) | ||||||||
if isinstance(maybe_midx, PandasMultiIndex): | ||||||||
|
@@ -1082,7 +1095,7 @@ def _flox_reduce( | |||||||
parsed_dim_list = list() | ||||||||
# preserve order | ||||||||
for dim_ in itertools.chain( | ||||||||
*(grouper.group.dims for grouper in self.groupers) | ||||||||
*(grouper.group.dims for grouper in self._groupers_tuple) | ||||||||
): | ||||||||
if dim_ not in parsed_dim_list: | ||||||||
parsed_dim_list.append(dim_) | ||||||||
|
@@ -1094,7 +1107,7 @@ def _flox_reduce( | |||||||
|
||||||||
# Do this so we raise the same error message whether flox is present or not. | ||||||||
# Better to control it here than in flox. | ||||||||
for grouper in self.groupers: | ||||||||
for grouper in self._groupers_tuple: | ||||||||
if any( | ||||||||
d not in grouper.group.dims and d not in obj.dims for d in parsed_dim | ||||||||
): | ||||||||
|
@@ -1117,10 +1130,10 @@ def _flox_reduce( | |||||||
|
||||||||
# pass RangeIndex as a hint to flox that `by` is already factorized | ||||||||
expected_groups = tuple( | ||||||||
pd.RangeIndex(len(grouper)) for grouper in self.groupers | ||||||||
pd.RangeIndex(len(grouper)) for grouper in self._groupers_tuple | ||||||||
) | ||||||||
|
||||||||
codes = tuple(g.codes for g in self.groupers) | ||||||||
codes = tuple(g.codes for g in self._groupers_tuple) | ||||||||
result = xarray_reduce( | ||||||||
obj.drop_vars(non_numeric.keys()), | ||||||||
*codes, | ||||||||
|
@@ -1137,7 +1150,7 @@ def _flox_reduce( | |||||||
new_coords = [] | ||||||||
to_drop = [] | ||||||||
if group_dims & set(parsed_dim): | ||||||||
for grouper in self.groupers: | ||||||||
for grouper in self._groupers_tuple: | ||||||||
output_index = grouper.full_index | ||||||||
if isinstance(output_index, pd.RangeIndex): | ||||||||
# flox always assigns an index so we must drop it here if we don't need it. | ||||||||
|
@@ -1529,7 +1542,7 @@ def _concat_shortcut(self, applied, dim, positions=None): | |||||||
|
||||||||
def _restore_dim_order(self, stacked: DataArray) -> DataArray: | ||||||||
def lookup_order(dimension): | ||||||||
for grouper in self.groupers: | ||||||||
for grouper in self._groupers_tuple: | ||||||||
if dimension == grouper.name and grouper.group.ndim == 1: | ||||||||
(dimension,) = grouper.group.dims | ||||||||
if dimension in self._obj.dims: | ||||||||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead, I'd return a flag `squeeze_group_label_on_iter` that is True if `group` is a `str`.