Skip to content

feat(linalg): add a way to pass controlFn to foldIntoPackUnpackPatterns #143685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

egebeysel
Copy link
Contributor

@egebeysel egebeysel commented Jun 11, 2025

This PR adds a mechanism, so that downstream consumers can pass in control functions for the application of these patterns. This change shouldn't affect any consumers of this method that do not specify a controlFn. The controlFn always gets the source operand of the consumer in each of the patterns as a parameter.

In IREE, we (will) use it to control preventing folding patterns that would inhibit fusion. See IREE issue #20896 for more details.

…rns` (llvm#22)

This PR adds a mechanism, so that downstream consumers can pass in control functions for the application of these patterns. This change shouldn't affect any consumers of this method that do not specify a controlFn.

In IREE, we (will) use it to control preventing folding patterns that would inhibit fusion. See IREE issue llvm#20896 for more details.
@egebeysel
Copy link
Contributor Author

cc @hanhanW

@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2025

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Ege Beysel (egebeysel)

Changes

This PR adds a mechanism, so that downstream consumers can pass in control functions for the application of these patterns. This change shouldn't affect any consumers of this method that do not specify a controlFn. The controlFn always gets the source operand of the consumer in each of the patterns as a parameter.

In IREE, we (will) use it to control preventing folding patterns that would inhibit fusion. See IREE issue #20896 for more details.


Full diff: https://github.com/llvm/llvm-project/pull/143685.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+9-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp (+76-7)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..2f0e57ca9f5a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1894,10 +1894,18 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+/// Function type which is used to control folding operations like `tensor.pad`
+/// and `tensor.extract_slice` into linalg.pack/unpack ops.
+using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
+inline bool defaultControlFoldIntoPackUnpackFn(OpOperand *opOperand) {
+  return true;
+};
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+void populateFoldIntoPackAndUnpackPatterns(
+    RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn =
+                                     defaultControlFoldIntoPackUnpackFn);
 
 /// Populates `patterns` with patterns that fold operations like `linalg.pack`
 /// and `linalg.unpack` into `tensor.empty`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..01cebb0f8e80d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
-  using OpRewritePattern<PackOp>::OpRewritePattern;
+public:
+  FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&packOp.getSourceMutable()))
+      return failure();
+
     Value constantPaddingValue = padOp.getConstantPaddingValue();
     if (!constantPaddingValue)
       return failure();
@@ -220,13 +227,20 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
         packOp.getOuterDimsPerm());
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
 /// has extract_slice semantics.
 struct FoldUnpackWithExtractSliceOp
     : public OpRewritePattern<tensor::ExtractSliceOp> {
-  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+public:
+  FoldUnpackWithExtractSliceOp(MLIRContext *context,
+                               ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<tensor::ExtractSliceOp>(context),
+        controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
@@ -234,6 +248,10 @@ struct FoldUnpackWithExtractSliceOp
     if (!unpackOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&sliceOp.getSourceMutable()))
+      return failure();
+
     if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
       return rewriter.notifyMatchFailure(
           sliceOp, "rank-reduced folding is not supported");
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
         unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 // Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+  FoldProducerPackWithConsumerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+        controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!packOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&linalgOp->getOpOperand(0)))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -331,13 +361,20 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldConsumerPackWithProducerLinalgTransposeOp
     : public OpRewritePattern<PackOp> {
-  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+public:
+  FoldConsumerPackWithProducerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
@@ -345,6 +382,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
     if (!linalgOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&packOp.getSourceMutable()))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+  FoldProducerUnPackWithConsumerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+        controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     if (!unPackOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&linalgOp->getOpOperand(0)))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     : public OpRewritePattern<UnPackOp> {
   using OpRewritePattern<UnPackOp>::OpRewritePattern;
 
+public:
+  FoldConsumerUnPackWithProducerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
+
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
     auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
     if (!linalgOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&unPackOp.getSourceMutable()))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
 
 } // namespace
 
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
+void populateFoldIntoPackAndUnpackPatterns(
+    RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
   patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                   FoldProducerPackWithConsumerLinalgTransposeOp,
                   FoldConsumerPackWithProducerLinalgTransposeOp,
                   FoldConsumerUnPackWithProducerLinalgTransposeOp,
                   FoldProducerUnPackWithConsumerLinalgTransposeOp>(
-      patterns.getContext());
+      patterns.getContext(), controlFn);
 }
 
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {

@hanhanW hanhanW requested a review from Max191 June 11, 2025 14:42
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is an indication that you dont want to be using pattern application. In such situations we have found it better to expose the core transformation done by the pattern as a function and then just call that function using a pass downstream. You can the have whatever control you want in that pass.

@egebeysel
Copy link
Contributor Author

I think this is an indication that you dont want to be using pattern application. In such situations we have found it better to expose the core transformation done by the pattern as a function and then just call that function using a pass downstream. You can the have whatever control you want in that pass.

Sure, I can do that. One thing I had in mind while doing this was the compatibility with current downstream users of the pattern. I guess with that change, users would also have to adjust to it.

Also, just to clarify this, in that case, I would expose every pattern inside populateFoldIntoPackUnpackPatterns as functions, right?

@egebeysel
Copy link
Contributor Author

egebeysel commented Jun 23, 2025

friendly ping @MaheshRavishankar 😊 the above might have gotten lost I believe🙂

@MaheshRavishankar
Copy link
Contributor

I think this is an indication that you dont want to be using pattern application. In such situations we have found it better to expose the core transformation done by the pattern as a function and then just call that function using a pass downstream. You can the have whatever control you want in that pass.

Sure, I can do that. One thing I had in mind while doing this was the compatibility with current downstream users of the pattern. I guess with that change, users would also have to adjust to it.

You can keep the patterns unchanged without a control function. You can add the transformation methods with the control function. The patterns can "apply always", but downstream you can change your pass to walk the function and apply the transformation as needed.

Also, just to clarify this, in that case, I would expose every pattern inside populateFoldIntoPackUnpackPatterns as functions, right?

Yeah. If you need to control all of them, then yes.

@MaheshRavishankar
Copy link
Contributor

but before I send you down that rabbit hole, @Max191 or @hanhanW what do you guys think?

@Max191
Copy link
Contributor

Max191 commented Jun 24, 2025

but before I send you down that rabbit hole, @Max191 or @hanhanW what do you guys think?

In general I don't see an issue with adding optional control functions to patterns, since it does give more control over the pattern application, although exposing the transforms as functions does give even more control (and you can always turn these transform functions into patterns anyway). That said, in this case, I do think the rewrite makes sense as a pattern application. These patterns can open up new folding opportunities, so the recursive pattern application is helpful here. Also, the patterns can be run alongside things like data layout propagation patterns, which can further expose more folding opportunities.

TLDR; I think this is going to end up being better as a pattern rewrite after all anyway, so I don't see any issue with just adding a control function to the patterns.

Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add a test for the control function by adding an additional flag here:

Similar to what is done in elementwise fusion:

Option<bool> fuseWithReshapeByCollapsingWithControlFn{
*this, "fuse-with-reshape-by-collapsing-control",
llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
"fusion patterns that "
"collapse the iteration space of the consumer"),
llvm::cl::init(false)};

@hanhanW
Copy link
Contributor

hanhanW commented Jun 24, 2025

but before I send you down that rabbit hole, @Max191 or @hanhanW what do you guys think?

In general I don't see an issue with adding optional control functions to patterns, since it does give more control over the pattern application, although exposing the transforms as functions does give even more control (and you can always turn these transform functions into patterns anyway). That said, in this case, I do think the rewrite makes sense as a pattern application. These patterns can open up new folding opportunities, so the recursive pattern application is helpful here. Also, the patterns can be run alongside things like data layout propagation patterns, which can further expose more folding opportunities.

TLDR; I think this is going to end up being better as a pattern rewrite after all anyway, so I don't see any issue with just adding a control function to the patterns.

I made the suggestion of having controlFn in the first place, but I did not think that much. After reading both your points, I'm +1 on having controlFn because of what Max said.

@MaheshRavishankar MaheshRavishankar dismissed their stale review June 24, 2025 23:33

Defering to Hanhan/Max's review

Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just a comment about the tests.

@@ -797,4 +892,4 @@ func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> ten
// CHECK-SAME: inner_tiles = [1, 64]
// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
// CHECK: }
// CHECK: }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: new line at end of file.

%1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
return %transposed, %unpack : tensor<1x1x16x4xi32>, tensor<16x4xi32>
}
//CHECK-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests can be less verbose in the checks. The main thing this is testing is whether or not the fusion happened with mutli-result, so you don't need to check all the shapes/op metadata. You could just check the operands and results of the operations to verify the producers/consumers.

into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
return %transposed, %pack : tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>
}
// CHECK-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with these tests. The checks can be less verbose.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants