Skip to content

[VectorCombine] Add initial nodes to the Worklist in VectorCombine #149047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 116 additions & 80 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,8 @@ class VectorCombine {
const Instruction &I,
ExtractElementInst *&ConvertToShuffle,
unsigned PreferredExtractIndex);
void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
Instruction &I);
void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
Instruction &I);
Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
bool foldExtractExtract(Instruction &I);
bool foldInsExtFNeg(Instruction &I);
bool foldInsExtBinop(Instruction &I);
Expand Down Expand Up @@ -138,7 +136,7 @@ class VectorCombine {
bool foldInterleaveIntrinsics(Instruction &I);
bool shrinkType(Instruction &I);

void replaceValue(Value &Old, Value &New) {
void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
LLVM_DEBUG(dbgs() << " With: " << New << '\n');
Old.replaceAllUsesWith(&New);
Expand All @@ -147,7 +145,11 @@ class VectorCombine {
Worklist.pushUsersToWorkList(*NewI);
Worklist.pushValue(NewI);
}
Worklist.pushValue(&Old);
if (Erase && isInstructionTriviallyDead(&Old)) {
eraseInstruction(Old);
} else {
Worklist.push(&Old);
}
}

void eraseInstruction(Instruction &I) {
Expand All @@ -158,11 +160,23 @@ class VectorCombine {

// Push remaining users of the operands and then the operand itself - allows
// further folds that were hindered by OneUse limits.
for (Value *Op : Ops)
if (auto *OpI = dyn_cast<Instruction>(Op)) {
Worklist.pushUsersToWorkList(*OpI);
Worklist.pushValue(OpI);
SmallPtrSet<Value *, 4> Visited;
for (Value *Op : Ops) {
if (Visited.insert(Op).second) {
if (auto *OpI = dyn_cast<Instruction>(Op)) {
if (RecursivelyDeleteTriviallyDeadInstructions(
OpI, nullptr, nullptr, [this](Value *V) {
if (auto I = dyn_cast<Instruction>(V)) {
LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
Worklist.remove(I);
}
}))
continue;
Worklist.pushUsersToWorkList(*OpI);
Worklist.pushValue(OpI);
}
}
}
}
};
} // namespace
Expand Down Expand Up @@ -546,9 +560,8 @@ static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
unsigned NewIndex,
IRBuilderBase &Builder) {
static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
IRBuilderBase &Builder) {
// Shufflevectors can only be created for fixed-width vectors.
Value *X = ExtElt->getVectorOperand();
if (!isa<FixedVectorType>(X->getType()))
Expand All @@ -563,52 +576,41 @@ static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,

Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
NewIndex, Builder);
return dyn_cast<ExtractElementInst>(
Builder.CreateExtractElement(Shuf, NewIndex));
return Shuf;
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
ExtractElementInst *Ext1, Instruction &I) {
Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
Instruction &I) {
assert(isa<CmpInst>(&I) && "Expected a compare");
assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
"Expected matching constant extract indexes");

// cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
++NumVecCmp;
CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
replaceValue(I, *NewExt);
return Builder.CreateExtractElement(VecCmp, ExtIndex, "foldExtExtCmp");
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
ExtractElementInst *Ext1, Instruction &I) {
Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
Instruction &I) {
assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
"Expected matching constant extract indexes");

// bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
++NumVecBO;
Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
Value *VecBO =
Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
Value *VecBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0,
V1, "foldExtExtBinop");

// All IR flags are safe to back-propagate because any potential poison
// created in unused vector elements is discarded by the extract.
if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
VecBOInst->copyIRFlags(&I);

Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
replaceValue(I, *NewExt);
return Builder.CreateExtractElement(VecBO, ExtIndex, "foldExtExtBinop");
}

/// Match an instruction with extracted vector operands.
Expand Down Expand Up @@ -647,25 +649,29 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
return false;

Value *ExtOp0 = Ext0->getVectorOperand();
Value *ExtOp1 = Ext1->getVectorOperand();

if (ExtractToChange) {
unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
ExtractElementInst *NewExtract =
Value *NewExtOp =
translateExtract(ExtractToChange, CheapExtractIdx, Builder);
if (!NewExtract)
if (!NewExtOp)
return false;
if (ExtractToChange == Ext0)
Ext0 = NewExtract;
ExtOp0 = NewExtOp;
else
Ext1 = NewExtract;
ExtOp1 = NewExtOp;
}

if (Pred != CmpInst::BAD_ICMP_PREDICATE)
foldExtExtCmp(Ext0, Ext1, I);
else
foldExtExtBinop(Ext0, Ext1, I);

Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
: Ext0->getIndexOperand();
Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
? foldExtExtCmp(ExtOp0, ExtOp1, ExtIndex, I)
: foldExtExtBinop(ExtOp0, ExtOp1, ExtIndex, I);
Worklist.push(Ext0);
Worklist.push(Ext1);
replaceValue(I, *NewExt);
return true;
}

Expand Down Expand Up @@ -1772,7 +1778,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
LI->getAlign(), VecTy->getElementType(), Idx, *DL);
NewLoad->setAlignment(ScalarOpAlignment);

replaceValue(*EI, *NewLoad);
replaceValue(*EI, *NewLoad, false);
}

FailureGuard.release();
Expand Down Expand Up @@ -2856,7 +2862,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
if (!IL.first)
return true;
Value *V = IL.first->get();
if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUser())
return false;
if (V->getValueID() != FrontV->getValueID())
return false;
Expand Down Expand Up @@ -3058,7 +3064,7 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
replaceValue(*Shuffle, *NewShuffle);
MadeChanges = true;
return true;
}

