Skip to content

Commit 7aabf47

Browse files
authored
[mlir][affine] Add pass --affine-raise-from-memref (#138004)
This adds a pass that converts memref.load/store into affine.load/store. This is useful as those memref operators are ignored by passes like --affine-scalrep as they don't implement the Affine[Read/Write]OpInterface. Doing this allows you to put as much of your program in affine form before you apply affine optimization passes. This also slightly changes the implementation of affine::isValidDim. The previous implementation allowed values from the iter_args of affine loops to be used as valid dims. I think this doesn't make sense and what was meant is just the induction vars. In the real world, there is little reason to find an index in the iter_args, but I wrote that in my tests and found out it was treated as an affine dim, so corrected that. Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]> Rebased from #114032.
1 parent bb2aa1a commit 7aabf47

File tree

6 files changed

+355
-6
lines changed

6 files changed

+355
-6
lines changed

mlir/include/mlir/Dialect/Affine/Passes.h

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ namespace mlir {
2323
namespace func {
2424
class FuncOp;
2525
} // namespace func
26+
namespace memref {
27+
class MemRefDialect;
28+
} // namespace memref
2629

2730
namespace affine {
2831
class AffineForOp;
@@ -45,6 +48,13 @@ createSimplifyAffineStructuresPass();
4548
std::unique_ptr<OperationPass<func::FuncOp>>
4649
createAffineLoopInvariantCodeMotionPass();
4750

51+
/// Creates a pass to convert all parallel affine.for's into 1-d affine.parallel
52+
/// ops.
53+
std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
54+
55+
/// Creates a pass that converts some memref operators to affine operators.
56+
std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
57+
4858
/// Apply normalization transformations to affine loop-like ops. If
4959
/// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
5060
/// loop is replaced by its loop body).

mlir/include/mlir/Dialect/Affine/Passes.td

+12
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
396396
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
397397
}
398398

399+
def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
400+
let summary = "Turn some memref operators to affine operators where supported";
401+
let description = [{
402+
Raise memref.load and memref.store to affine.store and affine.load, inferring
403+
the affine map of those operators if needed. This allows passes like --affine-scalrep
404+
to optimize those loads and stores (forwarding them or eliminating them).
405+
They can be turned back to memref dialect ops with --lower-affine.
406+
}];
407+
let constructor = "mlir::affine::createRaiseMemrefToAffine()";
408+
let dependentDialects = ["affine::AffineDialect"];
409+
}
410+
399411
def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
400412
let summary = "Simplify affine expressions in maps/sets and normalize "
401413
"memrefs";

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,12 @@ bool mlir::affine::isValidDim(Value value) {
294294
return isValidDim(value, getAffineScope(defOp));
295295

296296
// This value has to be a block argument for an op that has the
297-
// `AffineScope` trait or for an affine.for or affine.parallel.
297+
// `AffineScope` trait or an induction var of an affine.for or
298+
// affine.parallel.
299+
if (isAffineInductionVar(value))
300+
return true;
298301
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
299-
return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
300-
isa<AffineForOp, AffineParallelOp>(parentOp));
302+
return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
301303
}
302304

303305
// Value can be used as a dimension id iff it meets one of the following
@@ -316,10 +318,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
316318

317319
auto *op = value.getDefiningOp();
318320
if (!op) {
319-
// This value has to be a block argument for an affine.for or an
321+
// This value has to be an induction var for an affine.for or an
320322
// affine.parallel.
321-
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
322-
return isa<AffineForOp, AffineParallelOp>(parentOp);
323+
return isAffineInductionVar(value);
323324
}
324325

325326
// Affine apply operation is ok if all of its operands are ok.

mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
1313
LoopUnroll.cpp
1414
LoopUnrollAndJam.cpp
1515
PipelineDataTransfer.cpp
16+
RaiseMemrefDialect.cpp
1617
ReifyValueBounds.cpp
1718
SuperVectorize.cpp
1819
SimplifyAffineStructures.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements functionality to convert memref load and store ops to
10+
// the corresponding affine ops, inferring the affine map as needed.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Affine/Analysis/Utils.h"
15+
#include "mlir/Dialect/Affine/Passes.h"
16+
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
17+
#include "mlir/Dialect/Affine/Utils.h"
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
#include "mlir/IR/AffineExpr.h"
21+
#include "mlir/IR/Matchers.h"
22+
#include "mlir/IR/Operation.h"
23+
#include "mlir/Pass/Pass.h"
24+
#include "llvm/Support/Casting.h"
25+
#include "llvm/Support/Debug.h"
26+
27+
namespace mlir {
28+
namespace affine {
29+
#define GEN_PASS_DEF_RAISEMEMREFDIALECT
30+
#include "mlir/Dialect/Affine/Passes.h.inc"
31+
} // namespace affine
32+
} // namespace mlir
33+
34+
#define DEBUG_TYPE "raise-memref-to-affine"
35+
36+
using namespace mlir;
37+
using namespace mlir::affine;
38+
39+
namespace {
40+
41+
/// Find the index of the given value in the `dims` list,
42+
/// and append it if it was not already in the list. The
43+
/// dims list is a list of symbols or dimensions of the
44+
/// affine map. Within the results of an affine map, they
45+
/// are identified by their index, which is why we need
46+
/// this function.
47+
static std::optional<size_t>
48+
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
49+
function_ref<bool(Value)> isValidElement) {
50+
51+
Value *loopIV = std::find(dims.begin(), dims.end(), value);
52+
if (loopIV != dims.end()) {
53+
// We found an IV that already has an index, return that index.
54+
return {std::distance(dims.begin(), loopIV)};
55+
}
56+
if (isValidElement(value)) {
57+
// This is a valid element for the dim/symbol list, push this as a
58+
// parameter.
59+
size_t idx = dims.size();
60+
dims.push_back(value);
61+
return idx;
62+
}
63+
return std::nullopt;
64+
}
65+
66+
/// Convert a value to an affine expr if possible. Adds dims and symbols
67+
/// if needed.
68+
static AffineExpr toAffineExpr(Value value,
69+
llvm::SmallVectorImpl<Value> &affineDims,
70+
llvm::SmallVectorImpl<Value> &affineSymbols) {
71+
using namespace matchers;
72+
IntegerAttr::ValueType cst;
73+
if (matchPattern(value, m_ConstantInt(&cst))) {
74+
return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
75+
}
76+
77+
Operation *definingOp = value.getDefiningOp();
78+
if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
79+
llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
80+
// TODO: replace recursion with explicit stack.
81+
// For the moment this can be tolerated as we only recurse on
82+
// arith.addi and arith.muli, so there cannot be any infinite
83+
// recursion. The depth of these expressions should be in most
84+
// cases very manageable, as affine expressions should be as
85+
// simple as `a + b * c`.
86+
AffineExpr lhsE =
87+
toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols);
88+
AffineExpr rhsE =
89+
toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols);
90+
91+
if (lhsE && rhsE) {
92+
AffineExprKind kind;
93+
if (isa<arith::AddIOp>(definingOp)) {
94+
kind = mlir::AffineExprKind::Add;
95+
} else {
96+
kind = mlir::AffineExprKind::Mul;
97+
98+
if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
99+
// This is not an affine expression, give up.
100+
return {};
101+
}
102+
}
103+
return getAffineBinaryOpExpr(kind, lhsE, rhsE);
104+
}
105+
return {};
106+
}
107+
108+
if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
109+
return affine::isValidSymbol(v);
110+
})) {
111+
return getAffineSymbolExpr(*dimIx, value.getContext());
112+
}
113+
114+
if (auto dimIx = findInListOrAdd(
115+
value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
116+
117+
return getAffineDimExpr(*dimIx, value.getContext());
118+
}
119+
120+
return {};
121+
}
122+
123+
static LogicalResult
124+
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
125+
llvm::SmallVectorImpl<Value> &mapArgs) {
126+
SmallVector<AffineExpr> results;
127+
SmallVector<Value> symbols;
128+
SmallVector<Value> dims;
129+
130+
for (Value indexExpr : indices) {
131+
AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
132+
if (!res) {
133+
return failure();
134+
}
135+
results.push_back(res);
136+
}
137+
138+
map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
139+
140+
dims.append(symbols);
141+
mapArgs.swap(dims);
142+
return success();
143+
}
144+
145+
struct RaiseMemrefDialect
146+
: public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
147+
148+
void runOnOperation() override {
149+
auto *ctx = &getContext();
150+
Operation *op = getOperation();
151+
IRRewriter rewriter(ctx);
152+
AffineMap map;
153+
SmallVector<Value> mapArgs;
154+
op->walk([&](Operation *op) {
155+
rewriter.setInsertionPoint(op);
156+
if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
157+
158+
if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
159+
mapArgs))) {
160+
rewriter.replaceOpWithNewOp<AffineStoreOp>(
161+
op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
162+
return;
163+
}
164+
165+
LLVM_DEBUG(llvm::dbgs()
166+
<< "[affine] Cannot raise memref op: " << op << "\n");
167+
168+
} else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
169+
if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
170+
mapArgs))) {
171+
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
172+
mapArgs);
173+
return;
174+
}
175+
LLVM_DEBUG(llvm::dbgs()
176+
<< "[affine] Cannot raise memref op: " << op << "\n");
177+
}
178+
});
179+
}
180+
};
181+
182+
} // namespace
183+
184+
std::unique_ptr<OperationPass<func::FuncOp>>
185+
mlir::affine::createRaiseMemrefToAffine() {
186+
return std::make_unique<RaiseMemrefDialect>();
187+
}

0 commit comments

Comments
 (0)