Skip to content

[RISCV] Early exit if the type legalization cost is not valid for getIntrinsicInstrCost #154256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

tclin914
Copy link
Contributor

The motivation of this change is that a crash was encountered when the code calls TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType) with type f16/bf16, but f16/bf16 wasn't promotable if zvfhmin/zvfbfmin are not enabled. This leads to an assertion failure.

To prevent this, we now check whether the type legalization cost is valid at the beginning of the function before proceeding further.

The code to check the type legalization cost is valid is copied from llvm/include/llvm/CodeGen/BasicTTIImpl.h.

…IntrinsicInstrCost

The motivation of this change is that a crash was encountered when
the code calls `TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType)` with
type f16/bf16, but f16/bf16 wasn't promotable if zvfhmin/zvfbfmin are not enabled.
This leads to an assertion failure.

To prevent this, we now check whether the type legalization cost is
valid at the beginning of the function before proceeding further.

The code to check the type legalization cost is valid is copied from
`llvm/include/llvm/CodeGen/BasicTTIImpl.h`.
@llvmbot llvmbot added backend:RISC-V llvm:analysis Includes value tracking, cost tables and constant folding labels Aug 19, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 19, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-risc-v

Author: Jim Lin (tclin914)

Changes

The motivation of this change is that a crash was encountered when the code calls TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType) with type f16/bf16, but f16/bf16 wasn't promotable if zvfhmin/zvfbfmin are not enabled. This leads to an assertion failure.

To prevent this, we now check whether the type legalization cost is valid at the beginning of the function before proceeding further.

The code to check the type legalization cost is valid is copied from llvm/include/llvm/CodeGen/BasicTTIImpl.h.


