Description
Using jax.make_jaxpr with return_shape=True on a function that returns a HiType fails.
Here is an example:
import jax

def f():
    # make_tup (not shown) returns a HiType tuple, whose aval is a TupTy
    return make_tup(1, 2)

jaxpr, shape = jax.make_jaxpr(f, return_shape=True)()  # <- crash here
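For contrast, here is a minimal sketch of the same call on a function that returns an ordinary array. The function name `g` and the shape `(2, 3)` are illustrative only. In this case the output aval is a ShapedArray with `.shape` and `.dtype` attributes, so `return_shape=True` should produce a `jax.ShapeDtypeStruct` without error.

```python
import jax
import jax.numpy as jnp

def g():
    # Ordinary array output: its aval carries .shape and .dtype.
    return jnp.ones((2, 3), dtype=jnp.float32)

closed_jaxpr, shape = jax.make_jaxpr(g, return_shape=True)()
print(closed_jaxpr.out_avals)  # the avals make_jaxpr iterates over
print(shape)                   # a ShapeDtypeStruct built from that aval
```

The two calls differ only in the output type, which points at the shape conversion step rather than at tracing itself.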
Logs:
    if return_shape:
>     out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals]
                              ^^^^^^^
E     AttributeError: 'TupTy' object has no attribute 'shape'
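The failing line assumes every output aval has `.shape` and `.dtype`. The following is a minimal sketch of a defensive variant of that list comprehension. It is written outside JAX, and `out_shapes` is a hypothetical helper name, not part of JAX. The sketch converts array-like avals to `ShapeDtypeStruct` and passes any other aval (such as a `TupTy`) through unchanged. Whether a HiType should be passed through or described some other way is left open here.

```python
import jax
import jax.numpy as jnp

def out_shapes(out_avals):
    # Hypothetical helper, not a JAX API: mirrors the comprehension in the
    # traceback but skips avals that lack .shape/.dtype (e.g. HiTypes).
    result = []
    for aval in out_avals:
        if hasattr(aval, "shape") and hasattr(aval, "dtype"):
            result.append(jax.ShapeDtypeStruct(aval.shape, aval.dtype))
        else:
            result.append(aval)  # pass non-array avals through as-is
    return result

closed = jax.make_jaxpr(lambda x: (x * 2, x.sum()))(jnp.ones(3))
print(out_shapes(closed.out_avals))
```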
System info (python version, jaxlib version, accelerator, etc.)
Running on jax main:
jax: 0.8.3.dev20260107+e5c0b719b
jaxlib: 0.8.2
numpy: 2.4.0
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='XPS-15-9520-a-picard', release='6.8.0-90-generic', version='#91-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 18 14:14:30 UTC 2025', machine='x86_64')