Skip to content

[LV] Bundle partial reductions inside VPExpressionRecipe #147302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: users/SamTebbs33/expression-recipe-sub
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ class TargetTransformInfo {
/// Get the kind of extension that an instruction represents.
LLVM_ABI static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction *I);
LLVM_ABI static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction::CastOps CastOpc);

/// Construct a TTI object using a type implementing the \c Concept
/// API below.
Expand Down
19 changes: 15 additions & 4 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,13 +1001,24 @@ InstructionCost TargetTransformInfo::getShuffleCost(

TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
if (isa<SExtInst>(I))
return PR_SignExtend;
if (isa<ZExtInst>(I))
return PR_ZeroExtend;
if (auto *Cast = dyn_cast<CastInst>(I))
return getPartialReductionExtendKind(Cast->getOpcode());
return PR_None;
}

TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(
Instruction::CastOps CastOpc) {
switch (CastOpc) {
case Instruction::CastOps::ZExt:
return PR_ZeroExtend;
case Instruction::CastOps::SExt:
return PR_SignExtend;
default:
return PR_None;
}
}

TTI::CastContextHint
TargetTransformInfo::getCastContextHint(const Instruction *I) {
if (!I)
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5294,7 +5294,7 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
EVT ResVT = TLI->getValueType(DL, ResTy);

if (Opcode == Instruction::Add && VecVT.isSimple() && ResVT.isSimple() &&
VecVT.getSizeInBits() >= 64) {
VecVT.isFixedLengthVector() && VecVT.getSizeInBits() >= 64) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);

// The legal cases are:
Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -2470,7 +2470,8 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {

static inline bool classof(const VPRecipeBase *R) {
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this was missed before and only now is tested?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's right.

}

static inline bool classof(const VPUser *U) {
Expand Down Expand Up @@ -2532,7 +2533,10 @@ class VPPartialReductionRecipe : public VPReductionRecipe {
Opcode(Opcode), VFScaleFactor(ScaleFactor) {
[[maybe_unused]] auto *AccumulatorRecipe =
getChainOp()->getDefiningRecipe();
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
// When cloning as part of a VPExpressionRecipe the chain op could have
// replaced by a temporary VPValue, so it doesn't have a defining recipe.
assert((!AccumulatorRecipe ||
isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
"Unexpected operand order for partial reduction recipe");
}
Expand Down
42 changes: 36 additions & 6 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
case VPBlendSC:
case VPReductionEVLSC:
case VPPartialReductionSC:
case VPReductionSC:
case VPScalarIVStepsSC:
case VPVectorPointerSC:
Expand Down Expand Up @@ -2665,11 +2666,16 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
case ExpressionTypes::ExtendedReduction: {
unsigned Opcode = RecurrenceDescriptor::getOpcode(
cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind());
auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
return Ctx.TTI.getPartialReductionCost(
Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr, RedTy, VF,
TargetTransformInfo::getPartialReductionExtendKind(ExtR->getOpcode()),
TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
}
return Ctx.TTI.getExtendedReductionCost(
Opcode,
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy, SrcVecTy,
std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
Expand All @@ -2678,6 +2684,23 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
case ExpressionTypes::ExtNegatedMulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
unsigned Opcode =
ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction
? Instruction::Sub
: Instruction::Add;
return Ctx.TTI.getPartialReductionCost(
Opcode, Ctx.Types.inferScalarType(getOperand(0)),
Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF,
TargetTransformInfo::getPartialReductionExtendKind(
Ext0R->getOpcode()),
TargetTransformInfo::getPartialReductionExtendKind(
Ext1R->getOpcode()),
Mul->getOpcode(), Ctx.CostKind);
}
Comment on lines +2687 to +2703
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No code shared here with others, might be worth having different expression types?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They do share printing and matching code. There's no real difference (in terms of bundling) between a partial reduction bundle and a normal reduction bundle, except for costing. So I don't think it would be worth adding all the extra glue code just to have another expression type. We're moving towards making partial reductions VPReductionRecipes anyway.

return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
Expand Down Expand Up @@ -2710,12 +2733,15 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back());
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);

switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
getOperand(1)->printAsOperand(O, SlotTracker);
O << " +";
O << " reduce." << Instruction::getOpcodeName(Opcode) << " (";
O << " + ";
if (IsPartialReduction)
O << "partial.";
O << "reduce." << Instruction::getOpcodeName(Opcode) << " (";
getOperand(0)->printAsOperand(O, SlotTracker);
Red->printFlags(O);

Expand All @@ -2732,6 +2758,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
case ExpressionTypes::ExtNegatedMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
O << " + ";
if (IsPartialReduction)
O << "partial.";
O << "reduce."
<< Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
Expand All @@ -2758,6 +2786,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
O << " + ";
if (IsPartialReduction)
O << "partial.";
O << "reduce."
<< Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
Expand Down
33 changes: 22 additions & 11 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2857,15 +2857,25 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
VPValue *VecOp = Red->getVecOp();

// Clamp the range if using extended-reduction is profitable.
auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
Type *SrcTy) -> bool {
auto IsExtendedRedValidAndClampRange =
[&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags(),
CostKind);

InstructionCost ExtRedCost;
if (isa<VPPartialReductionRecipe>(Red)) {
TargetTransformInfo::PartialReductionExtendKind ExtKind =
TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
ExtRedCost = Ctx.TTI.getPartialReductionCost(
Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
} else {
ExtRedCost = Ctx.TTI.getExtendedReductionCost(
Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
Red->getFastMathFlags(), CostKind);
}
InstructionCost ExtCost =
cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
Expand All @@ -2879,8 +2889,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
IsExtendedRedValidAndClampRange(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()),
cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
Instruction::CastOps::ZExt,
cast<VPWidenCastRecipe>(VecOp)->getOpcode(),
Ctx.Types.inferScalarType(A)))
return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red);

Expand All @@ -2899,6 +2908,7 @@ static VPExpressionRecipe *
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
VPCostContext &Ctx, VFRange &Range) {
using namespace VPlanPatternMatch;
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);

unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
if (Opcode != Instruction::Add)
Expand Down Expand Up @@ -2955,12 +2965,13 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,

// Match reduce.add(mul(ext, ext)).
if (RecipeA && RecipeB &&
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
(RecipeA->getOpcode() == RecipeB->getOpcode() || IsPartialReduction) &&
match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
MulR, RecipeA, RecipeB, nullptr, Sub)) {
(IsPartialReduction ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we have to clamp the range also for partial reductions? Is this done somewhere else?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, VPRecipeBuilder::getScaledReductions clamps the range for partial reductions.

IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
MulR, RecipeA, RecipeB, nullptr, Sub))) {
if (Sub)
return new VPExpressionRecipe(
RecipeA, RecipeB, MulR,
Expand Down
Loading
Loading