// See if we can re-use foldSelectShuffle, getting it to reduce the size of
Expand Down Expand Up @@ -3438,7 +3444,7 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
Builder.SetInsertPoint(Shuffles[S]);
Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
replaceValue(*Shuffles[S], *NSV);
replaceValue(*Shuffles[S], *NSV, false);
}

Worklist.pushValue(NSV0A);
Expand Down Expand Up @@ -3703,8 +3709,7 @@ bool VectorCombine::run() {

LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");

bool MadeChange = false;
auto FoldInst = [this, &MadeChange](Instruction &I) {
auto FoldInst = [this](Instruction &I) {
Builder.SetInsertPoint(&I);
bool IsVectorType = isa<VectorType>(I.getType());
bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
Expand All @@ -3719,10 +3724,12 @@ bool VectorCombine::run() {
if (IsFixedVectorType) {
switch (Opcode) {
case Instruction::InsertElement:
MadeChange |= vectorizeLoadInsert(I);
if (vectorizeLoadInsert(I))
return true;
break;
case Instruction::ShuffleVector:
MadeChange |= widenSubvectorLoad(I);
if (widenSubvectorLoad(I))
return true;
break;
default:
break;
Expand All @@ -3732,19 +3739,25 @@ bool VectorCombine::run() {
// This transform works with scalable and fixed vectors
// TODO: Identify and allow other scalable transforms
if (IsVectorType) {
MadeChange |= scalarizeOpOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeExtExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
MadeChange |= foldInterleaveIntrinsics(I);
if (scalarizeOpOrCmp(I))
return true;
if (scalarizeLoadExtract(I))
return true;
if (scalarizeExtExtract(I))
return true;
if (scalarizeVPIntrinsic(I))
return true;
if (foldInterleaveIntrinsics(I))
return true;
}

if (Opcode == Instruction::Store)
MadeChange |= foldSingleElementStore(I);
if (foldSingleElementStore(I))
return true;

// If this is an early pipeline invocation of this pass, we are done.
if (TryEarlyFoldsOnly)
return;
return false;

// Otherwise, try folds that improve codegen but may interfere with
// early IR canonicalizations.
Expand All @@ -3753,56 +3766,79 @@ bool VectorCombine::run() {
if (IsFixedVectorType) {
switch (Opcode) {
case Instruction::InsertElement:
MadeChange |= foldInsExtFNeg(I);
MadeChange |= foldInsExtBinop(I);
MadeChange |= foldInsExtVectorToShuffle(I);
if (foldInsExtFNeg(I))
return true;
if (foldInsExtBinop(I))
return true;
if (foldInsExtVectorToShuffle(I))
return true;
break;
case Instruction::ShuffleVector:
MadeChange |= foldPermuteOfBinops(I);
MadeChange |= foldShuffleOfBinops(I);
MadeChange |= foldShuffleOfSelects(I);
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldShuffleOfShuffles(I);
MadeChange |= foldShuffleOfIntrinsics(I);
MadeChange |= foldSelectShuffle(I);
MadeChange |= foldShuffleToIdentity(I);
if (foldPermuteOfBinops(I))
return true;
if (foldShuffleOfBinops(I))
return true;
if (foldShuffleOfSelects(I))
return true;
if (foldShuffleOfCastops(I))
return true;
if (foldShuffleOfShuffles(I))
return true;
if (foldShuffleOfIntrinsics(I))
return true;
if (foldSelectShuffle(I))
return true;
if (foldShuffleToIdentity(I))
return true;
break;
case Instruction::BitCast:
MadeChange |= foldBitcastShuffle(I);
if (foldBitcastShuffle(I))
return true;
break;
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
MadeChange |= foldBitOpOfBitcasts(I);
if (foldBitOpOfBitcasts(I))
return true;
break;
default:
MadeChange |= shrinkType(I);
if (shrinkType(I))
return true;
break;
}
} else {
switch (Opcode) {
case Instruction::Call:
MadeChange |= foldShuffleFromReductions(I);
MadeChange |= foldCastFromReductions(I);
if (foldShuffleFromReductions(I))
return true;
if (foldCastFromReductions(I))
return true;
break;
case Instruction::ICmp:
case Instruction::FCmp:
MadeChange |= foldExtractExtract(I);
if (foldExtractExtract(I))
return true;
break;
case Instruction::Or:
MadeChange |= foldConcatOfBoolMasks(I);
if (foldConcatOfBoolMasks(I))
return true;
[[fallthrough]];
default:
if (Instruction::isBinaryOp(Opcode)) {
MadeChange |= foldExtractExtract(I);
MadeChange |= foldExtractedCmps(I);
MadeChange |= foldBinopOfReductions(I);
if (foldExtractExtract(I))
return true;
if (foldExtractedCmps(I))
return true;
if (foldBinopOfReductions(I))
return true;
}
break;
}
}
return false;
};

bool MadeChange = false;
for (BasicBlock &BB : F) {
// Ignore unreachable basic blocks.
if (!DT.isReachableFromEntry(&BB))
Expand All @@ -3811,7 +3847,7 @@ bool VectorCombine::run() {
for (Instruction &I : make_early_inc_range(BB)) {
if (I.isDebugOrPseudoInst())
continue;
FoldInst(I);
MadeChange |= FoldInst(I);
}
}

Expand All @@ -3825,7 +3861,7 @@ bool VectorCombine::run() {
continue;
}

FoldInst(*I);
MadeChange |= FoldInst(*I);
}

return MadeChange;
Expand Down
Loading
Loading