Skip to content

Commit

Permalink
[torch] Unpacking sometimes misses shape inference (#3609)
Browse files Browse the repository at this point in the history
It is possible that the unpacked tensor does not match the same inferred
shapes. This is pretty common when ingesting form the `onnx` frontend.
  • Loading branch information
rsuderman authored Aug 8, 2024
1 parent f91f816 commit fd98476
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
13 changes: 12 additions & 1 deletion lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3290,7 +3290,18 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
if (op->getNumResults() != listConstruct.getElements().size())
return failure();

rewriter.replaceOp(op, listConstruct.getElements());
SmallVector<Value> unpacked;
for (int i = 0, s = op->getNumResults(); i < s; ++i) {
auto element = listConstruct.getElements()[i];
if (element.getType() != op->getResult(i).getType()) {
element = rewriter.create<TensorStaticInfoCastOp>(
op.getLoc(), op->getResult(i).getType(), element);
}

unpacked.push_back(element);
}

rewriter.replaceOp(op, unpacked);
return success();
});
}
Expand Down
12 changes: 12 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,18 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t
return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
}

// CHECK-LABEL: func.func @prim.ListUnpack$fold_list_cast(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) {
// CHECK: %[[CAST0:.+]] = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32>
// CHECK: %[[CAST1:.+]] = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32>
// CHECK: return %[[CAST0]], %[[CAST1]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
func.func @prim.ListUnpack$fold_list_cast(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) {
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
%1:2 = torch.prim.ListUnpack %0 : !torch.list<vtensor> -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
return %1#0, %1#1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
Expand Down

0 comments on commit fd98476

Please sign in to comment.