
[VectorCombine] New folding pattern for extract/binop/shuffle chains #145232

Open · wants to merge 2 commits into main
303 changes: 303 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -130,6 +130,7 @@ class VectorCombine {
bool foldShuffleOfIntrinsics(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
bool foldShuffleChainsToReduce(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
bool foldInterleaveIntrinsics(Instruction &I);
@@ -2988,6 +2989,305 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
return foldSelectShuffle(*Shuffle, true);
}

/// For a given chain of patterns of the following form:
///
/// ```
/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison, <n x ty2> mask
///
/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
/// ty1> %1)
/// OR
/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
///
/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison, <n x ty2> mask
/// ...
/// ...
/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
/// 3), <n x ty1> %(i - 2))
/// OR
/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
///
/// %(i) = extractelement <n x ty1> %(i - 1), 0
/// ```
///
/// Where:
/// `mask` follows a partition pattern:
///
/// Ex:
/// [n = 8, p = poison]
///
/// 4 5 6 7 | p p p p
/// 2 3 | p p p p p p
/// 1 | p p p p p p p
///
/// For powers of 2 the partition pattern is regular; for other sizes the
/// parity of the current half value at each step decides the next
/// partition half (see `ExpectedParityMask` for how this is generalised).
///
/// Ex:
/// [n = 6]
///
/// 3 4 5 | p p p
/// 1 2 | p p p p
/// 1 | p p p p p
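///
/// As an illustrative sketch (hand-written, not taken from an existing
/// test), a <4 x i32> umin chain of this shape, i.e.
///
/// ```
/// %1 = shufflevector <4 x i32> %0, <4 x i32> poison, <4 x i32> <i32 2, i32 3, i32 poison, i32 poison>
/// %2 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %0, <4 x i32> %1)
/// %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> <i32 1, i32 poison, i32 poison, i32 poison>
/// %4 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %2, <4 x i32> %3)
/// %5 = extractelement <4 x i32> %4, i64 0
/// ```
///
/// would be rewritten (when the cost model agrees) to:
///
/// ```
/// %5 = call i32 @llvm.vector.reduce.umin.v4i32(<4 x i32> %0)
/// ```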
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
// Going bottom-up for the pattern.
auto *EEI = dyn_cast<ExtractElementInst>(&I);
if (!EEI)
return false;

std::queue<Value *> InstWorklist;
InstructionCost OrigCost = 0;

Value *InitEEV = nullptr;

// Common instruction operation after each shuffle op.
unsigned int CommonCallOp = 0;
Instruction::BinaryOps CommonBinOp = Instruction::BinaryOpsEnd;

bool IsFirstCallOrBinInst = true;
bool ShouldBeCallOrBinInst = true;

// This stores the most recently matched instructions for the
// shuffle/common op chain.
//
// PrevVecV[2] holds the source vector of the extractelement instruction,
// while PrevVecV[0] / PrevVecV[1] hold the two operands of the most
// recently matched common op (reordered so PrevVecV[1] is the shuffle).
SmallVector<Value *, 3> PrevVecV(3, nullptr);

Value *VecOp;
if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
return false;

auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
if (!FVT)
return false;

int64_t VecSize = FVT->getNumElements();
if (VecSize < 2)
return false;

// Number of levels would be ~log2(n), considering we always partition
// by half for this fold pattern.
unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;

// This is how we generalise to all element counts.
// At each step, if the vector size is odd, the non-poison values must
// cover the larger half so we don't miss out on any element.
//
// This mask lets us recover that choice as we go from bottom to top:
//
// Mask bit set -> N = N * 2 - 1
// Mask bit unset -> N = N * 2
for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
Cur = (Cur + 1) / 2, --Mask) {
if (Cur & 1)
ExpectedParityMask |= (1ll << Mask);
}
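// Illustrative walk-through (assuming VecSize = 6, matching the example
// in the function comment): NumLevels = 3 and the loop sees
// Cur = 6 -> 3 -> 2, so only bit 1 gets set and ExpectedParityMask = 0b010.
// Consuming the bits from the bottom up then yields non-poison prefixes of
// 1, 2 and 3 elements at the three shuffle levels.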

PrevVecV[2] = VecOp;
InitEEV = EEI;

InstWorklist.push(PrevVecV[2]);

while (!InstWorklist.empty()) {
Value *V = InstWorklist.front();
InstWorklist.pop();

auto *CI = dyn_cast<Instruction>(V);
if (!CI)
return false;

if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
if (!ShouldBeCallOrBinInst || !PrevVecV[2])
return false;

if (!IsFirstCallOrBinInst &&
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
return false;

// For the first found call/bin op, the vector has to come from the
// extract element op.
if (II != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
return false;
IsFirstCallOrBinInst = false;

if (!CommonCallOp)
CommonCallOp = II->getIntrinsicID();
if (II->getIntrinsicID() != CommonCallOp)
return false;

switch (II->getIntrinsicID()) {
case Intrinsic::umin:
case Intrinsic::umax:
case Intrinsic::smin:
case Intrinsic::smax: {
auto *Op0 = II->getOperand(0);
auto *Op1 = II->getOperand(1);
PrevVecV[0] = Op0;
PrevVecV[1] = Op1;
break;
}
default:
return false;
}
ShouldBeCallOrBinInst ^= 1;

IntrinsicCostAttributes ICA(
CommonCallOp, II->getType(),
{PrevVecV[0]->getType(), PrevVecV[1]->getType()});
OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);

// The operands may come as (a, b) or (b, a); swap if needed so that
// PrevVecV[1] is the shuffle as we continue up the chain.
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
std::swap(PrevVecV[0], PrevVecV[1]);
InstWorklist.push(PrevVecV[1]);
InstWorklist.push(PrevVecV[0]);
} else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
// Similar logic for bin ops.

if (!ShouldBeCallOrBinInst || !PrevVecV[2])
return false;

if (!IsFirstCallOrBinInst &&
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
return false;

if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
return false;
IsFirstCallOrBinInst = false;

if (CommonBinOp == Instruction::BinaryOpsEnd)
CommonBinOp = BinOp->getOpcode();

if (BinOp->getOpcode() != CommonBinOp)
return false;

switch (CommonBinOp) {
case BinaryOperator::Add:
case BinaryOperator::Mul:
case BinaryOperator::Or:
case BinaryOperator::And:
case BinaryOperator::Xor: {
auto *Op0 = BinOp->getOperand(0);
auto *Op1 = BinOp->getOperand(1);
PrevVecV[0] = Op0;
PrevVecV[1] = Op1;
break;
}
default:
return false;
}
ShouldBeCallOrBinInst ^= 1;

OrigCost +=
TTI.getArithmeticInstrCost(CommonBinOp, BinOp->getType(), CostKind);

if (!isa<ShuffleVectorInst>(PrevVecV[1]))
std::swap(PrevVecV[0], PrevVecV[1]);
InstWorklist.push(PrevVecV[1]);
InstWorklist.push(PrevVecV[0]);
} else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
// We shouldn't have any null values in the previous vectors;
// if so, there was a mismatch in the pattern.
if (ShouldBeCallOrBinInst ||
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
return false;

if (SVInst != PrevVecV[1])
return false;

auto *ShuffleVec = SVInst->getOperand(0);
if (!ShuffleVec || ShuffleVec != PrevVecV[0])
return false;

if (!isa<PoisonValue>(SVInst->getOperand(1)))
return false;

ArrayRef<int> CurMask = SVInst->getShuffleMask();

// Subtract the parity bit when checking the expected mask value.
for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
if (Mask < ShuffleMaskHalf &&
CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
return false;
if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
return false;
}

// Update mask values.
ShuffleMaskHalf *= 2;
ShuffleMaskHalf -= (ExpectedParityMask & 1);
ExpectedParityMask >>= 1;
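// For the illustrative VecSize = 6 walk-through above: at the second level
// the parity bit is set, so the expected prefix is
// {2 + 0 - 1, 2 + 1 - 1} = {1, 2}, and ShuffleMaskHalf advances to
// 2 * 2 - 1 = 3 for the widest shuffle {3, 4, 5}.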

OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
SVInst->getType(), SVInst->getType(),
CurMask, CostKind);

VisitedCnt += 1;
if (!ExpectedParityMask && VisitedCnt == NumLevels)
break;

ShouldBeCallOrBinInst ^= 1;
} else {
return false;
}
}

// Pattern should end with a shuffle op.
if (ShouldBeCallOrBinInst)
return false;

assert(VecSize != -1 && "Expected Match for Vector Size");

Value *FinalVecV = PrevVecV[0];
if (!InitEEV || !FinalVecV)
return false;

auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType());

assert(FinalVecVTy && "Expected non-null value for Vector Type");

Intrinsic::ID ReducedOp = 0;
if (CommonCallOp) {
switch (CommonCallOp) {
case Intrinsic::umin:
ReducedOp = Intrinsic::vector_reduce_umin;
break;
case Intrinsic::umax:
ReducedOp = Intrinsic::vector_reduce_umax;
break;
case Intrinsic::smin:
ReducedOp = Intrinsic::vector_reduce_smin;
break;
case Intrinsic::smax:
ReducedOp = Intrinsic::vector_reduce_smax;
break;
default:
return false;
}
} else if (CommonBinOp != Instruction::BinaryOpsEnd) {
ReducedOp = getReductionForBinop(CommonBinOp);
if (!ReducedOp)
return false;
}

IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

if (NewCost >= OrigCost)
return false;

auto *ReducedResult =
Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
replaceValue(*InitEEV, *ReducedResult);

return true;
}

/// Determine if it's more efficient to fold:
/// reduce(trunc(x)) -> trunc(reduce(x)).
/// reduce(sext(x)) -> sext(reduce(x)).
@@ -3705,6 +4005,9 @@ bool VectorCombine::run() {
MadeChange |= foldShuffleFromReductions(I);
MadeChange |= foldCastFromReductions(I);
break;
case Instruction::ExtractElement:
MadeChange |= foldShuffleChainsToReduce(I);
break;
case Instruction::ICmp:
case Instruction::FCmp:
MadeChange |= foldExtractExtract(I);