[SLP] Check for extracts, being replaced by original scalars, for user nodes #149572

Open: wants to merge 3 commits into main

165 changes: 152 additions & 13 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2008,6 +2008,7 @@ class BoUpSLP {
void deleteTree() {
VectorizableTree.clear();
ScalarToTreeEntries.clear();
PostponedNodesWithNonVecUsers.clear();
OperandsToTreeEntry.clear();
ScalarsInSplitNodes.clear();
MustGather.clear();
@@ -4016,6 +4017,9 @@ class BoUpSLP {
/// Returns true if any scalar in the list is a copyable element.
bool hasCopyableElements() const { return !CopyableElements.empty(); }

/// Returns the state of the operations.
const InstructionsState &getOperations() const { return S; }

/// When ReuseReorderShuffleIndices is empty it just returns position of \p
/// V within vector of Scalars. Otherwise, try to remap on its reuse index.
unsigned findLaneForValue(Value *V) const {
@@ -4419,6 +4423,13 @@ class BoUpSLP {
OrdersType &CurrentOrder,
SmallVectorImpl<Value *> &PointerOps);

/// Checks if it is profitable to vectorize the specified list of
/// instructions if not all users are vectorized.
bool isProfitableToVectorizeWithNonVecUsers(const InstructionsState &S,
const EdgeInfo &UserTreeIdx,
ArrayRef<Value *> VL,
ArrayRef<int> Mask);

/// Maps a specific scalar to its tree entry(ies).
SmallDenseMap<Value *, SmallVector<TreeEntry *>> ScalarToTreeEntries;

@@ -4429,6 +4440,9 @@ class BoUpSLP {
/// Scalars, used in split vectorize nodes.
SmallDenseMap<Value *, SmallVector<TreeEntry *>> ScalarsInSplitNodes;

/// List of tree node indices which have non-vectorized users.
SmallSet<unsigned, 4> PostponedNodesWithNonVecUsers;

/// Maps a value to the proposed vectorizable size.
SmallDenseMap<Value *, unsigned> InstrElementSize;

@@ -9151,6 +9165,83 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
return {IntrinsicCost, LibCost};
}

bool BoUpSLP::isProfitableToVectorizeWithNonVecUsers(
const InstructionsState &S, const EdgeInfo &UserTreeIdx,
ArrayRef<Value *> VL, ArrayRef<int> Mask) {
assert(S && "Expected valid instructions state.");
// Loads, extracts and geps are immediately scalarizable, so no need to check.
if (S.getOpcode() == Instruction::Load ||
S.getOpcode() == Instruction::ExtractElement ||
S.getOpcode() == Instruction::GetElementPtr)
return true;
// Check only vectorized users; others are (at least potentially) already
// scalarized.
if (!UserTreeIdx.UserTE || UserTreeIdx.UserTE->isGather() ||
UserTreeIdx.UserTE->State == TreeEntry::SplitVectorize)
return true;
// PHI nodes may have cyclic deps, so cannot check here.
if (UserTreeIdx.UserTE->getOpcode() == Instruction::PHI)
return true;
// Do not check root reduction nodes, they do not have non-vectorized users.
if (UserIgnoreList && UserTreeIdx.UserTE->Idx == 0)
return true;
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
ArrayRef<Value *> UserVL = UserTreeIdx.UserTE->Scalars;
Type *UserScalarTy = getValueType(UserVL.front());
if (!isValidElementType(UserScalarTy))
return true;
Type *ScalarTy = getValueType(VL.front());
if (!isValidElementType(ScalarTy))
return true;
Contributor: The comment is a bit misleading; those "Scalars" can be any vector instruction when using REVEC.

Member (Author): This check is exactly for the REVEC case.

Contributor: I understand that, but the comment makes it seem that VL is a list of llvm.vector.extract. Maybe change it to something like "If the instructions are already of vector type, assume that the cost of extracting one is cheaper than keeping the original instruction in IR."
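
For reference, the suggested wording applied to the lines below would look roughly like this (sketch only, not part of the patch):

// If the instructions are already of vector type, assume that the cost of
// extracting one is cheaper than keeping the original instruction in IR.
if (UserScalarTy->isVectorTy())
  return true;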

// Ignore subvectors extracts for revectorized nodes.
if (UserScalarTy->isVectorTy())
return true;
auto *UserVecTy =
getWidenedType(UserScalarTy, UserTreeIdx.UserTE->getVectorFactor());
APInt DemandedElts = APInt::getZero(UserTreeIdx.UserTE->getVectorFactor());
// Check the external uses to see whether the vector node + extracts is
// profitable for vectorization.
InstructionCost UserScalarsCost = 0;
for (Value *V : UserVL) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
if (areAllUsersVectorized(I, UserIgnoreList))
continue;
DemandedElts.setBit(UserTreeIdx.UserTE->findLaneForValue(V));
UserScalarsCost += TTI->getInstructionCost(I, CostKind);
}
// No non-vectorized users - success.
if (DemandedElts.isZero())
return true;

auto AreExtractsCheaperThanScalars = [&]() {
// If extracts are cheaper than the original scalars - success.
InstructionCost ExtractCost = ::getScalarizationOverhead(
*TTI, UserScalarTy, UserVecTy, DemandedElts,
/*Insert=*/false, /*Extract=*/true, CostKind);
if (ExtractCost <= UserScalarsCost)
return true;
SmallPtrSet<Value *, 4> CheckedExtracts;
InstructionCost NodeCost =
getEntryCost(UserTreeIdx.UserTE, {}, CheckedExtracts);
// The node is profitable for vectorization - success.
if (ExtractCost <= NodeCost)
return true;
auto *VecTy = getWidenedType(ScalarTy, VL.size());
InstructionCost ScalarsCost = ::getScalarizationOverhead(
*TTI, ScalarTy, VecTy, APInt::getAllOnes(VL.size()),
/*Insert=*/true, /*Extract=*/false, CostKind);
if (!Mask.empty())
ScalarsCost +=
getShuffleCost(*TTI, TTI::SK_PermuteSingleSrc, VecTy, Mask, CostKind);
return ExtractCost < UserScalarsCost + ScalarsCost;
};

// User extracts are cheaper than user scalars + immediate scalars - success.
return AreExtractsCheaperThanScalars();
}
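
In essence, the helper above allows vectorizing a node with non-vectorized users only when extracting the demanded lanes from the user's vector is cheap enough. A minimal sketch of the final decision rule, with hypothetical variable names standing in for the real TTI cost queries (illustrative only):

// Sketch of the rule implemented by isProfitableToVectorizeWithNonVecUsers();
// the parameter names are hypothetical, not the actual patch code.
static bool profitableWithNonVecUsers(InstructionCost ExtractCost,
                                      InstructionCost UserScalarsCost,
                                      InstructionCost UserNodeCost,
                                      InstructionCost GatherScalarsCost) {
  // Extracting the demanded lanes is acceptable if it is no more expensive
  // than keeping the non-vectorized user scalars, or than the user node
  // itself, or than the user scalars plus gathering the current scalars.
  return ExtractCost <= UserScalarsCost || ExtractCost <= UserNodeCost ||
         ExtractCost < UserScalarsCost + GatherScalarsCost;
}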

BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
const InstructionsState &S, ArrayRef<Value *> VL,
bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder,
@@ -10700,6 +10791,15 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
return;
}

// Postpone vectorization if the node is not profitable because of its
// external uses.
if (!isProfitableToVectorizeWithNonVecUsers(S, UserTreeIdx, VL,
ReuseShuffleIndices)) {
PostponedNodesWithNonVecUsers.insert(VectorizableTree.size());
newGatherTreeEntry(VL, S, UserTreeIdx, ReuseShuffleIndices);
return;
}

Instruction *VL0 = S.getMainOp();
BasicBlock *BB = VL0->getParent();
auto &BSRef = BlocksSchedules[BB];
@@ -12085,6 +12185,27 @@ void BoUpSLP::transformNodes() {
ArrayRef<Value *> VL = E.Scalars;
const unsigned Sz = getVectorElementSize(VL.front());
unsigned MinVF = getMinVF(2 * Sz);
const EdgeInfo &EI = E.UserTreeIndex;
// Try to vectorize postponed scalars if their external uses are vectorized.
if (PostponedNodesWithNonVecUsers.contains(E.Idx) &&
isProfitableToVectorizeWithNonVecUsers(
E.getOperations(), EI, E.Scalars, E.ReuseShuffleIndices)) {
assert(E.hasState() && "Expected to have state");
unsigned PrevSize = VectorizableTree.size();
[[maybe_unused]] unsigned PrevEntriesSize =
LoadEntriesToVectorize.size();
buildTreeRec(VL, 0, EdgeInfo(&E, UINT_MAX));
if (PrevSize + 1 == VectorizableTree.size() &&
VectorizableTree[PrevSize]->isGather()) {
VectorizableTree.pop_back();
assert(PrevEntriesSize == LoadEntriesToVectorize.size() &&
"LoadEntriesToVectorize expected to remain the same");
} else {
E.CombinedEntriesWithIndices.emplace_back(PrevSize, 0);
continue;
}
}

// Do not try partial vectorization for small nodes (<= 2), nodes with the
// same opcode and same parent block or all constants.
if (VL.size() <= 2 || LoadEntriesToVectorize.contains(Idx) ||
@@ -13245,7 +13366,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
});
InVectors.front() = V;
}
if (!SubVectors.empty()) {
if (!SubVectors.empty() &&
(SubVectors.size() > 1 || SubVectors.front().second != 0 ||
SubVectors.front().first->getVectorFactor() != CommonMask.size())) {
Contributor: Maybe simplify the !SubVectors.empty() away? We later check SubVectors.size() > 1 anyway.

Also, for conditions that aren't immediately readable, either write a comment or give a name to that check, e.g.:

bool WhatThisIsChecking = SubVectors.size() > 1 || SubVectors.front().second != 0 ||
    SubVectors.front().first->getVectorFactor() != CommonMask.size();
if (WhatThisIsChecking) {
  ...
}

const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front();
if (InVectors.size() == 2)
Cost += createShuffle(Vec, InVectors.back(), CommonMask);
@@ -13356,6 +13479,16 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
InstructionCost
BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
SmallPtrSetImpl<Value *> &CheckedExtracts) {
// No need to count the cost for combined entries; they are combined, so
// just skip their cost.
if (E->State == TreeEntry::CombinedVectorize) {
LLVM_DEBUG(
dbgs() << "SLP: Skipping cost for combined node that starts with "
<< *E->Scalars[0] << ".\n";
E->dump());
return 0;
}

ArrayRef<Value *> VL = E->Scalars;

Type *ScalarTy = getValueType(VL[0]);
@@ -13767,7 +13900,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
case Instruction::Trunc:
case Instruction::FPTrunc:
case Instruction::BitCast: {
auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
auto SrcIt =
MinBWs.empty() ? MinBWs.end() : MinBWs.find(getOperandEntry(E, 0));
Type *SrcScalarTy = VL0->getOperand(0)->getType();
auto *SrcVecTy = getWidenedType(SrcScalarTy, VL.size());
unsigned Opcode = ShuffleOrOp;
@@ -14214,7 +14348,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
Type *SrcSclTy = E->getMainOp()->getOperand(0)->getType();
auto *SrcTy = getWidenedType(SrcSclTy, VL.size());
if (SrcSclTy->isIntegerTy() && ScalarTy->isIntegerTy()) {
auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
auto SrcIt = MinBWs.empty() ? MinBWs.end()
: MinBWs.find(getOperandEntry(E, 0));
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
unsigned SrcBWSz =
DL->getTypeSizeInBits(E->getMainOp()->getOperand(0)->getType());
@@ -15019,15 +15154,6 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals,
SmallPtrSet<Value *, 4> CheckedExtracts;
for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) {
TreeEntry &TE = *VectorizableTree[I];
// No need to count the cost for combined entries, they are combined and
// just skip their cost.
if (TE.State == TreeEntry::CombinedVectorize) {
LLVM_DEBUG(
dbgs() << "SLP: Skipping cost for combined node that starts with "
<< *TE.Scalars[0] << ".\n";
TE.dump(); dbgs() << "SLP: Current total cost = " << Cost << "\n");
continue;
}
if (TE.hasState() &&
(TE.isGather() || TE.State == TreeEntry::SplitVectorize)) {
if (const TreeEntry *E =
@@ -15242,6 +15368,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals,
auto *Inst = cast<Instruction>(EU.Scalar);
InstructionCost ScalarCost = TTI->getInstructionCost(Inst, CostKind);
auto OperandIsScalar = [&](Value *V) {
if (!isa<Instruction>(V))
return true;
if (!isVectorized(V)) {
// Some extractelements might be not vectorized, but
// transformed into shuffle and removed from the function,
@@ -17338,7 +17466,18 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
});
InVectors.front() = Vec;
}
if (!SubVectors.empty()) {
if (SubVectors.size() == 1 && SubVectors.front().second == 0 &&
SubVectors.front().first->getVectorFactor() == CommonMask.size()) {
Contributor: Nit: same here, it would be nice to give a name to that condition.
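
A hypothetical way the condition could be named, in the spirit of the earlier suggestion (sketch only, not part of the patch):

// Hypothetical name; the actual patch inlines this condition.
bool IsSingleFullWidthSubVector =
    SubVectors.size() == 1 && SubVectors.front().second == 0 &&
    SubVectors.front().first->getVectorFactor() == CommonMask.size();
if (IsSingleFullWidthSubVector) {
  // ... use the subvector's vectorized value directly ...
} else if (!SubVectors.empty()) {
  // ... existing shuffle-insertion path ...
}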

Value *Vec = SubVectors.front().first->VectorizedValue;
if (Vec->getType()->isIntOrIntVectorTy())
Vec = castToScalarTyElem(
Vec, any_of(SubVectors.front().first->Scalars, [&](Value *V) {
if (isa<PoisonValue>(V))
return false;
return !isKnownNonNegative(V, SimplifyQuery(*R.DL));
}));
transformMaskAfterShuffle(CommonMask, CommonMask);
} else if (!SubVectors.empty()) {
Value *Vec = InVectors.front();
if (InVectors.size() == 2) {
Vec = createShuffle(Vec, InVectors.back(), CommonMask);