Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onnx] Fix onnx-to-torch lowering for flatten shape #2834

Merged
merged 5 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
auto message = llvm::formatv("unimplemented support for the given "
"dtype conversion (onnx 'type' = {0})",
dtypeIntOnnx);
llvm::errs() << message << "\n";
auto y = rewriter.notifyMatchFailure(binder.op, message);

return y;
Expand Down Expand Up @@ -1412,16 +1411,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();

auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
llvm::SmallVector<int64_t> shape(operandTy.getSizes());
int64_t rank = shape.size();

// If axis is negative, count from the right instead of left
int64_t rank =
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
if (axis < 0)
axis = rank + axis;

Value collapsedRight;
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
newling marked this conversation as resolved.
Show resolved Hide resolved
binder.op->getContext());
// We collapse in the dimensions to the right of the axis.
for (int i = axis + 1; i < rank; ++i) {
bool dynamic = shape[axis] == Torch::kUnknownSize ||
shape[i] == Torch::kUnknownSize;
if (dynamic) {
shape[axis] = Torch::kUnknownSize;
} else {
shape[axis] = shape[axis] * shape[i];
}
}

shape.resize(axis + 1, 1);

auto baseType = rewriter.getType<Torch::ValueTensorType>(
shape, operandTy.getDtype());
Value collapsedRight;
if (axis >= rank) {
// If the right range is empty, add a dim of size 1 to the
// right side of the shape:
Expand Down
42 changes: 21 additions & 21 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1286,23 +1286,23 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32>
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
return %0 : !torch.vtensor<[6,20],f32>
}

// -----

// CHECK-LABEL: @test_flatten_4d_axis_0
// // CHECK-LABEL: @test_flatten_4d_axis_0
func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32>
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
return %0 : !torch.vtensor<[1,120],f32>
}
Expand All @@ -1312,10 +1312,10 @@ func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc
// CHECK-LABEL: @test_flatten_4d_axis_4
func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor<[2,3,4,5,1],f32>
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32>
return %0 : !torch.vtensor<[120,1],f32>
}
Expand All @@ -1326,10 +1326,10 @@ func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc
func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32>
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
return %0 : !torch.vtensor<[6,20],f32>
}
Expand All @@ -1340,10 +1340,10 @@ func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>)
func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,5],f32>
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32>
return %0 : !torch.vtensor<[24,5],f32>
}
Expand All @@ -1354,9 +1354,9 @@ func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>)
func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32>
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
return %0 : !torch.vtensor<[1,120],f32>
}
Expand All @@ -1367,10 +1367,10 @@ func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>)
func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32>
return %0 : !torch.vtensor<[2,3],f32>
}
Expand All @@ -1381,9 +1381,9 @@ func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vt
func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32>
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
return %0 : !torch.vtensor<[1,2],f32>
}
Expand All @@ -1394,9 +1394,9 @@ func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vten
func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32>
// CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
// CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
return %0 : !torch.vtensor<[1,2],f32>
}
Expand All @@ -1406,10 +1406,10 @@ func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !t
// COM: CHECK-LABEL: @test_flatten_1d_axis_1
func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor
// CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2,1],f32>
// CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
// CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
// CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
%0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32>
return %0 : !torch.vtensor<[2,1],f32>
}
Expand Down
Loading