Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,17 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
is_pydata_sparse_array
"""
cls = cast(Hashable, type(x))
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
# We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on,
# tracers are not a subclass of jax.Array anymore. Note that tracers can also represent
# non-array values and a fully correct implementation would need to use isinstance checks. Since
# we use hash-based caching with type names as keys, we cannot use instance checks without
# losing performance here. For more information, see
# https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue.
return (
_issubclass_fast(cls, "jax", "Array")
or _issubclass_fast(cls, "jax.core", "Tracer")
Copy link
Contributor

@jakevdp jakevdp Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason for the change in v0.8.2 is that tracers now can represent more than just arrays, and so returning True for any tracer may lead to false positives.

The logic in Array.__instancecheck__ is what is required to accurately check in all contexts whether x is an array: https://github.com/jax-ml/jax/blob/82ae1b1cde42a5b93e00d8c3376cde627c2d83bb/jaxlib/py_array.cc#L2187-L2218

The easiest way to accomplish this would be to check isinstance(x, jax.Array) rather than recreating that logic here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will force us to use a non-cachable operation, which is going to slow things down. But I don't think we have a choice given that the Tracer type itself no longer holds information on whether or not it's an Array.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakevdp Can you elaborate a bit more on which kinds of non-array objects now create tracers? I.e. we use an _is_writable_cls and _is_lazy_cls. Even if tracers are not arrays, these functions could still be decidable based on the type only. Are tracers still always lazy and always immutable? I realize that these questions might be ill-defined since tracers do not represent real objects and can disappear from the final computation graph, but for our purposes that's not an issue.

Also, could you show an example of a tracer that does not wrap an array? E.g. are bools in the input now traced as bools and not as arrays? This would be very helpful for testing.

@crusaderky Current helper methods such as _is_writable_cls are designed to return None for non-array API objects. It seems we cannot make that decision based off of type information only on jax>=0.8.2. Are you fine with relaxing the None strategy and returning True for Tracers in general, or do you want to be strict here? The former still fits into our current setup, the latter must use non-cachable isinstance checks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, could you show an example of a tracer that does not wrap an array?

An example is the new hijax Box type. There are no public APIs for this (yet), but here's how you can construct it using currently-private APIs at head:

import jax
from jax._src import hijax

box = hijax.new_box()
hijax.box_set(box, (jnp.arange(4), jnp.ones((3, 3)), 2.0, None))

@jax.jit
def f(box):
  print(type(box))  # <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
  print(box.aval)  # BoxTy()
  print(hijax.box_get(box))  # (JitTracer(int32[4]), JitTracer(float32[3,3]), JitTracer(~float32[]), None)
  # print(box.dtype)  # fails with AttributeError
  # print(box.shape)  # fails with AttributeError

f(box)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current design is that Tracer subclass reflects the type of transformation being traced (e.g. jit, vmap, grad, jaxpr, etc.) while the aval attribute can be inspected to see what kind of object is being traced.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that's very helpful. At this point I think we need a decision by the array-api-compat team. Both versions shouldn't be hard to implement.

@crusaderky @lucascolley what are your thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the false positives seem like a concern from SciPy's side. Maybe we go with this, but add a note into the code comments about the false positives in case anyone complains in the future?

or _is_jax_zero_gradient_array(x)
)
Comment on lines +244 to +248
Copy link

Copilot AI Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check for jax.core.Tracer has been added to is_jax_array, but several other helper functions in this file may also need similar updates for consistency and correctness. Specifically:

  1. _is_array_api_cls (line 302) - checks for jax.Array but not Tracer
  2. _cls_to_namespace (line 550) - checks for jax.Array but not Tracer
  3. _is_writeable_cls (line 940) - checks for jax.Array but not Tracer (JAX tracers should also be non-writeable)
  4. _is_lazy_cls (line 979) - checks for jax.Array but not Tracer (JAX tracers should also be lazy)

If is_jax_array now returns True for Tracers, these other functions should be updated to handle Tracers consistently. Otherwise, a jitted JAX array might pass is_jax_array but fail in array_namespace or behave incorrectly with is_writeable_array and is_lazy_array.

Copilot uses AI. Check for mistakes.


def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
Expand Down Expand Up @@ -296,6 +306,7 @@ def _is_array_api_cls(cls: type) -> bool:
or _issubclass_fast(cls, "sparse", "SparseArray")
# TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
or _issubclass_fast(cls, "jax", "Array")
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
)


Expand Down Expand Up @@ -934,6 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
if (
_issubclass_fast(cls, "numpy", "generic")
or _issubclass_fast(cls, "jax", "Array")
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
or _issubclass_fast(cls, "sparse", "SparseArray")
):
return False
Expand Down Expand Up @@ -973,6 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
return False
if (
_issubclass_fast(cls, "jax", "Array")
or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations
or _issubclass_fast(cls, "dask.array", "Array")
or _issubclass_fast(cls, "ndonnx", "Array")
):
Expand Down
25 changes: 22 additions & 3 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from numpy.testing import assert_equal
import pytest

from array_api_compat import device, to_device
from array_api_compat import (
device,
to_device,
is_jax_array,
is_lazy_array,
is_array_api_obj,
is_writeable_array,
)

try:
import jax
Expand All @@ -13,7 +20,7 @@


@pytest.mark.parametrize(
"func",
"func",
[
lambda x: jnp.zeros(1, device=device(x)),
lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
Expand All @@ -26,7 +33,7 @@
),
),
lambda x: to_device(jnp.zeros(1), device(x)),
]
],
)
def test_device_jit(func):
# Test work around to https://github.com/jax-ml/jax/issues/26000
Expand All @@ -36,3 +43,15 @@ def test_device_jit(func):
x = jnp.ones(1)
assert_equal(func(x), jnp.asarray([0]))
assert_equal(jax.jit(func)(x), jnp.asarray([0]))


def test_inside_jit():
# Test if jax arrays are handled correctly inside jax.jit.
# Jax tracers are not a subclass of jax.Array from 0.8.2 on. We explicitly test that
# tracers are handled appropriately. For limitations, see is_jax_array() docstring.
# Reference issue: https://github.com/data-apis/array-api-compat/issues/368
x = jnp.asarray([1, 2, 3])
assert jax.jit(is_jax_array)(x)
assert jax.jit(is_array_api_obj)(x)
assert not jax.jit(is_writeable_array)(x)
assert jax.jit(is_lazy_array)(x)
Loading