@@ -3546,6 +3546,137 @@ LogicalResult scf::WhileOp::verify() {
35463546}
35473547
35483548namespace  {
3549+ // / Move an scf.if op that is directly before the scf.condition op in the while
3550+ // / before region, and whose condition matches the condition of the
3551+ // / scf.condition op, down into the while after region.
3552+ // /
3553+ // / scf.while (..) : (...) -> ... {
3554+ // /  %additional_used_values = ...
3555+ // /  %cond = ...
3556+ // /  ...
3557+ // /  %res = scf.if %cond -> (...) {
3558+ // /    use(%additional_used_values)
3559+ // /    ... // then block
3560+ // /    scf.yield %then_value
3561+ // /  } else {
3562+ // /    scf.yield %else_value
3563+ // /  }
3564+ // /  scf.condition(%cond) %res, ...
3565+ // / } do {
3566+ // / ^bb0(%res_arg, ...):
3567+ // /    use(%res_arg)
3568+ // /    ...
3569+ // /
3570+ // / becomes
3571+ // / scf.while (..) : (...) -> ... {
3572+ // /  %additional_used_values = ...
3573+ // /  %cond = ...
3574+ // /  ...
3575+ // /  scf.condition(%cond) %else_value, ..., %additional_used_values
3576+ // / } do {
3577+ // / ^bb0(%res_arg ..., %additional_args): :
3578+ // /    use(%additional_args)
3579+ // /    ... // if then block
3580+ // /    use(%then_value)
3581+ // /    ...
3582+ struct  WhileMoveIfDown  : public  OpRewritePattern <WhileOp> {
3583+   using  OpRewritePattern<WhileOp>::OpRewritePattern;
3584+ 
3585+   LogicalResult matchAndRewrite (WhileOp op,
3586+                                 PatternRewriter &rewriter) const  override  {
3587+     auto  conditionOp =
3588+         cast<scf::ConditionOp>(op.getBeforeBody ()->getTerminator ());
3589+     auto  ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode ());
3590+ 
3591+     //  Check that the ifOp is directly before the conditionOp and that it
3592+     //  matches the condition of the conditionOp. Also ensure that the ifOp has
3593+     //  no else block with content, as that would complicate the transformation.
3594+     //  TODO: support else blocks with content.
3595+     if  (!ifOp || ifOp.getCondition () != conditionOp.getCondition () ||
3596+         (ifOp.elseBlock () && !ifOp.elseBlock ()->without_terminator ().empty ()))
3597+       return  failure ();
3598+ 
3599+     assert (ifOp->use_empty () || (llvm::all_equal (ifOp->getUsers ()) &&
3600+                                  *ifOp->user_begin () == conditionOp) &&
3601+                                     " ifOp has unexpected uses" 
3602+ 
3603+     Location loc = op.getLoc ();
3604+ 
3605+     //  Replace uses of ifOp results in the conditionOp with the yielded values
3606+     //  from the ifOp branches.
3607+     for  (auto  [idx, arg] : llvm::enumerate (conditionOp.getArgs ())) {
3608+       auto  it = llvm::find (ifOp->getResults (), arg);
3609+       if  (it != ifOp->getResults ().end ()) {
3610+         size_t  ifOpIdx = it.getIndex ();
3611+         Value thenValue = ifOp.thenYield ()->getOperand (ifOpIdx);
3612+         Value elseValue = ifOp.elseYield ()->getOperand (ifOpIdx);
3613+ 
3614+         rewriter.replaceAllUsesWith (ifOp->getResults ()[ifOpIdx], elseValue);
3615+         rewriter.replaceAllUsesWith (op.getAfterArguments ()[idx], thenValue);
3616+       }
3617+     }
3618+ 
3619+     SmallVector<Value> additionalUsedValues;
3620+     auto  isValueUsedInsideIf = [&](Value val) {
3621+       return  llvm::any_of (val.getUsers (), [&](Operation *user) {
3622+         return  ifOp.getThenRegion ().isAncestor (user->getParentRegion ());
3623+       });
3624+     };
3625+ 
3626+     //  Collect additional used values from before region.
3627+     for  (Operation *it = ifOp->getPrevNode (); it != nullptr ;
3628+          it = it->getPrevNode ())
3629+       llvm::copy_if (it->getResults (), std::back_inserter (additionalUsedValues),
3630+                     isValueUsedInsideIf);
3631+ 
3632+     llvm::copy_if (op.getBeforeArguments (),
3633+                   std::back_inserter (additionalUsedValues),
3634+                   isValueUsedInsideIf);
3635+ 
3636+     //  Create new whileOp with additional used values as results.
3637+     auto  additionalValueTypes = llvm::map_to_vector (
3638+         additionalUsedValues, [](Value val) { return  val.getType (); });
3639+     size_t  additionalValueSize = additionalUsedValues.size ();
3640+     SmallVector<Type> newResultTypes (op.getResultTypes ());
3641+     newResultTypes.append (additionalValueTypes);
3642+ 
3643+     auto  newWhileOp =
3644+         scf::WhileOp::create (rewriter, loc, newResultTypes, op.getInits ());
3645+ 
3646+     newWhileOp.getBefore ().takeBody (op.getBefore ());
3647+     newWhileOp.getAfter ().takeBody (op.getAfter ());
3648+     newWhileOp.getAfter ().addArguments (
3649+         additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
3650+ 
3651+     SmallVector<Value> conditionArgs = conditionOp.getArgs ();
3652+     llvm::append_range (conditionArgs, additionalUsedValues);
3653+ 
3654+     //  Update conditionOp inside new whileOp before region.
3655+     rewriter.setInsertionPoint (conditionOp);
3656+     rewriter.replaceOpWithNewOp <scf::ConditionOp>(
3657+         conditionOp, conditionOp.getCondition (), conditionArgs);
3658+ 
3659+     //  Replace uses of additional used values inside the ifOp then region with
3660+     //  the whileOp after region arguments.
3661+     rewriter.replaceUsesWithIf (
3662+         additionalUsedValues,
3663+         newWhileOp.getAfterArguments ().take_back (additionalValueSize),
3664+         [&](OpOperand &use) {
3665+           return  ifOp.getThenRegion ().isAncestor (
3666+               use.getOwner ()->getParentRegion ());
3667+         });
3668+ 
3669+     //  Inline ifOp then region into new whileOp after region.
3670+     rewriter.eraseOp (ifOp.thenYield ());
3671+     rewriter.inlineBlockBefore (ifOp.thenBlock (), newWhileOp.getAfterBody (),
3672+                                newWhileOp.getAfterBody ()->begin ());
3673+     rewriter.eraseOp (ifOp);
3674+     rewriter.replaceOp (op,
3675+                        newWhileOp->getResults ().drop_back (additionalValueSize));
3676+     return  success ();
3677+   }
3678+ };
3679+ 
35493680// / Replace uses of the condition within the do block with true, since otherwise
35503681// / the block would not be evaluated.
35513682// /
@@ -4258,7 +4389,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
42584389  results.add <RemoveLoopInvariantArgsFromBeforeBlock,
42594390              RemoveLoopInvariantValueYielded, WhileConditionTruth,
42604391              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4261-               WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4392+               WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4393+       context);
42624394}
42634395
42644396// ===----------------------------------------------------------------------===//
0 commit comments