Skip to content

feat(linalg): add a way to pass controlFn to foldIntoPackUnpackPatterns #143685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1894,10 +1894,15 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

/// Function type which is used to control folding operations like `tensor.pad`
/// and `tensor.extract_slice` into linalg.pack/unpack ops.
using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
/// respectively.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
void populateFoldIntoPackAndUnpackPatterns(
RewritePatternSet &patterns,
const ControlFoldIntoPackUnpackFn &controlFn = nullptr);

/// Populates `patterns` with patterns that fold operations like `linalg.pack`
/// and `linalg.unpack` into `tensor.empty`.
Expand Down
83 changes: 76 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand Down Expand Up @@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
public:
FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
Expand All @@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
return failure();

// User controlled folding function.
if (controlFn && !controlFn(&packOp.getSourceMutable()))
return failure();

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue)
return failure();
Expand All @@ -220,20 +227,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
packOp.getOuterDimsPerm());
return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
public:
FoldUnpackWithExtractSliceOp(MLIRContext *context,
ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<tensor::ExtractSliceOp>(context),
controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
if (!unpackOp)
return failure();

// User controlled folding function.
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();

if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
return rewriter.notifyMatchFailure(
sliceOp, "rank-reduced folding is not supported");
Expand All @@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

// Applies 'permutation' on 'inVec' and stores the result in resVec.
Expand Down Expand Up @@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

public:
FoldProducerPackWithConsumerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
Expand All @@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!packOp)
return failure();

// User controlled folding function.
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand Down Expand Up @@ -331,20 +361,31 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldConsumerPackWithProducerLinalgTransposeOp
: public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;

public:
FoldConsumerPackWithProducerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();

// User controlled folding function.
if (controlFn && !controlFn(&packOp.getSourceMutable()))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand Down Expand Up @@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

public:
FoldProducerUnPackWithConsumerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpInterfaceRewritePattern<linalg::LinalgOp>(context),
controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
Expand All @@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
if (!unPackOp)
return failure();

// User controlled folding function.
if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand All @@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
Expand All @@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
: public OpRewritePattern<UnPackOp> {
using OpRewritePattern<UnPackOp>::OpRewritePattern;

public:
FoldConsumerUnPackWithProducerLinalgTransposeOp(
MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
: OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();

// User controlled folding function.
if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
return failure();

FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
Expand Down Expand Up @@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp

return success();
}

private:
ControlFoldIntoPackUnpackFn controlFn;
};

/// tensor.empty does not define any tensor contents, so an unpadded pack
Expand Down Expand Up @@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {

} // namespace

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
void populateFoldIntoPackAndUnpackPatterns(
RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
FoldProducerPackWithConsumerLinalgTransposeOp,
FoldConsumerPackWithProducerLinalgTransposeOp,
FoldConsumerUnPackWithProducerLinalgTransposeOp,
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
patterns.getContext());
patterns.getContext(), controlFn);
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
Expand Down
61 changes: 61 additions & 0 deletions mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL

func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
Expand Down Expand Up @@ -373,6 +374,36 @@ func.func @linalg_transpose_linalg.pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t

// -----

func.func @linalg_transpose_linalg.pack_fold_multi_result(%arg0: tensor<56x57x1x64xf32>) -> (tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>) {
%0 = tensor.empty() : tensor<1x56x57x64xf32>
%transposed = linalg.transpose
ins(%arg0 : tensor<56x57x1x64xf32>)
outs(%0 : tensor<1x56x57x64xf32>)
permutation = [2, 0, 1, 3]

%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
%pack = linalg.pack %transposed
outer_dims_perm = [0, 2, 1, 3]
inner_dims_pos = [3]
inner_tiles = [32]
into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
return %transposed, %pack : tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>
}
// CHECK-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with these tests. The checks can be less verbose.

// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose
// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
// CHECK: return %[[TRANSPOSE]], %[[PACK]]

// CONTROL-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose
// CONTROL: %[[PACK:.+]] = linalg.pack %[[TRANSPOSE]]
// CONTROL-SAME: outer_dims_perm = [0, 2, 1, 3]
// CONTROL: return %[[TRANSPOSE]], %[[PACK]]

