Skip to content

fix: Filter out StringDType even when the backing array is not NumpyExtensionArray #10559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jul 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
946dd78
fix: keep dtype as `object` for `pd.StringDtype` in `safe_cast_to_index`
ilan-gold Jul 22, 2025
147e3a7
chore: comment
ilan-gold Jul 22, 2025
3cf92dd
fix: broader fix
ilan-gold Jul 23, 2025
287e4be
feat: ban use of `pd.api.types.is_extension_array_dtype`
ilan-gold Jul 23, 2025
ffc905d
fix: type ignore
ilan-gold Jul 23, 2025
b97a64e
Merge branch 'main' into ig/fix_string_dtype
ilan-gold Jul 24, 2025
ed86c09
Update xarray/core/extension_array.py
ilan-gold Jul 24, 2025
9ea9d8d
? mypy
ilan-gold Jul 25, 2025
d131442
Merge branch 'ig/fix_string_dtype' of github.com:ilan-gold/xarray int…
ilan-gold Jul 25, 2025
7dc9662
fix: remove comment
ilan-gold Jul 25, 2025
7343f6d
try blanket ignore
ilan-gold Jul 28, 2025
cc1776f
try blanket ignore again
ilan-gold Jul 28, 2025
2e8ed67
fix: mypy
ilan-gold Jul 28, 2025
ac77713
use Any
ilan-gold Jul 28, 2025
3ecc62f
Merge branch 'main' into ig/fix_string_dtype
ilan-gold Jul 29, 2025
c0b7b4c
Merge branch 'main' into ig/fix_string_dtype
ilan-gold Jul 29, 2025
bfbe244
chore: add comment
ilan-gold Jul 29, 2025
594c164
Update xarray/core/utils.py
ilan-gold Jul 29, 2025
0c6ba04
Merge branch 'ig/fix_string_dtype' of github.com:ilan-gold/xarray int…
ilan-gold Jul 29, 2025
a436d78
refactor: use is_allowed_extension_array more
ilan-gold Jul 29, 2025
5e603a7
fix: remove one of the loops
ilan-gold Jul 29, 2025
a8d1aa1
Update xarray/computation/ops.py
ilan-gold Jul 29, 2025
bc27b15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2025
8172ae7
try again
dcherian Jul 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ accel = [
"numba>=0.59",
"flox>=0.9",
"opt_einsum",
"numpy<2.3", # numba has not updated yet: https://github.com/numba/numba/issues/10105
]
complete = ["xarray[accel,etc,io,parallel,viz]"]
io = [
Expand Down Expand Up @@ -324,6 +325,8 @@ known-first-party = ["xarray"]
[tool.ruff.lint.flake8-tidy-imports]
# Disallow all relative imports.
ban-relative-imports = "all"
[tool.ruff.lint.flake8-tidy-imports.banned-api]
"pandas.api.types.is_extension_array_dtype".msg = "Use xarray.core.utils.is_allowed_extension_array{_dtype} instead. Only use the banned API if the incoming data has already been sanitized by xarray"

[tool.pytest.ini_options]
addopts = [
Expand Down
9 changes: 6 additions & 3 deletions xarray/computation/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from __future__ import annotations

import operator
from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np

from xarray.core import dtypes, duck_array_ops

if TYPE_CHECKING:
pass

try:
import bottleneck as bn

Expand Down Expand Up @@ -158,8 +161,8 @@ def fillna(data, other, join="left", dataset_join="left"):
)


# Unsure why we get a mypy error here
def where_method(self, cond, other=dtypes.NA): # type: ignore[has-type]
# TODO: type this properly
def where_method(self, cond, other=dtypes.NA): # type: ignore[unused-ignore,has-type]
"""Return elements from `self` or `other` depending on `cond`.

Parameters
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.coding.calendar_ops import convert_calendar, interp_calendar
from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
Expand Down Expand Up @@ -91,6 +90,7 @@
either_dict_or_kwargs,
emit_user_level_warning,
infix_dims,
is_allowed_extension_array,
is_dict_like,
is_duck_array,
is_duck_dask_array,
Expand Down Expand Up @@ -6780,7 +6780,7 @@ def reduce(
elif (
# Some reduction functions (e.g. std, var) need to run on variables
# that don't have the reduce dims: PR5393
not is_extension_array_dtype(var.dtype)
not pd.api.types.is_extension_array_dtype(var.dtype) # noqa: TID251
and (
not reduce_dims
or not numeric_only
Expand Down Expand Up @@ -7105,12 +7105,12 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
non_extension_array_columns = [
k
for k in columns_in_order
if not is_extension_array_dtype(self.variables[k].data)
if not pd.api.types.is_extension_array_dtype(self.variables[k].data) # noqa: TID251
]
extension_array_columns = [
k
for k in columns_in_order
if is_extension_array_dtype(self.variables[k].data)
if pd.api.types.is_extension_array_dtype(self.variables[k].data) # noqa: TID251
]
extension_array_columns_different_index = [
k
Expand Down Expand Up @@ -7302,7 +7302,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
arrays = []
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v) or isinstance(
if not is_allowed_extension_array(v) or isinstance(
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add UNSUPPORTED_EXTENSION_ARRAY_TYPES to the new is_allowed_extension_array function because we do allow them as backing arrays to Index object, I think. Maybe should get a test?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should merge these checks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to account for internal duck array support i.e., that which allows preserving the dtype of extension array indices that are in this UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist. See the note here:

# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
# we do support extension arrays from datetime, for example, that need
# duck array support internally via this class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was that for DatetimeIndex? If so, clarifying that comment would be helpful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be for any index whose underlying array is one of the types in UNSUPPORTED_EXTENSION_ARRAY_TYPES, so possibly `DatetimeIndex`, but likely other index types as well.

):
arrays.append((k, np.asarray(v)))
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

import numpy as np
from pandas.api.types import is_extension_array_dtype
import pandas as pd

from xarray.compat import array_api_compat, npcompat
from xarray.compat.npcompat import HAS_STRING_DTYPE
Expand Down Expand Up @@ -213,7 +213,7 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:

if isinstance(dtype, np.dtype):
return npcompat.isdtype(dtype, kind)
elif is_extension_array_dtype(dtype):
elif pd.api.types.is_extension_array_dtype(dtype): # noqa: TID251
# we never want to match pandas extension array dtypes
return False
else:
Expand Down
13 changes: 7 additions & 6 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
take,
unravel_index, # noqa: F401
)
from pandas.api.types import is_extension_array_dtype

from xarray.compat import dask_array_compat, dask_array_ops
from xarray.compat.array_api_compat import get_array_namespace
Expand Down Expand Up @@ -184,7 +183,7 @@ def isnull(data):
dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool
return full_like(data, dtype=dtype, fill_value=False)
# at this point, array should have dtype=object
elif isinstance(data, np.ndarray) or is_extension_array_dtype(data):
elif isinstance(data, np.ndarray) or pd.api.types.is_extension_array_dtype(data): # noqa: TID251
return pandas_isnull(data)
else:
# Not reachable yet, but intended for use with other duck array
Expand Down Expand Up @@ -266,10 +265,12 @@ def asarray(data, xp=np, dtype=None):

def as_shared_dtype(scalars_or_arrays, xp=None):
"""Cast arrays to a shared dtype using xarray's type promotion rules."""
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
extension_array_types = [
x.dtype
for x in scalars_or_arrays
if pd.api.types.is_extension_array_dtype(x) # noqa: TID251
]
if len(extension_array_types) >= 1:
non_nans = [x for x in scalars_or_arrays if not isna(x)]
if len(extension_array_types) == len(non_nans) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
Expand Down
14 changes: 7 additions & 7 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import numpy as np
import pandas as pd
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype

from xarray.core.types import DTypeLikeSave, T_ExtensionArray
from xarray.core.utils import NDArrayMixin
from xarray.core.utils import NDArrayMixin, is_allowed_extension_array

HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}

