Skip to content

Commit b95ba89

Browse files
committed
JEP 28661: the __jax_array__ protocol
1 parent 7ba83d8 commit b95ba89

File tree

2 files changed

+215
-0
lines changed

2 files changed

+215
-0
lines changed
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# JEP 28661: Supporting the `__jax_array__` protocol
2+
3+
[@jakevdp](http://github.com/jakevdp), *May 2025*
4+
5+
An occasional user request is for the ability to define custom array-like objects that
6+
work with jax APIs. JAX currently has a partial implementation of a mechanism that does
7+
this via a `__jax_array__` method defined on the custom object. This was never intended
8+
to be a load-bearing public API (see the discussion at {jax-issue}`#4725`), but has
9+
become essential to packages like Keras and flax, which explicitly document the ability
10+
to use their custom array objects with jax functions. This JEP proposes a design for
11+
full, documented support of the `__jax_array__` protocol.
12+
13+
## Levels of array extensibility
14+
Requests for extensibility of JAX arrays come in a few flavors:
15+
16+
### Level 1 Extensibility: polymorphic inputs
17+
What I’ll call "Level 1" extensibility is the desire that JAX APIs accept polymorphic inputs.
18+
That is, a user desires behavior like this:
19+
20+
```python
21+
class CustomArray:
22+
data: numpy.ndarray
23+
...
24+
25+
x = CustomArray(np.arange(5))
26+
result = jnp.sin(x) # Converts `x` to JAX array and returns a JAX array
27+
```
28+
29+
Under this extensibility model, JAX functions would accept CustomArray objects as inputs,
30+
implicitly converting them to `jax.Array` objects for the sake of computation.
31+
This is similar to the functionality offered by NumPy via the `__array__` method, and in
32+
JAX (in many but not all cases) via the `__jax_array__` method.
33+
34+
This is the mode of extensibility that has been requested by the maintainers of `flax.nnx`
35+
and others. The current implementation is also used by JAX internally for the case of
36+
symbolic dimensions.
37+
38+
### Level 2 extensibility: polymorphic outputs
39+
What I’ll call "Level 2" extensibility is the desire that JAX APIs should not only accept
40+
polymorphic inputs, but also wrap outputs to match the class of the input.
41+
That is, a user desires behavior like this:
42+
43+
```python
44+
class CustomArray:
45+
data: numpy.ndarray
46+
...
47+
48+
x = CustomArray(np.arange(5))
49+
result = jnp.sin(x) # returns a new CustomArray
50+
```
51+
52+
Under this extensibility model, JAX functions would not only accept custom objects
53+
as inputs, but have some protocol to determine how to correctly re-wrap outputs with
54+
the same class. In NumPy, this sort of functionality is offered in varying degrees by
55+
the special `__array_ufunc__`, `__array_wrap__`, and `__array_function__` protocols,
56+
which allow user-defined objects to customize how NumPy API functions operate on
57+
arbitrary inputs and map input types to outputs.
58+
JAX does not currently have any equivalent to these interfaces in NumPy.
59+
60+
This is the mode of extensibility that has been requested by the maintainers of `keras`,
61+
among others.
62+
63+
### Level 3 extensibility: subclassing `Array`
64+
65+
What I’ll call "Level 3" extensibility is the desire that the JAX array object itself
66+
could be subclassable. NumPy provides some APIs that allow this
67+
(see [Subclassing ndarray](https://numpy.org/devdocs/user/basics.subclassing.html)) but
68+
this sort of approach would take some extra thought in JAX due to the need for
69+
representing array objects abstractly via tracing.
70+
71+
This mode of extensibility has occasionally been requested by users who want to add
72+
special metadata to JAX arrays, such as units of measurement.
73+
74+
## Synopsis
75+
76+
For the sake of this proposal, we will stick with the simplest, level 1 extensibility
77+
model. The proposed interface is the one currently non-uniformly supported by a number
78+
of JAX APIs, the `__jax_array__` method. Its usage looks something like this:
79+
80+
```python
81+
import jax
82+
import jax.numpy as jnp
83+
import numpy as np
84+
85+
class CustomArray:
86+
data: np.ndarray
87+
88+
def __init__(self, data: np.ndarray):
89+
self.data = data
90+
91+
def __jax_array__(self) -> jax.Array:
92+
return jnp.asarray(self.data)
93+
94+
arr = CustomArray(np.arange(5))
95+
result = jnp.multiply(arr, 2)
96+
print(repr(result))
97+
# Array([0, 2, 4, 6, 8], dtype=int32)
98+
```
99+
100+
We may revisit other extensibility levels in the future.
101+
102+
## Design challenges
103+
104+
JAX presents some interesting design challenges related to this kind of extensibility,
105+
which have not been fully explored previously. We’ll discuss them in turn here:
106+
107+
### Priority of `__jax_array__` vs. PyTree flattening
108+
JAX already has a supported mechanism for registering custom objects, namely pytree
109+
registration (see [Extending pytrees](https://docs.jax.dev/en/latest/pytrees.html#extending-pytrees)).
110+
If we also support __jax_array__, which one should take precedence?
111+
112+
To put this more concretely, what should be the result of this code?
113+
114+
```python
115+
@jax.jit
116+
def f(x):
117+
print("is JAX array:", isinstance(x, jax.Array))
118+
119+
f(CustomArray(...))
120+
```
121+
122+
If we choose to prioritize `__jax_array__` at the JIT boundary, then the output of this
123+
function would be:
124+
```
125+
is JAX array: True
126+
```
127+
That is, at the JIT boundary, the `CustomArray` object would be converted into a
128+
`__jax_array__`, and its shape and dtype would be used to construct a standard JAX
129+
tracer for the function.
130+
131+
If we choose to prioritize pytree flattening at the JIT boundary, then the output of
132+
this function would be:
133+
```
134+
type(x)=CustomArray
135+
```
136+
That is, at the JIT boundary, the `CustomArray` object is flattened, and then unflattened
137+
before being passed to the JIT-compiled function for tracing. If `CustomArray` has been
138+
registered as a pytree, it will generally contain traced arrays as its attributes, and
139+
when x is passed to any JAX API that supports `__jax_array__`, these traced attributes
140+
will be converted to a single traced array according to the logic specified in the method.
141+
142+
There are deeper consequences here for how other transformations like vmap and grad work
143+
when encountering custom objects: for example, if we prioritize pytree flattening, vmap
144+
would operate over the dimensions of the flattened contents of the custom object, while
145+
if we prioritize `__jax_array__`, vmap would operate over the converted array dimensions.
146+
147+
This also has consequences when it comes to JIT invariance: consider a function like this:
148+
```python
149+
def f(x):
150+
if isinstance(x, CustomArray):
151+
return x.custom_method()
152+
else:
153+
# do something else
154+
...
155+
156+
result1 = f(x)
157+
result2 = jax.jit(f)(x)
158+
```
159+
If `jit` consumes `x` via pytree flattening, the results should agree for a well-specified
160+
flattening rule. If `jit` consumes `x` via `__jax_array__`, the results will differ because
161+
`x` is no longer a CustomArray within the JIT-compiled version of the function.
162+
163+
#### Synopsis
164+
As of JAX v0.6.0, transformations prioritize `__jax_array__` when it is available. This status
165+
quo can lead to confusion around lack of JIT invariance, and the current implementation in practice
166+
leads to subtle bugs in the case of automatic differentiation, where the forward and backward pass
167+
do not treat inputs consistently.
168+
169+
Because the pytree extensibility mechanism already exists for the case of customizing
170+
transformations, it seems most straightforward if transformations act only via this
171+
mechanism: that is, **we propose to remove `__jax_array__` parsing during abstractification.**
172+
This approach will preserve object identity through transformations, and give the user the
173+
most possible flexibility. If the user wants to opt-in to array conversion semantics, that
174+
is always possible by explicitly casting their input via jnp.asarray, which will trigger the
175+
`__jax_array__` protocol.
176+
177+
### Which APIs should support `__jax_array__`?
178+
JAX has a number of different levels of API, from the level of explicit primitive binding
179+
(e.g. `jax.lax.add_p.bind(x, y)`) to the `jax.lax` APIs (e.g. `jax.lax.add(x, y)`) to the
180+
`jax.numpy` APIs (e.g. `jax.numpy.add(x, y)`). Which of these API categories should handle
181+
implicit conversion via `__jax_array__`?
182+
183+
In order to limit the scope of the change and the required testing, I propose that `__jax_array__`
184+
only be explicitly supported in `jax.numpy` APIs: after all, it is inspired by the` __array__`
185+
protocol which is supported by the NumPy package. We could always expand this in the future to
186+
`jax.lax` APIs if needed.
187+
188+
This is in line with the current state of the package, where `__jax_array__` handling is mainly
189+
within the input validation utilities used by `jax.numpy` APIs.
190+
191+
## Implementation
192+
With these design choices in mind, we plan to implement this as follows:
193+
194+
- **Adding runtime support to `jax.numpy`**: This is likely the easiest part, as most
195+
`jax.numpy` functions use a common internal utility (`ensure_arraylike`) to validate
196+
inputs and convert them to array. This utility already supports `__jax_array__`, and
197+
so most jax.numpy APIs are already compliant.
198+
- **Adding test coverage**: To ensure compliance across the APIs, we should add a new
199+
test scaffold that calls every `jax.numpy` API with custom inputs and validates correct
200+
behavior.
201+
- **Deprecating `__jax_array__` during abstractification**: Currently JAX's abstractification
202+
pass, used in `jit` and other transformations, does parse the `__jax_array__` protocol,
203+
and this is not the behavior we want long-term. We need to deprecate this behavior, and
204+
ensure that downstream packages that rely on it can move toward pytree registration or
205+
explicit array conversion where necessary.
206+
- **Adding type annotations**: the type interface for jax.numpy functions is in
207+
`jax/numpy/__init__.pyi`, and we’ll need to change each input type from `ArrayLike` to
208+
`ArrayLike | SupportsJAXArray`, where the latter is a protocol with a `__jax_array__`
209+
method. We cannot add this directly to the `ArrayLike` definition, because `ArrayLike`
210+
is used in contexts where `__jax_array__` should not be supported.
211+
- **Documentation**: once the above support is added, we should add a documentation section
212+
on array extensibility that outlines exactly what to expect regarding the `__jax_array__`
213+
protocol, with examples of how it can be used in conjunction with pytree registration
214+
in order to effectively work with user-defined types.

docs/jep/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Then create a pull request that adds a file named
5252
17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose>
5353
18137: Scope of JAX NumPy & SciPy Wrappers <18137-numpy-scipy-scope>
5454
25516: Effort-based versioning <25516-effver>
55+
28661: Supporting the `__jax_array__` protocol <28661-jax-array-protocol>
5556

5657

5758
Several early JEPs were converted in hindsight from other documentation,

0 commit comments

Comments
 (0)