diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index ec1044aaa42ac..0218339ee321a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -15,10 +15,132 @@ #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/RegionUtils.h" using namespace mlir; namespace { +/// Move a scf.if op that is directly before the scf.condition op in the while +/// before region, and whose condition matches the condition of the +/// scf.condition op, down into the while after region. +/// +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// %res = scf.if %cond -> (...) { +/// use(%additional_used_values) +/// ... // then block +/// scf.yield %then_value +/// } else { +/// scf.yield %else_value +/// } +/// scf.condition(%cond) %res, ... +/// } do { +/// ^bb0(%res_arg, ...): +/// use(%res_arg) +/// ... +/// +/// becomes +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// scf.condition(%cond) %else_value, ..., %additional_used_values +/// } do { +/// ^bb0(%res_arg ..., %additional_args): : +/// use(%additional_args) +/// ... // if then block +/// use(%then_value) +/// ... +struct WhileMoveIfDown : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + auto conditionOp = + cast(op.getBeforeBody()->getTerminator()); + auto ifOp = dyn_cast_or_null(conditionOp->getPrevNode()); + + // Check that the ifOp is directly before the conditionOp and that it + // matches the condition of the conditionOp. Also ensure that the ifOp has + // no else block with content, as that would complicate the transformation. + // TODO: support else blocks with content. + if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || + (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) + return failure(); + + assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && + *ifOp->user_begin() == conditionOp) && + "ifOp has unexpected uses"); + + Location loc = op.getLoc(); + + // Replace uses of ifOp results in the conditionOp with the yielded values + // from the ifOp branches. + for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { + auto it = llvm::find(ifOp->getResults(), arg); + if (it != ifOp->getResults().end()) { + size_t ifOpIdx = it.getIndex(); + Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); + Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); + + rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); + rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); + } + } + + // Collect additional used values from before region. + SetVector additionalUsedValues; + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { + if (op.getBefore().isAncestor(operand->get().getParentRegion())) + additionalUsedValues.insert(operand->get()); + }); + + // Create new whileOp with additional used values as results. + auto additionalValueTypes = llvm::map_to_vector( + additionalUsedValues, [](Value val) { return val.getType(); }); + size_t additionalValueSize = additionalUsedValues.size(); + SmallVector newResultTypes(op.getResultTypes()); + newResultTypes.append(additionalValueTypes); + + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); + + newWhileOp.getBefore().takeBody(op.getBefore()); + newWhileOp.getAfter().takeBody(op.getAfter()); + newWhileOp.getAfter().addArguments( + additionalValueTypes, SmallVector(additionalValueSize, loc)); + + SmallVector conditionArgs = conditionOp.getArgs(); + llvm::append_range(conditionArgs, additionalUsedValues); + + // Update conditionOp inside new whileOp before region. + rewriter.setInsertionPoint(conditionOp); + rewriter.replaceOpWithNewOp( + conditionOp, conditionOp.getCondition(), conditionArgs); + + // Replace uses of additional used values inside the ifOp then region with + // the whileOp after region arguments. + rewriter.replaceUsesWithIf( + additionalUsedValues.takeVector(), + newWhileOp.getAfterArguments().take_back(additionalValueSize), + [&](OpOperand &use) { + return ifOp.getThenRegion().isAncestor( + use.getOwner()->getParentRegion()); + }); + + // Inline ifOp then region into new whileOp after region. + rewriter.eraseOp(ifOp.thenYield()); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), + newWhileOp.getAfterBody()->begin()); + rewriter.eraseOp(ifOp); + rewriter.replaceOp(op, + newWhileOp->getResults().drop_back(additionalValueSize)); + return success(); + } +}; + struct UpliftWhileOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -267,5 +389,6 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, } void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext()); } diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir index cbe2ce5076ad2..736112824c515 100644 --- a/mlir/test/Dialect/SCF/uplift-while.mlir +++ b/mlir/test/Dialect/SCF/uplift-while.mlir @@ -185,3 +185,34 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) // CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32 // CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32 // CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32 + +// ----- + +func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 { + %c1 = arith.constant 1 : index + %1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) { + %2 = arith.cmpi slt, %iv, %upper : index + %3:2 = scf.if %2 -> (index, i32) { + %4 = "test.test"(%iter) : (i32) -> i32 + %5 = arith.addi %iv, %c1 : index + scf.yield %5, %4 : index, i32 + } else { + scf.yield %iv, %iter : index, i32 + } + scf.condition(%2) %3#0, %3#1 : index, i32 + } do { + ^bb0(%arg0: index, %arg1: i32): + scf.yield %arg0, %arg1 : index, i32 + } + return %1#1 : i32 +} + +// CHECK-LABEL: func.func @uplift_while( +// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index +// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) { +// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32 +// CHECK: scf.yield %[[VAL_2]] : i32 +// CHECK: } +// CHECK: return %[[FOR_0]] : i32 +// CHECK: }