Incompatibilities with bfloat16 after update to numpy 2 #9568

Open
@alvarosg

Description

What happened?

Computing max or isnull on a DataArray with bfloat16 values raises:

TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.

This worked before updating numpy to version 2. The relevant difference seems to be that with numpy < 2, xarray uses its own fallback implementation of isdtype (in xarray.core.npcompat), while with numpy >= 2 it relies on np.isdtype. Running import numpy as np; del np.isdtype before importing xarray makes the error disappear, which supports this.

What did you expect to happen?

I expected the computation to succeed, as it did before numpy 2.

Minimal Complete Verifiable Example

import numpy as np
# del np.isdtype  # Uncommenting this line fixes it.

import xarray
import ml_dtypes

da = xarray.DataArray(np.zeros([2], dtype=ml_dtypes.bfloat16), dims=("dim",))
da.isnull() # Or da.max("dim")

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

TypeError                                 Traceback (most recent call last)
Cell In[1], line 5
      3 import numpy as np
      4 da = xarray.DataArray(np.zeros([2], dtype=jnp.bfloat16), dims=("dim",))
----> 5 da.isnull()

File ~/dev/xarray/xarray/core/common.py:1293, in DataWithCoords.isnull(self, keep_attrs)
   1290 if keep_attrs is None:
   1291     keep_attrs = _get_keep_attrs(default=False)
-> 1293 return apply_ufunc(
   1294     duck_array_ops.isnull,
   1295     self,
   1296     dask="allowed",
   1297     keep_attrs=keep_attrs,
   1298 )

File ~/dev/xarray/xarray/core/computation.py:1278, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
   1276 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1277 elif any(isinstance(a, DataArray) for a in args):
-> 1278     return apply_dataarray_vfunc(
   1279         variables_vfunc,
   1280         *args,
   1281         signature=signature,
   1282         join=join,
   1283         exclude_dims=exclude_dims,
   1284         keep_attrs=keep_attrs,
   1285     )
   1286 # feed Variables directly through apply_variable_ufunc
   1287 elif any(isinstance(a, Variable) for a in args):

File ~/dev/xarray/xarray/core/computation.py:320, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    315 result_coords, result_indexes = build_output_coords_and_indexes(
    316     args, signature, exclude_dims, combine_attrs=keep_attrs
    317 )
    319 data_vars = [getattr(a, "variable", a) for a in args]
--> 320 result_var = func(*data_vars)
    322 out: tuple[DataArray, ...] | DataArray
    323 if signature.num_outputs > 1:

File ~/dev/xarray/xarray/core/computation.py:831, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    826     if vectorize:
    827         func = _vectorize(
    828             func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
    829         )
--> 831 result_data = func(*input_data)
    833 if signature.num_outputs == 1:
    834     result_data = (result_data,)

File ~/dev/xarray/xarray/core/duck_array_ops.py:144, in isnull(data)
    139 if dtypes.is_datetime_like(scalar_type):
    140     # datetime types use NaT for null
    141     # note: must check timedelta64 before integers, because currently
    142     # timedelta64 inherits from np.integer
    143     return isnat(data)
--> 144 elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
    145     # float types use NaN for null
    146     xp = get_array_namespace(data)
    147     return xp.isnan(data)

File ~/dev/xarray/xarray/core/dtypes.py:208, in isdtype(dtype, kind, xp)
    205     raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
    207 if isinstance(dtype, np.dtype):
--> 208     return npcompat.isdtype(dtype, kind)
    209 elif is_extension_array_dtype(dtype):
    210     # we never want to match pandas extension array dtypes
    211     return False

File ~/miniconda3/envs/xarray-py312/lib/python3.12/site-packages/numpy/_core/numerictypes.py:425, in isdtype(dtype, kind)
    423     dtype = _preprocess_dtype(dtype)
    424 except _PreprocessDTypeError:
--> 425     raise TypeError(
    426         "dtype argument must be a NumPy dtype, "
    427         f"but it is a {type(dtype)}."
    428     ) from None
    430 input_kinds = kind if isinstance(kind, tuple) else (kind,)
    432 processed_kinds = set()

TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.

Anything else we need to know?

Here's a different reproducer showing the inconsistency between np.isdtype and the fallback npcompat.isdtype that xarray uses with numpy < 2:

import importlib

import ml_dtypes
import numpy as np
from xarray.core import npcompat

# With numpy >= 2, npcompat.isdtype is np.isdtype, which rejects bfloat16.
try:
    npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # raises the TypeError shown in the traceback above
except Exception as e:
    print(e)

# Temporarily hide np.isdtype so that reloading npcompat picks up its
# fallback implementation, then restore np.isdtype.
numpy_isdtype = np.isdtype
del np.isdtype
importlib.reload(npcompat)
np.isdtype = numpy_isdtype

npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # No error, but returns False.
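The reproducer above shows the two code paths disagreeing on bfloat16: np.isdtype raises, while the fallback returns False. A regression check along the following lines could catch such divergence. It is a sketch that only needs numpy; the names fallback_isdtype, isdtype_paths_agree and BUILTIN_DTYPES are hypothetical. On numpy's built-in dtypes both paths are expected to agree; a third-party dtype like bfloat16 is exactly where numpy >= 2 raises instead.

```python
import numpy as np

# Hypothetical mapping for the two kinds that xarray's isnull asks about.
FLOAT_KINDS = {
    "real floating": np.floating,
    "complex floating": np.complexfloating,
}


def fallback_isdtype(dtype, kind):
    """np.issubdtype-based classification, similar in spirit to a pre-numpy-2 shim."""
    return bool(np.issubdtype(np.dtype(dtype), FLOAT_KINDS[kind]))


def isdtype_paths_agree(dtypes, kind):
    """True if np.isdtype and the fallback classify every dtype the same way.

    On numpy < 2 there is no np.isdtype, so there is nothing to compare.
    """
    np_isdtype = getattr(np, "isdtype", None)
    if np_isdtype is None:
        return True
    return all(
        bool(np_isdtype(np.dtype(dt), kind)) == fallback_isdtype(dt, kind)
        for dt in dtypes
    )


BUILTIN_DTYPES = [np.bool_, np.int8, np.uint16, np.int64, np.float16,
                  np.float32, np.float64, np.complex64, np.complex128]

for kind in FLOAT_KINDS:
    print(kind, isdtype_paths_agree(BUILTIN_DTYPES, kind))  # prints True for both kinds
```

Appending ml_dtypes.bfloat16 to the dtype list (with ml_dtypes installed and numpy >= 2) should make the np.isdtype call raise the TypeError from the traceback above, which is the inconsistency reported in this issue.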

Environment

In [5]: xarray.show_versions()

INSTALLED VERSIONS

commit: 03d3e0b
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.6.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.14.3
libnetcdf: 4.9.2

xarray: 2024.7.1.dev73+g781877cb
pandas: 2.2.2
numpy: 2.1.1
scipy: 1.13.1
netCDF4: 1.6.5
pydap: None
h5netcdf: None
h5py: None
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: None
dask: 2024.8.2
distributed: 2024.5.2
matplotlib: 3.9.0
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: 0.11.1
setuptools: 70.0.0
pip: 24.0
conda: 24.7.1
pytest: 8.2.2
mypy: 1.10.0
IPython: 8.25.0
