Skip to content

Commit f1eff5c

Browse files
committed
Addressed review comments
1. Moved the lambda into a static function. 2. Preserved CFG analyses. 3. Used the CreateFNegFMF API instead of CreateFNeg.
1 parent 7549860 commit f1eff5c

File tree

1 file changed

+65
-61
lines changed

1 file changed

+65
-61
lines changed

llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,65 @@
2626

2727
using namespace llvm;
2828

29+
/// Attempts to replace \p BI with a call to the llvm.fma intrinsic by
/// absorbing the multiply that feeds one of its operands.
///
/// The fold is performed only when \p MulOperand is a BinaryOperator with
/// opcode FMul that has exactly one use and allows contraction. Otherwise the
/// IR is left untouched and false is returned.
///
/// Rewrites performed:
///   fadd(fmul(a, b), c), fadd(c, fmul(a, b))  =>  fma(a, b, c)
///   fsub(fmul(a, b), c)                       =>  fma(a, b, fneg(c))
///   fsub(c, fmul(a, b))                       =>  fma(fneg(a), b, c)
///
/// \param BI             The add/sub instruction to rewrite. Its opcode is not
///                       inspected here; \p IsFSub selects the rewrite.
/// \param MulOperand     The operand of \p BI that is a candidate multiply.
/// \param OtherOperand   The remaining operand of \p BI.
/// \param IsFirstOperand True when \p MulOperand is operand 0 of \p BI. It
///                       picks the fsub rewrite and is reported in debug
///                       output.
/// \param IsFSub         True to apply the fsub rewrites, false for fadd.
/// \returns true if the fold happened. In that case all uses of \p BI have
///          been redirected to the new FMA, and both \p BI and the multiply
///          have been erased.
static bool tryFoldBinaryFMul(BinaryOperator *BI, Value *MulOperand,
                              Value *OtherOperand, bool IsFirstOperand,
                              bool IsFSub) {
  auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
  if (!FMul)
    return false;

  // The opcode test runs first so the fast-math query is only made on an FMul.
  const bool IsFoldableMul = FMul->getOpcode() == Instruction::FMul &&
                             FMul->hasOneUse() && FMul->hasAllowContract();
  if (!IsFoldableMul)
    return false;

  LLVM_DEBUG(dbgs() << "Found " << (IsFSub ? "FSub" : "FAdd")
                    << " with FMul (single use) as "
                    << (IsFirstOperand ? "first" : "second")
                    << " operand: " << *BI << "\n");

  // New instructions are inserted immediately before BI.
  IRBuilder<> Builder(BI);

  // Start from the plain fadd form fma(a, b, c) and negate one input when
  // lowering a subtraction.
  Value *Multiplicand = FMul->getOperand(0);
  Value *Multiplier = FMul->getOperand(1);
  Value *Addend = OtherOperand;

  if (IsFSub) {
    if (IsFirstOperand) {
      // fsub(fmul(a, b), c) => fma(a, b, fneg(c))
      Addend = Builder.CreateFNegFMF(OtherOperand, BI->getFastMathFlags());
    } else {
      // fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
      Multiplicand =
          Builder.CreateFNegFMF(Multiplicand, FMul->getFastMathFlags());
    }
  }

  Value *FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
                                       {Multiplicand, Multiplier, Addend});

  // Combine fast-math flags from the two instructions being replaced.
  const FastMathFlags AddSubFlags = BI->getFastMathFlags();
  const FastMathFlags MulFlags = FMul->getFastMathFlags();
  cast<Instruction>(FMA)->setFastMathFlags(
      FastMathFlags::intersectRewrite(AddSubFlags, MulFlags) |
      FastMathFlags::unionValue(AddSubFlags, MulFlags));

  LLVM_DEBUG(dbgs() << "Replacing " << (IsFSub ? "FSub" : "FAdd")
                    << " with FMA: " << *FMA << "\n");

  // BI is the multiply's only user, so it must be erased before the multiply.
  BI->replaceAllUsesWith(FMA);
  BI->eraseFromParent();
  FMul->eraseFromParent();
  return true;
}
87+
2988
static bool foldFMA(Function &F) {
3089
bool Changed = false;
3190
SmallVector<BinaryOperator *, 16> FAddFSubInsts;
@@ -50,66 +109,6 @@ static bool foldFMA(Function &F) {
50109
}
51110
}
52111

53-
auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand,
54-
Value *OtherOperand, bool IsFirstOperand,
55-
bool IsFSub) -> bool {
56-
auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
57-
if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() ||
58-
!FMul->hasAllowContract())
59-
return false;
60-
61-
LLVM_DEBUG({
62-
const char *OpName = IsFSub ? "FSub" : "FAdd";
63-
dbgs() << "Found " << OpName << " with FMul (single use) as "
64-
<< (IsFirstOperand ? "first" : "second") << " operand: " << *BI
65-
<< "\n";
66-
});
67-
68-
Value *MulOp0 = FMul->getOperand(0);
69-
Value *MulOp1 = FMul->getOperand(1);
70-
IRBuilder<> Builder(BI);
71-
Value *FMA = nullptr;
72-
73-
if (!IsFSub) {
74-
// fadd(fmul(a, b), c) => fma(a, b, c)
75-
// fadd(c, fmul(a, b)) => fma(a, b, c)
76-
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
77-
{MulOp0, MulOp1, OtherOperand});
78-
} else {
79-
if (IsFirstOperand) {
80-
// fsub(fmul(a, b), c) => fma(a, b, fneg(c))
81-
Value *NegOtherOp = Builder.CreateFNeg(OtherOperand);
82-
cast<Instruction>(NegOtherOp)->setFastMathFlags(BI->getFastMathFlags());
83-
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
84-
{MulOp0, MulOp1, NegOtherOp});
85-
} else {
86-
// fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
87-
Value *NegMulOp0 = Builder.CreateFNeg(MulOp0);
88-
cast<Instruction>(NegMulOp0)->setFastMathFlags(
89-
FMul->getFastMathFlags());
90-
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
91-
{NegMulOp0, MulOp1, OtherOperand});
92-
}
93-
}
94-
95-
// Combine fast-math flags from the original instructions
96-
auto *FMAInst = cast<Instruction>(FMA);
97-
FastMathFlags BinaryFMF = BI->getFastMathFlags();
98-
FastMathFlags FMulFMF = FMul->getFastMathFlags();
99-
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
100-
FastMathFlags::unionValue(BinaryFMF, FMulFMF);
101-
FMAInst->setFastMathFlags(NewFMF);
102-
103-
LLVM_DEBUG({
104-
const char *OpName = IsFSub ? "FSub" : "FAdd";
105-
dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n";
106-
});
107-
BI->replaceAllUsesWith(FMA);
108-
BI->eraseFromParent();
109-
FMul->eraseFromParent();
110-
return true;
111-
};
112-
113112
for (auto *BI : FAddFSubInsts) {
114113
Value *Op0 = BI->getOperand(0);
115114
Value *Op1 = BI->getOperand(1);
@@ -142,5 +141,10 @@ FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); }
142141

143142
PreservedAnalyses NVPTXFoldFMAPass::run(Function &F,
144143
FunctionAnalysisManager &) {
145-
return foldFMA(F) ? PreservedAnalyses::none() : PreservedAnalyses::all();
144+
if (!foldFMA(F))
145+
return PreservedAnalyses::all();
146+
147+
PreservedAnalyses PA;
148+
PA.preserveSet<CFGAnalyses>();
149+
return PA;
146150
}

0 commit comments

Comments
 (0)