Description
What happened?
Computing the max or the isnull of a DataArray with bfloat16 values raises:
TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.
This worked fine before updating numpy to version 2. The main difference in the code path seems to be that with numpy < 2, xarray uses its own implementation of isdtype, while for numpy >= 2 it relies on np.isdtype. This is confirmed by the fact that running import numpy as np; del np.isdtype before importing xarray makes the problem go away.
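For reference, the error can be reproduced with np.isdtype alone, without xarray (a minimal sketch; ml_dtypes.bfloat16 just stands in for any third-party dtype registered with NumPy):
import numpy as np
import ml_dtypes
# NumPy's own dtypes are accepted:
print(np.isdtype(np.dtype("float32"), "real floating"))  # True
# ...but np.isdtype (new in numpy 2) rejects dtypes NumPy itself does not define:
try:
    np.isdtype(np.dtype(ml_dtypes.bfloat16), "real floating")
except TypeError as e:
    print(e)  # dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.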
What did you expect to happen?
I expected the computation to succeed, as it did prior to numpy 2.
Minimal Complete Verifiable Example
import numpy as np
# del np.isdtype # Uncommenting this line fixes it.
import xarray
import ml_dtypes
da = xarray.DataArray(np.zeros([2], dtype=ml_dtypes.bfloat16), dims=("dim",))
da.isnull() # Or da.max("dim")
MVCE confirmation
- Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- Complete example — the example is self-contained, including all data and the text of any traceback.
- Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- New issue — a search of GitHub Issues suggests this is not a duplicate.
- Recent environment — the issue occurs with the latest version of xarray and its dependencies.
Relevant log output
TypeError Traceback (most recent call last)
Cell In[1], line 5
3 import numpy as np
4 da = xarray.DataArray(np.zeros([2], dtype=jnp.bfloat16), dims=("dim",))
----> 5 da.isnull()
File ~/dev/xarray/xarray/core/common.py:1293, in DataWithCoords.isnull(self, keep_attrs)
1290 if keep_attrs is None:
1291 keep_attrs = _get_keep_attrs(default=False)
-> 1293 return apply_ufunc(
1294 duck_array_ops.isnull,
1295 self,
1296 dask="allowed",
1297 keep_attrs=keep_attrs,
1298 )
File ~/dev/xarray/xarray/core/computation.py:1278, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
1276 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
1277 elif any(isinstance(a, DataArray) for a in args):
-> 1278 return apply_dataarray_vfunc(
1279 variables_vfunc,
1280 *args,
1281 signature=signature,
1282 join=join,
1283 exclude_dims=exclude_dims,
1284 keep_attrs=keep_attrs,
1285 )
1286 # feed Variables directly through apply_variable_ufunc
1287 elif any(isinstance(a, Variable) for a in args):
File ~/dev/xarray/xarray/core/computation.py:320, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
315 result_coords, result_indexes = build_output_coords_and_indexes(
316 args, signature, exclude_dims, combine_attrs=keep_attrs
317 )
319 data_vars = [getattr(a, "variable", a) for a in args]
--> 320 result_var = func(*data_vars)
322 out: tuple[DataArray, ...] | DataArray
323 if signature.num_outputs > 1:
File ~/dev/xarray/xarray/core/computation.py:831, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
826 if vectorize:
827 func = _vectorize(
828 func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
829 )
--> 831 result_data = func(*input_data)
833 if signature.num_outputs == 1:
834 result_data = (result_data,)
File ~/dev/xarray/xarray/core/duck_array_ops.py:144, in isnull(data)
139 if dtypes.is_datetime_like(scalar_type):
140 # datetime types use NaT for null
141 # note: must check timedelta64 before integers, because currently
142 # timedelta64 inherits from np.integer
143 return isnat(data)
--> 144 elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
145 # float types use NaN for null
146 xp = get_array_namespace(data)
147 return xp.isnan(data)
File ~/dev/xarray/xarray/core/dtypes.py:208, in isdtype(dtype, kind, xp)
205 raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
207 if isinstance(dtype, np.dtype):
--> 208 return npcompat.isdtype(dtype, kind)
209 elif is_extension_array_dtype(dtype):
210 # we never want to match pandas extension array dtypes
211 return False
File ~/miniconda3/envs/xarray-py312/lib/python3.12/site-packages/numpy/_core/numerictypes.py:425, in isdtype(dtype, kind)
423 dtype = _preprocess_dtype(dtype)
424 except _PreprocessDTypeError:
--> 425 raise TypeError(
426 "dtype argument must be a NumPy dtype, "
427 f"but it is a {type(dtype)}."
428 ) from None
430 input_kinds = kind if isinstance(kind, tuple) else (kind,)
432 processed_kinds = set()
TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.
Anything else we need to know?
Here's a different reproducer showing the inconsistency between np.isdtype and npcompat.isdtype:
import importlib
from xarray.core import npcompat
import ml_dtypes
import numpy as np
try:
    npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # `AttributeError: 'module' object has no attribute 'isdtype'`
except Exception as e:
    print(e)
numpy_isdtype = np.isdtype
del np.isdtype
importlib.reload(npcompat)
np.isdtype = numpy_isdtype
npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # No error, but returns False.
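One possible direction for a fix (only a sketch of the idea, not necessarily how xarray should implement it; the function name and the kind-to-superclass mapping below are my own) would be to catch the TypeError raised by np.isdtype and fall back to a subclass-based check for dtypes that NumPy does not know about:
import numpy as np
import ml_dtypes

# Hypothetical fallback: try np.isdtype first, and only use np.issubdtype for
# dtypes that np.isdtype refuses to handle (e.g. ml_dtypes.bfloat16).
_KIND_TO_ABSTRACT_TYPE = {
    "bool": np.bool_,
    "signed integer": np.signedinteger,
    "unsigned integer": np.unsignedinteger,
    "integral": np.integer,
    "real floating": np.floating,
    "complex floating": np.complexfloating,
    "numeric": np.number,
}

def isdtype_with_fallback(dtype, kind):
    kinds = kind if isinstance(kind, tuple) else (kind,)
    try:
        return np.isdtype(dtype, kind)
    except TypeError:
        # np.isdtype only accepts dtypes defined by NumPy itself.
        return any(np.issubdtype(dtype, _KIND_TO_ABSTRACT_TYPE[k]) for k in kinds)

# No TypeError is raised here, regardless of what the subclass check returns:
print(isdtype_with_fallback(np.dtype(ml_dtypes.bfloat16), ("real floating", "complex floating")))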
Environment
In [5]: xarray.show_versions()
INSTALLED VERSIONS
commit: 03d3e0b
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.6.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2
xarray: 2024.7.1.dev73+g781877cb
pandas: 2.2.2
numpy: 2.1.1
scipy: 1.13.1
netCDF4: 1.6.5
pydap: None
h5netcdf: None
h5py: None
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: None
dask: 2024.8.2
distributed: 2024.5.2
matplotlib: 3.9.0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: 0.11.1
setuptools: 70.0.0
pip: 24.0
conda: 24.7.1
pytest: 8.2.2
mypy: 1.10.0
IPython: 8.25.0