Skip to content

Commit 7549860

Browse files
committed
[NVPTX] Add IR pass for FMA transformation in NVPTX
This change introduces a new IR pass in the llc pipeline for NVPTX that transforms sequences of FMUL followed by FADD or FSUB into a single FMA instruction. Currently, all FMA folding for NVPTX occurs at the DAGCombine stage, which is too late for any IR-level passes that might want to optimize or analyze FMAs. By moving this transformation earlier into the IR phase, we enable more opportunities for FMA folding, including across basic blocks. Additionally, this new pass relies on the contract instruction level fast-math flag to perform these transformations, rather than depending on the -fp-contract=fast or -enable-unsafe-fp-math options passed to llc.
1 parent 5928619 commit 7549860

File tree

6 files changed

+391
-0
lines changed

6 files changed

+391
-0
lines changed

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ set(NVPTXCodeGen_sources
1717
NVPTXAssignValidGlobalNames.cpp
1818
NVPTXAtomicLower.cpp
1919
NVPTXCtorDtorLowering.cpp
20+
NVPTXFoldFMA.cpp
2021
NVPTXForwardParams.cpp
2122
NVPTXFrameLowering.cpp
2223
NVPTXGenericToNVVM.cpp

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass();
5252
FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
5353
bool NoTrapAfterNoreturn);
5454
FunctionPass *createNVPTXTagInvariantLoadsPass();
55+
FunctionPass *createNVPTXFoldFMAPass();
5556
MachineFunctionPass *createNVPTXPeephole();
5657
MachineFunctionPass *createNVPTXProxyRegErasurePass();
5758
MachineFunctionPass *createNVPTXForwardParamsPass();
@@ -76,12 +77,17 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &);
7677
void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
7778
void initializeNVPTXPeepholePass(PassRegistry &);
7879
void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &);
80+
void initializeNVPTXFoldFMAPass(PassRegistry &);
7981
void initializeNVPTXPrologEpilogPassPass(PassRegistry &);
8082

8183
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
8284
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
8385
};
8486

