From 754986095db086f9701e141cf27a3ccc7581871b Mon Sep 17 00:00:00 2001 From: rbajpai Date: Thu, 31 Jul 2025 11:31:29 +0530 Subject: [PATCH 1/3] [NVPTX] Add IR pass for FMA transformation in NVPTX This change introduces a new IR pass in the llc pipeline for NVPTX that transforms sequences of FMUL followed by FADD or FSUB into a single FMA instruction. Currently, all FMA folding for NVPTX occurs at the DAGCombine stage, which is too late for any IR-level passes that might want to optimize or analyze FMAs. By moving this transformation earlier into the IR phase, we enable more opportunities for FMA folding, including across basic blocks. Additionally, this new pass relies on the contract instruction level fast-math flag to perform these transformations, rather than depending on the -fp-contract=fast or -enable-unsafe-fp-math options passed to llc. --- llvm/lib/Target/NVPTX/CMakeLists.txt | 1 + llvm/lib/Target/NVPTX/NVPTX.h | 6 + llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp | 146 ++++++++++++ llvm/lib/Target/NVPTX/NVPTXPassRegistry.def | 1 + llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 9 + llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll | 228 +++++++++++++++++++ 6 files changed, 391 insertions(+) create mode 100644 llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp create mode 100644 llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt index 693f0d0b35edc..1264a5e2e9f32 100644 --- a/llvm/lib/Target/NVPTX/CMakeLists.txt +++ b/llvm/lib/Target/NVPTX/CMakeLists.txt @@ -17,6 +17,7 @@ set(NVPTXCodeGen_sources NVPTXAssignValidGlobalNames.cpp NVPTXAtomicLower.cpp NVPTXCtorDtorLowering.cpp + NVPTXFoldFMA.cpp NVPTXForwardParams.cpp NVPTXFrameLowering.cpp NVPTXGenericToNVVM.cpp diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index 77a0e03d4075a..e84fa42319b34 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass(); FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable, bool NoTrapAfterNoreturn); FunctionPass *createNVPTXTagInvariantLoadsPass(); +FunctionPass *createNVPTXFoldFMAPass(); MachineFunctionPass *createNVPTXPeephole(); MachineFunctionPass *createNVPTXProxyRegErasurePass(); MachineFunctionPass *createNVPTXForwardParamsPass(); @@ -76,12 +77,17 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &); void initializeNVPTXExternalAAWrapperPass(PassRegistry &); void initializeNVPTXPeepholePass(PassRegistry &); void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &); +void initializeNVPTXFoldFMAPass(PassRegistry &); void initializeNVPTXPrologEpilogPassPass(PassRegistry &); struct NVVMIntrRangePass : PassInfoMixin { PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; +struct NVPTXFoldFMAPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); +}; + struct NVVMReflectPass : PassInfoMixin { NVVMReflectPass() : SmVersion(0) {} NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {} diff --git a/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp b/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp new file mode 100644 index 0000000000000..b844b880559ac --- /dev/null +++ b/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp @@ -0,0 +1,146 @@ +//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements FMA folding for float/double type for NVPTX. It folds +// following patterns: +// 1. fadd(fmul(a, b), c) => fma(a, b, c) +// 2. fadd(c, fmul(a, b)) => fma(a, b, c) +// 3. fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d)) +// 4. fsub(fmul(a, b), c) => fma(a, b, fneg(c)) +// 5. fsub(a, fmul(b, c)) => fma(fneg(b), c, a) +// 6. fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d))) +//===----------------------------------------------------------------------===// + +#include "NVPTXUtilities.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" + +#define DEBUG_TYPE "nvptx-fold-fma" + +using namespace llvm; + +static bool foldFMA(Function &F) { + bool Changed = false; + SmallVector FAddFSubInsts; + + // Collect all float/double FAdd/FSub instructions with allow-contract + for (auto &I : instructions(F)) { + if (auto *BI = dyn_cast(&I)) { + // Only FAdd and FSub are supported. + if (BI->getOpcode() != Instruction::FAdd && + BI->getOpcode() != Instruction::FSub) + continue; + + // At minimum, the instruction should have allow-contract. + if (!BI->hasAllowContract()) + continue; + + // Only float and double are supported. + if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy()) + continue; + + FAddFSubInsts.push_back(BI); + } + } + + auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand, + Value *OtherOperand, bool IsFirstOperand, + bool IsFSub) -> bool { + auto *FMul = dyn_cast(MulOperand); + if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() || + !FMul->hasAllowContract()) + return false; + + LLVM_DEBUG({ + const char *OpName = IsFSub ? "FSub" : "FAdd"; + dbgs() << "Found " << OpName << " with FMul (single use) as " + << (IsFirstOperand ? "first" : "second") << " operand: " << *BI + << "\n"; + }); + + Value *MulOp0 = FMul->getOperand(0); + Value *MulOp1 = FMul->getOperand(1); + IRBuilder<> Builder(BI); + Value *FMA = nullptr; + + if (!IsFSub) { + // fadd(fmul(a, b), c) => fma(a, b, c) + // fadd(c, fmul(a, b)) => fma(a, b, c) + FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, + {MulOp0, MulOp1, OtherOperand}); + } else { + if (IsFirstOperand) { + // fsub(fmul(a, b), c) => fma(a, b, fneg(c)) + Value *NegOtherOp = Builder.CreateFNeg(OtherOperand); + cast(NegOtherOp)->setFastMathFlags(BI->getFastMathFlags()); + FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, + {MulOp0, MulOp1, NegOtherOp}); + } else { + // fsub(a, fmul(b, c)) => fma(fneg(b), c, a) + Value *NegMulOp0 = Builder.CreateFNeg(MulOp0); + cast(NegMulOp0)->setFastMathFlags( + FMul->getFastMathFlags()); + FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, + {NegMulOp0, MulOp1, OtherOperand}); + } + } + + // Combine fast-math flags from the original instructions + auto *FMAInst = cast(FMA); + FastMathFlags BinaryFMF = BI->getFastMathFlags(); + FastMathFlags FMulFMF = FMul->getFastMathFlags(); + FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) | + FastMathFlags::unionValue(BinaryFMF, FMulFMF); + FMAInst->setFastMathFlags(NewFMF); + + LLVM_DEBUG({ + const char *OpName = IsFSub ? "FSub" : "FAdd"; + dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n"; + }); + BI->replaceAllUsesWith(FMA); + BI->eraseFromParent(); + FMul->eraseFromParent(); + return true; + }; + + for (auto *BI : FAddFSubInsts) { + Value *Op0 = BI->getOperand(0); + Value *Op1 = BI->getOperand(1); + bool IsFSub = BI->getOpcode() == Instruction::FSub; + + if (tryFoldBinaryFMul(BI, Op0, Op1, true /*IsFirstOperand*/, IsFSub) || + tryFoldBinaryFMul(BI, Op1, Op0, false /*IsFirstOperand*/, IsFSub)) + Changed = true; + } + + return Changed; +} + +namespace { + +struct NVPTXFoldFMA : public FunctionPass { + static char ID; + NVPTXFoldFMA() : FunctionPass(ID) {} + bool runOnFunction(Function &F) override; +}; + +} // namespace + +char NVPTXFoldFMA::ID = 0; +INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false) + +bool NVPTXFoldFMA::runOnFunction(Function &F) { return foldFMA(F); } + +FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); } + +PreservedAnalyses NVPTXFoldFMAPass::run(Function &F, + FunctionAnalysisManager &) { + return foldFMA(F) ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def index ee37c9826012c..176d334321a80 100644 --- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def +++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def @@ -40,4 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass()) FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass()) FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this)) FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass()) +FUNCTION_PASS("nvptx-fold-fma", NVPTXFoldFMAPass()) #undef FUNCTION_PASS diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index 0603994606d71..ad0493229ecf8 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -51,6 +51,12 @@ static cl::opt cl::desc("Disable load/store vectorizer"), cl::init(false), cl::Hidden); +// FoldFMA is a new pass; this option will lets us turn it off in case we +// encounter some issues. +static cl::opt DisableFoldFMA("disable-nvptx-fold-fma", + cl::desc("Disable NVPTX Fold FMA"), + cl::init(false), cl::Hidden); + // TODO: Remove this flag when we are confident with no regressions. static cl::opt DisableRequireStructuredCFG( "disable-nvptx-require-structured-cfg", @@ -115,6 +121,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() { initializeNVPTXExternalAAWrapperPass(PR); initializeNVPTXPeepholePass(PR); initializeNVPTXTagInvariantLoadLegacyPassPass(PR); + initializeNVPTXFoldFMAPass(PR); initializeNVPTXPrologEpilogPassPass(PR); } @@ -397,6 +404,8 @@ void NVPTXPassConfig::addIRPasses() { addPass(createLoadStoreVectorizerPass()); addPass(createSROAPass()); addPass(createNVPTXTagInvariantLoadsPass()); + if (!DisableFoldFMA) + addPass(createNVPTXFoldFMAPass()); } if (ST.hasPTXASUnreachableBug()) { diff --git a/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll new file mode 100644 index 0000000000000..ef01e9f044acf --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll @@ -0,0 +1,228 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt < %s -passes=nvptx-fold-fma -S | FileCheck %s + +target triple = "nvptx64-nvidia-cuda" + +; fsub(fmul(a, b), c) => fma(a, b, fneg(c)) +define float @test_fsub_fmul_c(float %a, float %b, float %c) { +; CHECK-LABEL: define float @test_fsub_fmul_c( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = fneg contract float [[C]] +; CHECK-NEXT: [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP1]]) +; CHECK-NEXT: ret float [[TMP2]] +; + %mul = fmul contract float %a, %b + %sub = fsub contract float %mul, %c + ret float %sub +} + + +; fsub(c, fmul(a, b)) => fma(-a, b, c) +define float @test_fsub_c_fmul(float %a, float %b, float %c) { +; CHECK-LABEL: define float @test_fsub_c_fmul( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = fneg contract float [[A]] +; CHECK-NEXT: [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[TMP1]], float [[B]], float [[C]]) +; CHECK-NEXT: ret float [[TMP2]] +; + %mul = fmul contract float %a, %b + %sub = fsub contract float %c, %mul + ret float %sub +} + + +; fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d))) +define float @test_fsub_fmul_fmul(float %a, float %b, float %c, float %d) { +; CHECK-LABEL: define float @test_fsub_fmul_fmul( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) { +; CHECK-NEXT: [[MUL2:%.*]] = fmul contract float [[C]], [[D]] +; CHECK-NEXT: [[TMP1:%.*]] = fneg contract float [[MUL2]] +; CHECK-NEXT: [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP1]]) +; CHECK-NEXT: ret float [[TMP2]] +; + %mul1 = fmul contract float %a, %b + %mul2 = fmul contract float %c, %d + %sub = fsub contract float %mul1, %mul2 + ret float %sub +} + + +; fsub(fmul(a, b), c) => fma(a, b, fneg(c)) where fsub and fmul are in different BBs +define float @test_fsub_fmul_different_BB(float %a, float %b, float %c, i32 %n) { +; CHECK-LABEL: define float @test_fsub_fmul_different_BB( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], i32 [[N:%.*]]) { +; CHECK-NEXT: [[INIT:.*]]: +; CHECK-NEXT: [[CMP_ITER:%.*]] = icmp sgt i32 [[N]], 10 +; CHECK-NEXT: br i1 [[CMP_ITER]], label %[[ITERATION:.*]], label %[[EXIT:.*]] +; CHECK: [[ITERATION]]: +; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, %[[INIT]] ], [ [[I_NEXT:%.*]], %[[ITERATION]] ] +; CHECK-NEXT: [[ACC:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT:%.*]], %[[ITERATION]] ] +; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 +; CHECK-NEXT: [[ACC_NEXT]] = fadd contract float [[ACC]], 1.000000e+00 +; CHECK-NEXT: [[CMP_LOOP:%.*]] = icmp slt i32 [[I_NEXT]], [[N]] +; CHECK-NEXT: br i1 [[CMP_LOOP]], label %[[ITERATION]], label %[[EXIT]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: [[C_PHI:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT]], %[[ITERATION]] ] +; CHECK-NEXT: [[TMP0:%.*]] = fneg contract float [[C_PHI]] +; CHECK-NEXT: [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP0]]) +; CHECK-NEXT: ret float [[TMP1]] +; +init: + %mul = fmul contract float %a, %b + %cmp_iter = icmp sgt i32 %n, 10 + br i1 %cmp_iter, label %iteration, label %exit + +iteration: + %i = phi i32 [ 0, %init ], [ %i_next, %iteration ] + %acc = phi float [ %c, %init ], [ %acc_next, %iteration ] + %i_next = add i32 %i, 1 + %acc_next = fadd contract float %acc, 1.0 + %cmp_loop = icmp slt i32 %i_next, %n + br i1 %cmp_loop, label %iteration, label %exit + +exit: + %c_phi = phi float [ %c, %init ], [ %acc_next, %iteration ] + %sub = fsub contract float %mul, %c_phi + ret float %sub +} + + +; fadd(fmul(a, b), c) => fma(a, b, c) +define float @test_fadd_fmul_c(float %a, float %b, float %c) { +; CHECK-LABEL: define float @test_fadd_fmul_c( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C]]) +; CHECK-NEXT: ret float [[TMP1]] +; + %mul = fmul contract float %a, %b + %add = fadd contract float %mul, %c + ret float %add +} + + +; fadd(c, fmul(a, b)) => fma(a, b, c) +define float @test_fadd_c_fmul(float %a, float %b, float %c) { +; CHECK-LABEL: define float @test_fadd_c_fmul( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C]]) +; CHECK-NEXT: ret float [[TMP1]] +; + %mul = fmul contract float %a, %b + %add = fadd contract float %c, %mul + ret float %add +} + + +; fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d)) +define float @test_fadd_fmul_fmul(float %a, float %b, float %c, float %d) { +; CHECK-LABEL: define float @test_fadd_fmul_fmul( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) { +; CHECK-NEXT: [[MUL2:%.*]] = fmul contract float [[C]], [[D]] +; CHECK-NEXT: [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[MUL2]]) +; CHECK-NEXT: ret float [[TMP1]] +; + %mul1 = fmul contract float %a, %b + %mul2 = fmul contract float %c, %d + %add = fadd contract float %mul1, %mul2 + ret float %add +} + + +; fadd(fmul(a, b), c) => fma(a, b, c) where fadd and fmul are in different BBs +define float @test_fadd_fmul_different_BB(float %a, float %b, float %c, i32 %n) { +; CHECK-LABEL: define float @test_fadd_fmul_different_BB( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], i32 [[N:%.*]]) { +; CHECK-NEXT: [[INIT:.*]]: +; CHECK-NEXT: [[CMP_ITER:%.*]] = icmp sgt i32 [[N]], 10 +; CHECK-NEXT: br i1 [[CMP_ITER]], label %[[ITERATION:.*]], label %[[EXIT:.*]] +; CHECK: [[ITERATION]]: +; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, %[[INIT]] ], [ [[I_NEXT:%.*]], %[[ITERATION]] ] +; CHECK-NEXT: [[ACC:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT:%.*]], %[[ITERATION]] ] +; CHECK-NEXT: [[I_NEXT]] = add i32 [[I]], 1 +; CHECK-NEXT: [[ACC_NEXT]] = fadd contract float [[ACC]], 1.000000e+00 +; CHECK-NEXT: [[CMP_LOOP:%.*]] = icmp slt i32 [[I_NEXT]], [[N]] +; CHECK-NEXT: br i1 [[CMP_LOOP]], label %[[ITERATION]], label %[[EXIT]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: [[C_PHI:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT]], %[[ITERATION]] ] +; CHECK-NEXT: [[TMP0:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C_PHI]]) +; CHECK-NEXT: ret float [[TMP0]] +; +init: + %mul = fmul contract float %a, %b + %cmp_iter = icmp sgt i32 %n, 10 + br i1 %cmp_iter, label %iteration, label %exit + +iteration: + %i = phi i32 [ 0, %init ], [ %i_next, %iteration ] + %acc = phi float [ %c, %init ], [ %acc_next, %iteration ] + %i_next = add i32 %i, 1 + %acc_next = fadd contract float %acc, 1.0 + %cmp_loop = icmp slt i32 %i_next, %n + br i1 %cmp_loop, label %iteration, label %exit + +exit: + %c_phi = phi float [ %c, %init ], [ %acc_next, %iteration ] + %add = fadd contract float %mul, %c_phi + ret float %add +} + + +; These scenarios shouldn't work. +; fadd(fpext(fmul(a, b)), c) => fma(fpext(a), fpext(b), c) +define double @test_fadd_fpext_fmul_c(float %a, float %b, double %c) { +; CHECK-LABEL: define double @test_fadd_fpext_fmul_c( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], double [[C:%.*]]) { +; CHECK-NEXT: [[MUL:%.*]] = fmul contract float [[A]], [[B]] +; CHECK-NEXT: [[EXT:%.*]] = fpext float [[MUL]] to double +; CHECK-NEXT: [[ADD:%.*]] = fadd contract double [[EXT]], [[C]] +; CHECK-NEXT: ret double [[ADD]] +; + %mul = fmul contract float %a, %b + %ext = fpext float %mul to double + %add = fadd contract double %ext, %c + ret double %add +} + + +; fadd(c, fpext(fmul(a, b))) => fma(fpext(a), fpext(b), c) +define double @test_fadd_c_fpext_fmul(float %a, float %b, double %c) { +; CHECK-LABEL: define double @test_fadd_c_fpext_fmul( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], double [[C:%.*]]) { +; CHECK-NEXT: [[MUL:%.*]] = fmul contract float [[A]], [[B]] +; CHECK-NEXT: [[EXT:%.*]] = fpext float [[MUL]] to double +; CHECK-NEXT: [[ADD:%.*]] = fadd contract double [[C]], [[EXT]] +; CHECK-NEXT: ret double [[ADD]] +; + %mul = fmul contract float %a, %b + %ext = fpext float %mul to double + %add = fadd contract double %c, %ext + ret double %add +} + + +; Double precision tests +; fsub(fmul(a, b), c) => fma(a, b, fneg(c)) +define double @test_fsub_fmul_c_double(double %a, double %b, double %c) { +; CHECK-LABEL: define double @test_fsub_fmul_c_double( +; CHECK-SAME: double [[A:%.*]], double [[B:%.*]], double [[C:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = fneg contract double [[C]] +; CHECK-NEXT: [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[A]], double [[B]], double [[TMP1]]) +; CHECK-NEXT: ret double [[TMP2]] +; + %mul = fmul contract double %a, %b + %sub = fsub contract double %mul, %c + ret double %sub +} + + +; fadd(fmul(a, b), c) => fma(a, b, c) +define double @test_fadd_fmul_c_double(double %a, double %b, double %c) { +; CHECK-LABEL: define double @test_fadd_fmul_c_double( +; CHECK-SAME: double [[A:%.*]], double [[B:%.*]], double [[C:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = call contract double @llvm.fma.f64(double [[A]], double [[B]], double [[C]]) +; CHECK-NEXT: ret double [[TMP1]] +; + %mul = fmul contract double %a, %b + %add = fadd contract double %mul, %c + ret double %add +} From f1eff5c5db2107cb1dc8eb477e9965307be2ea15 Mon Sep 17 00:00:00 2001 From: rbajpai Date: Mon, 1 Sep 2025 16:20:15 +0530 Subject: [PATCH 2/3] Addressed review comments 1. Moved lambda function into a static function. 2. Preserving CFG analysis. 3. Using CreateFNegFMF instead of CreateFNeg api. --- llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp | 126 +++++++++++++------------ 1 file changed, 65 insertions(+), 61 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp b/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp index b844b880559ac..41e82a80c3f9e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp @@ -26,6 +26,65 @@ using namespace llvm; +static bool tryFoldBinaryFMul(BinaryOperator *BI, Value *MulOperand, + Value *OtherOperand, bool IsFirstOperand, + bool IsFSub) { + auto *FMul = dyn_cast(MulOperand); + if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() || + !FMul->hasAllowContract()) + return false; + + LLVM_DEBUG({ + const char *OpName = IsFSub ? "FSub" : "FAdd"; + dbgs() << "Found " << OpName << " with FMul (single use) as " + << (IsFirstOperand ? "first" : "second") << " operand: " << *BI + << "\n"; + }); + + Value *MulOp0 = FMul->getOperand(0); + Value *MulOp1 = FMul->getOperand(1); + IRBuilder<> Builder(BI); + Value *FMA = nullptr; + + if (!IsFSub) { + // fadd(fmul(a, b), c) => fma(a, b, c) + // fadd(c, fmul(a, b)) => fma(a, b, c) + FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, + {MulOp0, MulOp1, OtherOperand}); + } else { + if (IsFirstOperand) { + // fsub(fmul(a, b), c) => fma(a, b, fneg(c)) + Value *NegOtherOp = + Builder.CreateFNegFMF(OtherOperand, BI->getFastMathFlags()); + FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, + {MulOp0, MulOp1, NegOtherOp}); + } else { + // fsub(a, fmul(b, c)) => fma(fneg(b), c, a) + Value *NegMulOp0 = + Builder.CreateFNegFMF(MulOp0, FMul->getFastMathFlags()); + FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, + {NegMulOp0, MulOp1, OtherOperand}); + } + } + + // Combine fast-math flags from the original instructions + auto *FMAInst = cast(FMA); + FastMathFlags BinaryFMF = BI->getFastMathFlags(); + FastMathFlags FMulFMF = FMul->getFastMathFlags(); + FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) | + FastMathFlags::unionValue(BinaryFMF, FMulFMF); + FMAInst->setFastMathFlags(NewFMF); + + LLVM_DEBUG({ + const char *OpName = IsFSub ? "FSub" : "FAdd"; + dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n"; + }); + BI->replaceAllUsesWith(FMA); + BI->eraseFromParent(); + FMul->eraseFromParent(); + return true; +} + static bool foldFMA(Function &F) { bool Changed = false; SmallVector FAddFSubInsts; @@ -50,66 +109,6 @@ static bool foldFMA(Function &F) { } } - auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand, - Value *OtherOperand, bool IsFirstOperand, - bool IsFSub) -> bool { - auto *FMul = dyn_cast(MulOperand); - if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() || - !FMul->hasAllowContract()) - return false; - - LLVM_DEBUG({ - const char *OpName = IsFSub ? "FSub" : "FAdd"; - dbgs() << "Found " << OpName << " with FMul (single use) as " - << (IsFirstOperand ? "first" : "second") << " operand: " << *BI - << "\n"; - }); - - Value *MulOp0 = FMul->getOperand(0); - Value *MulOp1 = FMul->getOperand(1); - IRBuilder<> Builder(BI); - Value *FMA = nullptr; - - if (!IsFSub) { - // fadd(fmul(a, b), c) => fma(a, b, c) - // fadd(c, fmul(a, b)) => fma(a, b, c) - FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, - {MulOp0, MulOp1, OtherOperand}); - } else { - if (IsFirstOperand) { - // fsub(fmul(a, b), c) => fma(a, b, fneg(c)) - Value *NegOtherOp = Builder.CreateFNeg(OtherOperand); - cast(NegOtherOp)->setFastMathFlags(BI->getFastMathFlags()); - FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, - {MulOp0, MulOp1, NegOtherOp}); - } else { - // fsub(a, fmul(b, c)) => fma(fneg(b), c, a) - Value *NegMulOp0 = Builder.CreateFNeg(MulOp0); - cast(NegMulOp0)->setFastMathFlags( - FMul->getFastMathFlags()); - FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, - {NegMulOp0, MulOp1, OtherOperand}); - } - } - - // Combine fast-math flags from the original instructions - auto *FMAInst = cast(FMA); - FastMathFlags BinaryFMF = BI->getFastMathFlags(); - FastMathFlags FMulFMF = FMul->getFastMathFlags(); - FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) | - FastMathFlags::unionValue(BinaryFMF, FMulFMF); - FMAInst->setFastMathFlags(NewFMF); - - LLVM_DEBUG({ - const char *OpName = IsFSub ? "FSub" : "FAdd"; - dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n"; - }); - BI->replaceAllUsesWith(FMA); - BI->eraseFromParent(); - FMul->eraseFromParent(); - return true; - }; - for (auto *BI : FAddFSubInsts) { Value *Op0 = BI->getOperand(0); Value *Op1 = BI->getOperand(1); @@ -142,5 +141,10 @@ FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); } PreservedAnalyses NVPTXFoldFMAPass::run(Function &F, FunctionAnalysisManager &) { - return foldFMA(F) ? PreservedAnalyses::none() : PreservedAnalyses::all(); + if (!foldFMA(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet(); + return PA; } From 49adcdff1f59958d400b8da6686f2d692ea86c4d Mon Sep 17 00:00:00 2001 From: rbajpai Date: Fri, 5 Sep 2025 14:35:43 +0530 Subject: [PATCH 3/3] Addressed review comments 1. Removed extra arguments passed to tryFoldBinaryFMul. 2. Removed temporary storage to collect the binary instructions. 3. Made guarding condition little easier to read. 4. Added one more test scenario. --- llvm/lib/Target/NVPTX/CMakeLists.txt | 2 +- llvm/lib/Target/NVPTX/NVPTX.h | 6 +- .../{NVPTXFoldFMA.cpp => NVPTXIRPeephole.cpp} | 73 +++++++++++-------- llvm/lib/Target/NVPTX/NVPTXPassRegistry.def | 2 +- llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 17 +++-- llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll | 21 +++++- 6 files changed, 76 insertions(+), 45 deletions(-) rename llvm/lib/Target/NVPTX/{NVPTXFoldFMA.cpp => NVPTXIRPeephole.cpp} (69%) diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt index 1264a5e2e9f32..dfc4c12348895 100644 --- a/llvm/lib/Target/NVPTX/CMakeLists.txt +++ b/llvm/lib/Target/NVPTX/CMakeLists.txt @@ -17,7 +17,7 @@ set(NVPTXCodeGen_sources NVPTXAssignValidGlobalNames.cpp NVPTXAtomicLower.cpp NVPTXCtorDtorLowering.cpp - NVPTXFoldFMA.cpp + NVPTXIRPeephole.cpp NVPTXForwardParams.cpp NVPTXFrameLowering.cpp NVPTXGenericToNVVM.cpp diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index e84fa42319b34..e10331b1bd560 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -52,7 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass(); FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable, bool NoTrapAfterNoreturn); FunctionPass *createNVPTXTagInvariantLoadsPass(); -FunctionPass *createNVPTXFoldFMAPass(); +FunctionPass *createNVPTXIRPeepholePass(); MachineFunctionPass *createNVPTXPeephole(); MachineFunctionPass *createNVPTXProxyRegErasurePass(); MachineFunctionPass *createNVPTXForwardParamsPass(); @@ -77,14 +77,14 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &); void initializeNVPTXExternalAAWrapperPass(PassRegistry &); void initializeNVPTXPeepholePass(PassRegistry &); void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &); -void initializeNVPTXFoldFMAPass(PassRegistry &); +void initializeNVPTXIRPeepholePass(PassRegistry &); void initializeNVPTXPrologEpilogPassPass(PassRegistry &); struct NVVMIntrRangePass : PassInfoMixin { PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; -struct NVPTXFoldFMAPass : PassInfoMixin { +struct NVPTXIRPeepholePass : PassInfoMixin { PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; diff --git a/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp b/llvm/lib/Target/NVPTX/NVPTXIRPeephole.cpp similarity index 69% rename from llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp rename to llvm/lib/Target/NVPTX/NVPTXIRPeephole.cpp index 41e82a80c3f9e..50535434d22f0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXIRPeephole.cpp @@ -1,4 +1,4 @@ -//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===// +//===------ NVPTXIRPeephole.cpp - NVPTX IR Peephole --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -22,18 +22,37 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" -#define DEBUG_TYPE "nvptx-fold-fma" +#define DEBUG_TYPE "nvptx-ir-peephole" using namespace llvm; -static bool tryFoldBinaryFMul(BinaryOperator *BI, Value *MulOperand, - Value *OtherOperand, bool IsFirstOperand, - bool IsFSub) { - auto *FMul = dyn_cast(MulOperand); - if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() || - !FMul->hasAllowContract()) +static bool tryFoldBinaryFMul(BinaryOperator *BI) { + Value *Op0 = BI->getOperand(0); + Value *Op1 = BI->getOperand(1); + + auto *FMul0 = dyn_cast(Op0); + auto *FMul1 = dyn_cast(Op1); + + BinaryOperator *FMul = nullptr; + Value *OtherOperand = nullptr; + bool IsFirstOperand = false; + + // Either Op0 or Op1 should be a valid FMul + if (FMul0 && FMul0->getOpcode() == Instruction::FMul && FMul0->hasOneUse() && + FMul0->hasAllowContract()) { + FMul = FMul0; + OtherOperand = Op1; + IsFirstOperand = true; + } else if (FMul1 && FMul1->getOpcode() == Instruction::FMul && + FMul1->hasOneUse() && FMul1->hasAllowContract()) { + FMul = FMul1; + OtherOperand = Op0; + IsFirstOperand = false; + } else { return false; + } + bool IsFSub = BI->getOpcode() == Instruction::FSub; LLVM_DEBUG({ const char *OpName = IsFSub ? "FSub" : "FAdd"; dbgs() << "Found " << OpName << " with FMul (single use) as " @@ -87,10 +106,9 @@ static bool tryFoldBinaryFMul(BinaryOperator *BI, Value *MulOperand, static bool foldFMA(Function &F) { bool Changed = false; - SmallVector FAddFSubInsts; - // Collect all float/double FAdd/FSub instructions with allow-contract - for (auto &I : instructions(F)) { + // Iterate and process float/double FAdd/FSub instructions with allow-contract + for (auto &I : llvm::make_early_inc_range(instructions(F))) { if (auto *BI = dyn_cast(&I)) { // Only FAdd and FSub are supported. if (BI->getOpcode() != Instruction::FAdd && @@ -105,42 +123,35 @@ static bool foldFMA(Function &F) { if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy()) continue; - FAddFSubInsts.push_back(BI); + if (tryFoldBinaryFMul(BI)) + Changed = true; } } - - for (auto *BI : FAddFSubInsts) { - Value *Op0 = BI->getOperand(0); - Value *Op1 = BI->getOperand(1); - bool IsFSub = BI->getOpcode() == Instruction::FSub; - - if (tryFoldBinaryFMul(BI, Op0, Op1, true /*IsFirstOperand*/, IsFSub) || - tryFoldBinaryFMul(BI, Op1, Op0, false /*IsFirstOperand*/, IsFSub)) - Changed = true; - } - return Changed; } namespace { -struct NVPTXFoldFMA : public FunctionPass { +struct NVPTXIRPeephole : public FunctionPass { static char ID; - NVPTXFoldFMA() : FunctionPass(ID) {} + NVPTXIRPeephole() : FunctionPass(ID) {} bool runOnFunction(Function &F) override; }; } // namespace -char NVPTXFoldFMA::ID = 0; -INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false) +char NVPTXIRPeephole::ID = 0; +INITIALIZE_PASS(NVPTXIRPeephole, "nvptx-ir-peephole", "NVPTX IR Peephole", + false, false) -bool NVPTXFoldFMA::runOnFunction(Function &F) { return foldFMA(F); } +bool NVPTXIRPeephole::runOnFunction(Function &F) { return foldFMA(F); } -FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); } +FunctionPass *llvm::createNVPTXIRPeepholePass() { + return new NVPTXIRPeephole(); +} -PreservedAnalyses NVPTXFoldFMAPass::run(Function &F, - FunctionAnalysisManager &) { +PreservedAnalyses NVPTXIRPeepholePass::run(Function &F, + FunctionAnalysisManager &) { if (!foldFMA(F)) return PreservedAnalyses::all(); diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def index 176d334321a80..7d645bff7110f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def +++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def @@ -40,5 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass()) FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass()) FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this)) FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass()) -FUNCTION_PASS("nvptx-fold-fma", NVPTXFoldFMAPass()) +FUNCTION_PASS("nvptx-ir-peephole", NVPTXIRPeepholePass()) #undef FUNCTION_PASS diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index ad0493229ecf8..0b587145f61ec 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -51,11 +51,12 @@ static cl::opt cl::desc("Disable load/store vectorizer"), cl::init(false), cl::Hidden); -// FoldFMA is a new pass; this option will lets us turn it off in case we -// encounter some issues. -static cl::opt DisableFoldFMA("disable-nvptx-fold-fma", - cl::desc("Disable NVPTX Fold FMA"), - cl::init(false), cl::Hidden); +// NVPTX IR Peephole is a new pass; this option will lets us turn it off in case +// we encounter some issues. +static cl::opt + DisableNVPTXIRPeephole("disable-nvptx-ir-peephole", + cl::desc("Disable NVPTX IR Peephole"), + cl::init(false), cl::Hidden); // TODO: Remove this flag when we are confident with no regressions. static cl::opt DisableRequireStructuredCFG( @@ -121,7 +122,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() { initializeNVPTXExternalAAWrapperPass(PR); initializeNVPTXPeepholePass(PR); initializeNVPTXTagInvariantLoadLegacyPassPass(PR); - initializeNVPTXFoldFMAPass(PR); + initializeNVPTXIRPeepholePass(PR); initializeNVPTXPrologEpilogPassPass(PR); } @@ -404,8 +405,8 @@ void NVPTXPassConfig::addIRPasses() { addPass(createLoadStoreVectorizerPass()); addPass(createSROAPass()); addPass(createNVPTXTagInvariantLoadsPass()); - if (!DisableFoldFMA) - addPass(createNVPTXFoldFMAPass()); + if (!DisableNVPTXIRPeephole) + addPass(createNVPTXIRPeepholePass()); } if (ST.hasPTXASUnreachableBug()) { diff --git a/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll index ef01e9f044acf..6d9ad8d3ad436 100644 --- a/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll +++ b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 -; RUN: opt < %s -passes=nvptx-fold-fma -S | FileCheck %s +; RUN: opt < %s -passes=nvptx-ir-peephole -S | FileCheck %s target triple = "nvptx64-nvidia-cuda" @@ -47,6 +47,25 @@ define float @test_fsub_fmul_fmul(float %a, float %b, float %c, float %d) { } +; fsub(fmul(a, b), fmul(c, d)) => fma(fneg(c), d, fmul(a, b))) +; fmul(a, b) has multiple uses. +define float @test_fsub_fmul_fmul_multiple_use(float %a, float %b, float %c, float %d) { +; CHECK-LABEL: define float @test_fsub_fmul_fmul_multiple_use( +; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) { +; CHECK-NEXT: [[MUL1:%.*]] = fmul contract float [[A]], [[B]] +; CHECK-NEXT: [[TMP1:%.*]] = fneg contract float [[C]] +; CHECK-NEXT: [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[TMP1]], float [[D]], float [[MUL1]]) +; CHECK-NEXT: [[ADD:%.*]] = fadd float [[TMP2]], [[MUL1]] +; CHECK-NEXT: ret float [[ADD]] +; + %mul1 = fmul contract float %a, %b + %mul2 = fmul contract float %c, %d + %sub = fsub contract float %mul1, %mul2 + %add = fadd float %sub, %mul1 + ret float %add +} + + ; fsub(fmul(a, b), c) => fma(a, b, fneg(c)) where fsub and fmul are in different BBs define float @test_fsub_fmul_different_BB(float %a, float %b, float %c, i32 %n) { ; CHECK-LABEL: define float @test_fsub_fmul_different_BB(