Using jax.make_jaxpr with return_shape=True on a function that returns a HiType fails #34193

@armandpicard

Description

Using jax.make_jaxpr with return_shape=True on a function that returns a HiType fails with an AttributeError.
Here is an example:

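  import jax

  # make_tup is not defined in this report; it returns a value whose
  # abstract type is the HiType TupTy (see the traceback below).
  def f():
    return make_tup(1, 2)

  jaxpr, shape = jax.make_jaxpr(f, return_shape=True)()  # <- crash here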

Logs:

      if return_shape:
>       out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals]
                                ^^^^^^^
E       AttributeError: 'TupTy' object has no attribute 'shape'
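The traceback suggests that the return_shape path assumes every output aval exposes shape and dtype. Here is a minimal sketch of that failure mode in plain Python. ArrayAval, TupAval, ShapeStruct and both helper functions are hypothetical stand-ins invented for illustration, not JAX's real classes, and the guard shown is only one possible approach, not necessarily the fix JAX would adopt:

```python
from dataclasses import dataclass


# Hypothetical stand-ins for illustration only; these are NOT JAX classes.
@dataclass(frozen=True)
class ArrayAval:
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class TupAval:
    # Mimics a HiType such as TupTy: it holds element types
    # but has no shape or dtype attribute of its own.
    elts: tuple


@dataclass(frozen=True)
class ShapeStruct:
    shape: tuple
    dtype: str


def out_shapes_naive(out_avals):
    # Mirrors the failing line in the traceback: every aval is
    # assumed to have .shape and .dtype.
    return [ShapeStruct(o.shape, o.dtype) for o in out_avals]


def out_shapes_guarded(out_avals):
    # One possible guard (an assumption, not the actual fix):
    # avals without shape/dtype are passed through unchanged.
    return [
        ShapeStruct(o.shape, o.dtype)
        if hasattr(o, "shape") and hasattr(o, "dtype")
        else o
        for o in out_avals
    ]


avals = [
    ArrayAval((2, 3), "float32"),
    TupAval((ArrayAval((), "int32"), ArrayAval((), "int32"))),
]

try:
    out_shapes_naive(avals)
    naive_failed = False
except AttributeError:
    # Same class of error as in the report: the tuple-like aval has no 'shape'.
    naive_failed = True

guarded = out_shapes_guarded(avals)
```

With the guard, the array-like output is converted to a shape struct while the tuple-like output is returned as-is. Whether a HiType should instead be mapped to some structured shape description is a design question this sketch does not answer.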

System info (python version, jaxlib version, accelerator, etc.)

Running on jax main:

jax:    0.8.3.dev20260107+e5c0b719b
jaxlib: 0.8.2
numpy:  2.4.0
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='XPS-15-9520-a-picard', release='6.8.0-90-generic', version='#91-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 18 14:14:30 UTC 2025', machine='x86_64')

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)
