Skip to content

Commit ecffe33

Browse files
committed
[MLIR][SCF] Sink scf.if from scf.while before region into after region.
1 parent 5f0f758 commit ecffe33

File tree

2 files changed

+170
-1
lines changed

2 files changed

+170
-1
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3546,6 +3546,137 @@ LogicalResult scf::WhileOp::verify() {
35463546
}
35473547

35483548
namespace {
3549+
/// Move an scf.if op that is directly before the scf.condition op in the while
3550+
/// before region, and whose condition matches the condition of the
3551+
/// scf.condition op, down into the while after region.
3552+
///
3553+
/// scf.while (..) : (...) -> ... {
3554+
/// %additional_used_values = ...
3555+
/// %cond = ...
3556+
/// ...
3557+
/// %res = scf.if %cond -> (...) {
3558+
/// use(%additional_used_values)
3559+
/// ... // then block
3560+
/// scf.yield %then_value
3561+
/// } else {
3562+
/// scf.yield %else_value
3563+
/// }
3564+
/// scf.condition(%cond) %res, ...
3565+
/// } do {
3566+
/// ^bb0(%res_arg, ...):
3567+
/// use(%res_arg)
3568+
/// ...
3569+
///
3570+
/// becomes
3571+
/// scf.while (..) : (...) -> ... {
3572+
/// %additional_used_values = ...
3573+
/// %cond = ...
3574+
/// ...
3575+
/// scf.condition(%cond) %else_value, ..., %additional_used_values
3576+
/// } do {
3577+
/// ^bb0(%res_arg ..., %additional_args): :
3578+
/// use(%additional_args)
3579+
/// ... // if then block
3580+
/// use(%then_value)
3581+
/// ...
3582+
struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
3583+
using OpRewritePattern<WhileOp>::OpRewritePattern;
3584+
3585+
LogicalResult matchAndRewrite(WhileOp op,
3586+
PatternRewriter &rewriter) const override {
3587+
auto conditionOp =
3588+
cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
3589+
auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3590+
3591+
// Check that the ifOp is directly before the conditionOp and that it
3592+
// matches the condition of the conditionOp. Also ensure that the ifOp has
3593+
// no else block with content, as that would complicate the transformation.
3594+
// TODO: support else blocks with content.
3595+
if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3596+
(ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3597+
return failure();
3598+
3599+
assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3600+
*ifOp->user_begin() == conditionOp) &&
3601+
"ifOp has unexpected uses");
3602+
3603+
Location loc = op.getLoc();
3604+
3605+
// Replace uses of ifOp results in the conditionOp with the yielded values
3606+
// from the ifOp branches.
3607+
for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3608+
auto it = llvm::find(ifOp->getResults(), arg);
3609+
if (it != ifOp->getResults().end()) {
3610+
size_t ifOpIdx = it.getIndex();
3611+
Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3612+
Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3613+
3614+
rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
3615+
rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
3616+
}
3617+
}
3618+
3619+
SmallVector<Value> additionalUsedValues;
3620+
auto isValueUsedInsideIf = [&](Value val) {
3621+
return llvm::any_of(val.getUsers(), [&](Operation *user) {
3622+
return ifOp.getThenRegion().isAncestor(user->getParentRegion());
3623+
});
3624+
};
3625+
3626+
// Collect additional used values from before region.
3627+
for (Operation *it = ifOp->getPrevNode(); it != nullptr;
3628+
it = it->getPrevNode())
3629+
llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
3630+
isValueUsedInsideIf);
3631+
3632+
llvm::copy_if(op.getBeforeArguments(),
3633+
std::back_inserter(additionalUsedValues),
3634+
isValueUsedInsideIf);
3635+
3636+
// Create new whileOp with additional used values as results.
3637+
auto additionalValueTypes = llvm::map_to_vector(
3638+
additionalUsedValues, [](Value val) { return val.getType(); });
3639+
size_t additionalValueSize = additionalUsedValues.size();
3640+
SmallVector<Type> newResultTypes(op.getResultTypes());
3641+
newResultTypes.append(additionalValueTypes);
3642+
3643+
auto newWhileOp =
3644+
scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3645+
3646+
newWhileOp.getBefore().takeBody(op.getBefore());
3647+
newWhileOp.getAfter().takeBody(op.getAfter());
3648+
newWhileOp.getAfter().addArguments(
3649+
additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
3650+
3651+
SmallVector<Value> conditionArgs = conditionOp.getArgs();
3652+
llvm::append_range(conditionArgs, additionalUsedValues);
3653+
3654+
// Update conditionOp inside new whileOp before region.
3655+
rewriter.setInsertionPoint(conditionOp);
3656+
rewriter.replaceOpWithNewOp<scf::ConditionOp>(
3657+
conditionOp, conditionOp.getCondition(), conditionArgs);
3658+
3659+
// Replace uses of additional used values inside the ifOp then region with
3660+
// the whileOp after region arguments.
3661+
rewriter.replaceUsesWithIf(
3662+
additionalUsedValues,
3663+
newWhileOp.getAfterArguments().take_back(additionalValueSize),
3664+
[&](OpOperand &use) {
3665+
return ifOp.getThenRegion().isAncestor(
3666+
use.getOwner()->getParentRegion());
3667+
});
3668+
3669+
// Inline ifOp then region into new whileOp after region.
3670+
rewriter.eraseOp(ifOp.thenYield());
3671+
rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
3672+
newWhileOp.getAfterBody()->begin());
3673+
rewriter.eraseOp(ifOp);
3674+
rewriter.replaceOp(op,
3675+
newWhileOp->getResults().drop_back(additionalValueSize));
3676+
return success();
3677+
}
3678+
};
3679+
35493680
/// Replace uses of the condition within the do block with true, since otherwise
35503681
/// the block would not be evaluated.
35513682
///
@@ -4258,7 +4389,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
42584389
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
42594390
RemoveLoopInvariantValueYielded, WhileConditionTruth,
42604391
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4261-
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4392+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4393+
context);
42624394
}
42634395

42644396
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,43 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
974974

975975
// -----
976976

977+
// CHECK-LABEL: @while_move_if_down
978+
func.func @while_move_if_down() -> i32 {
979+
%0 = scf.while () : () -> (i32) {
980+
%additional_used_value = "test.get_some_value1" () : () -> (i32)
981+
%else_value = "test.get_some_value2" () : () -> (i32)
982+
%condition = "test.condition"() : () -> i1
983+
%res = scf.if %condition -> (i32) {
984+
"test.use1" (%additional_used_value) : (i32) -> ()
985+
%then_value = "test.get_some_value3" () : () -> (i32)
986+
scf.yield %then_value : i32
987+
} else {
988+
scf.yield %else_value : i32
989+
}
990+
scf.condition(%condition) %res : i32
991+
} do {
992+
^bb0(%res_arg: i32):
993+
"test.use2" (%res_arg) : (i32) -> ()
994+
scf.yield
995+
}
996+
return %0 : i32
997+
}
998+
// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
999+
// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
1000+
// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
1001+
// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
1002+
// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
1003+
// CHECK-NEXT: } do {
1004+
// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
1005+
// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> ()
1006+
// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
1007+
// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> ()
1008+
// CHECK-NEXT: scf.yield
1009+
// CHECK-NEXT: }
1010+
// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32
1011+
1012+
// -----
1013+
9771014
// CHECK-LABEL: @while_cond_true
9781015
func.func @while_cond_true() -> i1 {
9791016
%0 = scf.while () : () -> i1 {

0 commit comments

Comments
 (0)