Expand Down Expand Up @@ -100,10 +99,11 @@ def __post_init__(self):
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
# we do support extension arrays from datetime, for example, that need
# duck array support internally via this class.
if isinstance(self.array, pd.arrays.NumpyExtensionArray):
# duck array support internally via this class. These can appear from `DatetimeIndex`
# wrapped by `PandasIndex` internally, for example.
if not is_allowed_extension_array(self.array):
raise TypeError(
"`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
f"{self.array.dtype!r} should be converted to a numpy array in `xarray` internally."
)

def __array_function__(self, func, types, args, kwargs):
Expand All @@ -126,7 +126,7 @@ def replace_duck_with_extension_array(args) -> list:
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
raise KeyError("Function not registered for pandas extension arrays.")
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
if is_extension_array_dtype(res):
if is_allowed_extension_array(res):
return PandasExtensionArray(res)
return res

Expand All @@ -135,7 +135,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
item = self.array[key]
if is_extension_array_dtype(item):
if is_allowed_extension_array(item):
return PandasExtensionArray(item)
if np.isscalar(item) or isinstance(key, int):
return PandasExtensionArray(type(self.array)._from_sequence([item])) # type: ignore[call-arg,attr-defined,unused-ignore]
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Frozen,
emit_user_level_warning,
get_valid_numpy_dtype,
is_allowed_extension_array_dtype,
is_dict_like,
is_scalar,
)
Expand Down Expand Up @@ -666,9 +667,8 @@ def __init__(

self.index = index
self.dim = dim

if coord_dtype is None:
if pd.api.types.is_extension_array_dtype(index.dtype):
if is_allowed_extension_array_dtype(index.dtype):
cast(pd.api.extensions.ExtensionDtype, index.dtype)
coord_dtype = index.dtype
else:
Expand Down
13 changes: 6 additions & 7 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
NDArrayMixin,
either_dict_or_kwargs,
get_valid_numpy_dtype,
is_allowed_extension_array,
is_allowed_extension_array_dtype,
is_duck_array,
is_duck_dask_array,
is_scalar,
Expand Down Expand Up @@ -1763,12 +1765,12 @@ def __init__(
self.array = safe_cast_to_index(array)

if dtype is None:
if pd.api.types.is_extension_array_dtype(array.dtype):
if is_allowed_extension_array(array):
cast(pd.api.extensions.ExtensionDtype, array.dtype)
self._dtype = array.dtype
else:
self._dtype = get_valid_numpy_dtype(array)
elif pd.api.types.is_extension_array_dtype(dtype):
elif is_allowed_extension_array_dtype(dtype):
self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype)
else:
self._dtype = np.dtype(cast(DTypeLike, dtype))
Expand Down Expand Up @@ -1816,10 +1818,7 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
# We return an PandasExtensionArray wrapper type that satisfies
# duck array protocols.
# `NumpyExtensionArray` is excluded
if pd.api.types.is_extension_array_dtype(self.array) and not isinstance(
self.array.array,
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
):
if is_allowed_extension_array(self.array):
from xarray.core.extension_array import PandasExtensionArray

return PandasExtensionArray(self.array.array)
Expand Down Expand Up @@ -1916,7 +1915,7 @@ def copy(self, deep: bool = True) -> Self:

@property
def nbytes(self) -> int:
if pd.api.types.is_extension_array_dtype(self.dtype):
if is_allowed_extension_array(self.array):
return self.array.nbytes

dtype = self._get_numpy_dtype()
Expand Down
14 changes: 14 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@
T = TypeVar("T")


def is_allowed_extension_array_dtype(dtype: Any) -> bool:
    """Return whether ``dtype`` is a pandas extension dtype other than ``pd.StringDtype``.

    Parameters
    ----------
    dtype : Any
        A dtype object, a dtype alias string (e.g. ``"category"``), or an
        object exposing a ``dtype`` attribute.

    Returns
    -------
    bool
        ``True`` if pandas recognizes ``dtype`` as an extension dtype and it
        does not resolve to ``pd.StringDtype``; ``False`` otherwise.
    """
    if not pd.api.types.is_extension_array_dtype(dtype):  # noqa: TID251
        return False
    # The pandas predicate above also accepts array-likes (through their
    # ``dtype`` attribute) and registered alias strings. Resolve both before
    # the isinstance test, otherwise e.g. the alias "string" would be
    # reported as allowed even though it denotes ``pd.StringDtype``.
    resolved = getattr(dtype, "dtype", dtype)
    if isinstance(resolved, str):
        resolved = pd.api.types.pandas_dtype(resolved)
    return not isinstance(resolved, pd.StringDtype)


def is_allowed_extension_array(array: Any) -> bool:
    """Return whether ``array`` has an allowed extension dtype and is not a ``NumpyExtensionArray``.

    Parameters
    ----------
    array : Any
        Object to inspect. Objects without a ``dtype`` attribute are rejected.

    Returns
    -------
    bool
        ``True`` only if ``array`` has a ``dtype`` attribute, that dtype
        passes ``is_allowed_extension_array_dtype``, and ``array`` is not an
        instance of ``pd.arrays.NumpyExtensionArray``.
    """
    if not hasattr(array, "dtype"):
        return False
    if not is_allowed_extension_array_dtype(array.dtype):
        return False
    is_numpy_backed = isinstance(array, pd.arrays.NumpyExtensionArray)  # type: ignore[attr-defined]
    return not is_numpy_backed


def alias_message(old_name: str, new_name: str) -> str:
    """Build the deprecation notice pointing from ``old_name`` to ``new_name``.

    Parameters
    ----------
    old_name : str
        Name that has been deprecated.
    new_name : str
        Name callers should use instead.

    Returns
    -------
    str
        The formatted deprecation message.
    """
    message = f"{old_name} has been deprecated. Use {new_name} instead."
    return message

Expand Down
8 changes: 6 additions & 2 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
emit_user_level_warning,
ensure_us_time_resolution,
infix_dims,
is_allowed_extension_array,
is_dict_like,
is_duck_array,
is_duck_dask_array,
Expand Down Expand Up @@ -198,7 +199,9 @@ def _maybe_wrap_data(data):
return PandasIndexingAdapter(data)
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES):
return data.to_numpy()
if isinstance(data, pd.api.extensions.ExtensionArray):
if isinstance(
data, pd.api.extensions.ExtensionArray
) and is_allowed_extension_array(data):
return PandasExtensionArray(data)
return data

