diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 2eef0a06d0eb4..34bef0e56e1e4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1894,10 +1894,15 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); /// convert to a `linalg.dot`. void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns); +/// Function type which is used to control folding operations like `tensor.pad` +/// and `tensor.extract_slice` into linalg.pack/unpack ops. +using ControlFoldIntoPackUnpackFn = std::function; /// Populates `patterns` with patterns that fold operations like `tensor.pad` /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations /// respectively. -void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns); +void populateFoldIntoPackAndUnpackPatterns( + RewritePatternSet &patterns, + const ControlFoldIntoPackUnpackFn &controlFn = nullptr); /// Populates `patterns` with patterns that fold operations like `linalg.pack` /// and `linalg.unpack` into `tensor.empty`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 0984b6988b93b..9d8f90d991720 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and /// the pad op has zero low paddings, or if `pack` has no padding values. struct FoldPadWithPackOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +public: + FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn) + : OpRewritePattern(context), controlFn(std::move(controlFn)) {} LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { @@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern { if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) return failure(); + // User controlled folding function. + if (controlFn && !controlFn(&packOp.getSourceMutable())) + return failure(); + Value constantPaddingValue = padOp.getConstantPaddingValue(); if (!constantPaddingValue) return failure(); @@ -220,13 +227,20 @@ struct FoldPadWithPackOp : public OpRewritePattern { packOp.getOuterDimsPerm()); return success(); } + +private: + ControlFoldIntoPackUnpackFn controlFn; }; /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already /// has extract_slice semantics. struct FoldUnpackWithExtractSliceOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +public: + FoldUnpackWithExtractSliceOp(MLIRContext *context, + ControlFoldIntoPackUnpackFn controlFn) + : OpRewritePattern(context), + controlFn(std::move(controlFn)) {} LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { @@ -234,6 +248,10 @@ struct FoldUnpackWithExtractSliceOp if (!unpackOp) return failure(); + // User controlled folding function. + if (controlFn && !controlFn(&sliceOp.getSourceMutable())) + return failure(); + if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) { return rewriter.notifyMatchFailure( sliceOp, "rank-reduced folding is not supported"); @@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); return success(); } + +private: + ControlFoldIntoPackUnpackFn controlFn; }; // Applies 'permutation' on 'inVec' and stores the result in resVec. @@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef permutation, /// semantics. struct FoldProducerPackWithConsumerLinalgTransposeOp : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + +public: + FoldProducerPackWithConsumerLinalgTransposeOp( + MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn) + : OpInterfaceRewritePattern(context), + controlFn(std::move(controlFn)) {} LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const override { @@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp if (!packOp) return failure(); + // User controlled folding function. + if (controlFn && !controlFn(&linalgOp->getOpOperand(0))) + return failure(); + FailureOr> maybePerm = getTransposeOpPermutation(linalgOp); if (failed(maybePerm)) @@ -331,13 +361,20 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp return success(); } + +private: + ControlFoldIntoPackUnpackFn controlFn; }; /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose /// semantics. struct FoldConsumerPackWithProducerLinalgTransposeOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + +public: + FoldConsumerPackWithProducerLinalgTransposeOp( + MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn) + : OpRewritePattern(context), controlFn(std::move(controlFn)) {} LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { @@ -345,6 +382,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp if (!linalgOp) return failure(); + // User controlled folding function. + if (controlFn && !controlFn(&packOp.getSourceMutable())) + return failure(); + FailureOr> maybePerm = getTransposeOpPermutation(linalgOp); if (failed(maybePerm)) @@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp return success(); } + +private: + ControlFoldIntoPackUnpackFn controlFn; }; /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has /// transpose semantics. struct FoldProducerUnPackWithConsumerLinalgTransposeOp : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + +public: + FoldProducerUnPackWithConsumerLinalgTransposeOp( + MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn) + : OpInterfaceRewritePattern(context), + controlFn(std::move(controlFn)) {} LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, PatternRewriter &rewriter) const override { @@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp if (!unPackOp) return failure(); + // User controlled folding function. + if (controlFn && !controlFn(&linalgOp->getOpOperand(0))) + return failure(); + FailureOr> maybePerm = getTransposeOpPermutation(linalgOp); if (failed(maybePerm)) @@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp return success(); } + +private: + ControlFoldIntoPackUnpackFn controlFn; }; /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has @@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; +public: + FoldConsumerUnPackWithProducerLinalgTransposeOp( + MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn) + : OpRewritePattern(context), controlFn(std::move(controlFn)) {} + LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { auto linalgOp = unPackOp.getSource().getDefiningOp(); if (!linalgOp) return failure(); + // User controlled folding function. + if (controlFn && !controlFn(&unPackOp.getSourceMutable())) + return failure(); + FailureOr> maybePerm = getTransposeOpPermutation(linalgOp); if (failed(maybePerm)) @@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp return success(); } + +private: + ControlFoldIntoPackUnpackFn controlFn; }; /// tensor.empty does not define any tensor contents, so an unpadded pack @@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern { } // namespace -void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) { +void populateFoldIntoPackAndUnpackPatterns( + RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) { patterns.insert( - patterns.getContext()); + patterns.getContext(), controlFn); } void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) { diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir index 84eb60248b8be..16efa73f87a2a 100644 --- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir +++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL func.func @fold_unpack_slice(%arg0 : tensor, %arg1 : tensor, %arg2 : index, %arg3 : index) -> tensor { @@ -373,6 +374,36 @@ func.func @linalg_transpose_linalg.pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t // ----- +func.func @linalg_transpose_linalg.pack_fold_multi_result(%arg0: tensor<56x57x1x64xf32>) -> (tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>) { + %0 = tensor.empty() : tensor<1x56x57x64xf32> + %transposed = linalg.transpose + ins(%arg0 : tensor<56x57x1x64xf32>) + outs(%0 : tensor<1x56x57x64xf32>) + permutation = [2, 0, 1, 3] + + %1 = tensor.empty() : tensor<1x57x56x2x32xf32> + %pack = linalg.pack %transposed + outer_dims_perm = [0, 2, 1, 3] + inner_dims_pos = [3] + inner_tiles = [32] + into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32> + return %transposed, %pack : tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32> +} +// CHECK-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result( +// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>) +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose +// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG0]] +// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3] +// CHECK: return %[[TRANSPOSE]], %[[PACK]] + +// CONTROL-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result( +// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose +// CONTROL: %[[PACK:.+]] = linalg.pack %[[TRANSPOSE]] +// CONTROL-SAME: outer_dims_perm = [0, 2, 1, 3] +// CONTROL: return %[[TRANSPOSE]], %[[PACK]] + +// ----- + func.func @linalg_transpose_linalg.pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> { %0 = tensor.empty() : tensor<1x56x57x55xf32> %transpose = linalg.transpose @@ -550,6 +581,36 @@ func.func @linalg_transpose_linalg.unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t // ----- +func.func @linalg_transpose_linalg.unpack_fold_multi_result(%arg0: tensor<1x1x4x16xi32>) -> (tensor<1x1x16x4xi32>, tensor<16x4xi32>) { + %0 = tensor.empty() : tensor<1x1x16x4xi32> + %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>) + outs(%0 : tensor<1x1x16x4xi32>) + permutation = [1, 0, 3, 2] + %1 = tensor.empty() : tensor<16x4xi32> + %unpack = linalg.unpack %transposed + outer_dims_perm = [0, 1] + inner_dims_pos = [0, 1] + inner_tiles = [16, 4] into + %1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32> + return %transposed, %unpack : tensor<1x1x16x4xi32>, tensor<16x4xi32> +} +//CHECK-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result( +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>) +// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] +// CHECK-SAME: outer_dims_perm = [1, 0] +// CHECK: return %[[TRANSPOSE]], %[[UNPACK]] +// CHECK: } + +//CONTROL-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result( +// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose +// CONTROL: %[[UNPACK:.+]] = linalg.unpack %[[TRANSPOSE]] +// CONTROL-SAME: outer_dims_perm = [0, 1] +// CONTROL: return %[[TRANSPOSE]], %[[UNPACK]] +// CONTROL: } + +// ----- + func.func @linalg_transpose_linalg.unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> { %0 = tensor.empty() : tensor<1x1x16x4xi32> %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>) diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 046b9a65f3359..5612ed2f40d12 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -135,6 +135,11 @@ struct TestLinalgTransforms *this, "test-fold-into-pack-and-unpack", llvm::cl::desc("Test folding ops into linalg.pack and linalg.unpack"), llvm::cl::init(false)}; + Option testFoldIntoPackAndUnpackWithControlFn{ + *this, "test-fold-into-pack-and-unpack-control", + llvm::cl::desc( + "Test controlling folding ops into linalg.pack and linalg.unpack"), + llvm::cl::init(false)}; Option testSimplifyPackUnpackPatterns{ *this, "test-simplify-pack-unpack-patterns", llvm::cl::desc("Test patterns to simplify linalg.pack and linalg.unpack"), @@ -236,9 +241,11 @@ static void applyDecomposeWinogradOps(func::FuncOp funcOp) { (void)applyPatternsGreedily(funcOp, std::move(patterns)); } -static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { +static void applyFoldIntoPackAndUnpackPatterns( + Operation *rootOp, + linalg::ControlFoldIntoPackUnpackFn controlFn = nullptr) { RewritePatternSet patterns(rootOp->getContext()); - linalg::populateFoldIntoPackAndUnpackPatterns(patterns); + linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn); (void)applyPatternsGreedily(rootOp, std::move(patterns)); } @@ -279,6 +286,19 @@ void TestLinalgTransforms::runOnOperation() { Operation *rootOp = getOperation(); if (testFoldIntoPackAndUnpack) applyFoldIntoPackAndUnpackPatterns(rootOp); + if (testFoldIntoPackAndUnpackWithControlFn) { + linalg::ControlFoldIntoPackUnpackFn controlFn = [](OpOperand *opOperand) { + Operation *producer = opOperand->get().getDefiningOp(); + Operation *consumer = opOperand->getOwner(); + // If we have a pack/unpack consumer and a producer that has multiple + // uses, do not apply the folding patterns. + if (isa(consumer) && + isa(producer) && !producer->hasOneUse()) + return false; + return true; + }; + applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn); + } if (testSimplifyPackUnpackPatterns) applySimplifyPackUnpackPatterns(rootOp); }