From 8ee74884dd9225c9b3cbbfb1cc3843545889b8e4 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 16:49:12 -0700 Subject: [PATCH 01/12] Pint array: strip and reattach appropriate units Closes #163 --- flox/aggregations.py | 73 +++++++++++++++++++++++++++++++++++++------- flox/core.py | 10 +++++- tests/__init__.py | 16 ++++++++++ tests/test_core.py | 37 ++++++++++++++++++++++ 4 files changed, 124 insertions(+), 12 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 1a52cbee0..699df05b6 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -3,6 +3,7 @@ import copy import warnings from functools import partial +from typing import Callable import numpy as np import numpy_groupies as npg @@ -114,6 +115,7 @@ def __init__( dtypes=None, final_dtype=None, reduction_type="reduce", + units_func: Callable | None = None, ): """ Blueprint for computing grouped aggregations. @@ -156,6 +158,8 @@ def __init__( per reduction in ``chunk`` as a tuple. final_dtype : DType, optional DType for output. By default, uses dtype of array being reduced. + units_func : pint.Unit + units for the output """ self.name = name # preprocess before blockwise @@ -187,6 +191,8 @@ def __init__( # The following are set by _initialize_aggregation self.finalize_kwargs = {} self.min_count = None + self.units_func = units_func + self.units = None def _normalize_dtype_fill_value(self, value, name): value = _atleast_1d(value) @@ -235,17 +241,44 @@ def __repr__(self): final_dtype=np.intp, ) + +def identity(x): + return x + + +def square(x): + return x**2 + + +def raise_units_error(x): + raise ValueError( + "Units cannot supported for prod in general. " + "We can only attach units when there are " + "equal number of members in each group. " + "Please strip units and then reattach units " + "to the output manually." + ) + + # note that the fill values are the result of np.func([np.nan, np.nan]) # final_fill_value is used for groups that don't exist. This is usually np.nan -sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0) -nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0) -prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1) +sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0, units_func=identity) +nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0, units_func=identity) +prod = Aggregation( + "prod", + chunk="prod", + combine="prod", + fill_value=1, + final_fill_value=1, + units_func=raise_units_error, +) nanprod = Aggregation( "nanprod", chunk="nanprod", combine="prod", fill_value=1, final_fill_value=dtypes.NA, + units_func=raise_units_error, ) @@ -262,6 +295,7 @@ def _mean_finalize(sum_, count): fill_value=(0, 0), dtypes=(None, np.intp), final_dtype=np.floating, + units_func=identity, ) nanmean = Aggregation( "nanmean", @@ -271,6 +305,7 @@ def _mean_finalize(sum_, count): fill_value=(0, 0), dtypes=(None, np.intp), final_dtype=np.floating, + units_func=identity, ) @@ -296,6 +331,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0): final_fill_value=np.nan, dtypes=(None, None, np.intp), final_dtype=np.floating, + units_func=square, ) nanvar = Aggregation( "nanvar", @@ -306,6 +342,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0): final_fill_value=np.nan, dtypes=(None, None, np.intp), final_dtype=np.floating, + units_func=square, ) std = Aggregation( "std", @@ -316,6 +353,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0): final_fill_value=np.nan, dtypes=(None, None, np.intp), final_dtype=np.floating, + units_func=identity, ) nanstd = Aggregation( "nanstd", @@ -329,10 +367,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0): ) -min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF) -nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan) -max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF) -nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan) +min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, units_func=identity) +nanmin = Aggregation( + "nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan, units_func=identity +) +max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, units_func=identity) +nanmax = Aggregation( + "nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan, units_func=identity +) def argreduce_preprocess(array, axis): @@ -420,10 +462,14 @@ def _pick_second(*x): final_dtype=np.intp, ) -first = Aggregation("first", chunk=None, combine=None, fill_value=0) -last = Aggregation("last", chunk=None, combine=None, fill_value=0) -nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan) -nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan) +first = Aggregation("first", chunk=None, combine=None, fill_value=0, units_func=identity) +last = Aggregation("last", chunk=None, combine=None, fill_value=0, units_func=identity) +nanfirst = Aggregation( + "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan, units_func=identity +) +nanlast = Aggregation( + "nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan, units_func=identity +) all_ = Aggregation( "all", @@ -483,6 +529,7 @@ def _initialize_aggregation( dtype, array_dtype, fill_value, + array_units, min_count: int | None, finalize_kwargs, ) -> Aggregation: @@ -547,4 +594,8 @@ def _initialize_aggregation( agg.dtype["intermediate"] += (np.intp,) agg.dtype["numpy"] += (np.intp,) + if array_units is not None and agg.units_func is not None: + import pint + + agg.units = agg.units_func(pint.Quantity([1], units=array_units)) return agg diff --git a/flox/core.py b/flox/core.py index 022f29582..9874d3a56 100644 --- a/flox/core.py +++ b/flox/core.py @@ -24,6 +24,7 @@ generic_aggregate, ) from .cache import memoize +from .pint_compat import _reattach_units, _strip_units from .xrutils import is_duck_array, is_duck_dask_array, isnull if TYPE_CHECKING: @@ -1702,6 +1703,8 @@ def groupby_reduce( by_is_dask = tuple(is_duck_dask_array(b) for b in bys) any_by_dask = any(by_is_dask) + array, *bys, units = _strip_units(array, *bys) + if method in ["split-reduce", "cohorts"] and any_by_dask: raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.") @@ -1803,7 +1806,9 @@ def groupby_reduce( fill_value = np.nan kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine) - agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs) + agg = _initialize_aggregation( + func, dtype, array.dtype, fill_value, units[0], min_count, finalize_kwargs + ) if not has_dask: results = _reduce_blockwise( @@ -1862,4 +1867,7 @@ def groupby_reduce( if _is_minmax_reduction(func) and is_bool_array: result = result.astype(bool) + + units[0] = agg.units + result, *groups = _reattach_units(result, *groups, units=units) return (result, *groups) diff --git a/tests/__init__.py b/tests/__init__.py index b1a266652..75a8342a0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -24,6 +24,13 @@ except ImportError: xr_types = () # type: ignore +try: + import pint + + pint_types = pint.Quantity +except ImportError: + pint_types = () # type: ignore + def _importorskip(modname, minversion=None): try: @@ -46,6 +53,7 @@ def LooseVersion(vstring): has_dask, requires_dask = _importorskip("dask") +has_pint, requires_pint = _importorskip("pint") has_xarray, requires_xarray = _importorskip("xarray") @@ -95,6 +103,14 @@ def assert_equal(a, b, tolerance=None): xr.testing.assert_identical(a, b) return + if has_pint and isinstance(a, pint_types) or isinstance(b, pint_types): + assert isinstance(a, pint_types) + assert isinstance(b, pint_types) + assert a.units == b.units + + a = a.magnitude + b = b.magnitude + if tolerance is None and ( np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64) ): diff --git a/tests/test_core.py b/tests/test_core.py index 0841e531a..3177cf946 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -31,6 +31,7 @@ has_dask, raise_if_dask_computes, requires_dask, + requires_pint, ) labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0]) @@ -1321,3 +1322,39 @@ def test_negative_index_factorize_race_condition(): for f in func ] [dask.compute(out, scheduler="threads") for _ in range(5)] + + +@requires_pint +@pytest.mark.parametrize("func", ["all", "count", "sum", "var"]) +@pytest.mark.parametrize("chunk", [True, False]) +def test_pint(chunk, func): + import pint + + if chunk: + d = dask.array.array([1, 2, 3]) + else: + d = np.array([1, 2, 3]) + q = pint.Quantity(d, units="m") + + actual, _ = groupby_reduce(q, [0, 0, 1], func=func) + expected, _ = groupby_reduce(q.magnitude, [0, 0, 1], func=func) + + units = None if func in ["count", "all"] else getattr(np, func)(q).units + if units is not None: + expected = pint.Quantity(expected, units=units) + assert_equal(expected, actual) + + +@requires_pint +@pytest.mark.parametrize("chunk", [True, False]) +def test_pint_prod_error(chunk): + import pint + + if chunk: + d = dask.array.array([1, 2, 3]) + else: + d = np.array([1, 2, 3]) + q = pint.Quantity(d, units="m") + + with pytest.raises(ValueError): + groupby_reduce(q, [0, 0, 1], func="prod") From c2e173cc60c3a2f4ac4a8c0b1eebdf814fe2e88a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 24 Jan 2023 16:50:52 -0700 Subject: [PATCH 02/12] Update flox/aggregations.py --- flox/aggregations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 699df05b6..41c7bbb03 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -158,8 +158,8 @@ def __init__( per reduction in ``chunk`` as a tuple. final_dtype : DType, optional DType for output. By default, uses dtype of array being reduced. - units_func : pint.Unit - units for the output + units_func : callable + function whose output will be used to infer units. """ self.name = name # preprocess before blockwise From 0eb9e7e9b96dcc246da2be9c9a4e7101b6d5e1e8 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 16:51:33 -0700 Subject: [PATCH 03/12] Add pint to env --- ci/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/environment.yml b/ci/environment.yml index 0510e4b20..d847e56a5 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -11,6 +11,7 @@ dependencies: - numpy>=1.20 - lxml # for mypy coverage report - matplotlib + - pint - pip - pytest - pytest-cov From 600e6cbc18e2491d556c79212b955afec0a0d74f Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 16:55:33 -0700 Subject: [PATCH 04/12] Add pint_compat --- flox/pint_compat.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 flox/pint_compat.py diff --git a/flox/pint_compat.py b/flox/pint_compat.py new file mode 100644 index 000000000..c7212a807 --- /dev/null +++ b/flox/pint_compat.py @@ -0,0 +1,25 @@ +def _strip_units(*arrays): + try: + import pint + + pint_quantity = pint.Quantity + + except ImportError: + pint_quantity = None + + bare = [array.magnitude if isinstance(array, pint_quantity) else array for array in arrays] + units = [array.units if isinstance(array, pint_quantity) else None for array in arrays] + + return *bare, units + + +def _reattach_units(*arrays, units): + try: + import pint + + return [ + pint.Quantity(array, unit) if unit is not None else array + for array, unit in zip(arrays, units) + ] + except ImportError: + return arrays From 0a8aa703f3b46217556e821d96e06ad7ffe2e213 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 20:24:58 -0700 Subject: [PATCH 05/12] More comprehensive tests --- flox/aggregations.py | 1 + tests/test_core.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 41c7bbb03..2ff67fdf1 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -364,6 +364,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0): final_fill_value=np.nan, dtypes=(None, None, np.intp), final_dtype=np.floating, + units_func=identity, ) diff --git a/tests/test_core.py b/tests/test_core.py index 3177cf946..f064537ff 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1325,11 +1325,14 @@ def test_negative_index_factorize_race_condition(): @requires_pint -@pytest.mark.parametrize("func", ["all", "count", "sum", "var"]) +@pytest.mark.parametrize("func", ALL_FUNCS) @pytest.mark.parametrize("chunk", [True, False]) -def test_pint(chunk, func): +def test_pint(chunk, func, engine): import pint + if func in ["prod", "nanprod"]: + pytest.skip() + if chunk: d = dask.array.array([1, 2, 3]) else: @@ -1339,7 +1342,7 @@ def test_pint(chunk, func): actual, _ = groupby_reduce(q, [0, 0, 1], func=func) expected, _ = groupby_reduce(q.magnitude, [0, 0, 1], func=func) - units = None if func in ["count", "all"] else getattr(np, func)(q).units + units = None if func in ["count", "all", "any"] or "arg" in func else getattr(np, func)(q).units if units is not None: expected = pint.Quantity(expected, units=units) assert_equal(expected, actual) From 70aa03a8065d0113ddef69eb87935642b3914385 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 20:26:39 -0700 Subject: [PATCH 06/12] Fix test --- flox/pint_compat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flox/pint_compat.py b/flox/pint_compat.py index c7212a807..6396bbd23 100644 --- a/flox/pint_compat.py +++ b/flox/pint_compat.py @@ -7,7 +7,7 @@ def _strip_units(*arrays): except ImportError: pint_quantity = None - bare = [array.magnitude if isinstance(array, pint_quantity) else array for array in arrays] + bare = tuple(array.magnitude if isinstance(array, pint_quantity) else array for array in arrays) units = [array.units if isinstance(array, pint_quantity) else None for array in arrays] return *bare, units @@ -17,9 +17,9 @@ def _reattach_units(*arrays, units): try: import pint - return [ + return tuple( pint.Quantity(array, unit) if unit is not None else array for array, unit in zip(arrays, units) - ] + ) except ImportError: return arrays From 8204f712616d91ac3de8c6b9ac6660cbac61d76f Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 20:31:05 -0700 Subject: [PATCH 07/12] Fix test --- flox/core.py | 2 +- flox/pint_compat.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 9874d3a56..0f24ad78a 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1703,7 +1703,7 @@ def groupby_reduce( by_is_dask = tuple(is_duck_dask_array(b) for b in bys) any_by_dask = any(by_is_dask) - array, *bys, units = _strip_units(array, *bys) + array, bys, units = _strip_units(array, *bys) if method in ["split-reduce", "cohorts"] and any_by_dask: raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.") diff --git a/flox/pint_compat.py b/flox/pint_compat.py index 6396bbd23..ccd66db08 100644 --- a/flox/pint_compat.py +++ b/flox/pint_compat.py @@ -10,7 +10,7 @@ def _strip_units(*arrays): bare = tuple(array.magnitude if isinstance(array, pint_quantity) else array for array in arrays) units = [array.units if isinstance(array, pint_quantity) else None for array in arrays] - return *bare, units + return bare[0], bare[1:], units def _reattach_units(*arrays, units): From 79763910ae65dc4cfabe74f0923cddef90bbf9b1 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 20:55:05 -0700 Subject: [PATCH 08/12] Avoid converting group_idx to same array type --- flox/aggregations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index 2ff67fdf1..128053fdc 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -48,8 +48,6 @@ def generic_aggregate( f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead." ) - group_idx = np.asarray(group_idx, like=array) - with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") result = method( From dfe5a8e32942e088e62418115efd10c72f1e0e90 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 21:17:13 -0700 Subject: [PATCH 09/12] Fix test --- flox/pint_compat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flox/pint_compat.py b/flox/pint_compat.py index ccd66db08..8a4b26150 100644 --- a/flox/pint_compat.py +++ b/flox/pint_compat.py @@ -1,8 +1,9 @@ def _strip_units(*arrays): + pint_quantity: tuple | None try: import pint - pint_quantity = pint.Quantity + pint_quantity = (pint.Quantity,) except ImportError: pint_quantity = None From 1806a5c59b21912dd3c9b8f4c72b9338d995370e Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 24 Jan 2023 21:23:41 -0700 Subject: [PATCH 10/12] proper fix? --- flox/pint_compat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flox/pint_compat.py b/flox/pint_compat.py index 8a4b26150..eddf58654 100644 --- a/flox/pint_compat.py +++ b/flox/pint_compat.py @@ -1,12 +1,11 @@ def _strip_units(*arrays): - pint_quantity: tuple | None try: import pint pint_quantity = (pint.Quantity,) except ImportError: - pint_quantity = None + pint_quantity = () bare = tuple(array.magnitude if isinstance(array, pint_quantity) else array for array in arrays) units = [array.units if isinstance(array, pint_quantity) else None for array in arrays] From 46b5b9396ed2e306815191a2a3f2aa3a045495ba Mon Sep 17 00:00:00 2001 From: dcherian Date: Sun, 29 Jan 2023 14:43:25 -0700 Subject: [PATCH 11/12] Fix bad merge. --- flox/aggregations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index c9e78be6e..b130ea05d 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -3,7 +3,7 @@ import copy import warnings from functools import partial -from typing import Callable +from typing import TYPE_CHECKING, Any, Callable, TypedDict import numpy as np import numpy_groupies as npg From 1f56e6eea60a4e3affbd22d4a178cfeafd107a43 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Mar 2023 16:53:08 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index c34365443..88eaa13cd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1377,6 +1377,8 @@ def test_pint_prod_error(chunk): with pytest.raises(ValueError): groupby_reduce(q, [0, 0, 1], func="prod") + + @pytest.mark.parametrize("sort", [True, False]) def test_expected_index_conversion_passthrough_range_index(sort): index = pd.RangeIndex(100)