feat(linalg): add a way to pass controlFn to foldIntoPackUnpackPatterns
#143685
base: main
Conversation
…rns` (llvm#22) This PR adds a mechanism so that downstream consumers can pass control functions to the application of these patterns. The change shouldn't affect any consumers of this method that do not specify a controlFn. In IREE, we (will) use it to prevent the application of folding patterns that would inhibit fusion. See IREE issue llvm#20896 for more details.
cc @hanhanW
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-linalg
Author: Ege Beysel (egebeysel)
Changes: This PR adds a mechanism so that downstream consumers can pass control functions to the application of these patterns. The change shouldn't affect any consumers of this method that do not specify a controlFn. The controlFn always receives the source operand of the consumer op in each of the patterns as its parameter. In IREE, we (will) use it to prevent the application of folding patterns that would inhibit fusion. See IREE issue #20896 for more details.
Full diff: https://github.com/llvm/llvm-project/pull/143685.diff
2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..2f0e57ca9f5a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1894,10 +1894,18 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+/// Function type which is used to control folding operations like `tensor.pad`
+/// and `tensor.extract_slice` into linalg.pack/unpack ops.
+using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
+inline bool defaultControlFoldIntoPackUnpackFn(OpOperand *opOperand) {
+ return true;
+};
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
/// respectively.
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+void populateFoldIntoPackAndUnpackPatterns(
+ RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn =
+ defaultControlFoldIntoPackUnpackFn);
/// Populates `patterns` with patterns that fold operations like `linalg.pack`
/// and `linalg.unpack` into `tensor.empty`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..01cebb0f8e80d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
- using OpRewritePattern<PackOp>::OpRewritePattern;
+public:
+ FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
return failure();
+ // User controlled folding function.
+ if (!controlFn(&packOp.getSourceMutable()))
+ return failure();
+
Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue)
return failure();
@@ -220,13 +227,20 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
packOp.getOuterDimsPerm());
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp
: public OpRewritePattern<tensor::ExtractSliceOp> {
- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+public:
+ FoldUnpackWithExtractSliceOp(MLIRContext *context,
+ ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<tensor::ExtractSliceOp>(context),
+ controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
@@ -234,6 +248,10 @@ struct FoldUnpackWithExtractSliceOp
if (!unpackOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&sliceOp.getSourceMutable()))
+ return failure();
+
if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
return rewriter.notifyMatchFailure(
sliceOp, "rank-reduced folding is not supported");
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
// Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+ FoldProducerPackWithConsumerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+ controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!packOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&linalgOp->getOpOperand(0)))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -331,13 +361,20 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldConsumerPackWithProducerLinalgTransposeOp
: public OpRewritePattern<PackOp> {
- using OpRewritePattern<PackOp>::OpRewritePattern;
+
+public:
+ FoldConsumerPackWithProducerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
@@ -345,6 +382,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
if (!linalgOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&packOp.getSourceMutable()))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+ FoldProducerUnPackWithConsumerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+ controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
if (!unPackOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&linalgOp->getOpOperand(0)))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
: public OpRewritePattern<UnPackOp> {
using OpRewritePattern<UnPackOp>::OpRewritePattern;
+public:
+ FoldConsumerUnPackWithProducerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
+
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&unPackOp.getSourceMutable()))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
} // namespace
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
+void populateFoldIntoPackAndUnpackPatterns(
+ RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
FoldProducerPackWithConsumerLinalgTransposeOp,
FoldConsumerPackWithProducerLinalgTransposeOp,
FoldConsumerUnPackWithProducerLinalgTransposeOp,
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
- patterns.getContext());
+ patterns.getContext(), controlFn);
}
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
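For downstream users, here is a minimal sketch of calling the new overload. Only the `ControlFoldIntoPackUnpackFn` type and the `populateFoldIntoPackAndUnpackPatterns` signature come from the diff above; the driver call, the `no_fold` attribute, and the predicate body are illustrative assumptions.

```cpp
// Sketch only: wiring the new controlFn overload from downstream code.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runControlledPackUnpackFolding(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Each pattern hands the control function the source operand of the
  // consumer op; this example refuses to fold whenever the producer carries
  // a (hypothetical) "no_fold" unit attribute.
  linalg::ControlFoldIntoPackUnpackFn controlFn = [](OpOperand *opOperand) {
    Operation *producer = opOperand->get().getDefiningOp();
    return !(producer && producer->hasAttr("no_fold"));
  };
  linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn);
  return applyPatternsGreedily(root, std::move(patterns));
}
```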
I think this is an indication that you don't want to be using pattern application. In such situations we have found it better to expose the core transformation done by the pattern as a function and then just call that function from a pass downstream. You can then have whatever control you want in that pass.
Sure, I can do that. One thing I had in mind while doing this was compatibility with current downstream users of the pattern; I guess with that change, users would also have to adjust to it. Also, just to clarify: in that case, I would expose every pattern inside
Friendly ping @MaheshRavishankar 😊 I believe the above might have gotten lost 🙂
You can keep the patterns unchanged without a control function. You can add the transformation methods with the control function. The patterns can "apply always", but downstream you can change your pass to walk the function and apply the transformation as needed.
Yeah. If you need to control all of them, then yes.
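For context, a hedged sketch of that alternative arrangement. Everything below is hypothetical: the `shouldFold` heuristic and the `foldPadIntoPack` entry point are assumed names, not functions added by this PR; the point is only to illustrate the "keep the patterns control-free, walk the function downstream" shape.

```cpp
// Hypothetical downstream pass body: decide per op whether to invoke an
// (assumed) standalone transform helper instead of relying on a controlFn.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Assumed declarations, for illustration only.
bool shouldFold(linalg::PackOp packOp);
LogicalResult foldPadIntoPack(RewriterBase &rewriter, linalg::PackOp packOp);

void runControlledFolding(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  funcOp.walk([&](linalg::PackOp packOp) {
    if (!shouldFold(packOp)) // downstream heuristic, entirely user-defined
      return;
    // Call the (assumed) function form of the pad-into-pack folding.
    (void)foldPadIntoPack(rewriter, packOp);
  });
}
```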
In general I don't see an issue with adding optional control functions to patterns, since it does give more control over the pattern application, although exposing the transforms as functions gives even more control (and you can always turn those transform functions into patterns anyway). That said, in this case I do think the rewrite makes sense as a pattern application. These patterns can open up new folding opportunities, so the recursive pattern application is helpful here. Also, the patterns can be run alongside things like the data layout propagation patterns, which can expose further folding opportunities. TL;DR: I think this is going to end up being better as a pattern rewrite anyway, so I don't see any issue with just adding a control function to the patterns.
You can add a test for the control function by adding an additional flag here:
Similar to what is done in elementwise fusion:
llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
Lines 134 to 139 in bd96918
Option<bool> fuseWithReshapeByCollapsingWithControlFn{
    *this, "fuse-with-reshape-by-collapsing-control",
    llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
                   "fusion patterns that "
                   "collapse the iteration space of the consumer"),
    llvm::cl::init(false)};
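A hedged sketch of how such a flag could drive the new controlFn in a test pass. The pass name, the option string, and the `__test_fold__` attribute are assumptions modeled on the elementwise-fusion example; only the populate function and the `ControlFoldIntoPackUnpackFn` type come from this PR.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
// Hypothetical test pass, analogous to TestLinalgElementwiseFusion.
struct TestFoldIntoPackUnpack
    : public PassWrapper<TestFoldIntoPackUnpack, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldIntoPackUnpack)

  TestFoldIntoPackUnpack() = default;
  TestFoldIntoPackUnpack(const TestFoldIntoPackUnpack &pass)
      : PassWrapper(pass) {}

  Option<bool> useControlFn{
      *this, "fold-into-pack-unpack-control",
      llvm::cl::desc("Test controlling the fold-into-pack/unpack patterns"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    linalg::ControlFoldIntoPackUnpackFn controlFn = [&](OpOperand *operand) {
      if (!useControlFn)
        return true; // flag unset: same behavior as the default controlFn
      // Flag set: only fold when the producer carries a (hypothetical)
      // "__test_fold__" unit attribute placed in the test input.
      Operation *producer = operand->get().getDefiningOp();
      return producer && producer->hasAttr("__test_fold__");
    };
    linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};
} // namespace
```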
I made the suggestion of having a controlFn in the first place, but I did not think about it that much. After reading both your points, I'm +1 on having a controlFn because of what Max said.
Looks good, just a comment about the tests.
@@ -797,4 +892,4 @@ func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> ten
// CHECK-SAME: inner_tiles = [1, 64]
// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
// CHECK: }
// CHECK: } |
nit: newline at end of file.
%1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
return %transposed, %unpack : tensor<1x1x16x4xi32>, tensor<16x4xi32>
}
//CHECK-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result( |
These tests can be less verbose in the checks. The main thing this is testing is whether or not the fusion happened with multi-result ops, so you don't need to check all the shapes/op metadata. You could just check the operands and results of the operations to verify the producers/consumers.
into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
return %transposed, %pack : tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>
}
// CHECK-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result( |
Same with these tests. The checks can be less verbose.