Skip to content

Commit

Permalink
ENH: use correct dtype in groupby cython ops when it is known (without try/except) (#38291)
Browse files Browse the repository at this point in the history

Co-authored-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
jbrockmendel and jorisvandenbossche authored Dec 8, 2020
1 parent 37f7bdc commit 5bfa653
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 10 deletions.
18 changes: 12 additions & 6 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,18 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
The desired dtype of the result.
"""
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype

if how in ["add", "cumsum", "sum"] and (dtype == np.dtype(bool)):
return np.dtype(np.int64)
elif how in ["add", "cumsum", "sum"] and isinstance(dtype, BooleanDtype):
return Int64Dtype()
from pandas.core.arrays.floating import Float64Dtype
from pandas.core.arrays.integer import Int64Dtype, _IntegerDtype

if how in ["add", "cumsum", "sum", "prod"]:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
return Int64Dtype()
elif how in ["mean", "median", "var"] and isinstance(
dtype, (BooleanDtype, _IntegerDtype)
):
return Float64Dtype()
return dtype


Expand Down
15 changes: 14 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
is_datetime64_any_dtype,
is_datetime64tz_dtype,
is_extension_array_dtype,
is_float_dtype,
is_integer_dtype,
is_numeric_dtype,
is_period_dtype,
Expand Down Expand Up @@ -521,7 +522,19 @@ def _ea_wrap_cython_operation(
res_values = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
result = maybe_cast_result(result=res_values, obj=orig_values, how=how)
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
if is_extension_array_dtype(dtype):
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)
return res_values

elif is_float_dtype(values.dtype):
# FloatingArray
values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
res_values = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
result = type(orig_values)._from_sequence(res_values)
return result

raise NotImplementedError(values.dtype)
Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/arrays/integer/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,10 @@ def test_reduce_to_float(op):
result = getattr(df.groupby("A"), op)()

expected = pd.DataFrame(
{"B": np.array([1.0, 3.0]), "C": integer_array([1, 3], dtype="Int64")},
{
"B": np.array([1.0, 3.0]),
"C": pd.array([1, 3], dtype="Float64"),
},
index=pd.Index(["a", "b"], name="A"),
)
tm.assert_frame_equal(result, expected)
Expand Down
68 changes: 68 additions & 0 deletions pandas/tests/groupby/aggregate/test_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_float_dtype

import pandas as pd
from pandas import DataFrame, Index, NaT, Series, Timedelta, Timestamp, bdate_range
import pandas._testing as tm
Expand Down Expand Up @@ -312,3 +314,69 @@ def test_cython_agg_nullable_int(op_name):
# so for now just checking the values by casting to float
result = result.astype("float64")
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("with_na", [True, False])
@pytest.mark.parametrize(
"op_name, action",
[
# ("count", "always_int"),
("sum", "large_int"),
# ("std", "always_float"),
("var", "always_float"),
# ("sem", "always_float"),
("mean", "always_float"),
("median", "always_float"),
("prod", "large_int"),
("min", "preserve"),
("max", "preserve"),
("first", "preserve"),
("last", "preserve"),
],
)
@pytest.mark.parametrize(
"data",
[
pd.array([1, 2, 3, 4], dtype="Int64"),
pd.array([1, 2, 3, 4], dtype="Int8"),
pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float32"),
pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float64"),
pd.array([True, True, False, False], dtype="boolean"),
],
)
def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na):
if with_na:
data[3] = pd.NA

df = DataFrame({"key": ["a", "a", "b", "b"], "col": data})
grouped = df.groupby("key")

if action == "always_int":
# always Int64
expected_dtype = pd.Int64Dtype()
elif action == "large_int":
# for any int/bool use Int64, for float preserve dtype
if is_float_dtype(data.dtype):
expected_dtype = data.dtype
else:
expected_dtype = pd.Int64Dtype()
elif action == "always_float":
# for any int/bool use Float64, for float preserve dtype
if is_float_dtype(data.dtype):
expected_dtype = data.dtype
else:
expected_dtype = pd.Float64Dtype()
elif action == "preserve":
expected_dtype = data.dtype

result = getattr(grouped, op_name)()
assert result["col"].dtype == expected_dtype

result = grouped.aggregate(op_name)
assert result["col"].dtype == expected_dtype

result = getattr(grouped["col"], op_name)()
assert result.dtype == expected_dtype

result = grouped["col"].aggregate(op_name)
assert result.dtype == expected_dtype
2 changes: 1 addition & 1 deletion pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ def test_apply_to_nullable_integer_returns_float(values, function):
output = 0.5 if function == "var" else 1.5
arr = np.array([output] * 3, dtype=float)
idx = Index([1, 2, 3], dtype=object, name="a")
expected = DataFrame({"b": arr}, index=idx)
expected = DataFrame({"b": arr}, index=idx).astype("Float64")

groups = DataFrame(values, dtype="Int64").groupby("a")

Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/resample/test_datetime_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def test_resample_integerarray():

result = ts.resample("3T").mean()
expected = Series(
[1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64"
[1, 4, 7],
index=pd.date_range("1/1/2000", periods=3, freq="3T"),
dtype="Float64",
)
tm.assert_series_equal(result, expected)

Expand Down

0 comments on commit 5bfa653

Please sign in to comment.