// -----

func.func @linalg_transpose_linalg.pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
%0 = tensor.empty() : tensor<1x56x57x55xf32>
%transpose = linalg.transpose
Expand Down Expand Up @@ -550,6 +581,36 @@ func.func @linalg_transpose_linalg.unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t

// -----

func.func @linalg_transpose_linalg.unpack_fold_multi_result(%arg0: tensor<1x1x4x16xi32>) -> (tensor<1x1x16x4xi32>, tensor<16x4xi32>) {
%0 = tensor.empty() : tensor<1x1x16x4xi32>
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
outs(%0 : tensor<1x1x16x4xi32>)
permutation = [1, 0, 3, 2]
%1 = tensor.empty() : tensor<16x4xi32>
%unpack = linalg.unpack %transposed
outer_dims_perm = [0, 1]
inner_dims_pos = [0, 1]
inner_tiles = [16, 4] into
%1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
return %transposed, %unpack : tensor<1x1x16x4xi32>, tensor<16x4xi32>
}
//CHECK-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests can be less verbose in the checks. The main thing this is testing is whether or not the fusion happened with mutli-result, so you don't need to check all the shapes/op metadata. You could just check the operands and results of the operations to verify the producers/consumers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made them less verbose, thanks! I definitely got a little caught up while writing those 😂

// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>)
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [1, 0]
// CHECK: return %[[TRANSPOSE]], %[[UNPACK]]
// CHECK: }

//CONTROL-LABEL: func.func @linalg_transpose_linalg.unpack_fold_multi_result(
// CONTROL: %[[TRANSPOSE:.+]] = linalg.transpose
// CONTROL: %[[UNPACK:.+]] = linalg.unpack %[[TRANSPOSE]]
// CONTROL-SAME: outer_dims_perm = [0, 1]
// CONTROL: return %[[TRANSPOSE]], %[[UNPACK]]
// CONTROL: }

// -----

func.func @linalg_transpose_linalg.unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
%0 = tensor.empty() : tensor<1x1x16x4xi32>
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
Expand Down
24 changes: 22 additions & 2 deletions mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ struct TestLinalgTransforms
*this, "test-fold-into-pack-and-unpack",
llvm::cl::desc("Test folding ops into linalg.pack and linalg.unpack"),
llvm::cl::init(false)};
Option<bool> testFoldIntoPackAndUnpackWithControlFn{
*this, "test-fold-into-pack-and-unpack-control",
llvm::cl::desc(
"Test controlling folding ops into linalg.pack and linalg.unpack"),
llvm::cl::init(false)};
Option<bool> testSimplifyPackUnpackPatterns{
*this, "test-simplify-pack-unpack-patterns",
llvm::cl::desc("Test patterns to simplify linalg.pack and linalg.unpack"),
Expand Down Expand Up @@ -236,9 +241,11 @@ static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}

static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
static void applyFoldIntoPackAndUnpackPatterns(
Operation *rootOp,
linalg::ControlFoldIntoPackUnpackFn controlFn = nullptr) {
RewritePatternSet patterns(rootOp->getContext());
linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn);
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}

Expand Down Expand Up @@ -279,6 +286,19 @@ void TestLinalgTransforms::runOnOperation() {
Operation *rootOp = getOperation();
if (testFoldIntoPackAndUnpack)
applyFoldIntoPackAndUnpackPatterns(rootOp);
if (testFoldIntoPackAndUnpackWithControlFn) {
linalg::ControlFoldIntoPackUnpackFn controlFn = [](OpOperand *opOperand) {
Operation *producer = opOperand->get().getDefiningOp();
Operation *consumer = opOperand->getOwner();
// If we have a pack/unpack consumer and a producer that has multiple
// uses, do not apply the folding patterns.
if (isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
isa<TilingInterface>(producer) && !producer->hasOneUse())
return false;
return true;
};
applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn);
}
if (testSimplifyPackUnpackPatterns)
applySimplifyPackUnpackPatterns(rootOp);
}
Expand Down
Loading