From 1a2ce713a9db1004d286c25843fd7e7b3e2888ed Mon Sep 17 00:00:00 2001
From: gdehame
Date: Mon, 10 Feb 2025 13:10:37 +0100
Subject: [PATCH] Lowering to linalg for AtenCol2ImOp

Added a lowering to linalg for the torch.aten.col2im operation.
Added a unit test to verify the lowering.
---
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 268 ++++++++++++++++++
 .../TorchToLinalg/datamovement.mlir           |  42 +++
 2 files changed, 310 insertions(+)

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index b8c20bc73f65..9fd4b1153961 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -24,6 +24,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "llvm/ADT/APInt.h"
 #include <numeric>
@@ -2735,6 +2736,271 @@ SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
     "torch.aten.to_dense", "torch.aten.to_sparse", "torch.aten.to_csr",
     "torch.aten.to_csc",   "torch.aten.to_bsr",    "torch.aten.to_bsc",
 };
+
+class ConvertAtenCol2ImOp : public OpConversionPattern<AtenCol2imOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  // Rewriting method.
+  LogicalResult
+  matchAndRewrite(AtenCol2imOp col2imOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Retrieve the hyperparameters; each one must be a compile-time constant
+    // list of two integers for the pattern to apply.
+    Value input = col2imOp.getSelf();
+    if (!(col2imOp.getOutputSize().getDefiningOp() &&
+          isa<Torch::PrimListConstructOp>(
+              col2imOp.getOutputSize().getDefiningOp())))
+      return failure();
+
+    Torch::PrimListConstructOp outputSizes = cast<Torch::PrimListConstructOp>(
+        col2imOp.getOutputSize().getDefiningOp());
+    if (!(outputSizes.getNumOperands() == 2 &&
+          outputSizes->getOperand(0).getDefiningOp() &&
+          outputSizes->getOperand(1).getDefiningOp() &&
+          isa<Torch::ConstantIntOp>(
+              outputSizes->getOperand(0).getDefiningOp())))
+      return failure();
+    if (!isa<Torch::ConstantIntOp>(outputSizes->getOperand(1).getDefiningOp()))
+      return failure();
+    int height =
+        cast<Torch::ConstantIntOp>(outputSizes->getOperand(0).getDefiningOp())
+            .getValue();
+    int width =
+        cast<Torch::ConstantIntOp>(outputSizes->getOperand(1).getDefiningOp())
+            .getValue();
+    if (!(col2imOp.getPadding().getDefiningOp() &&
+          isa<Torch::PrimListConstructOp>(
+              col2imOp.getPadding().getDefiningOp())))
+      return failure();
+
+    Torch::PrimListConstructOp paddings = cast<Torch::PrimListConstructOp>(
+        col2imOp.getPadding().getDefiningOp());
+
+    if (!(paddings.getNumOperands() == 2 &&
+          paddings->getOperand(0).getDefiningOp() &&
+          paddings->getOperand(1).getDefiningOp() &&
+          isa<Torch::ConstantIntOp>(paddings->getOperand(0).getDefiningOp())))
+      return failure();
+
+    if (!isa<Torch::ConstantIntOp>(paddings->getOperand(1).getDefiningOp()))
+      return failure();
+    int horizontalPadding =
+        cast<Torch::ConstantIntOp>(paddings->getOperand(1).getDefiningOp())
+            .getValue();
+    int verticalPadding =
+        cast<Torch::ConstantIntOp>(paddings->getOperand(0).getDefiningOp())
+            .getValue();
+    int paddedWidth = width + 2 * horizontalPadding;
+    int paddedHeight = height + 2 * verticalPadding;
+    if (!(col2imOp.getKernelSize().getDefiningOp() &&
+          isa<Torch::PrimListConstructOp>(
+              col2imOp.getKernelSize().getDefiningOp())))
+      return failure();
+    Torch::PrimListConstructOp kerSizes = cast<Torch::PrimListConstructOp>(
+        col2imOp.getKernelSize().getDefiningOp());
+    if (!(kerSizes.getNumOperands() == 2 &&
+          kerSizes->getOperand(0).getDefiningOp() &&
+          kerSizes->getOperand(1).getDefiningOp() &&
+          isa<Torch::ConstantIntOp>(kerSizes->getOperand(0).getDefiningOp())))
+      return failure();
+    if (!isa<Torch::ConstantIntOp>(kerSizes->getOperand(1).getDefiningOp()))
+      return failure();
+    // kernel_size is ordered (kernelHeight, kernelWidth), following the
+    // PyTorch convention of listing the vertical dimension first.
+    int kernelHeight =
+        cast<Torch::ConstantIntOp>(kerSizes->getOperand(0).getDefiningOp())
+            .getValue();
+    int kernelWidth =
+        cast<Torch::ConstantIntOp>(kerSizes->getOperand(1).getDefiningOp())
+            .getValue();
+    if (!(col2imOp.getDilation().getDefiningOp() &&
+          isa<Torch::PrimListConstructOp>(
+              col2imOp.getDilation().getDefiningOp())))
+      return failure();
+    Torch::PrimListConstructOp dilations = cast<Torch::PrimListConstructOp>(
+        col2imOp.getDilation().getDefiningOp());
+
+    if (!(dilations.getNumOperands() == 2 &&
+          dilations->getOperand(0).getDefiningOp() &&
+          dilations->getOperand(1).getDefiningOp() &&
+          isa<Torch::ConstantIntOp>(dilations->getOperand(0).getDefiningOp())))
+      return failure();
+    if (!isa<Torch::ConstantIntOp>(dilations->getOperand(1).getDefiningOp()))
+      return failure();
+    int verticalDilation =
+        cast<Torch::ConstantIntOp>(dilations->getOperand(0).getDefiningOp())
+            .getValue();
+    int horizontalDilation =
+        cast<Torch::ConstantIntOp>(dilations->getOperand(1).getDefiningOp())
+            .getValue();
+    if (!(col2imOp.getStride().getDefiningOp() &&
+          isa<Torch::PrimListConstructOp>(
+              col2imOp.getStride().getDefiningOp())))
+      return failure();
+    Torch::PrimListConstructOp strides =
+        cast<Torch::PrimListConstructOp>(col2imOp.getStride().getDefiningOp());
+
+    if (!(strides.getNumOperands() == 2 &&
+          strides->getOperand(0).getDefiningOp() &&
+          strides->getOperand(1).getDefiningOp() &&
+          isa<Torch::ConstantIntOp>(strides->getOperand(0).getDefiningOp())))
+      return failure();
+
+    if (!isa<Torch::ConstantIntOp>(strides->getOperand(1).getDefiningOp()))
+      return failure();
+
+    int verticalStride =
+        cast<Torch::ConstantIntOp>(strides->getOperand(0).getDefiningOp())
+            .getValue();
+    int horizontalStride =
+        cast<Torch::ConstantIntOp>(strides->getOperand(1).getDefiningOp())
+            .getValue();
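+    // The hyperparameters are now all compile-time constants. The input is
+    // expected to have shape (N, C * kernelHeight * kernelWidth, L), where
+    // each of the L columns holds one flattened sliding block; col2im sums
+    // the overlapping blocks back into an (N, C, height, width) image.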
+    // Create intermediate buffers
+    TensorType outputType =
+        cast<Torch::ValueTensorType>(col2imOp.getType()).toBuiltinTensor();
+    Type elementType = outputType.getElementType();
+    Value outputBuffer = rewriter.create<tensor::EmptyOp>(
+        col2imOp->getLoc(),
+        ArrayRef<int64_t>{outputType.getDimSize(0), outputType.getDimSize(1),
+                          height, width},
+        elementType);
+    Value paddedOutput = rewriter.create<tensor::EmptyOp>(
+        col2imOp->getLoc(),
+        ArrayRef<int64_t>{outputType.getDimSize(0), outputType.getDimSize(1),
+                          paddedHeight, paddedWidth},
+        elementType);
+    // Create the linalg loop iterators: batch and channel are parallel, the
+    // four block/kernel dimensions are reductions.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        6, utils::IteratorType::reduction);
+    iteratorTypes[0] = utils::IteratorType::parallel;
+    iteratorTypes[1] = utils::IteratorType::parallel;
+
+    SmallVector<AffineMap> indexingMaps;
+    AffineExpr batch = rewriter.getAffineDimExpr(0);
+    AffineExpr chan = rewriter.getAffineDimExpr(1);
+    AffineExpr line = rewriter.getAffineDimExpr(2);
+    AffineExpr col = rewriter.getAffineDimExpr(3);
+    AffineExpr kerLineIndex = rewriter.getAffineDimExpr(4);
+    AffineExpr kerColIndex = rewriter.getAffineDimExpr(5);
+    indexingMaps.push_back(AffineMap::get(
+        6, 0,
+        ArrayRef<AffineExpr>{
+            batch,
+            kerLineIndex * kernelWidth + kerColIndex +
+                chan * kernelWidth * kernelHeight,
+            col + line * (1 + (paddedWidth - 1 -
+                               (kernelWidth - 1) * horizontalDilation) /
+                                  horizontalStride)},
+        rewriter.getContext()));
+    // We create 2 additional dummy indexing maps and inputs (kernel,
+    // upperBounds) so that the operation can infer the upper bound of each
+    // loop dimension: a bound is only inferable when the dimension appears
+    // as a plain dimension in some operand's map. Without them we get the
+    // following error:
+    // "'linalg.generic' op expected the shape-to-loops map to be non-null"
+    indexingMaps.push_back(
+        AffineMap::get(6, 0, ArrayRef<AffineExpr>{kerLineIndex, kerColIndex},
+                       rewriter.getContext()));
+    indexingMaps.push_back(AffineMap::get(
+        6, 0, ArrayRef<AffineExpr>{line, col}, rewriter.getContext()));
+    indexingMaps.push_back(AffineMap::get(
+        6, 0,
+        ArrayRef<AffineExpr>{
+            batch, chan,
+            line * verticalStride + kerLineIndex * verticalDilation,
+            col * horizontalStride + kerColIndex * horizontalDilation},
+        rewriter.getContext()));
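+    // Taken together, the four maps implement a scatter-accumulate: iteration
+    // (n, c, row, col, ki, kj) reads
+    //   input[n, c * kH * kW + ki * kW + kj, row * blocksPerRow + col]
+    // and accumulates it into
+    //   paddedOutput[n, c, row * strideH + ki * dilationH,
+    //                      col * strideW + kj * dilationW].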
+    // The body of the linalg.generic op: accumulate the current block value
+    // into the output element, using the add of the element type.
+    auto body = [&](OpBuilder &b, Location loc, ValueRange args) {
+      Value acc =
+          (elementType.isInteger())
+              ? b.create<arith::AddIOp>(loc, args[0], args[3]).getResult()
+              : (isa<ComplexType>(elementType)
+                     ? b.create<complex::AddOp>(loc, args[0], args[3])
+                           .getResult()
+                     : b.create<arith::AddFOp>(loc, args[0], args[3])
+                           .getResult());
+      b.create<linalg::YieldOp>(loc, acc);
+    };
+    input = rewriter.create<TorchConversion::ToBuiltinTensorOp>(
+        col2imOp->getLoc(),
+        cast<Torch::ValueTensorType>(input.getType()).toBuiltinTensor(),
+        input);
+
+    // Create the two dummy inputs
+    Value kernel = rewriter.create<tensor::EmptyOp>(
+        col2imOp->getLoc(), ArrayRef<int64_t>{kernelHeight, kernelWidth},
+        elementType);
+    Value upperBounds = rewriter.create<tensor::EmptyOp>(
+        col2imOp->getLoc(),
+        ArrayRef<int64_t>{
+            1 + (paddedHeight - 1 - (kernelHeight - 1) * verticalDilation) /
+                    verticalStride,
+            1 + (paddedWidth - 1 - (kernelWidth - 1) * horizontalDilation) /
+                    horizontalStride},
+        elementType);
+    assert(((isa<ComplexType>(elementType) &&
+             (cast<ComplexType>(elementType).getElementType().isInteger() ||
+              isa<FloatType>(
+                  cast<ComplexType>(elementType).getElementType()))) ||
+            isa<FloatType>(elementType) || elementType.isInteger()) &&
+           "aten.col2im: unsupported element type");
+
+    TypedAttr init0 =
+        elementType.isInteger()
+            ? rewriter.getIntegerAttr(elementType, 0)
+            : (isa<FloatType>(elementType)
+                   ? rewriter.getFloatAttr(elementType, 0.0)
+                   : (cast<ComplexType>(elementType)
+                              .getElementType()
+                              .isInteger()
+                          ? TypedAttr(rewriter.getIntegerAttr(
+                                cast<ComplexType>(elementType).getElementType(),
+                                0))
+                          : rewriter.getFloatAttr(
+                                cast<ComplexType>(elementType).getElementType(),
+                                0)));
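+    // Materialize the zero used to clear the accumulation buffer:
+    // complex.constant expects an ArrayAttr holding the real and imaginary
+    // parts, while every other supported element type takes a plain
+    // arith.constant.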
+    Value fill0 =
+        isa<ComplexType>(elementType)
+            ? rewriter.createOrFold<complex::ConstantOp>(
+                  col2imOp->getLoc(), elementType,
+                  rewriter.getArrayAttr(ArrayRef<Attribute>{init0, init0}))
+            : rewriter.createOrFold<arith::ConstantOp>(col2imOp->getLoc(),
+                                                       elementType, init0);
+
+    paddedOutput =
+        rewriter
+            .create<linalg::FillOp>(col2imOp->getLoc(), ValueRange(fill0),
+                                    ValueRange(paddedOutput))
+            ->getResult(0);
+    paddedOutput =
+        rewriter
+            .create<linalg::GenericOp>(
+                col2imOp->getLoc(), paddedOutput.getType(),
+                ValueRange{input, kernel, upperBounds},
+                ValueRange(paddedOutput), indexingMaps, iteratorTypes, body)
+            ->getResult(0);
+
+    // Remove the padding by extracting the interior height x width window.
+    OpFoldResult one = rewriter.getI32IntegerAttr(1);
+    OpFoldResult zero = rewriter.getI32IntegerAttr(0);
+    OpFoldResult vpad = rewriter.getI32IntegerAttr(verticalPadding);
+    OpFoldResult hpad = rewriter.getI32IntegerAttr(horizontalPadding);
+    OpFoldResult vdim = rewriter.getI32IntegerAttr(height);
+    OpFoldResult hdim = rewriter.getI32IntegerAttr(width);
+    OpFoldResult batchSize =
+        rewriter.getI32IntegerAttr(outputType.getDimSize(0));
+    OpFoldResult nChannels =
+        rewriter.getI32IntegerAttr(outputType.getDimSize(1));
+    outputBuffer = rewriter.create<tensor::ExtractSliceOp>(
+        col2imOp->getLoc(), paddedOutput,
+        ArrayRef<Range>{Range{zero, batchSize, one},
+                        Range{zero, nChannels, one}, Range{vpad, vdim, one},
+                        Range{hpad, hdim, one}});
+    rewriter.setInsertionPoint(col2imOp);
+    TorchConversion::FromBuiltinTensorOp newOp =
+        rewriter.create<TorchConversion::FromBuiltinTensorOp>(
+            col2imOp->getLoc(), col2imOp.getType(), outputBuffer);
+    rewriter.replaceOp(col2imOp, newOp);
+    return success();
+  }
+};
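+
+// Note: the pattern above bails out unless every hyperparameter list is a
+// torch.prim.ListConstruct of two torch.constant.int values. Since
+// AtenCol2imOp is registered as unconditionally illegal below, inputs with
+// dynamic hyperparameters will fail to convert rather than stay as torch ops.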
 } // namespace

 void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
@@ -2800,6 +3066,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenViewAsComplexOp>(typeConverter, context);
   target.addIllegalOp<AtenViewAsRealOp>();
   patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
+  target.addIllegalOp<AtenCol2imOp>();
+  patterns.add<ConvertAtenCol2ImOp>(typeConverter, context);
   // Rewrite all special sparse conversions hidden as operators.
   target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
     return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
diff --git a/test/Conversion/TorchToLinalg/datamovement.mlir b/test/Conversion/TorchToLinalg/datamovement.mlir
index dd5e5c553d31..a09bdb596637 100644
--- a/test/Conversion/TorchToLinalg/datamovement.mlir
+++ b/test/Conversion/TorchToLinalg/datamovement.mlir
@@ -32,3 +32,45 @@ func.func @torch.aten.permute$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vte
   %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
   return %1 : !torch.vtensor<[],f32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4 * 2 + d5 + d1 * 4, d3 + d2 * 16)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3)>
+// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
+// CHECK-LABEL: func.func @torch.aten.col2im(
+// CHECK-SAME: %[[VAL_ARG0:.*]]: !torch.vtensor<[1,12,128],f32>) -> !torch.vtensor<[1,3,14,30],f32> {
+// CHECK: %[[VAL_CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x3x16x32xf32>
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_ARG0]] : !torch.vtensor<[1,12,128],f32> -> tensor<1x12x128xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<2x2xf32>
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<8x16xf32>
+// CHECK: %[[VAL_4:.*]] = linalg.fill ins(%[[VAL_CST]] : f32) outs(%[[VAL_0]] : tensor<1x3x16x32xf32>) -> tensor<1x3x16x32xf32>
+// CHECK: %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]] : tensor<1x12x128xf32>, tensor<2x2xf32>, tensor<8x16xf32>) outs(%[[VAL_4]] : tensor<1x3x16x32xf32>) {
+// CHECK: ^bb0(%[[VAL_IN0:.*]]: f32, %[[VAL_IN1:.*]]: f32, %[[VAL_IN2:.*]]: f32, %[[VAL_OUT:.*]]: f32):
+// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_IN0]], %[[VAL_OUT]] : f32
+// CHECK: linalg.yield %[[VAL_7]] : f32
+// CHECK: } -> tensor<1x3x16x32xf32>
+// CHECK: %[[VAL_SLICE:.*]] = tensor.extract_slice %[[VAL_5]][0, 0, 1, 1] [1, 3, 14, 30] [1, 1, 1, 1] : tensor<1x3x16x32xf32> to tensor<1x3x14x30xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_SLICE]] : tensor<1x3x14x30xf32> -> !torch.vtensor<[1,3,14,30],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,3,14,30],f32>
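+// The hyperparameters below are output_size = (14, 30), kernel_size = (2, 2),
+// dilation = (1, 1), padding = (1, 1) and stride = (2, 2): the padded image
+// is 16x32, the sliding blocks form an 8x16 grid (hence the 8x16 dummy
+// upper-bounds tensor and the L = 128 input columns), and each block column
+// has 3 * 2 * 2 = 12 entries.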
+func.func @torch.aten.col2im(%arg0: !torch.vtensor<[1,12,128],f32>) -> !torch.vtensor<[1,3,14,30],f32> {
+  %int14 = torch.constant.int 14
+  %int30 = torch.constant.int 30
+  %0 = torch.prim.ListConstruct %int14, %int30 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int2 = torch.constant.int 2
+  %int2_0 = torch.constant.int 2
+  %1 = torch.prim.ListConstruct %int2, %int2_0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int1 = torch.constant.int 1
+  %int1_1 = torch.constant.int 1
+  %2 = torch.prim.ListConstruct %int1, %int1_1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int1_2 = torch.constant.int 1
+  %int1_3 = torch.constant.int 1
+  %3 = torch.prim.ListConstruct %int1_2, %int1_3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int2_4 = torch.constant.int 2
+  %int2_5 = torch.constant.int 2
+  %4 = torch.prim.ListConstruct %int2_4, %int2_5 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.col2im %arg0, %0, %1, %2, %3, %4 : !torch.vtensor<[1,12,128],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,3,14,30],f32>
+  return %5 : !torch.vtensor<[1,3,14,30],f32>
+}