26
26
27
27
using namespace llvm ;
28
28
29
+ static bool tryFoldBinaryFMul (BinaryOperator *BI, Value *MulOperand,
30
+ Value *OtherOperand, bool IsFirstOperand,
31
+ bool IsFSub) {
32
+ auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
33
+ if (!FMul || FMul->getOpcode () != Instruction::FMul || !FMul->hasOneUse () ||
34
+ !FMul->hasAllowContract ())
35
+ return false ;
36
+
37
+ LLVM_DEBUG ({
38
+ const char *OpName = IsFSub ? " FSub" : " FAdd" ;
39
+ dbgs () << " Found " << OpName << " with FMul (single use) as "
40
+ << (IsFirstOperand ? " first" : " second" ) << " operand: " << *BI
41
+ << " \n " ;
42
+ });
43
+
44
+ Value *MulOp0 = FMul->getOperand (0 );
45
+ Value *MulOp1 = FMul->getOperand (1 );
46
+ IRBuilder<> Builder (BI);
47
+ Value *FMA = nullptr ;
48
+
49
+ if (!IsFSub) {
50
+ // fadd(fmul(a, b), c) => fma(a, b, c)
51
+ // fadd(c, fmul(a, b)) => fma(a, b, c)
52
+ FMA = Builder.CreateIntrinsic (Intrinsic::fma, {BI->getType ()},
53
+ {MulOp0, MulOp1, OtherOperand});
54
+ } else {
55
+ if (IsFirstOperand) {
56
+ // fsub(fmul(a, b), c) => fma(a, b, fneg(c))
57
+ Value *NegOtherOp =
58
+ Builder.CreateFNegFMF (OtherOperand, BI->getFastMathFlags ());
59
+ FMA = Builder.CreateIntrinsic (Intrinsic::fma, {BI->getType ()},
60
+ {MulOp0, MulOp1, NegOtherOp});
61
+ } else {
62
+ // fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
63
+ Value *NegMulOp0 =
64
+ Builder.CreateFNegFMF (MulOp0, FMul->getFastMathFlags ());
65
+ FMA = Builder.CreateIntrinsic (Intrinsic::fma, {BI->getType ()},
66
+ {NegMulOp0, MulOp1, OtherOperand});
67
+ }
68
+ }
69
+
70
+ // Combine fast-math flags from the original instructions
71
+ auto *FMAInst = cast<Instruction>(FMA);
72
+ FastMathFlags BinaryFMF = BI->getFastMathFlags ();
73
+ FastMathFlags FMulFMF = FMul->getFastMathFlags ();
74
+ FastMathFlags NewFMF = FastMathFlags::intersectRewrite (BinaryFMF, FMulFMF) |
75
+ FastMathFlags::unionValue (BinaryFMF, FMulFMF);
76
+ FMAInst->setFastMathFlags (NewFMF);
77
+
78
+ LLVM_DEBUG ({
79
+ const char *OpName = IsFSub ? " FSub" : " FAdd" ;
80
+ dbgs () << " Replacing " << OpName << " with FMA: " << *FMA << " \n " ;
81
+ });
82
+ BI->replaceAllUsesWith (FMA);
83
+ BI->eraseFromParent ();
84
+ FMul->eraseFromParent ();
85
+ return true ;
86
+ }
87
+
29
88
static bool foldFMA (Function &F) {
30
89
bool Changed = false ;
31
90
SmallVector<BinaryOperator *, 16 > FAddFSubInsts;
@@ -50,66 +109,6 @@ static bool foldFMA(Function &F) {
50
109
}
51
110
}
52
111
53
- auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand,
54
- Value *OtherOperand, bool IsFirstOperand,
55
- bool IsFSub) -> bool {
56
- auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
57
- if (!FMul || FMul->getOpcode () != Instruction::FMul || !FMul->hasOneUse () ||
58
- !FMul->hasAllowContract ())
59
- return false ;
60
-
61
- LLVM_DEBUG ({
62
- const char *OpName = IsFSub ? " FSub" : " FAdd" ;
63
- dbgs () << " Found " << OpName << " with FMul (single use) as "
64
- << (IsFirstOperand ? " first" : " second" ) << " operand: " << *BI
65
- << " \n " ;
66
- });
67
-
68
- Value *MulOp0 = FMul->getOperand (0 );
69
- Value *MulOp1 = FMul->getOperand (1 );
70
- IRBuilder<> Builder (BI);
71
- Value *FMA = nullptr ;
72
-
73
- if (!IsFSub) {
74
- // fadd(fmul(a, b), c) => fma(a, b, c)
75
- // fadd(c, fmul(a, b)) => fma(a, b, c)
76
- FMA = Builder.CreateIntrinsic (Intrinsic::fma, {BI->getType ()},
77
- {MulOp0, MulOp1, OtherOperand});
78
- } else {
79
- if (IsFirstOperand) {
80
- // fsub(fmul(a, b), c) => fma(a, b, fneg(c))
81
- Value *NegOtherOp = Builder.CreateFNeg (OtherOperand);
82
- cast<Instruction>(NegOtherOp)->setFastMathFlags (BI->getFastMathFlags ());
83
- FMA = Builder.CreateIntrinsic (Intrinsic::fma, {BI->getType ()},
84
- {MulOp0, MulOp1, NegOtherOp});
85
- } else {
86
- // fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
87
- Value *NegMulOp0 = Builder.CreateFNeg (MulOp0);
88
- cast<Instruction>(NegMulOp0)->setFastMathFlags (
89
- FMul->getFastMathFlags ());
90
- FMA = Builder.CreateIntrinsic (Intrinsic::fma, {BI->getType ()},
91
- {NegMulOp0, MulOp1, OtherOperand});
92
- }
93
- }
94
-
95
- // Combine fast-math flags from the original instructions
96
- auto *FMAInst = cast<Instruction>(FMA);
97
- FastMathFlags BinaryFMF = BI->getFastMathFlags ();
98
- FastMathFlags FMulFMF = FMul->getFastMathFlags ();
99
- FastMathFlags NewFMF = FastMathFlags::intersectRewrite (BinaryFMF, FMulFMF) |
100
- FastMathFlags::unionValue (BinaryFMF, FMulFMF);
101
- FMAInst->setFastMathFlags (NewFMF);
102
-
103
- LLVM_DEBUG ({
104
- const char *OpName = IsFSub ? " FSub" : " FAdd" ;
105
- dbgs () << " Replacing " << OpName << " with FMA: " << *FMA << " \n " ;
106
- });
107
- BI->replaceAllUsesWith (FMA);
108
- BI->eraseFromParent ();
109
- FMul->eraseFromParent ();
110
- return true ;
111
- };
112
-
113
112
for (auto *BI : FAddFSubInsts) {
114
113
Value *Op0 = BI->getOperand (0 );
115
114
Value *Op1 = BI->getOperand (1 );
@@ -142,5 +141,10 @@ FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); }
142
141
143
142
// New-pass-manager entry point for the FMA folding pass: runs foldFMA on F
// and reports which analyses remain valid afterwards. The analysis manager
// argument is unnamed and unused.
PreservedAnalyses NVPTXFoldFMAPass::run (Function &F,
144
143
FunctionAnalysisManager &) {
145
// Removed line: any change reported by foldFMA invalidated every analysis.
- return foldFMA (F) ? PreservedAnalyses::none () : PreservedAnalyses::all ();
144
// Added lines: when foldFMA reports no change, all analyses are preserved.
+ if (!foldFMA (F))
145
+ return PreservedAnalyses::all ();
146
+
147
// Otherwise only the CFGAnalyses set is marked preserved; everything else is
// treated as invalidated.
+ PreservedAnalyses PA;
148
+ PA.preserveSet <CFGAnalyses>();
149
+ return PA;
146
150
}
0 commit comments