[MLIR][ONNX] Add OnnxToTorch Support for DepthToSpace op
Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 committed Jan 10, 2024
1 parent 4707d3b commit 208ae35
Showing 5 changed files with 210 additions and 15 deletions.
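
For reference, ONNX DepthToSpace moves data from the channel dimension into
spatial blocks; the DCR and CRD modes differ only in how the channel index is
split. A minimal standalone C++ sketch of the index mapping (illustrative
only; nothing below is part of the patch):

#include <cstdint>
#include <vector>

// Reference NCHW DepthToSpace. `dcr` selects the "DCR" channel split
// (blocksize, blocksize, c/bs^2); otherwise the "CRD" split
// (c/bs^2, blocksize, blocksize) is used.
std::vector<float> depthToSpace(const std::vector<float> &x, int64_t b,
                                int64_t c, int64_t h, int64_t w, int64_t bs,
                                bool dcr) {
  int64_t oc = c / (bs * bs), oh = h * bs, ow = w * bs;
  std::vector<float> y(b * oc * oh * ow);
  for (int64_t n = 0; n < b; ++n)
    for (int64_t d = 0; d < oc; ++d)
      for (int64_t p = 0; p < oh; ++p)
        for (int64_t q = 0; q < ow; ++q) {
          int64_t r = p % bs, col = q % bs;
          int64_t ch = dcr ? (r * bs + col) * oc + d
                           : d * bs * bs + r * bs + col;
          y[((n * oc + d) * oh + p) * ow + q] =
              x[((n * c + ch) * h + p / bs) * w + q / bs];
        }
  return y;
}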
5 changes: 4 additions & 1 deletion include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -127,7 +127,10 @@ Value createInitTensor(PatternRewriter &rewriter, Location loc,
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
                        BaseTensorType inputType, Value scalar);

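// Computes the type obtained from `inType` by swapping the sizes of
// dimensions `dimA` and `dimB`; fails if `inType` has no sizes.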
LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
                                int64_t dimB, Type &transposedType);

} // namespace Torch
} // namespace torch
126 changes: 125 additions & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -37,6 +37,23 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
  return dtypeIntTorch;
}

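// Helper to emit a torch.aten.transpose.int op swapping `dimA` and `dimB` of
// `input`, with the result type computed via getTransposedType.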
static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
                                            Location loc, Value input,
                                            int64_t dimA, int64_t dimB,
                                            Value &transposed) {
  Type transposedType;
  if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
                               dimA, dimB, transposedType)))
    return failure();
  Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(dimA));
  Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(dimB));
  transposed = rewriter.create<Torch::AtenTransposeIntOp>(
      loc, transposedType, input, cstDimA, cstDimB);
  return success();
}

// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
@@ -978,11 +995,118 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            binder.op, resultType, operand, dim, resultDType);
        return success();
      });
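  // DepthToSpace rearranges data from the channel dim into spatial blocks.
  // It is lowered as reshape -> transpose(s) -> reshape, following the ONNX
  // reference decomposition: both modes converge to the layout
  // (b, c/(bs*bs), h, bs, w, bs) before the final reshape to
  // (b, c/(bs*bs), h*bs, w*bs).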
  patterns.onOp(
      "DepthToSpace", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value input;
        int64_t blockSize;
        std::string mode;
        if (binder.tensorOperand(input) ||
            binder.s64IntegerAttr(blockSize, "blocksize") ||
            binder.customOpNameStringAttr(mode, "mode", "DCR") ||
            binder.tensorResultType(resultType))
          return failure();
        auto inputTy = input.getType().dyn_cast<Torch::BaseTensorType>();
        if (!inputTy || !inputTy.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op, "Expected input type having sizes");
        }
        SmallVector<int64_t> inputSizes{inputTy.getSizes()};
        if (inputSizes.size() != 4) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Expected input rank to be 4");
        }
        Value b = rewriter.create<Torch::AtenSizeIntOp>(
            binder.getLoc(), input,
            rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(0)));
        Value c = rewriter.create<Torch::AtenSizeIntOp>(
            binder.getLoc(), input,
            rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(1)));
        Value h = rewriter.create<Torch::AtenSizeIntOp>(
            binder.getLoc(), input,
            rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(2)));
        Value w = rewriter.create<Torch::AtenSizeIntOp>(
            binder.getLoc(), input,
            rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(3)));
        Value cstBlockSize = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(blockSize));
        Value cstBlockSizeSquare = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize));
        Value cDivBlockSizeSquare = rewriter.create<Torch::AtenDivIntOp>(
            binder.getLoc(), c, cstBlockSizeSquare);
        cDivBlockSizeSquare = rewriter.create<Torch::AtenIntFloatOp>(
            binder.getLoc(), cDivBlockSizeSquare);
        // Per the ONNX reference decomposition, DCR splits the channel dim as
        // (bs, bs, c/(bs*bs)) while CRD splits it as (c/(bs*bs), bs, bs).
        int64_t cDivBlockSizeSquareInt =
            inputSizes[1] == Torch::kUnknownSize
                ? Torch::kUnknownSize
                : inputSizes[1] / (blockSize * blockSize);
        SmallVector<Value> reshapeSizesValues;
        SmallVector<int64_t, 6> reshapeSizesInt;
        if (mode == "DCR") {
          reshapeSizesValues = {b, cstBlockSize, cstBlockSize,
                                cDivBlockSizeSquare, h, w};
          reshapeSizesInt = {inputSizes[0], blockSize, blockSize,
                             cDivBlockSizeSquareInt, inputSizes[2],
                             inputSizes[3]};
        } else {
          // mode == "CRD"
          reshapeSizesValues = {b, cDivBlockSizeSquare, cstBlockSize,
                                cstBlockSize, h, w};
          reshapeSizesInt = {inputSizes[0], cDivBlockSizeSquareInt, blockSize,
                             blockSize, inputSizes[2], inputSizes[3]};
        }
        Value reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(input.getContext())),
            reshapeSizesValues);
        Value reshapedInput = rewriter.create<Torch::AtenReshapeOp>(
            binder.getLoc(),
            inputTy.getWithSizesAndDtype(reshapeSizesInt,
                                         inputTy.getOptionalDtype()),
            input, reshapeSizesList);

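        // Bring each block dim next to its spatial dim so the final reshape
        // can merge them (c' = c / (bs * bs)):
        //   DCR: (b,bs,bs,c',h,w) -(1,3)(2,4)(4,5)-> (b,c',h,bs,w,bs)
        //   CRD: (b,c',bs,bs,h,w) -(2,4)(3,4)(4,5)-> (b,c',h,bs,w,bs)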
        Value transposedInput;
        if (mode == "DCR") {
          if (failed(createTorchTransposeOp(
                  rewriter, binder.getLoc(), reshapedInput,
                  /*dimA=*/1, /*dimB=*/3, transposedInput)))
            return rewriter.notifyMatchFailure(
                binder.op, "Failed to create TorchTranspose op");
          if (failed(createTorchTransposeOp(
                  rewriter, binder.getLoc(), transposedInput,
                  /*dimA=*/2, /*dimB=*/4, transposedInput)))
            return rewriter.notifyMatchFailure(
                binder.op, "Failed to create TorchTranspose op");
        } else {
          // mode == "CRD"
          if (failed(createTorchTransposeOp(
                  rewriter, binder.getLoc(), reshapedInput,
                  /*dimA=*/2, /*dimB=*/4, transposedInput)))
            return rewriter.notifyMatchFailure(
                binder.op, "Failed to create TorchTranspose op");
          if (failed(createTorchTransposeOp(
                  rewriter, binder.getLoc(), transposedInput,
                  /*dimA=*/3, /*dimB=*/4, transposedInput)))
            return rewriter.notifyMatchFailure(
                binder.op, "Failed to create TorchTranspose op");
        }
        if (failed(createTorchTransposeOp(
                rewriter, binder.getLoc(), transposedInput,
                /*dimA=*/4, /*dimB=*/5, transposedInput)))
          return rewriter.notifyMatchFailure(
              binder.op, "Failed to create TorchTranspose op");

        Value hMulBlockSize = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), h, cstBlockSize);
        Value wMulBlockSize = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), w, cstBlockSize);
        reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            Torch::ListType::get(Torch::IntType::get(input.getContext())),
            llvm::SmallVector<Value>{b, cDivBlockSizeSquare, hMulBlockSize,
                                     wMulBlockSize});
        rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
            binder.op, resultType, transposedInput, reshapeSizesList);
        return success();
      });
patterns.onOp("Div", 14,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
13 changes: 0 additions & 13 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2554,19 +2554,6 @@ class DecomposeAtenConvTranspose2dOp
};
} // namespace

static LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
                                       int64_t dimB, Type &transposedType) {
  if (!inType.hasSizes())
    return failure();
  SmallVector<int64_t> shape(inType.getSizes());
  int64_t tmp = shape[0];
  shape[0] = shape[1];
  shape[1] = tmp;
  transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape),
                                               inType.getOptionalDtype());
  return success();
}

// The convolution backward op is decomposed as follows:
// inputH, inputW = input.shape[2:]
// output_padding_ = [
13 changes: 13 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -500,3 +500,16 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc,
      ValueRange{});
  return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList);
}

LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA,
                                       int64_t dimB, Type &transposedType) {
  if (!inType.hasSizes())
    return failure();
  SmallVector<int64_t> shape(inType.getSizes());
  int64_t tmp = shape[dimA];
  shape[dimA] = shape[dimB];
  shape[dimB] = tmp;
  transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape),
                                               inType.getOptionalDtype());
  return success();
}
68 changes: 68 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -857,3 +857,71 @@ func.func @test_elu_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
  %0 = torch.operator "onnx.Elu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
  return %0 : !torch.vtensor<[3],f32>
}

// CHECK-LABEL: @test_depthtospace_example
func.func @test_depthtospace_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C4:.*]] = torch.constant.int 4
// CHECK: %[[DIV:.*]] = torch.aten.div.int %[[SIZE_0]], %[[C4]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[INT:.*]] = torch.aten.Int.float %[[DIV]] : !torch.float -> !torch.int
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[C2_0]], %[[C2_0]], %[[INT]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,8,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,2,2,2,3],f32>
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
// CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[RESHAPE]], %[[C1_0]], %[[C3_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32>
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
// CHECK: %[[TRANSPOSE_1:.*]] = torch.aten.transpose.int %[[TRANSPOSE]], %[[C2_1]], %[[C4_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32>
// CHECK: %[[C4_1:.*]] = torch.constant.int 4
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_1]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[INT]], %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list<int> -> !torch.vtensor<[1,2,4,6],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32>
  %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "DCR"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32>
  return %0 : !torch.vtensor<[1,2,4,6],f32>
}

// CHECK-LABEL: @test_depthtospace_crd_mode_example
func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C4:.*]] = torch.constant.int 4
// CHECK: %[[DIV:.*]] = torch.aten.div.int %[[SIZE_0]], %[[C4]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[INT:.*]] = torch.aten.Int.float %[[DIV]] : !torch.float -> !torch.int
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[INT]], %[[C2_0]], %[[C2_0]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,8,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,2,2,2,3],f32>
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
// CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[RESHAPE]], %[[C2_1]], %[[C4_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32>
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
// CHECK: %[[C4_1:.*]] = torch.constant.int 4
// CHECK: %[[TRANSPOSE_1:.*]] = torch.aten.transpose.int %[[TRANSPOSE]], %[[C3_0]], %[[C4_1]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32>
// CHECK: %[[C4_2:.*]] = torch.constant.int 4
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_2]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[INT]], %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list<int> -> !torch.vtensor<[1,2,4,6],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32>
  %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32>
  return %0 : !torch.vtensor<[1,2,4,6],f32>
}
