diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4995b2953f9..c07e75c5d4c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -66,6 +66,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Ensured tuple return types of ``groupby`` calls with sequences, regardless of ``len==1``. + By `Nick Papior <https://github.com/zerothi>`_. - :py:meth:`~xarray.Dataset.to_stacked_array` now uses dimensions in order of appearance. This fixes the issue where using :py:meth:`~xarray.Dataset.transpose` before :py:meth:`~xarray.Dataset.to_stacked_array` had no effect. (Mentioned in :issue:`9921`) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 6d9dd355638..a1a80518fe4 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -383,7 +383,7 @@ def _parse_group_and_groupers( groupers: dict[str, Grouper], *, eagerly_compute_group: bool, -) -> tuple[ResolvedGrouper, ...]: +) -> ResolvedGrouper | tuple[ResolvedGrouper, ...]: from xarray.core.dataarray import DataArray from xarray.core.variable import Variable from xarray.groupers import UniqueGrouper @@ -407,6 +407,7 @@ def _parse_group_and_groupers( rgroupers: tuple[ResolvedGrouper, ...] if isinstance(group, DataArray | Variable): + # TODO add test for this, see gh-10246 rgroupers = ( ResolvedGrouper( UniqueGrouper(), group, obj, eagerly_compute_group=eagerly_compute_group @@ -429,6 +430,8 @@ def _parse_group_and_groupers( ) for group, grouper in grouper_mapping.items() ) + if isinstance(group, str): + rgroupers = rgroupers[0] return rgroupers @@ -453,7 +456,6 @@ def _resolve_group( "match the length of this variable along its " "dimensions" ) - newgroup: T_Group if isinstance(group, DataArray): try: @@ -602,7 +604,7 @@ class GroupBy(Generic[T_Xarray]): "groupers", ) _obj: T_Xarray - groupers: tuple[ResolvedGrouper, ...] + groupers: ResolvedGrouper | tuple[ResolvedGrouper, ...] 
_restore_coord_dims: bool _original_obj: T_Xarray @@ -626,7 +628,7 @@ class GroupBy(Generic[T_Xarray]): def __init__( self, obj: T_Xarray, - groupers: tuple[ResolvedGrouper, ...], + groupers: ResolvedGrouper | tuple[ResolvedGrouper, ...], restore_coord_dims: bool = True, ) -> None: """Create a GroupBy object @@ -635,7 +637,7 @@ def __init__( ---------- obj : Dataset or DataArray Object to group. - grouper : Grouper + groupers : ResolvedGrouper or tuple[ResolvedGrouper, ...] Grouper object restore_coord_dims : bool, default: True If True, also restore the dimension order of multi-dimensional @@ -645,9 +647,8 @@ def __init__( self._restore_coord_dims = restore_coord_dims self.groupers = groupers - if len(groupers) == 1: - (grouper,) = groupers - self.encoded = grouper.encoded + if isinstance(groupers, ResolvedGrouper): + self.encoded = groupers.encoded else: if any( isinstance(obj._indexes.get(grouper.name, None), PandasMultiIndex) @@ -699,6 +700,12 @@ def sizes(self) -> Mapping[Hashable, int]: self._sizes = self._obj.isel({self._group_dim: index}).sizes return self._sizes + @property + def _groupers_tuple(self) -> tuple[ResolvedGrouper, ...]: + if isinstance(self.groupers, ResolvedGrouper): + return (self.groupers,) + return self.groupers + def shuffle_to_chunks(self, chunks: T_Chunks = None) -> T_Xarray: """ Sort or "shuffle" the underlying object. @@ -756,7 +763,8 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray: was_array = isinstance(self._obj, DataArray) as_dataset = self._obj._to_temp_dataset() if was_array else self._obj - for grouper in self.groupers: + groupers = self._groupers_tuple + for grouper in groupers: if grouper.name not in as_dataset._variables: as_dataset.coords[grouper.name] = grouper.group @@ -801,7 +809,9 @@ def _raise_if_by_is_chunked(self): ) def _raise_if_not_single_group(self): - if len(self.groupers) != 1: + # This ensures that one does not do `groupby_bins(..., ["x"])` + # TODO this should be enabled later. 
+ if not isinstance(self.groupers, ResolvedGrouper): raise NotImplementedError( "This method is not supported for grouping by multiple variables yet." ) @@ -836,12 +846,13 @@ def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: return zip(self.encoded.unique_coord.data, self._iter_grouped(), strict=True) def __repr__(self) -> str: + groupers = self._groupers_tuple text = ( f"<{self.__class__.__name__}, " - f"grouped over {len(self.groupers)} grouper(s)," + f"grouped over {len(groupers)} grouper(s)," f" {self._len} groups in total:" ) - for grouper in self.groupers: + for grouper in groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) text += f"\n {grouper.name!r}: {coord.size}/{grouper.full_index.size} groups present with labels {labels}" @@ -871,7 +882,7 @@ def _binary_op(self, other, f, reflexive=False): g = f if not reflexive else lambda x, y: f(y, x) self._raise_if_not_single_group() - (grouper,) = self.groupers + (grouper,) = self._groupers_tuple obj = self._original_obj name = grouper.name group = grouper.group @@ -971,11 +982,12 @@ def _maybe_reindex(self, combined): self.encoded.unique_coord.size != self.encoded.full_index.size ) indexers = {} - for grouper in self.groupers: + groupers = self._groupers_tuple + for grouper in groupers: index = combined._indexes.get(grouper.name, None) if has_missing_groups and index is not None: indexers[grouper.name] = grouper.full_index - elif len(self.groupers) > 1: + elif len(groupers) > 1: if not isinstance( grouper.full_index, pd.RangeIndex ) and not index.index.equals(grouper.full_index): @@ -989,6 +1001,7 @@ def _maybe_unstack(self, obj): multidimensional group.""" from xarray.groupers import UniqueGrouper + groupers = self._groupers_tuple stacked_dim = self._stacked_dim if stacked_dim is not None and stacked_dim in obj.dims: inserted_dims = self._inserted_dims @@ -997,7 +1010,7 @@ def _maybe_unstack(self, obj): if dim in obj.coords: del obj.coords[dim] obj._indexes = 
filter_indexes_from_coords(obj._indexes, set(obj.coords)) - elif len(self.groupers) > 1: + elif len(groupers) > 1: # TODO: we could clean this up by setting the appropriate `stacked_dim` # and `inserted_dims` # if multiple groupers all share the same single dimension, then @@ -1007,7 +1020,7 @@ def _maybe_unstack(self, obj): obj = obj.unstack(*dims_to_unstack) to_drop = [ grouper.name - for grouper in self.groupers + for grouper in groupers if isinstance(grouper.group, _DummyGroup) and isinstance(grouper.grouper, UniqueGrouper) ] @@ -1044,7 +1057,7 @@ def _flox_reduce( kwargs.setdefault("method", "cohorts") midx_grouping_vars: tuple[Hashable, ...] = () - for grouper in self.groupers: + for grouper in self._groupers_tuple: name = grouper.name maybe_midx = obj._indexes.get(name, None) if isinstance(maybe_midx, PandasMultiIndex): @@ -1082,7 +1095,7 @@ def _flox_reduce( parsed_dim_list = list() # preserve order for dim_ in itertools.chain( - *(grouper.group.dims for grouper in self.groupers) + *(grouper.group.dims for grouper in self._groupers_tuple) ): if dim_ not in parsed_dim_list: parsed_dim_list.append(dim_) @@ -1094,7 +1107,7 @@ def _flox_reduce( # Do this so we raise the same error message whether flox is present or not. # Better to control it here than in flox. 
- for grouper in self.groupers: + for grouper in self._groupers_tuple: if any( d not in grouper.group.dims and d not in obj.dims for d in parsed_dim ): @@ -1117,10 +1130,10 @@ def _flox_reduce( # pass RangeIndex as a hint to flox that `by` is already factorized expected_groups = tuple( - pd.RangeIndex(len(grouper)) for grouper in self.groupers + pd.RangeIndex(len(grouper)) for grouper in self._groupers_tuple ) - codes = tuple(g.codes for g in self.groupers) + codes = tuple(g.codes for g in self._groupers_tuple) result = xarray_reduce( obj.drop_vars(non_numeric.keys()), *codes, @@ -1137,7 +1150,7 @@ def _flox_reduce( new_coords = [] to_drop = [] if group_dims & set(parsed_dim): - for grouper in self.groupers: + for grouper in self._groupers_tuple: output_index = grouper.full_index if isinstance(output_index, pd.RangeIndex): # flox always assigns an index so we must drop it here if we don't need it. @@ -1529,7 +1542,7 @@ def _concat_shortcut(self, applied, dim, positions=None): def _restore_dim_order(self, stacked: DataArray) -> DataArray: def lookup_order(dimension): - for grouper in self.groupers: + for grouper in self._groupers_tuple: if dimension == grouper.name and grouper.group.ndim == 1: (dimension,) = grouper.group.dims if dimension in self._obj.dims: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index c8f2b106b07..943460501c4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3308,6 +3308,33 @@ def test_groupby_dask_eager_load_warnings() -> None: ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False) +def test_groupby_return_group_dataset_type(dataset): + # Checks GH10246 + def group_val(groupers): + ret = next(iter(groupers))[0] + return ret + + assert isinstance(group_val(dataset.groupby("baz")), str) + assert isinstance(group_val(dataset.groupby(["baz"])), tuple) + assert isinstance(group_val(dataset.groupby(["baz"]))[0], str) + + +def test_groupby_return_group_dataarray_type(array): + # 
Checks GH10246 + def group_val(groupers): + ret = next(iter(groupers))[0] + return ret + + assert isinstance(group_val(array.groupby("x")), str) + assert isinstance(group_val(array.groupby(["x"])), tuple) + assert isinstance(group_val(array.groupby(["x"]))[0], str) + + +def test_groupby_return_group_type_raise(dataset): + with pytest.raises(TypeError, match="xarray variable or dimension"): + dataset.groupby_bins(["y"], [0, 1]) + + # TODO: Possible property tests to add to this module # 1. lambda x: x # 2. grouped-reduce on unique coords is identical to array