87+
struct NVPTXFoldFMAPass : PassInfoMixin<NVPTXFoldFMAPass> {
88+
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
89+
};
90+
8591
struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
8692
NVVMReflectPass() : SmVersion(0) {}
8793
NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements FMA folding for float/double type for NVPTX. It folds
10+
// following patterns:
11+
// 1. fadd(fmul(a, b), c) => fma(a, b, c)
12+
// 2. fadd(c, fmul(a, b)) => fma(a, b, c)
13+
// 3. fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
14+
// 4. fsub(fmul(a, b), c) => fma(a, b, fneg(c))
15+
// 5. fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
16+
// 6. fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
17+
//===----------------------------------------------------------------------===//
18+
19+
#include "NVPTXUtilities.h"
20+
#include "llvm/IR/IRBuilder.h"
21+
#include "llvm/IR/InstIterator.h"
22+
#include "llvm/IR/Instructions.h"
23+
#include "llvm/IR/Intrinsics.h"
24+
25+
#define DEBUG_TYPE "nvptx-fold-fma"
26+
27+
using namespace llvm;
28+
29+
static bool foldFMA(Function &F) {
30+
bool Changed = false;
31+
SmallVector<BinaryOperator *, 16> FAddFSubInsts;
32+
33+
// Collect all float/double FAdd/FSub instructions with allow-contract
34+
for (auto &I : instructions(F)) {
35+
if (auto *BI = dyn_cast<BinaryOperator>(&I)) {
36+
// Only FAdd and FSub are supported.
37+
if (BI->getOpcode() != Instruction::FAdd &&
38+
BI->getOpcode() != Instruction::FSub)
39+
continue;
40+
41+
// At minimum, the instruction should have allow-contract.
42+
if (!BI->hasAllowContract())
43+
continue;
44+
45+
// Only float and double are supported.
46+
if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
47+
continue;
48+
49+
FAddFSubInsts.push_back(BI);
50+
}
51+
}
52+
53+
auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand,
54+
Value *OtherOperand, bool IsFirstOperand,
55+
bool IsFSub) -> bool {
56+
auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
57+
if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() ||
58+
!FMul->hasAllowContract())
59+
return false;
60+
61+
LLVM_DEBUG({
62+
const char *OpName = IsFSub ? "FSub" : "FAdd";
63+
dbgs() << "Found " << OpName << " with FMul (single use) as "
64+
<< (IsFirstOperand ? "first" : "second") << " operand: " << *BI
65+
<< "\n";
66+
});
67+
68+
Value *MulOp0 = FMul->getOperand(0);
69+
Value *MulOp1 = FMul->getOperand(1);
70+
IRBuilder<> Builder(BI);
71+
Value *FMA = nullptr;
72+
73+
if (!IsFSub) {
74+
// fadd(fmul(a, b), c) => fma(a, b, c)
75+
// fadd(c, fmul(a, b)) => fma(a, b, c)
76+
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
77+
{MulOp0, MulOp1, OtherOperand});
78+
} else {
79+
if (IsFirstOperand) {
80+
// fsub(fmul(a, b), c) => fma(a, b, fneg(c))
81+
Value *NegOtherOp = Builder.CreateFNeg(OtherOperand);
82+
cast<Instruction>(NegOtherOp)->setFastMathFlags(BI->getFastMathFlags());
83+
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
84+
{MulOp0, MulOp1, NegOtherOp});
85+
} else {
86+
// fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
87+
Value *NegMulOp0 = Builder.CreateFNeg(MulOp0);
88+
cast<Instruction>(NegMulOp0)->setFastMathFlags(
89+
FMul->getFastMathFlags());
90+
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
91+
{NegMulOp0, MulOp1, OtherOperand});
92+
}
93+
}
94+
95+
// Combine fast-math flags from the original instructions
96+
auto *FMAInst = cast<Instruction>(FMA);
97+
FastMathFlags BinaryFMF = BI->getFastMathFlags();
98+
FastMathFlags FMulFMF = FMul->getFastMathFlags();
99+
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
100+
FastMathFlags::unionValue(BinaryFMF, FMulFMF);
101+
FMAInst->setFastMathFlags(NewFMF);
102+
103+
LLVM_DEBUG({
104+
const char *OpName = IsFSub ? "FSub" : "FAdd";
105+
dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n";
106+
});
107+
BI->replaceAllUsesWith(FMA);
108+
BI->eraseFromParent();
109+
FMul->eraseFromParent();
110+
return true;
111+
};
112+
113+
for (auto *BI : FAddFSubInsts) {
114+
Value *Op0 = BI->getOperand(0);
115+
Value *Op1 = BI->getOperand(1);
116+
bool IsFSub = BI->getOpcode() == Instruction::FSub;
117+
118+
if (tryFoldBinaryFMul(BI, Op0, Op1, true /*IsFirstOperand*/, IsFSub) ||
119+
tryFoldBinaryFMul(BI, Op1, Op0, false /*IsFirstOperand*/, IsFSub))
120+
Changed = true;
121+
}
122+
123+
return Changed;
124+
}
125+
126+
namespace {
127+
128+
struct NVPTXFoldFMA : public FunctionPass {
129+
static char ID;
130+
NVPTXFoldFMA() : FunctionPass(ID) {}
131+
bool runOnFunction(Function &F) override;
132+
};
133+
134+
} // namespace
135+
136+
char NVPTXFoldFMA::ID = 0;
137+
INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false)
138+
139+
bool NVPTXFoldFMA::runOnFunction(Function &F) { return foldFMA(F); }
140+
141+
FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); }
142+
143+
PreservedAnalyses NVPTXFoldFMAPass::run(Function &F,
144+
FunctionAnalysisManager &) {
145+
return foldFMA(F) ? PreservedAnalyses::none() : PreservedAnalyses::all();
146+
}

llvm/lib/Target/NVPTX/NVPTXPassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
4040
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
4141
FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this))
4242
FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass())
43+
FUNCTION_PASS("nvptx-fold-fma", NVPTXFoldFMAPass())
4344
#undef FUNCTION_PASS

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ static cl::opt<bool>
5151
cl::desc("Disable load/store vectorizer"),
5252
cl::init(false), cl::Hidden);
5353

54+
// FoldFMA is a new pass; this option will lets us turn it off in case we
55+
// encounter some issues.
56+
static cl::opt<bool> DisableFoldFMA("disable-nvptx-fold-fma",
57+
cl::desc("Disable NVPTX Fold FMA"),
58+
cl::init(false), cl::Hidden);
59+
5460
// TODO: Remove this flag when we are confident with no regressions.
5561
static cl::opt<bool> DisableRequireStructuredCFG(
5662
"disable-nvptx-require-structured-cfg",
@@ -115,6 +121,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
115121
initializeNVPTXExternalAAWrapperPass(PR);
116122
initializeNVPTXPeepholePass(PR);
117123
initializeNVPTXTagInvariantLoadLegacyPassPass(PR);
124+
initializeNVPTXFoldFMAPass(PR);
118125
initializeNVPTXPrologEpilogPassPass(PR);
119126
}
120127

@@ -397,6 +404,8 @@ void NVPTXPassConfig::addIRPasses() {
397404
addPass(createLoadStoreVectorizerPass());
398405
addPass(createSROAPass());
399406
addPass(createNVPTXTagInvariantLoadsPass());
407+
if (!DisableFoldFMA)
408+
addPass(createNVPTXFoldFMAPass());
400409
}
401410

402411
if (ST.hasPTXASUnreachableBug()) {

0 commit comments

Comments
 (0)