From 5bfa653839a20a63a7272f3b201592ce6b8386e9 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 8 Dec 2020 02:22:10 -0800 Subject: [PATCH] ENH: use correct dtype in groupby cython ops when it is known (without try/except) (#38291) Co-authored-by: Joris Van den Bossche --- pandas/core/dtypes/cast.py | 18 +++-- pandas/core/groupby/ops.py | 15 +++- .../tests/arrays/integer/test_arithmetic.py | 5 +- pandas/tests/groupby/aggregate/test_cython.py | 68 +++++++++++++++++++ pandas/tests/groupby/test_function.py | 2 +- pandas/tests/resample/test_datetime_index.py | 4 +- 6 files changed, 102 insertions(+), 10 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 6cd8036da1577..c77991ced3907 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -357,12 +357,18 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj: The desired dtype of the result. """ from pandas.core.arrays.boolean import BooleanDtype - from pandas.core.arrays.integer import Int64Dtype - - if how in ["add", "cumsum", "sum"] and (dtype == np.dtype(bool)): - return np.dtype(np.int64) - elif how in ["add", "cumsum", "sum"] and isinstance(dtype, BooleanDtype): - return Int64Dtype() + from pandas.core.arrays.floating import Float64Dtype + from pandas.core.arrays.integer import Int64Dtype, _IntegerDtype + + if how in ["add", "cumsum", "sum", "prod"]: + if dtype == np.dtype(bool): + return np.dtype(np.int64) + elif isinstance(dtype, (BooleanDtype, _IntegerDtype)): + return Int64Dtype() + elif how in ["mean", "median", "var"] and isinstance( + dtype, (BooleanDtype, _IntegerDtype) + ): + return Float64Dtype() return dtype diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index c60a59916affc..7724e3930f7df 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -45,6 +45,7 @@ is_datetime64_any_dtype, is_datetime64tz_dtype, is_extension_array_dtype, + is_float_dtype, is_integer_dtype, is_numeric_dtype, is_period_dtype, @@ -521,7 +522,19 @@ def _ea_wrap_cython_operation( res_values = self._cython_operation( kind, values, how, axis, min_count, **kwargs ) - result = maybe_cast_result(result=res_values, obj=orig_values, how=how) + dtype = maybe_cast_result_dtype(orig_values.dtype, how) + if is_extension_array_dtype(dtype): + cls = dtype.construct_array_type() + return cls._from_sequence(res_values, dtype=dtype) + return res_values + + elif is_float_dtype(values.dtype): + # FloatingArray + values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan) + res_values = self._cython_operation( + kind, values, how, axis, min_count, **kwargs + ) + result = type(orig_values)._from_sequence(res_values) return result raise NotImplementedError(values.dtype) diff --git a/pandas/tests/arrays/integer/test_arithmetic.py b/pandas/tests/arrays/integer/test_arithmetic.py index 4b8d95ae83e4f..617cb6407d857 100644 --- a/pandas/tests/arrays/integer/test_arithmetic.py +++ b/pandas/tests/arrays/integer/test_arithmetic.py @@ -277,7 +277,10 @@ def test_reduce_to_float(op): result = getattr(df.groupby("A"), op)() expected = pd.DataFrame( - {"B": np.array([1.0, 3.0]), "C": integer_array([1, 3], dtype="Int64")}, + { + "B": np.array([1.0, 3.0]), + "C": pd.array([1, 3], dtype="Float64"), + }, index=pd.Index(["a", "b"], name="A"), ) tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/groupby/aggregate/test_cython.py b/pandas/tests/groupby/aggregate/test_cython.py index c907391917ca8..8799f6faa775c 100644 --- a/pandas/tests/groupby/aggregate/test_cython.py +++ b/pandas/tests/groupby/aggregate/test_cython.py @@ -5,6 +5,8 @@ import numpy as np import pytest +from pandas.core.dtypes.common import is_float_dtype + import pandas as pd from pandas import DataFrame, Index, NaT, Series, Timedelta, Timestamp, bdate_range import pandas._testing as tm @@ -312,3 +314,69 @@ def test_cython_agg_nullable_int(op_name): # so for now just checking the values by casting to float result = result.astype("float64") tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("with_na", [True, False]) +@pytest.mark.parametrize( + "op_name, action", + [ + # ("count", "always_int"), + ("sum", "large_int"), + # ("std", "always_float"), + ("var", "always_float"), + # ("sem", "always_float"), + ("mean", "always_float"), + ("median", "always_float"), + ("prod", "large_int"), + ("min", "preserve"), + ("max", "preserve"), + ("first", "preserve"), + ("last", "preserve"), + ], +) +@pytest.mark.parametrize( + "data", + [ + pd.array([1, 2, 3, 4], dtype="Int64"), + pd.array([1, 2, 3, 4], dtype="Int8"), + pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float32"), + pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float64"), + pd.array([True, True, False, False], dtype="boolean"), + ], +) +def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na): + if with_na: + data[3] = pd.NA + + df = DataFrame({"key": ["a", "a", "b", "b"], "col": data}) + grouped = df.groupby("key") + + if action == "always_int": + # always Int64 + expected_dtype = pd.Int64Dtype() + elif action == "large_int": + # for any int/bool use Int64, for float preserve dtype + if is_float_dtype(data.dtype): + expected_dtype = data.dtype + else: + expected_dtype = pd.Int64Dtype() + elif action == "always_float": + # for any int/bool use Float64, for float preserve dtype + if is_float_dtype(data.dtype): + expected_dtype = data.dtype + else: + expected_dtype = pd.Float64Dtype() + elif action == "preserve": + expected_dtype = data.dtype + + result = getattr(grouped, op_name)() + assert result["col"].dtype == expected_dtype + + result = grouped.aggregate(op_name) + assert result["col"].dtype == expected_dtype + + result = getattr(grouped["col"], op_name)() + assert result.dtype == expected_dtype + + result = grouped["col"].aggregate(op_name) + assert result.dtype == expected_dtype diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index c915c95294ba0..8d7fcbfcfe694 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -1093,7 +1093,7 @@ def test_apply_to_nullable_integer_returns_float(values, function): output = 0.5 if function == "var" else 1.5 arr = np.array([output] * 3, dtype=float) idx = Index([1, 2, 3], dtype=object, name="a") - expected = DataFrame({"b": arr}, index=idx) + expected = DataFrame({"b": arr}, index=idx).astype("Float64") groups = DataFrame(values, dtype="Int64").groupby("a") diff --git a/pandas/tests/resample/test_datetime_index.py b/pandas/tests/resample/test_datetime_index.py index 3e41dab39e71d..8bf40c924ec86 100644 --- a/pandas/tests/resample/test_datetime_index.py +++ b/pandas/tests/resample/test_datetime_index.py @@ -124,7 +124,9 @@ def test_resample_integerarray(): result = ts.resample("3T").mean() expected = Series( - [1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64" + [1, 4, 7], + index=pd.date_range("1/1/2000", periods=3, freq="3T"), + dtype="Float64", ) tm.assert_series_equal(result, expected)