Preserve static shape after AdvancedSubtensor that only shuffles elements #1532

@jessegrabowski

Description

I'm thinking about graphs like this:

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(3,))
x[[2, 0, 1]].type.shape  # (None,)
```

For each index input, `make_node` of `AdvancedSubtensor` could check whether:

  1. `x` has a known static shape in the dimension being indexed; if so,
  2. the index is a constant; if so,
  3. the index only permutes the elements, i.e. `len(index) == x.type.shape[dim]` and `np.array_equal(np.sort(index), np.arange(len(index)))`, where `dim` is the indexed dimension.

If all three conditions hold, the indexed dimension keeps its static length, so the output has the same static shape as `x`.
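A minimal sketch of the three checks, written against plain NumPy rather than inside PyTensor's `make_node`. The names `is_static_permutation` and `static_length_after_indexing` are hypothetical; the sketch assumes the static length of the indexed dimension is passed as an `int` (or `None` when unknown) and the constant index as an array-like (or `None` when the index is not a constant). Extracting those values from PyTensor variables is not shown. Like the condition above, it does not normalize negative indices, so `[-1, 0, 1]` is rejected even though it is a permutation.

```python
import numpy as np


def is_static_permutation(dim_length, index):
    """Return True if `index` only permutes a dimension of known static length.

    dim_length: static length of the indexed dimension, or None if unknown.
    index: constant index values (array-like), or None if the index is not constant.
    """
    # 1. The indexed dimension must have a known static length.
    if dim_length is None:
        return False
    # 2. The index must be a constant.
    if index is None:
        return False
    index = np.asarray(index)
    # This sketch only considers 1-D integer index vectors.
    if index.ndim != 1 or not np.issubdtype(index.dtype, np.integer):
        return False
    # 3. Same length as the dimension, and every position appears exactly once.
    return len(index) == dim_length and np.array_equal(
        np.sort(index), np.arange(dim_length)
    )


def static_length_after_indexing(dim_length, index):
    """Keep the static length for a permutation, otherwise fall back to unknown (None)."""
    return dim_length if is_static_permutation(dim_length, index) else None
```

For the graph in the example, `static_length_after_indexing(3, [2, 0, 1])` gives `3` instead of `None`, which matches NumPy, where `np.arange(3)[[2, 0, 1]]` has shape `(3,)`.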
