[onnx] Fix onnx.Shape to include start and end processing (#3580)
`onnx.Shape` can select a subset of the input's shape using its `start` and
`end` attributes. Add support for these attributes.
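
For example (mirroring the new lit test below): with `start = 1` and
`end = -1` on a `[3,4,5]` input, only the middle entry of the shape vector is
selected, yielding the one-element tensor `[4]`:

%0 = torch.operator "onnx.Shape"(%arg0) {torch.onnx.end = -1 : si64, torch.onnx.start = 1 : si64}
    : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64>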

---------

Co-authored-by: zjgarvey <[email protected]>
rsuderman and zjgarvey authored Aug 5, 2024
1 parent 839fe90 commit b1a2322
Showing 2 changed files with 60 additions and 14 deletions.
lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp (42 additions, 11 deletions)
@@ -1615,17 +1615,48 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         return success();
       });

-  patterns.onOp("Shape", 9,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
-                  rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        int64_t start, end;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(start, "start", 0) ||
+            binder.s64IntegerAttr(end, "end", -1))
+          return failure();
+
+        auto inputType = dyn_cast<Torch::ValueTensorType>(operand.getType());
+        int64_t inputRank = inputType.getSizes().size();
+
+        auto shapeType = Torch::ValueTensorType::get(
+            binder.op->getContext(), SmallVector<int64_t>{inputRank},
+            resultType.getOptionalDtype());
+
+        Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
+            binder.getLoc(), shapeType, operand);
+
+        if (start == 0 && end == -1) {
+          rewriter.replaceOp(binder.op, shape);
+          return success();
+        }
+
+        Value sv = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(start));
+
+        Value ev = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(end));
+
+        Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
+
+        Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
+
+        shape = rewriter.create<Torch::AtenSliceTensorOp>(
+            binder.getLoc(), resultType, shape, dim, sv, ev, step);
+
+        rewriter.replaceOp(binder.op, shape);
+        return success();
+      });

   patterns.onOp("Sinh", 9,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
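In short, the lowering now always materializes the shape with
`torch.aten._shape_as_tensor` and, only when `start`/`end` differ from their
defaults, narrows it with `torch.aten.slice.Tensor` along dim 0 with step 1.
A minimal sketch of the IR this produces for `start = 1`, `end = -1` (value
names are illustrative, not taken from the commit):

%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64>
%start = torch.constant.int 1
%end = torch.constant.int -1
%step = torch.constant.int 1
%dim = torch.constant.int 0
%sliced = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>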
test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir (18 additions, 3 deletions)
@@ -2715,6 +2715,21 @@ func.func @test_sequence_map_extract_shapes(%arg0: !torch.list<vtensor<[?,?,?],f
   return %0 : !torch.list<vtensor<[3],si64>>
 }

+// -----
+
+// CHECK-LABEL: func.func @test_shape_start_1_end_negative_1
+func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64} {
+  // CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT2_0:.+]] = torch.constant.int -1
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[SHAPE]], %[[INT0_0]], %[[INT1_0]], %[[INT2_0]], %[[INT1_1]]
+  %0 = torch.operator "onnx.Shape"(%arg0) {torch.onnx.end = -1 : si64, torch.onnx.start = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64>
+  return %0 : !torch.vtensor<[1],si64>
+}
+

 // -----

 // CHECK-LABEL: func.func @test_upsample_nearest
@@ -3133,7 +3148,7 @@ func.func @test_scatternd_min(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.
   return %0 : !torch.vtensor<[4,4,4],f32>
 }

-// ----
+// -----

 // CHECK-LABEL: func.func @test_split_to_sequence_1
 func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.list<vtensor<[3,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
@@ -3151,7 +3166,7 @@ func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !to
   return %1 : !torch.list<vtensor<[3,6],f32>>
 }

-// ----
+// -----

 // CHECK-LABEL: func.func @test_split_to_sequence_2
 func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.list<vtensor<[1,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
@@ -3169,7 +3184,7 @@ func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !to
   return %1 : !torch.list<vtensor<[1,6],f32>>
 }

-// ----
+// -----

 // CHECK-LABEL: func.func @test_split_to_sequence_with_list(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,6],f32>,
