Description
What happened?
Consider the following case of an xarray.DataArray wrapping a single element JAX array:
import jax
import xarray
import numpy as np
da = xarray.DataArray(jax.numpy.ones(1))
This object wraps a jax.Array, with operations implemented via the Array API (yay!), as one can check by inspecting da.data.
The results of da * 1 and 1 * da both wrap JAX arrays. So does da * np.array(1.0).
Unfortunately, np.array(1.0) * da does not: it wraps a base NumPy array.
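The asymmetry is consistent with how Python dispatches binary operators. For `a * b`, Python first calls `type(a).__mul__(a, b)` and falls back to `type(b).__rmul__(b, a)` only if that returns `NotImplemented` (ignoring the subclass special case). With an ndarray on the left, NumPy gets the first say. Since `DataArray` implements `__array_ufunc__`, NumPy plausibly hands the operation to that hook as an `np.multiply` call, and applying a NumPy ufunc to the JAX data coerces it to NumPy; this is an assumption, not a traced code path. The stdlib-only toy model below sketches that chain. `HostArray`, `DuckArray`, `Wrapper`, `host_ufunc` and `unwrap` are hypothetical names, not NumPy, JAX or xarray APIs, and the real code paths are more involved.

```python
# Toy model of Python's binary-operator dispatch. All classes are
# hypothetical stand-ins, NOT the real NumPy/JAX/xarray implementations.


class HostArray:
    """Stands in for numpy.ndarray."""

    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        # Loosely mimics NumPy handing the operation to an operand's
        # __array_ufunc__: if the other operand has a hook, it decides.
        hook = getattr(other, "host_ufunc", None)
        if hook is not None:
            return hook(self, other)
        return HostArray(self.value * unwrap(other))


class DuckArray:
    """Stands in for jax.Array: always returns its own type."""

    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        return DuckArray(self.value * unwrap(other))

    __rmul__ = __mul__


class Wrapper:
    """Stands in for xarray.DataArray wrapping a duck array."""

    def __init__(self, data):
        self.data = data

    def __mul__(self, other):
        # Wrapper on the left: the wrapped DuckArray's operator runs.
        return Wrapper(self.data * other)

    def __rmul__(self, other):
        # Reached only when the left operand returned NotImplemented
        # (e.g. a plain int), so the DuckArray still decides the type.
        return Wrapper(self.data * other)

    def host_ufunc(self, left, right):
        # Reached when a HostArray is on the left: the host's multiply is
        # applied to raw values, so the duck array type is lost.
        return Wrapper(HostArray(unwrap(left) * unwrap(right)))


def unwrap(x):
    """Return the plain Python number inside any toy container."""
    if isinstance(x, Wrapper):
        return unwrap(x.data)
    if isinstance(x, (HostArray, DuckArray)):
        return x.value
    return x


da = Wrapper(DuckArray(1.0))
for expr, res in [
    ("da * 1", da * 1),
    ("1 * da", 1 * da),
    ("da * host", da * HostArray(1.0)),
    ("host * da", HostArray(1.0) * da),
]:
    # Only the last line reports HostArray as the wrapped type.
    print(f"{expr}: Wrapper({type(res.data).__name__})")
```

In the toy, as in the report, only the case with the host array on the left loses the duck array type, because that is the only case where the host type's operator runs first.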
This feels quite inconsistent. Ideally JAX would take precedence in all these cases, even though the Python Array API standard technically does not prescribe an order of precedence between different array types.
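One way to state the requested behavior is an order-independent precedence rule: rank the array libraries involved and let the highest-ranked one produce the result, whichever side of the operator it appears on (similar in spirit to NumPy's `__array_priority__`). The sketch below is a hypothetical illustration only; the `PRIORITY` table and `result_library` function are invented for this example and are not part of the Array API standard, NumPy, JAX or xarray.

```python
from itertools import permutations

# Hypothetical ranking of array libraries: a higher number takes precedence.
# The labels are illustrative strings, not real library attributes.
PRIORITY = {"python-scalar": 0, "numpy": 1, "jax": 2}


def result_library(*operand_libs: str) -> str:
    """Return the library that should produce the result of a binary op.

    max() only compares ranks, so the answer does not depend on operand
    order, unlike Python's left-operand-first dispatch.
    """
    return max(operand_libs, key=PRIORITY.__getitem__)


# The four combinations from the report: each pair checked in both orders.
for pair in [("jax", "python-scalar"), ("jax", "numpy")]:
    for ordered in permutations(pair):
        print(" * ".join(ordered), "->", result_library(*ordered))
```

Under this rule all four expressions from the report resolve to JAX, including the np.array(1.0) * da case.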
What did you expect to happen?
No response
Minimal Complete Verifiable Example
No response
MVCE confirmation
- Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- Complete example — the example is self-contained, including all data and the text of any traceback.
- Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- New issue — a search of GitHub Issues suggests this is not a duplicate.
- Recent environment — the issue occurs with the latest version of xarray and its dependencies.
Relevant log output
No response
Anything else we need to know?
No response
Environment
xarray = 2025.1.1
jax = 0.4.38