
Unclear precedence between NumPy and JAX arrays in arithmetic #9952

Open
@shoyer

Description


What happened?

Consider the following case of an xarray.DataArray wrapping a single-element JAX array:

import jax
import xarray
import numpy as np

da = xarray.DataArray(jax.numpy.ones(1))

This object wraps a jax.Array, with operations implemented via the Array API (yay!), as one can check by inspecting da.data.

da * 1 and 1 * da both return DataArrays wrapping JAX arrays, and so does da * np.array(1.0).

Unfortunately, np.array(1.0) * da does not: its data is a plain NumPy array.
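A plausible explanation, sketched below with NumPy only, starts with how the multiplication is dispatched. The expression np.array(1.0) * da begins in ndarray.__mul__. Because DataArray defines __array_ufunc__, NumPy does not defer to it and instead calls np.multiply, which xarray forwards to the wrapped data. Calling the ufunc directly skips the __array_priority__ check that normally makes a bare np.array(1.0) * jax_array defer to JAX. As a result, NumPy coerces the JAX array through __array__.

This sketch rests on an assumption: jax.Array defines __array__ and a high __array_priority__, but no __array_ufunc__. ForeignArray and Wrapper are hypothetical, heavily simplified stand-ins for jax.Array and xarray.DataArray, not their real implementations.

```python
import numpy as np


class ForeignArray:
    """Hypothetical stand-in for a non-NumPy array such as jax.Array.

    Assumed to mimic JAX: it converts via __array__, defines no
    __array_ufunc__, and sets a high __array_priority__ so that a bare
    ``ndarray * ForeignArray`` defers to ForeignArray.__rmul__.
    """

    __array_priority__ = 100

    def __init__(self, values):
        self.values = np.asarray(values, dtype=float)

    def __array__(self, dtype=None, copy=None):
        # Called when NumPy coerces this object inside a ufunc.
        return np.asarray(self.values, dtype=dtype)

    def __mul__(self, other):
        return ForeignArray(self.values * np.asarray(other))

    __rmul__ = __mul__


class Wrapper:
    """Hypothetical, heavily simplified stand-in for xarray.DataArray."""

    def __init__(self, data):
        self.data = data

    def __mul__(self, other):
        # Forward operator: delegate to the wrapped array's own __mul__.
        return Wrapper(self.data * other)

    def __rmul__(self, other):
        return Wrapper(other * self.data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Because this method exists (and is not None), ndarray.__mul__
        # does not defer to Wrapper; it calls np.multiply, which lands here.
        if method != "__call__":
            return NotImplemented
        unwrapped = [x.data if isinstance(x, Wrapper) else x for x in inputs]
        # Calling the ufunc directly bypasses __array_priority__, so NumPy
        # coerces ForeignArray via __array__ and returns a plain ndarray.
        return Wrapper(ufunc(*unwrapped, **kwargs))


def result_types():
    da = Wrapper(ForeignArray([1.0]))
    return {
        "da * 1": type((da * 1).data).__name__,
        "1 * da": type((1 * da).data).__name__,
        "da * np.array(1.0)": type((da * np.array(1.0)).data).__name__,
        "np.array(1.0) * da": type((np.array(1.0) * da).data).__name__,
        # Without the wrapper, NumPy defers to ForeignArray.__rmul__.
        "np.array(1.0) * da.data": type(np.array(1.0) * da.data).__name__,
    }


for expr, name in result_types().items():
    print(f"{expr:>26}: {name}")
```

Under these assumptions, only np.array(1.0) * da comes back as a plain ndarray. This mirrors the asymmetry reported above.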

This feels quite inconsistent. Ideally JAX would take precedence in all of these cases, even though the Python array API standard technically does not prescribe an order of precedence between different array types.
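One possible direction is offered here only as a hypothetical sketch, not as xarray's actual code or an agreed fix. When __array_ufunc__ receives a plain arithmetic ufunc, it could apply the corresponding Python operator to the unwrapped data. NumPy's usual binary-operator deferral would then decide the result type. To my understanding, JAX arrays opt into that deferral via a high __array_priority__, which is assumed here. _BINARY_OPERATORS, ForeignArray and Wrapper are illustrative names, not real xarray or JAX APIs.

```python
import operator

import numpy as np

# Hypothetical mapping from arithmetic ufuncs to Python operators.
_BINARY_OPERATORS = {
    np.add: operator.add,
    np.subtract: operator.sub,
    np.multiply: operator.mul,
    np.divide: operator.truediv,
}


class ForeignArray:
    """Hypothetical JAX-like stand-in: converts via __array__, has a high
    __array_priority__, and defines no __array_ufunc__ (assumed)."""

    __array_priority__ = 100

    def __init__(self, values):
        self.values = np.asarray(values, dtype=float)

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self.values, dtype=dtype)

    def __mul__(self, other):
        return ForeignArray(self.values * np.asarray(other))

    __rmul__ = __mul__


class Wrapper:
    """Hypothetical DataArray stand-in whose __array_ufunc__ prefers operators."""

    def __init__(self, data):
        self.data = data

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__":
            return NotImplemented
        unwrapped = [x.data if isinstance(x, Wrapper) else x for x in inputs]
        op = _BINARY_OPERATORS.get(ufunc)
        if op is not None and not kwargs:
            # Going through the Python operator lets the wrapped arrays' own
            # dispatch (here, ndarray deferring on __array_priority__) pick
            # the result type instead of forcing NumPy coercion.
            return Wrapper(op(*unwrapped))
        # Any other ufunc still coerces the foreign array to NumPy.
        return Wrapper(ufunc(*unwrapped, **kwargs))


# Usage: the left-hand NumPy operand no longer forces a NumPy result.
da = Wrapper(ForeignArray([2.0]))
print(type((np.array(3.0) * da).data).__name__)
```

This fallback only covers a few binary ufuncs called without extra keyword arguments (such as out=). Anything else, e.g. np.sqrt(da), still goes through NumPy and returns a NumPy array.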

What did you expect to happen?

No response

Minimal Complete Verifiable Example

No response

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

No response

Environment

xarray = 2025.1.1, jax = 0.4.38


Labels: array API standard (Support for the Python array API standard), bug, topic-arrays (related to flexible array support)
