
Conversation

rajatbajpai
Contributor

This change introduces a new IR pass in the llc pipeline for NVPTX that transforms sequences of FMUL followed by FADD or FSUB into a single FMA instruction.

Currently, all FMA folding for NVPTX occurs at the DAGCombine stage, which is too late for any IR-level passes that might want to optimize or analyze FMAs. By moving this transformation earlier into the IR phase, we enable more opportunities for FMA folding, including across basic blocks.

Additionally, this new pass relies on the instruction-level contract fast-math flag to perform these transformations, rather than depending on the -fp-contract=fast or -enable-unsafe-fp-math options passed to llc.
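
For illustration, the simplest fadd case looks like this at the IR level (the same pattern exercised by the new test file); a minimal before/after sketch:

```llvm
; Before: separate mul and add, both carrying the contract flag.
define float @example(float %a, float %b, float %c) {
  %mul = fmul contract float %a, %b
  %add = fadd contract float %mul, %c
  ret float %add
}

; After nvptx-fold-fma: a single fma intrinsic call.
define float @example_folded(float %a, float %b, float %c) {
  %fma = call contract float @llvm.fma.f32(float %a, float %b, float %c)
  ret float %fma
}
```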

@llvmbot
Member

llvmbot commented Aug 21, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Rajat Bajpai (rajatbajpai)

Changes

This change introduces a new IR pass in the llc pipeline for NVPTX that transforms sequences of FMUL followed by FADD or FSUB into a single FMA instruction.

Currently, all FMA folding for NVPTX occurs at the DAGCombine stage, which is too late for any IR-level passes that might want to optimize or analyze FMAs. By moving this transformation earlier into the IR phase, we enable more opportunities for FMA folding, including across basic blocks.

Additionally, this new pass relies on the instruction-level contract fast-math flag to perform these transformations, rather than depending on the -fp-contract=fast or -enable-unsafe-fp-math options passed to llc.


Full diff: https://github.com/llvm/llvm-project/pull/154735.diff

6 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/CMakeLists.txt (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTX.h (+6)
  • (added) llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp (+146)
  • (modified) llvm/lib/Target/NVPTX/NVPTXPassRegistry.def (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp (+9)
  • (added) llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll (+228)
diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index 693f0d0b35edc..1264a5e2e9f32 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -17,6 +17,7 @@ set(NVPTXCodeGen_sources
   NVPTXAssignValidGlobalNames.cpp
   NVPTXAtomicLower.cpp
   NVPTXCtorDtorLowering.cpp
+  NVPTXFoldFMA.cpp
   NVPTXForwardParams.cpp
   NVPTXFrameLowering.cpp
   NVPTXGenericToNVVM.cpp
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 77a0e03d4075a..e84fa42319b34 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass();
 FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
                                               bool NoTrapAfterNoreturn);
 FunctionPass *createNVPTXTagInvariantLoadsPass();
+FunctionPass *createNVPTXFoldFMAPass();
 MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
 MachineFunctionPass *createNVPTXForwardParamsPass();
@@ -76,12 +77,17 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &);
 void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
 void initializeNVPTXPeepholePass(PassRegistry &);
 void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &);
+void initializeNVPTXFoldFMAPass(PassRegistry &);
 void initializeNVPTXPrologEpilogPassPass(PassRegistry &);
 
 struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 
+struct NVPTXFoldFMAPass : PassInfoMixin<NVPTXFoldFMAPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
 struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
   NVVMReflectPass() : SmVersion(0) {}
   NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
diff --git a/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp b/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp
new file mode 100644
index 0000000000000..b844b880559ac
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp
@@ -0,0 +1,146 @@
+//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements FMA folding for float/double types for NVPTX. It folds
+// the following patterns:
+// 1. fadd(fmul(a, b), c) => fma(a, b, c)
+// 2. fadd(c, fmul(a, b)) => fma(a, b, c)
+// 3. fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
+// 4. fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+// 5. fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
+// 6. fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
+//===----------------------------------------------------------------------===//
+
+#include "NVPTXUtilities.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+
+#define DEBUG_TYPE "nvptx-fold-fma"
+
+using namespace llvm;
+
+static bool foldFMA(Function &F) {
+  bool Changed = false;
+  SmallVector<BinaryOperator *, 16> FAddFSubInsts;
+
+  // Collect all float/double FAdd/FSub instructions with allow-contract
+  for (auto &I : instructions(F)) {
+    if (auto *BI = dyn_cast<BinaryOperator>(&I)) {
+      // Only FAdd and FSub are supported.
+      if (BI->getOpcode() != Instruction::FAdd &&
+          BI->getOpcode() != Instruction::FSub)
+        continue;
+
+      // At minimum, the instruction should have allow-contract.
+      if (!BI->hasAllowContract())
+        continue;
+
+      // Only float and double are supported.
+      if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
+        continue;
+
+      FAddFSubInsts.push_back(BI);
+    }
+  }
+
+  auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand,
+                              Value *OtherOperand, bool IsFirstOperand,
+                              bool IsFSub) -> bool {
+    auto *FMul = dyn_cast<BinaryOperator>(MulOperand);
+    if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() ||
+        !FMul->hasAllowContract())
+      return false;
+
+    LLVM_DEBUG({
+      const char *OpName = IsFSub ? "FSub" : "FAdd";
+      dbgs() << "Found " << OpName << " with FMul (single use) as "
+             << (IsFirstOperand ? "first" : "second") << " operand: " << *BI
+             << "\n";
+    });
+
+    Value *MulOp0 = FMul->getOperand(0);
+    Value *MulOp1 = FMul->getOperand(1);
+    IRBuilder<> Builder(BI);
+    Value *FMA = nullptr;
+
+    if (!IsFSub) {
+      // fadd(fmul(a, b), c) => fma(a, b, c)
+      // fadd(c, fmul(a, b)) => fma(a, b, c)
+      FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+                                    {MulOp0, MulOp1, OtherOperand});
+    } else {
+      if (IsFirstOperand) {
+        // fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+        Value *NegOtherOp = Builder.CreateFNeg(OtherOperand);
+        cast<Instruction>(NegOtherOp)->setFastMathFlags(BI->getFastMathFlags());
+        FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+                                      {MulOp0, MulOp1, NegOtherOp});
+      } else {
+        // fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
+        Value *NegMulOp0 = Builder.CreateFNeg(MulOp0);
+        cast<Instruction>(NegMulOp0)->setFastMathFlags(
+            FMul->getFastMathFlags());
+        FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
+                                      {NegMulOp0, MulOp1, OtherOperand});
+      }
+    }
+
+    // Combine fast-math flags from the original instructions
+    auto *FMAInst = cast<Instruction>(FMA);
+    FastMathFlags BinaryFMF = BI->getFastMathFlags();
+    FastMathFlags FMulFMF = FMul->getFastMathFlags();
+    FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
+                           FastMathFlags::unionValue(BinaryFMF, FMulFMF);
+    FMAInst->setFastMathFlags(NewFMF);
+
+    LLVM_DEBUG({
+      const char *OpName = IsFSub ? "FSub" : "FAdd";
+      dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n";
+    });
+    BI->replaceAllUsesWith(FMA);
+    BI->eraseFromParent();
+    FMul->eraseFromParent();
+    return true;
+  };
+
+  for (auto *BI : FAddFSubInsts) {
+    Value *Op0 = BI->getOperand(0);
+    Value *Op1 = BI->getOperand(1);
+    bool IsFSub = BI->getOpcode() == Instruction::FSub;
+
+    if (tryFoldBinaryFMul(BI, Op0, Op1, true /*IsFirstOperand*/, IsFSub) ||
+        tryFoldBinaryFMul(BI, Op1, Op0, false /*IsFirstOperand*/, IsFSub))
+      Changed = true;
+  }
+
+  return Changed;
+}
+
+namespace {
+
+struct NVPTXFoldFMA : public FunctionPass {
+  static char ID;
+  NVPTXFoldFMA() : FunctionPass(ID) {}
+  bool runOnFunction(Function &F) override;
+};
+
+} // namespace
+
+char NVPTXFoldFMA::ID = 0;
+INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false)
+
+bool NVPTXFoldFMA::runOnFunction(Function &F) { return foldFMA(F); }
+
+FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); }
+
+PreservedAnalyses NVPTXFoldFMAPass::run(Function &F,
+                                        FunctionAnalysisManager &) {
+  return foldFMA(F) ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index ee37c9826012c..176d334321a80 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -40,4 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
 FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
 FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this))
 FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass())
+FUNCTION_PASS("nvptx-fold-fma", NVPTXFoldFMAPass())
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 0603994606d71..ad0493229ecf8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -51,6 +51,12 @@ static cl::opt<bool>
                                cl::desc("Disable load/store vectorizer"),
                                cl::init(false), cl::Hidden);
 
+// FoldFMA is a new pass; this option lets us turn it off in case we
+// encounter any issues.
+static cl::opt<bool> DisableFoldFMA("disable-nvptx-fold-fma",
+                                    cl::desc("Disable NVPTX Fold FMA"),
+                                    cl::init(false), cl::Hidden);
+
 // TODO: Remove this flag when we are confident with no regressions.
 static cl::opt<bool> DisableRequireStructuredCFG(
     "disable-nvptx-require-structured-cfg",
@@ -115,6 +121,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
   initializeNVPTXExternalAAWrapperPass(PR);
   initializeNVPTXPeepholePass(PR);
   initializeNVPTXTagInvariantLoadLegacyPassPass(PR);
+  initializeNVPTXFoldFMAPass(PR);
   initializeNVPTXPrologEpilogPassPass(PR);
 }
 
@@ -397,6 +404,8 @@ void NVPTXPassConfig::addIRPasses() {
       addPass(createLoadStoreVectorizerPass());
     addPass(createSROAPass());
     addPass(createNVPTXTagInvariantLoadsPass());
+    if (!DisableFoldFMA)
+      addPass(createNVPTXFoldFMAPass());
   }
 
   if (ST.hasPTXASUnreachableBug()) {
diff --git a/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll
new file mode 100644
index 0000000000000..ef01e9f044acf
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nvptx-fold-fma.ll
@@ -0,0 +1,228 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=nvptx-fold-fma -S | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+
+; fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+define float @test_fsub_fmul_c(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fsub_fmul_c(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract float [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP1]])
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %mul = fmul contract float %a, %b
+  %sub = fsub contract float %mul, %c
+  ret float %sub
+}
+
+
+; fsub(c, fmul(a, b)) => fma(-a, b, c)
+define float @test_fsub_c_fmul(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fsub_c_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract float [[A]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[TMP1]], float [[B]], float [[C]])
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %mul = fmul contract float %a, %b
+  %sub = fsub contract float %c, %mul
+  ret float %sub
+}
+
+
+; fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
+define float @test_fsub_fmul_fmul(float %a, float %b, float %c, float %d) {
+; CHECK-LABEL: define float @test_fsub_fmul_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) {
+; CHECK-NEXT:    [[MUL2:%.*]] = fmul contract float [[C]], [[D]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract float [[MUL2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP1]])
+; CHECK-NEXT:    ret float [[TMP2]]
+;
+  %mul1 = fmul contract float %a, %b
+  %mul2 = fmul contract float %c, %d
+  %sub = fsub contract float %mul1, %mul2
+  ret float %sub
+}
+
+
+; fsub(fmul(a, b), c) => fma(a, b, fneg(c)) where fsub and fmul are in different BBs
+define float @test_fsub_fmul_different_BB(float %a, float %b, float %c, i32 %n) {
+; CHECK-LABEL: define float @test_fsub_fmul_different_BB(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], i32 [[N:%.*]]) {
+; CHECK-NEXT:  [[INIT:.*]]:
+; CHECK-NEXT:    [[CMP_ITER:%.*]] = icmp sgt i32 [[N]], 10
+; CHECK-NEXT:    br i1 [[CMP_ITER]], label %[[ITERATION:.*]], label %[[EXIT:.*]]
+; CHECK:       [[ITERATION]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i32 [ 0, %[[INIT]] ], [ [[I_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[ACC:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[I_NEXT]] = add i32 [[I]], 1
+; CHECK-NEXT:    [[ACC_NEXT]] = fadd contract float [[ACC]], 1.000000e+00
+; CHECK-NEXT:    [[CMP_LOOP:%.*]] = icmp slt i32 [[I_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP_LOOP]], label %[[ITERATION]], label %[[EXIT]]
+; CHECK:       [[EXIT]]:
+; CHECK-NEXT:    [[C_PHI:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = fneg contract float [[C_PHI]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[TMP0]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+init:
+  %mul = fmul contract float %a, %b
+  %cmp_iter = icmp sgt i32 %n, 10
+  br i1 %cmp_iter, label %iteration, label %exit
+
+iteration:
+  %i = phi i32 [ 0, %init ], [ %i_next, %iteration ]
+  %acc = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %i_next = add i32 %i, 1
+  %acc_next = fadd contract float %acc, 1.0
+  %cmp_loop = icmp slt i32 %i_next, %n
+  br i1 %cmp_loop, label %iteration, label %exit
+
+exit:
+  %c_phi = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %sub = fsub contract float %mul, %c_phi
+  ret float %sub
+}
+
+
+; fadd(fmul(a, b), c) => fma(a, b, c)
+define float @test_fadd_fmul_c(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fadd_fmul_c(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %mul = fmul contract float %a, %b
+  %add = fadd contract float %mul, %c
+  ret float %add
+}
+
+
+; fadd(c, fmul(a, b)) => fma(a, b, c)
+define float @test_fadd_c_fmul(float %a, float %b, float %c) {
+; CHECK-LABEL: define float @test_fadd_c_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %mul = fmul contract float %a, %b
+  %add = fadd contract float %c, %mul
+  ret float %add
+}
+
+
+; fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
+define float @test_fadd_fmul_fmul(float %a, float %b, float %c, float %d) {
+; CHECK-LABEL: define float @test_fadd_fmul_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], float [[D:%.*]]) {
+; CHECK-NEXT:    [[MUL2:%.*]] = fmul contract float [[C]], [[D]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[MUL2]])
+; CHECK-NEXT:    ret float [[TMP1]]
+;
+  %mul1 = fmul contract float %a, %b
+  %mul2 = fmul contract float %c, %d
+  %add = fadd contract float %mul1, %mul2
+  ret float %add
+}
+
+
+; fadd(fmul(a, b), c) => fma(a, b, c) where fadd and fmul are in different BBs
+define float @test_fadd_fmul_different_BB(float %a, float %b, float %c, i32 %n) {
+; CHECK-LABEL: define float @test_fadd_fmul_different_BB(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], float [[C:%.*]], i32 [[N:%.*]]) {
+; CHECK-NEXT:  [[INIT:.*]]:
+; CHECK-NEXT:    [[CMP_ITER:%.*]] = icmp sgt i32 [[N]], 10
+; CHECK-NEXT:    br i1 [[CMP_ITER]], label %[[ITERATION:.*]], label %[[EXIT:.*]]
+; CHECK:       [[ITERATION]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i32 [ 0, %[[INIT]] ], [ [[I_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[ACC:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT:%.*]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[I_NEXT]] = add i32 [[I]], 1
+; CHECK-NEXT:    [[ACC_NEXT]] = fadd contract float [[ACC]], 1.000000e+00
+; CHECK-NEXT:    [[CMP_LOOP:%.*]] = icmp slt i32 [[I_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP_LOOP]], label %[[ITERATION]], label %[[EXIT]]
+; CHECK:       [[EXIT]]:
+; CHECK-NEXT:    [[C_PHI:%.*]] = phi float [ [[C]], %[[INIT]] ], [ [[ACC_NEXT]], %[[ITERATION]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = call contract float @llvm.fma.f32(float [[A]], float [[B]], float [[C_PHI]])
+; CHECK-NEXT:    ret float [[TMP0]]
+;
+init:
+  %mul = fmul contract float %a, %b
+  %cmp_iter = icmp sgt i32 %n, 10
+  br i1 %cmp_iter, label %iteration, label %exit
+
+iteration:
+  %i = phi i32 [ 0, %init ], [ %i_next, %iteration ]
+  %acc = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %i_next = add i32 %i, 1
+  %acc_next = fadd contract float %acc, 1.0
+  %cmp_loop = icmp slt i32 %i_next, %n
+  br i1 %cmp_loop, label %iteration, label %exit
+
+exit:
+  %c_phi = phi float [ %c, %init ], [ %acc_next, %iteration ]
+  %add = fadd contract float %mul, %c_phi
+  ret float %add
+}
+
+
+; These scenarios shouldn't work.
+; fadd(fpext(fmul(a, b)), c) => fma(fpext(a), fpext(b), c)
+define double @test_fadd_fpext_fmul_c(float %a, float %b, double %c) {
+; CHECK-LABEL: define double @test_fadd_fpext_fmul_c(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = fmul contract float [[A]], [[B]]
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[MUL]] to double
+; CHECK-NEXT:    [[ADD:%.*]] = fadd contract double [[EXT]], [[C]]
+; CHECK-NEXT:    ret double [[ADD]]
+;
+  %mul = fmul contract float %a, %b
+  %ext = fpext float %mul to double
+  %add = fadd contract double %ext, %c
+  ret double %add
+}
+
+
+; fadd(c, fpext(fmul(a, b))) => fma(fpext(a), fpext(b), c)
+define double @test_fadd_c_fpext_fmul(float %a, float %b, double %c) {
+; CHECK-LABEL: define double @test_fadd_c_fpext_fmul(
+; CHECK-SAME: float [[A:%.*]], float [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = fmul contract float [[A]], [[B]]
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[MUL]] to double
+; CHECK-NEXT:    [[ADD:%.*]] = fadd contract double [[C]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD]]
+;
+  %mul = fmul contract float %a, %b
+  %ext = fpext float %mul to double
+  %add = fadd contract double %c, %ext
+  ret double %add
+}
+
+
+; Double precision tests
+; fsub(fmul(a, b), c) => fma(a, b, fneg(c))
+define double @test_fsub_fmul_c_double(double %a, double %b, double %c) {
+; CHECK-LABEL: define double @test_fsub_fmul_c_double(
+; CHECK-SAME: double [[A:%.*]], double [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg contract double [[C]]
+; CHECK-NEXT:    [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[A]], double [[B]], double [[TMP1]])
+; CHECK-NEXT:    ret double [[TMP2]]
+;
+  %mul = fmul contract double %a, %b
+  %sub = fsub contract double %mul, %c
+  ret double %sub
+}
+
+
+; fadd(fmul(a, b), c) => fma(a, b, c)
+define double @test_fadd_fmul_c_double(double %a, double %b, double %c) {
+; CHECK-LABEL: define double @test_fadd_fmul_c_double(
+; CHECK-SAME: double [[A:%.*]], double [[B:%.*]], double [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call contract double @llvm.fma.f64(double [[A]], double [[B]], double [[C]])
+; CHECK-NEXT:    ret double [[TMP1]]
+;
+  %mul = fmul contract double %a, %b
+  %add = fadd contract double %mul, %c
+  ret double %add
+}

@AlexMaclean
Member

@nikic / @Artem-B Do either of you know if there is any existing support for LLVM IR level FMA-fusion? This seems broadly useful and maybe even like something that should be occurring in InstCombine (though I assume there is a good reason we don't do it there already)?

@nikic
Contributor

nikic commented Aug 21, 2025

@arsenm probably has more context on this.

@justinfargnoli
Contributor

> which is too late for any IR-level passes that might want to optimize or analyze FMAs

Which passes are these? If there aren't any, do we plan on adding passes that do care about this?

@rajatbajpai
Contributor Author

> If there aren't any, do we plan on adding passes that do care about this?

Yes, we plan to add one.

@arsenm
Contributor

arsenm commented Aug 22, 2025

There's a CodeGenPrepare hook that enables cross block FMA formation, e.g. #121465

@arsenm
Contributor

arsenm commented Aug 22, 2025

> that might want to optimize or analyze FMAs.

Fusing the FMA doesn't really give you new information. You could perform equivalent analysis on the separate operations.

@rajatbajpai
Contributor Author

> Fusing the FMA doesn't really give you new information. You could perform equivalent analysis on the separate operations.

We're aiming to vectorize fma.f32 instructions into fma.f32x2. To enable this, we plan to fold FMAs during the IR phase, prior to ISel (see: CUDA FMA Instructions).

While the bandwidth of two scalar FMAs is equivalent to that of a vectorized FMA, vectorization can benefit workloads that are bottlenecked by instruction issue rates. We plan to add this transformation as a separate, opt-in optimization pass in the llc pipeline.
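
To make the intent concrete (this is the planned follow-up, not part of this PR), here is a rough sketch of the kind of rewrite the later vectorization pass would perform; the function names and shapes are purely illustrative:

```llvm
; Scalar FMAs as produced by this pass:
define void @scalar_fmas(float %a0, float %b0, float %c0,
                         float %a1, float %b1, float %c1, ptr %out) {
  %r0 = call contract float @llvm.fma.f32(float %a0, float %b0, float %c0)
  %r1 = call contract float @llvm.fma.f32(float %a1, float %b1, float %c1)
  store float %r0, ptr %out
  %out1 = getelementptr float, ptr %out, i64 1
  store float %r1, ptr %out1
  ret void
}

; Hypothetical packed form a later pass would produce; ISel could then lower
; the 2-wide call to a packed FMA instruction on targets that support it.
define void @vector_fma(<2 x float> %a, <2 x float> %b, <2 x float> %c, ptr %out) {
  %r = call contract <2 x float> @llvm.fma.v2f32(<2 x float> %a, <2 x float> %b, <2 x float> %c)
  store <2 x float> %r, ptr %out
  ret void
}
```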

@justinfargnoli justinfargnoli requested a review from Artem-B August 22, 2025 18:14
@rajatbajpai
Contributor Author

@Artem-B Please let me know if you need any additional details or clarification to proceed with the review of these code changes. Thanks!

Member

@AlexMaclean AlexMaclean left a comment

Looks reasonable to me.

While I'm not convinced this is needed for <2 x float> vectorization, I think it can be simpler to just fold to an intrinsic in the IR as opposed to trying to keep a multi-instruction idiom recognizable and intact through other transformations. I know we've moved this direction with integer min/max and I think it makes sense here as well.
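
(For readers unfamiliar with the integer min/max precedent mentioned above, the contrast is roughly the following; a minimal illustration, not code from this PR:)

```llvm
; Multi-instruction idiom that every later pass must keep recognizing:
define i32 @smax_idiom(i32 %a, i32 %b) {
  %cmp = icmp sgt i32 %a, %b
  %max = select i1 %cmp, i32 %a, i32 %b
  ret i32 %max
}

; Canonical intrinsic form, which stays intact through other transformations:
define i32 @smax_intrinsic(i32 %a, i32 %b) {
  %max = call i32 @llvm.smax.i32(i32 %a, i32 %b)
  ret i32 %max
}
```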

Comment on lines 46 to 47
if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy())
continue;
Member

Why not half and bfloat?

Contributor Author

No specific reason; initially I wanted to support float and double. This pass could be extended in the future to handle half and bfloat types as well.
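
For reference, here is a case the current type check skips; extending the check to half/bfloat would let the same fold fire here (illustrative only, not part of this patch):

```llvm
; Left untouched today because the pass only handles float and double.
define half @fold_half(half %a, half %b, half %c) {
  %mul = fmul contract half %a, %b
  %add = fadd contract half %mul, %c
  ret half %add
}
```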

Comment on lines 99 to 100
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) |
FastMathFlags::unionValue(BinaryFMF, FMulFMF);
Member

Could you provide some justification for this? Unless you're copying it from somewhere else, maybe some alive2 proofs would be good.

Contributor Author

This is based on my experience with fast-math flag handling in PR #106492 (comment).
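
For what it's worth, my reading of that expression (rewrite-style flags such as contract are kept only if present on both instructions, while value flags such as nnan/ninf are taken from either) would play out as follows; treat the exact flag classification as an assumption to verify against the FastMathFlags implementation:

```llvm
define float @fmf_merge(float %a, float %b, float %c) {
  ; The fmul carries contract+nnan, the fadd carries contract+ninf.
  %mul = fmul contract nnan float %a, %b
  %add = fadd contract ninf float %mul, %c
  ret float %add
}

; Expected result under that reading: contract survives the intersection of
; rewrite flags; nnan and ninf survive the union of value flags:
;   %fma = call contract nnan ninf float @llvm.fma.f32(float %a, float %b, float %c)
```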

1. Moved the lambda function into a static function.
2. Preserved CFG analyses.
3. Used CreateFNegFMF instead of the CreateFNeg API.
@rajatbajpai rajatbajpai force-pushed the dev/rbajpai/nvptx-fma-vectorizer branch from 49f0c03 to f1eff5c Compare September 1, 2025 10:56
Member

@Artem-B Artem-B left a comment

Some sort of late IR-level pass with peephole optimizations/transformations to aid lowering would be useful in general, not just for FMA. Perhaps we should rename it to NVPTXIRPeepholePass, similar to the MF pass NVPTXPeepholePass.

@arsenm
Contributor

arsenm commented Sep 3, 2025

> We're aiming to vectorize fma.f32 instructions into fma.f32x2. To enable this, we plan to fold FMAs during the IR phase, prior to ISel (see: CUDA FMA Instructions).

This is the SLP vectorizer's job, and the other vectorizers'.

IMO there's no reason this should be a new pass.

1. Removed extra arguments passed to tryFoldBinaryFMul.
2. Removed temporary storage to collect the binary instructions.
3. Made the guarding condition a little easier to read.
4. Added one more test scenario.
@rajatbajpai
Contributor Author

> This is the SLP vectorizer's job, and the other vectorizers'.
>
> IMO there's no reason this should be a new pass.

I understand the intention, but I have three concerns:

  1. Currently, the SLPVectorizer is invoked within the opt pipeline. Introducing it in llc solely for FMA vectorisation may not be justified, especially if its scope remains limited to that functionality.
  2. The NVPTXIRPeephole pass will synthesize FMAs during llc. For the SLPVectorizer to recognize and act on these FMAs, it must be scheduled after this pass. Otherwise, it won't have visibility into the transformed instructions.
  3. As far as I understand, SLPVectorizer performs vectorization that benefits multiple backends. However, it's unclear which other targets would gain from FMA-specific vectorization. Even for NVPTX, the benefits of this transformation are not universally guaranteed.

@rajatbajpai
Contributor Author

Gentle ping for review.

@rajatbajpai
Contributor Author

Ping @Artem-B and @arsenm for review.

@arsenm
Contributor

arsenm commented Sep 22, 2025

> 1. Currently, the SLPVectorizer is invoked within the opt pipeline. Introducing it in llc solely for FMA vectorisation may not be justified, especially if its scope remains limited to that functionality.

FMA vectorization isn't special. It's just another one of many vectorizable operations.

> 2. The NVPTXIRPeephole pass will synthesize FMAs during llc. For the SLPVectorizer to recognize and act on these FMAs, it must be scheduled after this pass. Otherwise, it won't have visibility into the transformed instructions.

llc is just a testing utility; it doesn't mean anything on its own. The placement of the vectorizer in the late middle end or early codegen is fairly arbitrary.

> 3. As far as I understand, SLPVectorizer performs vectorization that benefits multiple backends. However, it's unclear which other targets would gain from FMA-specific vectorization. Even for NVPTX, the benefits of this transformation are not universally guaranteed.

This isn't a unique property, and that's what the cost model is for. The solution to cost model questions isn't to reimplement a new vectorizer for every operation on every backend.

@rajatbajpai
Contributor Author

Thanks for your suggestions @arsenm.

> llc is just a testing utility; it doesn't mean anything on its own. The placement of the vectorizer in the late middle end or early codegen is fairly arbitrary.

Yes, I'm aware. I do have one concern with this suggestion. If we move SLP to early codegen, the expectation would be that FMA synthesis occurs before SLP; otherwise, we risk missing vectorisation opportunities.
As far as I know, FMA synthesis currently happens in the DAGCombiner across all backends. So, is the idea here to move FMA folding from the DAGCombiner into an IR-level pass for all backends, similar to what this new pass does for NVPTX?

> FMA vectorization isn't special. It's just another one of many vectorizable operations.

Out of curiosity, are there any other backends besides NVPTX that currently have vector FMA instructions?

Member

Artem-B commented Oct 3, 2025

> are there any other backends besides NVPTX that currently have vector FMA instructions?

AVX2 on Intel has FMA3 instructions:

; FMA3-NEXT: vfmaddsub213pd {{.*#+}} xmm0 = (xmm1 * xmm0) +/- xmm2
