POC: PDEP16 default to masked nullable dtypes #61716

Draft · wants to merge 2 commits into main
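As a rough sketch of what this POC is aiming for (my own illustration, not output captured from the branch): with the new `mode.pdep16_data_types` option enabled, construction from integer, float, and bool data should land on the masked extension dtypes rather than plain numpy dtypes.

```python
import pandas as pd

# Assumed behavior on this branch; the option does not exist in released pandas.
with pd.option_context("mode.pdep16_data_types", True):
    s = pd.Series([1, 2, 3])
    print(s.dtype)  # intended: Int64 (masked), not numpy int64

    s = pd.Series([1.5, 2.5, None])
    print(s.dtype)  # intended: Float64, with the missing value as pd.NA
```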
5 changes: 5 additions & 0 deletions pandas/_testing/__init__.py
@@ -89,6 +89,9 @@
NpDtype,
)

# Alias so we can update old `assert obj.dtype == np_dtype` checks to PDEP16
# behavior.
to_dtype = pd.core.dtypes.common.pandas_dtype

UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
@@ -304,6 +307,8 @@ def box_expected(expected, box_cls, transpose: bool = True):
expected = pd.concat([expected] * 2, ignore_index=True)
elif box_cls is np.ndarray or box_cls is np.array:
expected = np.array(expected)
if expected.dtype.kind in "iufb" and pd.get_option("mode.pdep16_data_types"):
expected = pd.array(expected, copy=False)
elif box_cls is to_array:
expected = to_array(expected)
else:
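The `to_dtype` alias added above is just `pandas_dtype`, which under the option resolves numpy-style names to their masked counterparts. A sketch of how a hard-coded dtype assertion in the test suite could be made option-aware (the surrounding snippet is hypothetical):

```python
import pandas as pd
import pandas._testing as tm  # tm.to_dtype only exists on this branch

obj = pd.Series([1, 2, 3])
# Old-style check, pinned to a numpy dtype:
#     assert obj.dtype == np.dtype("int64")
# Option-aware check: "int64" resolves to Int64 when
# mode.pdep16_data_types is enabled, and to numpy int64 otherwise.
assert obj.dtype == tm.to_dtype("int64")
```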
18 changes: 12 additions & 6 deletions pandas/_testing/asserters.py
@@ -323,13 +323,19 @@ def _check_types(left, right, obj: str = "Index") -> None:
elif check_exact and check_categorical:
if not left.equals(right):
mismatch = left._values != right._values
if isinstance(left, RangeIndex) and not mismatch.any():
# TODO: probably need to fix RangeIndex.equals?
pass
elif isinstance(right, RangeIndex) and not mismatch.any():
# TODO: probably need to fix some other equals method?
pass
else:
if not isinstance(mismatch, np.ndarray):
mismatch = cast("ExtensionArray", mismatch).fillna(True)

if not isinstance(mismatch, np.ndarray):
mismatch = cast("ExtensionArray", mismatch).fillna(True)

diff = np.sum(mismatch.astype(int)) * 100.0 / len(left)
msg = f"{obj} values are different ({np.round(diff, 5)} %)"
raise_assert_detail(obj, msg, left, right)
diff = np.sum(mismatch.astype(int)) * 100.0 / len(left)
msg = f"{obj} values are different ({np.round(diff, 5)} %)"
raise_assert_detail(obj, msg, left, right)
else:
# if we have "equiv", this becomes True
exact_bool = bool(exact)
2 changes: 1 addition & 1 deletion pandas/core/arrays/_mixins.py
@@ -127,7 +127,7 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
# pass those through to the underlying ndarray
return self._ndarray.view(dtype)

dtype = pandas_dtype(dtype)
dtype = pandas_dtype(dtype, allow_numpy_dtypes=True)
arr = self._ndarray

if isinstance(dtype, PeriodDtype):
13 changes: 12 additions & 1 deletion pandas/core/arrays/base.py
@@ -22,6 +22,8 @@

import numpy as np

from pandas._config import get_option

from pandas._libs import (
algos as libalgos,
lib,
@@ -2420,7 +2422,12 @@ def _where(self, mask: npt.NDArray[np.bool_], value) -> Self:
result = self.copy()

if is_list_like(value):
val = value[~mask]
if np.ndim(value) == 1 and len(value) == 1:
# test_where.test_broadcast if we change to use nullable...
# maybe this should be handled at a higher level?
val = value[0]
else:
val = value[~mask]
else:
val = value

@@ -2655,6 +2662,10 @@ def _groupby_op(
if op.how in op.cast_blocklist:
# i.e. how in ["rank"], since other cast_blocklist methods don't go
# through cython_operation
if get_option("mode.pdep16_data_types"):
from pandas import array as pd_array

return pd_array(res_values)
return res_values

if isinstance(self.dtype, StringDtype):
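`rank` is in `cast_blocklist`, so its result skips the usual re-wrapping; the `pd.array(...)` call above is what keeps the output on a masked dtype. An assumed example of the intended effect:

```python
import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "val": pd.array([10, 20, 5])})
ranks = df.groupby("key")["val"].rank()
# Intended with mode.pdep16_data_types on: masked Float64 rather than
# numpy float64 for the ranked output.
print(ranks.dtype)
```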
2 changes: 2 additions & 0 deletions pandas/core/arrays/datetimelike.py
@@ -1701,6 +1701,8 @@ def _groupby_op(
if op.how in op.cast_blocklist:
# i.e. how in ["rank"], since other cast_blocklist methods don't go
# through cython_operation
# if get_option("mode.pdep16_data_types"):
# return pd_array(res_values)  # breaks because masked arrays don't support 2D
return res_values

# We did a view to M8[ns] above, now we go the other direction
2 changes: 1 addition & 1 deletion pandas/core/arrays/string_.py
@@ -830,7 +830,7 @@ def astype(self, dtype, copy: bool = True):
arr_ea = self.copy()
mask = self.isna()
arr_ea[mask] = "0"
values = arr_ea.astype(dtype.numpy_dtype)
values = arr_ea.to_numpy(dtype=dtype.numpy_dtype)
return FloatingArray(values, mask, copy=False)
elif isinstance(dtype, ExtensionDtype):
# Skip the NumpyExtensionArray.astype method
4 changes: 4 additions & 0 deletions pandas/core/arrays/timedeltas.py
@@ -507,6 +507,10 @@ def __mul__(self, other) -> Self:
# numpy >= 2.1 may not raise a TypeError
# and seems to dispatch to others.__rmul__?
raise TypeError(f"Cannot multiply with {type(other).__name__}")
if isinstance(result, type(self)):
# e.g. if other is IntegerArray
assert result.dtype == self.dtype
return result
return type(self)._simple_new(result, dtype=result.dtype)

__rmul__ = __mul__
9 changes: 9 additions & 0 deletions pandas/core/config_init.py
@@ -427,6 +427,15 @@ def is_terminal() -> bool:
validator=is_one_of_factory([True, False, "warn"]),
)

with cf.config_prefix("mode"):
cf.register_option(
"pdep16_data_types",
True,
"Whether to default to numpy-nullable dtypes for integer, float, "
"and bool dtypes",
validator=is_one_of_factory([True, False]),
)


# user warnings
chained_assignment = """
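The option lives under the `mode` prefix and defaults to `True` in this POC, so it goes through the normal options machinery:

```python
import pandas as pd

# Only available on this branch; defaults to True here.
print(pd.get_option("mode.pdep16_data_types"))

# Opt back into the pre-PDEP16 numpy dtypes globally...
pd.set_option("mode.pdep16_data_types", False)

# ...or toggle it temporarily within a block.
with pd.option_context("mode.pdep16_data_types", True):
    print(pd.Series([1, 2, 3]).dtype)  # intended: Int64
```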
14 changes: 11 additions & 3 deletions pandas/core/construction.py
@@ -16,7 +16,10 @@
import numpy as np
from numpy import ma

from pandas._config import using_string_dtype
from pandas._config import (
get_option,
using_string_dtype,
)

from pandas._libs import lib
from pandas._libs.tslibs import (
@@ -612,7 +615,9 @@ def sanitize_array(
if dtype is None:
subarr = data
if data.dtype == object and infer_object:
subarr = maybe_infer_to_datetimelike(data)
subarr = maybe_infer_to_datetimelike(
data, convert_to_nullable_dtype=get_option("mode.pdep16_data_types")
)
elif data.dtype.kind == "U" and using_string_dtype():
from pandas.core.arrays.string_ import StringDtype

@@ -659,7 +664,10 @@
subarr = maybe_convert_platform(data)
if subarr.dtype == object:
subarr = cast(np.ndarray, subarr)
subarr = maybe_infer_to_datetimelike(subarr)
subarr = maybe_infer_to_datetimelike(
subarr,
convert_to_nullable_dtype=get_option("mode.pdep16_data_types"),
)

subarr = _sanitize_ndim(subarr, data, dtype, index, allow_2d=allow_2d)

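Both `sanitize_array` call sites now forward the option into the object-inference path (`maybe_infer_to_datetimelike` / `maybe_convert_objects`), so list or object input with missing values can infer directly to a masked dtype instead of falling back to float64 with NaN. A hedged sketch of the intended difference:

```python
import pandas as pd

s = pd.Series([1, 2, None])
# Today: inferred as float64; the None becomes NaN and the ints become floats.
# Intended with mode.pdep16_data_types on: Int64 with pd.NA, integers preserved.
print(s.dtype, s.iloc[-1])
```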
9 changes: 7 additions & 2 deletions pandas/core/dtypes/cast.py
@@ -18,7 +18,10 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config import (
get_option,
using_string_dtype,
)

from pandas._libs import (
Interval,
@@ -135,7 +138,9 @@ def maybe_convert_platform(

if arr.dtype == _dtype_obj:
arr = cast(np.ndarray, arr)
arr = lib.maybe_convert_objects(arr)
arr = lib.maybe_convert_objects(
arr, convert_to_nullable_dtype=get_option("mode.pdep16_data_types")
)

return arr

45 changes: 40 additions & 5 deletions pandas/core/dtypes/common.py
@@ -12,7 +12,10 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config import (
get_option,
using_string_dtype,
)

from pandas._libs import (
Interval,
@@ -1793,14 +1796,36 @@ def validate_all_hashable(*args, error_name: str | None = None) -> None:
raise TypeError("All elements must be hashable")


def pandas_dtype(dtype) -> DtypeObj:
def _map_np_dtype(dtype: np.dtype) -> DtypeObj:
if dtype.kind in "iu":
from pandas.core.arrays.integer import NUMPY_INT_TO_DTYPE

return NUMPY_INT_TO_DTYPE[dtype]
elif dtype.kind == "f":
from pandas.core.arrays.floating import NUMPY_FLOAT_TO_DTYPE

if dtype.itemsize != 2:
# TODO: What do we do for float16? float128?
return NUMPY_FLOAT_TO_DTYPE[dtype]

elif dtype.kind == "b":
from pandas import BooleanDtype

return BooleanDtype()

return dtype


def pandas_dtype(dtype, allow_numpy_dtypes: bool = False) -> DtypeObj:
"""
Convert input into a pandas only dtype object or a numpy dtype object.

Parameters
----------
dtype : object
The object to be converted into a dtype.
allow_numpy_dtypes : bool, default False
Whether to return pre-PDEP16 numpy dtypes for ints, floats, and bools.

Returns
-------
@@ -1820,10 +1845,18 @@ def pandas_dtype(dtype) -> DtypeObj:
>>> pd.api.types.pandas_dtype(int)
dtype('int64')
"""
allow_numpy_dtypes = allow_numpy_dtypes or not get_option("mode.pdep16_data_types")

# short-circuit
if isinstance(dtype, np.ndarray):
return dtype.dtype
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
if allow_numpy_dtypes:
return dtype.dtype
return _map_np_dtype(dtype.dtype)
elif isinstance(dtype, np.dtype):
if allow_numpy_dtypes:
return dtype
return _map_np_dtype(dtype)
elif isinstance(dtype, ExtensionDtype):
return dtype

# builtin aliases
@@ -1879,7 +1912,9 @@ def pandas_dtype(dtype) -> DtypeObj:
elif npdtype.kind == "O":
raise TypeError(f"dtype '{dtype}' not understood")

return npdtype
if allow_numpy_dtypes:
return npdtype
return _map_np_dtype(npdtype)


def is_all_strings(value: ArrayLike) -> bool:
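`_map_np_dtype` translates a plain numpy dtype into its masked counterpart (skipping float16, which has none). The helper below is my own public-API stand-in for illustration, not the PR's code; it leans on `pd.array` already inferring masked dtypes for numeric numpy input:

```python
import numpy as np
import pandas as pd

def masked_counterpart(np_dtype) -> pd.api.extensions.ExtensionDtype:
    """Rough public-API stand-in for the PR's _map_np_dtype helper."""
    # pd.array infers IntegerArray/FloatingArray/BooleanArray for
    # numeric numpy input, preserving the itemsize.
    return pd.array(np.zeros(1, dtype=np_dtype)).dtype

print(masked_counterpart("int32"))    # Int32
print(masked_counterpart("uint16"))   # UInt16
print(masked_counterpart("float64"))  # Float64
print(masked_counterpart("bool"))     # boolean
# float16 deliberately has no masked counterpart (see the TODO above).
```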
2 changes: 1 addition & 1 deletion pandas/core/dtypes/dtypes.py
@@ -1774,7 +1774,7 @@ def __init__(self, dtype: Dtype = np.float64, fill_value: Any = None) -> None:
)
from pandas.core.dtypes.missing import na_value_for_dtype

dtype = pandas_dtype(dtype)
dtype = pandas_dtype(dtype, allow_numpy_dtypes=True)
if is_string_dtype(dtype):
dtype = np.dtype("object")
if not isinstance(dtype, np.dtype):
9 changes: 9 additions & 0 deletions pandas/core/frame.py
@@ -141,6 +141,7 @@
)
from pandas.core.arrays.sparse import SparseFrameAccessor
from pandas.core.construction import (
array as pd_array,
ensure_wrapped_if_datetimelike,
sanitize_array,
sanitize_masked_array,
@@ -4411,6 +4412,14 @@ def _iset_item_mgr(
def _set_item_mgr(
self, key, value: ArrayLike, refs: BlockValuesRefs | None = None
) -> None:
if get_option("mode.pdep16_data_types"):
# TODO: possibly handle this at a lower level?
if (
isinstance(value, np.ndarray)
and value.dtype.kind in "iufb"
and value.dtype != np.float16
):
value = pd_array(value, copy=False)
try:
loc = self._info_axis.get_loc(key)
except KeyError:
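`_set_item_mgr` wraps plain numeric ndarrays in `pd.array` on the way in, so column assignment also lands on masked dtypes (float16 again excluded). An assumed illustration:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": pd.array([1, 2, 3])})
df["b"] = np.arange(3.0)
# Intended with mode.pdep16_data_types on:
#   a      Int64
#   b    Float64
print(df.dtypes)
```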
3 changes: 2 additions & 1 deletion pandas/core/groupby/generic.py
@@ -2166,9 +2166,10 @@ def _cython_transform(
)

def arr_func(bvalues: ArrayLike) -> ArrayLike:
return self._grouper._cython_operation(
blk_res = self._grouper._cython_operation(
"transform", bvalues, how, 1, **kwargs
)
return blk_res

res_mgr = mgr.apply(arr_func)

12 changes: 11 additions & 1 deletion pandas/core/indexes/base.py
@@ -172,6 +172,7 @@
)
import pandas.core.common as com
from pandas.core.construction import (
array as pd_array,
ensure_wrapped_if_datetimelike,
extract_array,
sanitize_array,
@@ -576,6 +577,12 @@ def __new__(
raise ValueError("Index data must be 1-dimensional") from err
raise
arr = ensure_wrapped_if_datetimelike(arr)
if (
arr.dtype.kind in "iufb"
and arr.dtype != np.float16
and get_option("mode.pdep16_data_types")
):
arr = pd_array(arr, copy=False)

klass = cls._dtype_to_subclass(arr.dtype)

@@ -5391,6 +5398,8 @@ def putmask(self, mask, value) -> Index:

# See also: Block.coerce_to_target_dtype
dtype = self._find_common_type_compat(value)
assert self.dtype != dtype, (self.dtype, value)
# FIXME: should raise with useful message to report a bug!
return self.astype(dtype).putmask(mask, value)

values = self._values.copy()
@@ -6932,7 +6941,8 @@ def insert(self, loc: int, item) -> Index:
return self.astype(dtype).insert(loc, item)

try:
if isinstance(arr, ExtensionArray):
if isinstance(arr, ExtensionArray) and not isinstance(self, ABCRangeIndex):
# RangeIndex's _simple_new expects a range object
res_values = arr.insert(loc, item)
return type(self)._simple_new(res_values, name=self.name)
else:
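`Index.__new__` gets the same wrapping, so numeric numpy input yields a masked-dtype Index, while `RangeIndex` keeps its range-backed representation (hence the `insert` special case above). An assumed illustration:

```python
import numpy as np
import pandas as pd

idx = pd.Index(np.array([1, 2, 3], dtype="int64"))
# Intended with mode.pdep16_data_types on: dtype Int64 rather than int64.
print(idx.dtype)

# RangeIndex stays backed by a range object, so it is not wrapped.
print(pd.RangeIndex(3))
```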
2 changes: 1 addition & 1 deletion pandas/core/indexes/datetimelike.py
@@ -557,7 +557,7 @@ def _wrap_range_setop(self, other, res_i8) -> Self:
# because test_setops_preserve_freq fails with _validate_frequency raising.
# This raising is incorrect, as 'on_freq' is incorrect. This will
# be fixed by GH#41493
res_values = res_i8.values.view(self._data._ndarray.dtype)
res_values = np.asarray(res_i8.values).view(self._data._ndarray.dtype)
result = type(self._data)._simple_new(
# error: Argument "dtype" to "_simple_new" of "DatetimeArray" has
# incompatible type "Union[dtype[Any], ExtensionDtype]"; expected