Expand Down Expand Up @@ -261,7 +264,8 @@ def convert_non_numpy_type(data):
if isinstance(data, pd.Series | pd.DataFrame):
if (
isinstance(data, pd.Series)
and pd.api.types.is_extension_array_dtype(data)
and is_allowed_extension_array(data.array)
# Some datetime types are not allowed as well as backing Variable types
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES)
):
pandas_data = data.array
Expand Down
5 changes: 4 additions & 1 deletion xarray/tests/test_pandas_to_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import pandas as pd
import pandas._testing as tm
import pytest
from packaging.version import Version
from pandas import (
Categorical,
CategoricalIndex,
Expand Down Expand Up @@ -171,7 +172,9 @@ def test_to_xarray_with_multiindex(self, df):

result = result.to_dataframe()
expected = df.copy()
expected["f"] = expected["f"].astype(object)
expected["f"] = expected["f"].astype(
object if Version(pd.__version__) < Version("3.0.0dev0") else str
)
expected.columns.name = None
tm.assert_frame_equal(result, expected)

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ def test_pandas_categorical_dtype(self):
data = pd.Categorical(np.arange(10, dtype="int64"))
v = self.cls("x", data)
print(v) # should not error
assert pd.api.types.is_extension_array_dtype(v.dtype)
assert isinstance(v.dtype, pd.CategoricalDtype)

def test_squeeze(self):
v = Variable(["x", "y"], [[1]])
Expand Down
Loading