Full diff: https://github.com/llvm/llvm-project/pull/154256.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+12-25)
  • (modified) llvm/test/Analysis/CostModel/RISCV/fp-sqrt-pow.ll (+2)
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index af78b3cc2c7ff..9b41d8cad10fe 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1245,12 +1245,17 @@ InstructionCost
 RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                     TTI::TargetCostKind CostKind) const {
   auto *RetTy = ICA.getReturnType();
+  auto *STy = dyn_cast<StructType>(RetTy);
+  Type *LegalizeTy = STy ? STy->getContainedType(0) : RetTy;
+  auto LT = getTypeLegalizationCost(LegalizeTy);
+  if (!LT.first.isValid())
+    return InstructionCost::getInvalid();
+
   switch (ICA.getID()) {
   case Intrinsic::lrint:
   case Intrinsic::llrint:
   case Intrinsic::lround:
   case Intrinsic::llround: {
-    auto LT = getTypeLegalizationCost(RetTy);
     Type *SrcTy = ICA.getArgTypes().front();
     auto SrcLT = getTypeLegalizationCost(SrcTy);
     if (ST->hasVInstructions() && LT.second.isVector()) {
@@ -1258,16 +1263,12 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
       unsigned SrcEltSz = DL.getTypeSizeInBits(SrcTy->getScalarType());
       unsigned DstEltSz = DL.getTypeSizeInBits(RetTy->getScalarType());
       if (LT.second.getVectorElementType() == MVT::bf16) {
-        if (!ST->hasVInstructionsBF16Minimal())
-          return InstructionCost::getInvalid();
         if (DstEltSz == 32)
           Ops = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFCVT_X_F_V};
         else
           Ops = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVT_X_F_V};
       } else if (LT.second.getVectorElementType() == MVT::f16 &&
                  !ST->hasVInstructionsF16()) {
-        if (!ST->hasVInstructionsF16Minimal())
-          return InstructionCost::getInvalid();
         if (DstEltSz == 32)
           Ops = {RISCV::VFWCVT_F_F_V, RISCV::VFCVT_X_F_V};
         else
@@ -1297,7 +1298,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   case Intrinsic::round:
   case Intrinsic::roundeven: {
     // These all use the same code.
-    auto LT = getTypeLegalizationCost(RetTy);
     if (!LT.second.isVector() && TLI->isOperationCustom(ISD::FCEIL, LT.second))
       return LT.first * 8;
     break;
@@ -1306,7 +1306,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   case Intrinsic::umax:
   case Intrinsic::smin:
   case Intrinsic::smax: {
-    auto LT = getTypeLegalizationCost(RetTy);
     if (LT.second.isScalarInteger() && ST->hasStdExtZbb())
       return LT.first;
 
@@ -1334,7 +1333,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   case Intrinsic::ssub_sat:
   case Intrinsic::uadd_sat:
   case Intrinsic::usub_sat: {
-    auto LT = getTypeLegalizationCost(RetTy);
     if (ST->hasVInstructions() && LT.second.isVector()) {
       unsigned Op;
       switch (ICA.getID()) {
@@ -1358,14 +1356,12 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   case Intrinsic::fma:
   case Intrinsic::fmuladd: {
     // TODO: handle promotion with f16/bf16 with zvfhmin/zvfbfmin
-    auto LT = getTypeLegalizationCost(RetTy);
     if (ST->hasVInstructions() && LT.second.isVector())
       return LT.first *
              getRISCVInstructionCost(RISCV::VFMADD_VV, LT.second, CostKind);
     break;
   }
   case Intrinsic::fabs: {
-    auto LT = getTypeLegalizationCost(RetTy);
     if (ST->hasVInstructions() && LT.second.isVector()) {
       // lui a0, 8
       // addi a0, a0, -1
@@ -1385,7 +1381,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     break;
   }
   case Intrinsic::sqrt: {
-    auto LT = getTypeLegalizationCost(RetTy);
     if (ST->hasVInstructions() && LT.second.isVector()) {
       SmallVector<unsigned, 4> ConvOp;
       SmallVector<unsigned, 2> FsqrtOp;
@@ -1430,7 +1425,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   case Intrinsic::cttz:
   case Intrinsic::ctlz:
   case Intrinsic::ctpop: {
-    auto LT = getTypeLegalizationCost(RetTy);
     if (ST->hasVInstructions() && ST->hasStdExtZvbb() && LT.second.isVector()) {
       unsigned Op;
       switch (ICA.getID()) {
@@ -1449,7 +1443,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     break;
   }
   case Intrinsic::abs: {
-    auto LT = getTypeLegalizationCost(RetTy);
     if (ST->hasVInstructions() && LT.second.isVector()) {
       // vrsub.vi v10, v8, 0
       // vmax.vv v8, v8, v10
@@ -1476,7 +1469,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   }
   // TODO: add more intrinsic
   case Intrinsic::stepvector: {
-    auto LT = getTypeLegalizationCost(RetTy);
     // Legalisation of illegal types involves an `index' instruction plus
     // (LT.first - 1) vector adds.
     if (ST->hasVInstructions())
@@ -1506,7 +1498,6 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     return Cost;
   }
   case Intrinsic::experimental_vp_splat: {
-    auto LT = getTypeLegalizationCost(RetTy);
     // TODO: Lower i1 experimental_vp_splat
     if (!ST->hasVInstructions() || LT.second.getScalarType() == MVT::i1)
       return InstructionCost::getInvalid();
@@ -1530,11 +1521,10 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     Type *SrcTy = ICA.getArgTypes()[0];
 
     auto SrcLT = getTypeLegalizationCost(SrcTy);
-    auto DstLT = getTypeLegalizationCost(RetTy);
     if (!SrcTy->isVectorTy())
       break;
 
-    if (!SrcLT.first.isValid() || !DstLT.first.isValid())
+    if (!SrcLT.first.isValid())
       return InstructionCost::getInvalid();
 
     Cost +=
@@ -1553,14 +1543,11 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   }
   }
 
-  if (ST->hasVInstructions() && RetTy->isVectorTy()) {
-    if (auto LT = getTypeLegalizationCost(RetTy);
-        LT.second.isVector()) {
-      MVT EltTy = LT.second.getVectorElementType();
-      if (const auto *Entry = CostTableLookup(VectorIntrinsicCostTable,
-                                              ICA.getID(), EltTy))
-        return LT.first * Entry->Cost;
-    }
+  if (ST->hasVInstructions() && LT.second.isVector()) {
+    MVT EltTy = LT.second.getVectorElementType();
+    if (const auto *Entry = CostTableLookup(VectorIntrinsicCostTable,
+                                            ICA.getID(), EltTy))
+      return LT.first * Entry->Cost;
   }
 
   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
diff --git a/llvm/test/Analysis/CostModel/RISCV/fp-sqrt-pow.ll b/llvm/test/Analysis/CostModel/RISCV/fp-sqrt-pow.ll
index 32ad44f7dda7b..60f21690a950e 100644
--- a/llvm/test/Analysis/CostModel/RISCV/fp-sqrt-pow.ll
+++ b/llvm/test/Analysis/CostModel/RISCV/fp-sqrt-pow.ll
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
 ; RUN: opt < %s -passes="print<cost-model>" 2>&1 -disable-output -S -mtriple=riscv64 -mattr=+v,+f,+d,+zvfh,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFH
 ; RUN: opt < %s -passes="print<cost-model>" 2>&1 -disable-output -S -mtriple=riscv64 -mattr=+v,+f,+d,+zvfhmin,+zvfbfmin | FileCheck %s --check-prefixes=CHECK,ZVFHMIN
+; Check that we don't crash querying costs when zvfhmin/zvfbfmin are not enabled.
+; RUN: opt -passes="print<cost-model>" 2>&1 -disable-output -mtriple=riscv64 -mattr=+v,+f,+d
 
 define void @sqrt() {
 ; CHECK-LABEL: 'sqrt'

Copy link
Contributor

@artagnon artagnon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The motivation of this change is that a crash was encountered when the code calls TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType) with type f16/bf16, but f16/bf16 wasn't promotable if zvfhmin/zvfbfmin are not enabled. This leads to an assertion failure.

Isn't the right fix something like the following?

        if (!ST->hasVInstructionsF16Minimal())
          return InstructionCost::getInvalid();
       if (!ST->hasVInstructionsBF16Minimal())
          return InstructionCost::getInvalid();

@tclin914
Copy link
Contributor Author

The motivation of this change is that a crash was encountered when the code calls TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType) with type f16/bf16, but f16/bf16 wasn't promotable if zvfhmin/zvfbfmin are not enabled. This leads to an assertion failure.

Isn't the right fix something like the following?

        if (!ST->hasVInstructionsF16Minimal())
          return InstructionCost::getInvalid();
       if (!ST->hasVInstructionsBF16Minimal())
          return InstructionCost::getInvalid();

Yeah, that is another fix we can check if zvfhmin/zvfbfmin are enabled before calling TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType) to avoid assertion failure. The changes in this PR is also kind of refactor for all cases that can early exit if the type is invalid.

@@ -1358,14 +1356,12 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
case Intrinsic::fma:
case Intrinsic::fmuladd: {
// TODO: handle promotion with f16/bf16 with zvfhmin/zvfbfmin
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outdated comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't outdated comment. We still need to handle promotion cost with f16/bf16 if it is only with zvfhmin/zvfbfmin.

Comment on lines +1248 to +1249
auto *STy = dyn_cast<StructType>(RetTy);
Type *LegalizeTy = STy ? STy->getContainedType(0) : RetTy;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary? Is it correct to get the 0th contained type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The RetTy might be {i8, i1} when the intrinsic is llvm.sadd.with.overflow, so we need to use this code to handle this case properly.

Comment on lines +4 to +5
; Check that we don't crash querying costs when zvfhmin/zvfbfmin are not enabled.
; RUN: opt < %s -passes="print<cost-model>" 2>&1 -disable-output -mtriple=riscv64 -mattr=+v,+f,+d
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add this line to some other intrinsic test-files as well?

Copy link
Contributor Author

@tclin914 tclin914 Aug 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added this line to other intrinsic test files which are related to f16/bf16 types. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:RISC-V llvm:analysis Includes value tracking, cost tables and constant folding
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants