Skip to content

Commit 5520ab3

Browse files
authored
[VPlan] Add ComputeAnyOfResult VPInstruction (NFC) (#141932)
Add a dedicated opcode for any-of reduction, similar to #132689 and #132690. The patch also explicitly adds the start value to not require RecurrenceDescriptor during execute. It also allows freezing the start value to make it poison-safe. PR: #141932
1 parent 3a8b488 commit 5520ab3

File tree

6 files changed

+79
-26
lines changed

6 files changed

+79
-26
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7209,15 +7209,25 @@ static void addRuntimeUnrollDisableMetaData(Loop *L) {
72097209
}
72107210
}
72117211

7212-
// If \p R is a ComputeReductionResult when vectorizing the epilog loop,
7213-
// fix the reduction's scalar PHI node by adding the incoming value from the
7214-
// main vector loop.
7212+
static Value *getStartValueFromReductionResult(VPInstruction *RdxResult) {
7213+
using namespace VPlanPatternMatch;
7214+
assert(RdxResult->getOpcode() == VPInstruction::ComputeFindLastIVResult &&
7215+
"RdxResult must be ComputeFindLastIVResult");
7216+
VPValue *StartVPV = RdxResult->getOperand(1);
7217+
match(StartVPV, m_Freeze(m_VPValue(StartVPV)));
7218+
return StartVPV->getLiveInIRValue();
7219+
}
7220+
7221+
// If \p R is a Compute{Reduction,AnyOf,FindLastIV}Result when vectorizing the
7222+
// epilog loop, fix the reduction's scalar PHI node by adding the incoming value
7223+
// from the main vector loop.
72157224
static void fixReductionScalarResumeWhenVectorizingEpilog(
72167225
VPRecipeBase *R, VPTransformState &State, BasicBlock *LoopMiddleBlock,
72177226
BasicBlock *BypassBlock) {
72187227
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
72197228
if (!EpiRedResult ||
7220-
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
7229+
(EpiRedResult->getOpcode() != VPInstruction::ComputeAnyOfResult &&
7230+
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
72217231
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
72227232
return;
72237233

@@ -7229,15 +7239,18 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72297239
EpiRedHeaderPhi->getStartValue()->getUnderlyingValue();
72307240
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
72317241
RdxDesc.getRecurrenceKind())) {
7242+
Value *StartV = EpiRedResult->getOperand(1)->getLiveInIRValue();
7243+
(void)StartV;
72327244
auto *Cmp = cast<ICmpInst>(MainResumeValue);
72337245
assert(Cmp->getPredicate() == CmpInst::ICMP_NE &&
72347246
"AnyOf expected to start with ICMP_NE");
7235-
assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue() &&
7247+
assert(Cmp->getOperand(1) == StartV &&
72367248
"AnyOf expected to start by comparing main resume value to original "
72377249
"start value");
72387250
MainResumeValue = Cmp->getOperand(0);
72397251
} else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
72407252
RdxDesc.getRecurrenceKind())) {
7253+
Value *StartV = getStartValueFromReductionResult(EpiRedResult);
72417254
using namespace llvm::PatternMatch;
72427255
Value *Cmp, *OrigResumeV, *CmpOp;
72437256
bool IsExpectedPattern =
@@ -7246,10 +7259,7 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72467259
m_Value(OrigResumeV))) &&
72477260
(match(Cmp, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(OrigResumeV),
72487261
m_Value(CmpOp))) &&
7249-
(match(CmpOp,
7250-
m_Freeze(m_Specific(RdxDesc.getRecurrenceStartValue()))) ||
7251-
(CmpOp == RdxDesc.getRecurrenceStartValue() &&
7252-
isGuaranteedNotToBeUndefOrPoison(CmpOp))));
7262+
((CmpOp == StartV && isGuaranteedNotToBeUndefOrPoison(CmpOp))));
72537263
assert(IsExpectedPattern && "Unexpected reduction resume pattern");
72547264
(void)IsExpectedPattern;
72557265
MainResumeValue = OrigResumeV;
@@ -9184,6 +9194,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
91849194
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
91859195
return isa<VPInstruction>(&U) &&
91869196
(cast<VPInstruction>(&U)->getOpcode() ==
9197+
VPInstruction::ComputeAnyOfResult ||
9198+
cast<VPInstruction>(&U)->getOpcode() ==
91879199
VPInstruction::ComputeReductionResult ||
91889200
cast<VPInstruction>(&U)->getOpcode() ==
91899201
VPInstruction::ComputeFindLastIVResult);
@@ -9236,6 +9248,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92369248
FinalReductionResult =
92379249
Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
92389250
{PhiR, Start, NewExitingVPV}, ExitDL);
9251+
} else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
9252+
RdxDesc.getRecurrenceKind())) {
9253+
VPValue *Start = PhiR->getStartValue();
9254+
FinalReductionResult =
9255+
Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
9256+
{PhiR, Start, NewExitingVPV}, ExitDL);
92399257
} else {
92409258
VPIRFlags Flags = RecurrenceDescriptor::isFloatingPointRecurrenceKind(
92419259
RdxDesc.getRecurrenceKind())
@@ -9764,23 +9782,37 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
97649782
Value *ResumeV = nullptr;
97659783
// TODO: Move setting of resume values to prepareToExecute.
97669784
if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) {
9785+
auto *RdxResult =
9786+
cast<VPInstruction>(*find_if(ReductionPhi->users(), [](VPUser *U) {
9787+
auto *VPI = dyn_cast<VPInstruction>(U);
9788+
return VPI &&
9789+
(VPI->getOpcode() == VPInstruction::ComputeAnyOfResult ||
9790+
VPI->getOpcode() == VPInstruction::ComputeReductionResult ||
9791+
VPI->getOpcode() == VPInstruction::ComputeFindLastIVResult);
9792+
}));
97679793
ResumeV = cast<PHINode>(ReductionPhi->getUnderlyingInstr())
97689794
->getIncomingValueForBlock(L->getLoopPreheader());
97699795
const RecurrenceDescriptor &RdxDesc =
97709796
ReductionPhi->getRecurrenceDescriptor();
97719797
RecurKind RK = RdxDesc.getRecurrenceKind();
97729798
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
9799+
Value *StartV = RdxResult->getOperand(1)->getLiveInIRValue();
9800+
assert(RdxDesc.getRecurrenceStartValue() == StartV &&
9801+
"start value from ComputeAnyOfResult must match");
9802+
97739803
// VPReductionPHIRecipes for AnyOf reductions expect a boolean as
97749804
// start value; compare the final value from the main vector loop
97759805
// to the start value.
97769806
BasicBlock *PBB = cast<Instruction>(ResumeV)->getParent();
97779807
IRBuilder<> Builder(PBB, PBB->getFirstNonPHIIt());
9778-
ResumeV =
9779-
Builder.CreateICmpNE(ResumeV, RdxDesc.getRecurrenceStartValue());
9808+
ResumeV = Builder.CreateICmpNE(ResumeV, StartV);
97809809
} else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) {
9781-
ToFrozen[RdxDesc.getRecurrenceStartValue()] =
9782-
cast<PHINode>(ResumeV)->getIncomingValueForBlock(
9783-
EPI.MainLoopIterationCountCheck);
9810+
Value *StartV = getStartValueFromReductionResult(RdxResult);
9811+
assert(RdxDesc.getRecurrenceStartValue() == StartV &&
9812+
"start value from ComputeFindLastIVResult must match");
9813+
9814+
ToFrozen[StartV] = cast<PHINode>(ResumeV)->getIncomingValueForBlock(
9815+
EPI.MainLoopIterationCountCheck);
97849816

97859817
// VPReductionPHIRecipe for FindLastIV reductions requires an adjustment
97869818
// to the resume value. The resume value is adjusted to the sentinel
@@ -9790,8 +9822,7 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
97909822
// variable.
97919823
BasicBlock *ResumeBB = cast<Instruction>(ResumeV)->getParent();
97929824
IRBuilder<> Builder(ResumeBB, ResumeBB->getFirstNonPHIIt());
9793-
Value *Cmp = Builder.CreateICmpEQ(
9794-
ResumeV, ToFrozen[RdxDesc.getRecurrenceStartValue()]);
9825+
Value *Cmp = Builder.CreateICmpEQ(ResumeV, ToFrozen[StartV]);
97959826
ResumeV =
97969827
Builder.CreateSelect(Cmp, RdxDesc.getSentinelValue(), ResumeV);
97979828
}

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
907907
BranchOnCount,
908908
BranchOnCond,
909909
Broadcast,
910+
ComputeAnyOfResult,
910911
ComputeFindLastIVResult,
911912
ComputeReductionResult,
912913
// Extracts the last lane from its operand if it is a vector, or the last

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
8989
inferScalarType(R->getOperand(1)) &&
9090
"different types inferred for different operands");
9191
return IntegerType::get(Ctx, 1);
92+
case VPInstruction::ComputeAnyOfResult:
9293
case VPInstruction::ComputeFindLastIVResult:
9394
case VPInstruction::ComputeReductionResult: {
9495
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,12 @@ m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
318318
{Op0, Op1, Op2});
319319
}
320320

321+
template <typename Op0_t>
322+
inline UnaryVPInstruction_match<Op0_t, Instruction::Freeze>
323+
m_Freeze(const Op0_t &Op0) {
324+
return m_VPInstruction<Instruction::Freeze>(Op0);
325+
}
326+
321327
template <typename Op0_t>
322328
inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
323329
m_Not(const Op0_t &Op0) {

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
604604
return Builder.CreateVectorSplat(
605605
State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
606606
}
607+
case VPInstruction::ComputeAnyOfResult: {
608+
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609+
// and will be removed by breaking up the recipe further.
610+
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
611+
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
612+
Value *ReducedPartRdx = State.get(getOperand(2));
613+
for (unsigned Idx = 3; Idx < getNumOperands(); ++Idx)
614+
ReducedPartRdx = Builder.CreateBinOp(
615+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(
616+
RecurKind::AnyOf),
617+
State.get(getOperand(Idx)), ReducedPartRdx, "bin.rdx");
618+
return createAnyOfReduction(Builder, ReducedPartRdx,
619+
State.get(getOperand(1), VPLane(0)), OrigPhi);
620+
}
607621
case VPInstruction::ComputeFindLastIVResult: {
608622
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609623
// and will be removed by breaking up the recipe further.
@@ -681,18 +695,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
681695

682696
// Create the reduction after the loop. Note that inloop reductions create
683697
// the target reduction in the loop using a Reduction recipe.
684-
if ((State.VF.isVector() ||
685-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
686-
!PhiR->isInLoop()) {
698+
if (State.VF.isVector() && !PhiR->isInLoop()) {
687699
// TODO: Support in-order reductions based on the recurrence descriptor.
688700
// All ops in the reduction inherit fast-math-flags from the recurrence
689701
// descriptor.
690-
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
691-
ReducedPartRdx =
692-
createAnyOfReduction(Builder, ReducedPartRdx,
693-
RdxDesc.getRecurrenceStartValue(), OrigPhi);
694-
else
695-
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
702+
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
696703

697704
// If the reduction can be performed in a smaller type, we need to extend
698705
// the reduction to the wider type before we branch to the original loop.
@@ -830,6 +837,7 @@ bool VPInstruction::isVectorToScalar() const {
830837
getOpcode() == VPInstruction::ExtractPenultimateElement ||
831838
getOpcode() == Instruction::ExtractElement ||
832839
getOpcode() == VPInstruction::FirstActiveLane ||
840+
getOpcode() == VPInstruction::ComputeAnyOfResult ||
833841
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
834842
getOpcode() == VPInstruction::ComputeReductionResult ||
835843
getOpcode() == VPInstruction::AnyOf;
@@ -925,6 +933,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
925933
return true;
926934
case VPInstruction::PtrAdd:
927935
return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this);
936+
case VPInstruction::ComputeAnyOfResult:
928937
case VPInstruction::ComputeFindLastIVResult:
929938
return Op == getOperand(1);
930939
};
@@ -1005,6 +1014,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10051014
case VPInstruction::ExtractPenultimateElement:
10061015
O << "extract-penultimate-element";
10071016
break;
1017+
case VPInstruction::ComputeAnyOfResult:
1018+
O << "compute-anyof-result";
1019+
break;
10081020
case VPInstruction::ComputeFindLastIVResult:
10091021
O << "compute-find-last-iv-result";
10101022
break;

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,9 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
327327
// Add all VPValues for all parts to ComputeReductionResult which combines
328328
// the parts to compute the final reduction value.
329329
VPValue *Op1;
330-
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
330+
if (match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
331+
m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
332+
match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
331333
m_VPValue(), m_VPValue(Op1))) ||
332334
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
333335
m_VPValue(), m_VPValue(), m_VPValue(Op1)))) {

0 commit comments

Comments
 (0)