|
| 1 | +# JEP 28661: Supporting the `__jax_array__` protocol |
| 2 | + |
| 3 | +[@jakevdp](http://github.com/jakevdp), *May 2025* |
| 4 | + |
| 5 | +An occasional user request is for the ability to define custom array-like objects that |
| 6 | +work with jax APIs. JAX currently has a partial implementation of a mechanism that does |
| 7 | +this via a `__jax_array__` method defined on the custom object. This was never intended |
| 8 | +to be a load-bearing public API (see the discussion at {jax-issue}`#4725`), but has |
| 9 | +become essential to packages like Keras and flax, which explicitly document the ability |
| 10 | +to use their custom array objects with jax functions. This JEP proposes a design for |
| 11 | +full, documented support of the `__jax_array__` protocol. |
| 12 | + |
| 13 | +## Levels of array extensibility |
| 14 | +Requests for extensibility of JAX arrays come in a few flavors: |
| 15 | + |
| 16 | +### Level 1 Extensibility: polymorphic inputs |
| 17 | +What I’ll call "Level 1" extensibility is the desire that JAX APIs accept polymorphic inputs. |
| 18 | +That is, a user desires behavior like this: |
| 19 | + |
| 20 | +```python |
| 21 | +class CustomArray: |
| 22 | + data: numpy.ndarray |
| 23 | + ... |
| 24 | + |
| 25 | +x = CustomArray(np.arange(5)) |
| 26 | +result = jnp.sin(x) # Converts `x` to JAX array and returns a JAX array |
| 27 | +``` |
| 28 | + |
| 29 | +Under this extensibility model, JAX functions would accept CustomArray objects as inputs, |
| 30 | +implicitly converting them to `jax.Array` objects for the sake of computation. |
| 31 | +This is similar to the functionality offered by NumPy via the `__array__` method, and in |
| 32 | +JAX (in many but not all cases) via the `__jax_array__` method. |
| 33 | + |
| 34 | +This is the mode of extensibility that has been requested by the maintainers of `flax.nnx` |
| 35 | +and others. The current implementation is also used by JAX internally for the case of |
| 36 | +symbolic dimensions. |
| 37 | + |
| 38 | +### Level 2 extensibility: polymorphic outputs |
| 39 | +What I’ll call "Level 2" extensibility is the desire that JAX APIs should not only accept |
| 40 | +polymorphic inputs, but also wrap outputs to match the class of the input. |
| 41 | +That is, a user desires behavior like this: |
| 42 | + |
| 43 | +```python |
| 44 | +class CustomArray: |
| 45 | + data: numpy.ndarray |
| 46 | + ... |
| 47 | + |
| 48 | +x = CustomArray(np.arange(5)) |
| 49 | +result = jnp.sin(x) # returns a new CustomArray |
| 50 | +``` |
| 51 | + |
| 52 | +Under this extensibility model, JAX functions would not only accept custom objects |
| 53 | +as inputs, but have some protocol to determine how to correctly re-wrap outputs with |
| 54 | +the same class. In NumPy, this sort of functionality is offered in varying degrees by |
| 55 | +the special `__array_ufunc__`, `__array_wrap__`, and `__array_function__` protocols, |
| 56 | +which allow user-defined objects to customize how NumPy API functions operate on |
| 57 | +arbitrary inputs and map input types to outputs. |
| 58 | +JAX does not currently have any equivalent to these interfaces in NumPy. |
| 59 | + |
| 60 | +This is the mode of extensibility that has been requested by the maintainers of `keras`, |
| 61 | +among others. |
| 62 | + |
| 63 | +### Level 3 extensibility: subclassing `Array` |
| 64 | + |
| 65 | +What I’ll call "Level 3" extensibility is the desire that the JAX array object itself |
| 66 | +could be subclassable. NumPy provides some APIs that allow this |
| 67 | +(see [Subclassing ndarray](https://numpy.org/devdocs/user/basics.subclassing.html)) but |
| 68 | +this sort of approach would take some extra thought in JAX due to the need for |
| 69 | +representing array objects abstractly via tracing. |
| 70 | + |
| 71 | +This mode of extensibility has occasionally been requested by users who want to add |
| 72 | +special metadata to JAX arrays, such as units of measurement. |
| 73 | + |
| 74 | +## Synopsis |
| 75 | + |
| 76 | +For the sake of this proposal, we will stick with the simplest, level 1 extensibility |
| 77 | +model. The proposed interface is the one currently non-uniformly supported by a number |
| 78 | +of JAX APIs, the `__jax_array__` method. Its usage looks something like this: |
| 79 | + |
| 80 | +```python |
| 81 | +import jax |
| 82 | +import jax.numpy as jnp |
| 83 | +import numpy as np |
| 84 | + |
| 85 | +class CustomArray: |
| 86 | + data: np.ndarray |
| 87 | + |
| 88 | + def __init__(self, data: np.ndarray): |
| 89 | + self.data = data |
| 90 | + |
| 91 | + def __jax_array__(self) -> jax.Array: |
| 92 | + return jnp.asarray(self.data) |
| 93 | + |
| 94 | +arr = CustomArray(np.arange(5)) |
| 95 | +result = jnp.multiply(arr, 2) |
| 96 | +print(repr(result)) |
| 97 | +# Array([0, 2, 4, 6, 8], dtype=int32) |
| 98 | +``` |
| 99 | + |
| 100 | +We may revisit other extensibility levels in the future. |
| 101 | + |
| 102 | +## Design challenges |
| 103 | + |
| 104 | +JAX presents some interesting design challenges related to this kind of extensibility, |
| 105 | +which have not been fully explored previously. We’ll discuss them in turn here: |
| 106 | + |
| 107 | +### Priority of `__jax_array__` vs. PyTree flattening |
| 108 | +JAX already has a supported mechanism for registering custom objects, namely pytree |
| 109 | +registration (see [Extending pytrees](https://docs.jax.dev/en/latest/pytrees.html#extending-pytrees)). |
| 110 | +If we also support __jax_array__, which one should take precedence? |
| 111 | + |
| 112 | +To put this more concretely, what should be the result of this code? |
| 113 | + |
| 114 | +```python |
| 115 | +@jax.jit |
| 116 | +def f(x): |
| 117 | + print("is JAX array:", isinstance(x, jax.Array)) |
| 118 | + |
| 119 | +f(CustomArray(...)) |
| 120 | +``` |
| 121 | + |
| 122 | +If we choose to prioritize `__jax_array__` at the JIT boundary, then the output of this |
| 123 | +function would be: |
| 124 | +``` |
| 125 | +is JAX array: True |
| 126 | +``` |
| 127 | +That is, at the JIT boundary, the `CustomArray` object would be converted into a |
| 128 | +`__jax_array__`, and its shape and dtype would be used to construct a standard JAX |
| 129 | +tracer for the function. |
| 130 | + |
| 131 | +If we choose to prioritize pytree flattening at the JIT boundary, then the output of |
| 132 | +this function would be: |
| 133 | +``` |
| 134 | +type(x)=CustomArray |
| 135 | +``` |
| 136 | +That is, at the JIT boundary, the `CustomArray` object is flattened, and then unflattened |
| 137 | +before being passed to the JIT-compiled function for tracing. If `CustomArray` has been |
| 138 | +registered as a pytree, it will generally contain traced arrays as its attributes, and |
| 139 | +when x is passed to any JAX API that supports `__jax_array__`, these traced attributes |
| 140 | +will be converted to a single traced array according to the logic specified in the method. |
| 141 | + |
| 142 | +There are deeper consequences here for how other transformations like vmap and grad work |
| 143 | +when encountering custom objects: for example, if we prioritize pytree flattening, vmap |
| 144 | +would operate over the dimensions of the flattened contents of the custom object, while |
| 145 | +if we prioritize `__jax_array__`, vmap would operate over the converted array dimensions. |
| 146 | + |
| 147 | +This also has consequences when it comes to JIT invariance: consider a function like this: |
| 148 | +```python |
| 149 | +def f(x): |
| 150 | + if isinstance(x, CustomArray): |
| 151 | + return x.custom_method() |
| 152 | + else: |
| 153 | + # do something else |
| 154 | + ... |
| 155 | + |
| 156 | +result1 = f(x) |
| 157 | +result2 = jax.jit(f)(x) |
| 158 | +``` |
| 159 | +If `jit` consumes `x` via pytree flattening, the results should agree for a well-specified |
| 160 | +flattening rule. If `jit` consumes `x` via `__jax_array__`, the results will differ because |
| 161 | +`x` is no longer a CustomArray within the JIT-compiled version of the function. |
| 162 | + |
| 163 | +#### Synopsis |
| 164 | +As of JAX v0.6.0, transformations prioritize `__jax_array__` when it is available. This status |
| 165 | +quo can lead to confusion around lack of JIT invariance, and the current implementation in practice |
| 166 | +leads to subtle bugs in the case of automatic differentiation, where the forward and backward pass |
| 167 | +do not treat inputs consistently. |
| 168 | + |
| 169 | +Because the pytree extensibility mechanism already exists for the case of customizing |
| 170 | +transformations, it seems most straightforward if transformations act only via this |
| 171 | +mechanism: that is, **we propose to remove `__jax_array__` parsing during abstractification.** |
| 172 | +This approach will preserve object identity through transformations, and give the user the |
| 173 | +most possible flexibility. If the user wants to opt-in to array conversion semantics, that |
| 174 | +is always possible by explicitly casting their input via jnp.asarray, which will trigger the |
| 175 | +`__jax_array__` protocol. |
| 176 | + |
| 177 | +### Which APIs should support `__jax_array__`? |
| 178 | +JAX has a number of different levels of API, from the level of explicit primitive binding |
| 179 | +(e.g. `jax.lax.add_p.bind(x, y)`) to the `jax.lax` APIs (e.g. `jax.lax.add(x, y)`) to the |
| 180 | +`jax.numpy` APIs (e.g. `jax.numpy.add(x, y)`). Which of these API categories should handle |
| 181 | +implicit conversion via `__jax_array__`? |
| 182 | + |
| 183 | +In order to limit the scope of the change and the required testing, I propose that `__jax_array__` |
| 184 | +only be explicitly supported in `jax.numpy` APIs: after all, it is inspired by the` __array__` |
| 185 | +protocol which is supported by the NumPy package. We could always expand this in the future to |
| 186 | +`jax.lax` APIs if needed. |
| 187 | + |
| 188 | +This is in line with the current state of the package, where `__jax_array__` handling is mainly |
| 189 | +within the input validation utilities used by `jax.numpy` APIs. |
| 190 | + |
| 191 | +## Implementation |
| 192 | +With these design choices in mind, we plan to implement this as follows: |
| 193 | + |
| 194 | +- **Adding runtime support to `jax.numpy`**: This is likely the easiest part, as most |
| 195 | + `jax.numpy` functions use a common internal utility (`ensure_arraylike`) to validate |
| 196 | + inputs and convert them to array. This utility already supports `__jax_array__`, and |
| 197 | + so most jax.numpy APIs are already compliant. |
| 198 | +- **Adding test coverage**: To ensure compliance across the APIs, we should add a new |
| 199 | + test scaffold that calls every `jax.numpy` API with custom inputs and validates correct |
| 200 | + behavior. |
| 201 | +- **Deprecating `__jax_array__` during abstractification**: Currently JAX's abstractification |
| 202 | + pass, used in `jit` and other transformations, does parse the `__jax_array__` protocol, |
| 203 | + and this is not the behavior we want long-term. We need to deprecate this behavior, and |
| 204 | + ensure that downstream packages that rely on it can move toward pytree registration or |
| 205 | + explicit array conversion where necessary. |
| 206 | +- **Adding type annotations**: the type interface for jax.numpy functions is in |
| 207 | + `jax/numpy/__init__.pyi`, and we’ll need to change each input type from `ArrayLike` to |
| 208 | + `ArrayLike | SupportsJAXArray`, where the latter is a protocol with a `__jax_array__` |
| 209 | + method. We cannot add this directly to the `ArrayLike` definition, because `ArrayLike` |
| 210 | + is used in contexts where `__jax_array__` should not be supported. |
| 211 | +- **Documentation**: once the above support is added, we should add a documentation section |
| 212 | + on array extensibility that outlines exactly what to expect regarding the `__jax_array__` |
| 213 | + protocol, with examples of how it can be used in conjunction with pytree registration |
| 214 | + in order to effectively work with user-defined types. |
0 commit comments