Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp #3730

Merged
merged 1 commit into from
Sep 24, 2024

Conversation

vinayakdsci
Copy link
Contributor

Fixes iree-org/iree#18562.

During canonicalization pass on AtenUnflattenIntOp, if the second dim was statically equal to one, we would create an AtenAddIntOp to add one to the dimension obtained from op.getDim(). This, when passed into Torch::unsqueezeTensor(), would make it get interpreted as non-constant, which would lead to MLIR failing an assertion when UnsqueezeOp would later get lowered into ExpandShapeOp, as the output of the UnsqueezeOp would consist of only dynamic dims.

This patch fixes this behavior, by extracting the integer value from the dim if it was constant, and then emitting a ConstantIntOp from (dim+1). This creates an output with static shape.

Copy link
Collaborator

@vivekkhandelwal1 vivekkhandelwal1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent. Thanks, Vinayak.

@zjgarvey zjgarvey merged commit 6773288 into llvm:main Sep 24, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Crash: T = mlir::Value]: Assertion `isa<T>(*this) && "Invalid accessor called"' failed.
3 participants