Skip to content

[DirectX][NFC] Refactoring DirectX backend to not use llvm::to_underlying in switch cases. #151032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

joaosaffran
Copy link
Contributor

As request in #150676, DirectX has some switch statements that use to_underlying to handle the cases, this prevents compiler warnings when the statements are not fully covered.

Closes: #150676

@joaosaffran joaosaffran requested review from bogner and llvm-beanz July 28, 2025 20:28
@llvmbot llvmbot added mc Machine (object) code backend:DirectX HLSL HLSL Language Support objectyaml labels Jul 28, 2025
@joaosaffran joaosaffran requested a review from inbelic July 28, 2025 20:28
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-mc

@llvm/pr-subscribers-hlsl

Author: None (joaosaffran)

Changes

As request in #150676, DirectX has some switch statements that use to_underlying to handle the cases, this prevents compiler warnings when the statements are not fully covered.

Closes: #150676


Full diff: https://github.com/llvm/llvm-project/pull/151032.diff

6 Files Affected:

  • (modified) llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp (+12-5)
  • (modified) llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp (+3-4)
  • (modified) llvm/lib/MC/DXContainerRootSignature.cpp (+23-12)
  • (modified) llvm/lib/ObjectYAML/DXContainerEmitter.cpp (+17-11)
  • (modified) llvm/lib/ObjectYAML/DXContainerYAML.cpp (+12-6)
  • (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+8-6)
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f59349ae029..742b2c55bcce9 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
@@ -559,11 +560,17 @@ bool MetadataParser::validateRootSignature(
     assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
            "Invalid value for ParameterType");
 
-    switch (Info.Header.ParameterType) {
+    dxbc::RootParameterType PT =
+        static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
 
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit:
+      // ToDo: Add proper validation.
+      continue;
+
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::UAV:
+    case dxbc::RootParameterType::SRV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           RSD.ParametersContainer.getRootDescriptor(Info.Location);
       if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
@@ -580,7 +587,7 @@ bool MetadataParser::validateRootSignature(
       }
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const mcdxbc::DescriptorTable &Table =
           RSD.ParametersContainer.getDescriptorTable(Info.Location);
       for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index f11c7d2033bfb..e55d3ec15c19d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -53,10 +53,9 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
 
 bool verifyRangeType(uint32_t Type) {
   switch (Type) {
-  case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+#define DESCRIPTOR_RANGE(Num, Val)                                             \
+  case llvm::to_underlying(dxbc::DescriptorRangeType::Val):
+#include "llvm/BinaryFormat/DXContainerConstants.def"
     return true;
   };
 
diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 482280b5ef289..c94a39f80eeb2 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/MC/DXContainerRootSignature.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/EndianStream.h"
 
 using namespace llvm;
@@ -35,20 +36,26 @@ size_t RootSignatureDesc::getSize() const {
       StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
 
   for (const RootParameterInfo &I : ParametersContainer) {
-    switch (I.Header.ParameterType) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+    if (!dxbc::isValidParameterType(I.Header.ParameterType))
+      continue;
+
+    dxbc::RootParameterType PT =
+        static_cast<dxbc::RootParameterType>(I.Header.ParameterType);
+
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit:
       Size += sizeof(dxbc::RTS0::v1::RootConstants);
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV:
       if (Version == 1)
         Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
       else
         Size += sizeof(dxbc::RTS0::v2::RootDescriptor);
 
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
+    case dxbc::RootParameterType::DescriptorTable:
       const DescriptorTable &Table =
           ParametersContainer.getDescriptorTable(I.Location);
 
@@ -97,8 +104,12 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
   for (size_t I = 0; I < NumParameters; ++I) {
     rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
     const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
-    switch (Type) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+    if (!dxbc::isValidParameterType(Type))
+      continue;
+    dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit: {
       const dxbc::RTS0::v1::RootConstants &Constants =
           ParametersContainer.getConstant(Loc);
       support::endian::write(BOS, Constants.ShaderRegister,
@@ -109,9 +120,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
                              llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           ParametersContainer.getRootDescriptor(Loc);
 
@@ -123,7 +134,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
         support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const DescriptorTable &Table =
           ParametersContainer.getDescriptorTable(Loc);
       support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 043b575a43b11..35e1c3e3b5953 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -278,8 +278,19 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
         dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
                                          L.Header.Offset};
 
-        switch (L.Header.Type) {
-        case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+        if (!dxbc::isValidParameterType(L.Header.Type)) {
+          // Handling invalid parameter type edge case. We intentionally let
+          // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
+          // for that to be used as a testing tool more effectively.
+          RS.ParametersContainer.addInvalidParameter(Header);
+          continue;
+        }
+
+        dxbc::RootParameterType ParameterType =
+            static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+        switch (ParameterType) {
+        case dxbc::RootParameterType::Constants32Bit: {
           const DXContainerYAML::RootConstantsYaml &ConstantYaml =
               P.RootSignature->Parameters.getOrInsertConstants(L);
           dxbc::RTS0::v1::RootConstants Constants;
@@ -289,9 +300,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Constants);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::CBV):
-        case llvm::to_underlying(dxbc::RootParameterType::SRV):
-        case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+        case dxbc::RootParameterType::CBV:
+        case dxbc::RootParameterType::SRV:
+        case dxbc::RootParameterType::UAV: {
           const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
               P.RootSignature->Parameters.getOrInsertDescriptor(L);
 
@@ -303,7 +314,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Descriptor);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+        case dxbc::RootParameterType::DescriptorTable: {
           const DXContainerYAML::DescriptorTableYaml &TableYaml =
               P.RootSignature->Parameters.getOrInsertTable(L);
           mcdxbc::DescriptorTable Table;
@@ -323,11 +334,6 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Table);
           break;
         }
-        default:
-          // Handling invalid parameter type edge case. We intentionally let
-          // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
-          // for that to be used as a testing tool more effectively.
-          RS.ParametersContainer.addInvalidParameter(Header);
         }
       }
 
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 263f7bdf37bca..03d5725c21491 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -424,22 +424,28 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
   IO.mapRequired("ParameterType", L.Header.Type);
   IO.mapRequired("ShaderVisibility", L.Header.Visibility);
 
-  switch (L.Header.Type) {
-  case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+  if (!dxbc::isValidParameterType(L.Header.Type))
+    return;
+  dxbc::RootParameterType PT =
+      static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+  // We allow ParameterType to be invalid here.
+  switch (PT) {
+  case dxbc::RootParameterType::Constants32Bit: {
     DXContainerYAML::RootConstantsYaml &Constants =
         S.Parameters.getOrInsertConstants(L);
     IO.mapRequired("Constants", Constants);
     break;
   }
-  case llvm::to_underlying(dxbc::RootParameterType::CBV):
-  case llvm::to_underlying(dxbc::RootParameterType::SRV):
-  case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+  case dxbc::RootParameterType::CBV:
+  case dxbc::RootParameterType::SRV:
+  case dxbc::RootParameterType::UAV: {
     DXContainerYAML::RootDescriptorYaml &Descriptor =
         S.Parameters.getOrInsertDescriptor(L);
     IO.mapRequired("Descriptor", Descriptor);
     break;
   }
-  case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+  case dxbc::RootParameterType::DescriptorTable: {
     DXContainerYAML::DescriptorTableYaml &Table =
         S.Parameters.getOrInsertTable(L);
     IO.mapRequired("Table", Table);
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index ebdfcaa566b51..04c7c77953a5b 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -175,8 +175,10 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
       OS << "- Parameter Type: " << Type << "\n"
          << "  Shader Visibility: " << Header.ShaderVisibility << "\n";
 
-      switch (Type) {
-      case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+      assert(dxbc::isValidParameterType(Type) && "Invalid Parameter Type");
+      dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+      switch (PT) {
+      case dxbc::RootParameterType::Constants32Bit: {
         const dxbc::RTS0::v1::RootConstants &Constants =
             RS.ParametersContainer.getConstant(Loc);
         OS << "  Register Space: " << Constants.RegisterSpace << "\n"
@@ -184,9 +186,9 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
            << "  Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::CBV):
-      case llvm::to_underlying(dxbc::RootParameterType::UAV):
-      case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+      case dxbc::RootParameterType::CBV:
+      case dxbc::RootParameterType::UAV:
+      case dxbc::RootParameterType::SRV: {
         const dxbc::RTS0::v2::RootDescriptor &Descriptor =
             RS.ParametersContainer.getRootDescriptor(Loc);
         OS << "  Register Space: " << Descriptor.RegisterSpace << "\n"
@@ -195,7 +197,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
           OS << "  Flags: " << Descriptor.Flags << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      case dxbc::RootParameterType::DescriptorTable: {
         const mcdxbc::DescriptorTable &Table =
             RS.ParametersContainer.getDescriptorTable(Loc);
         OS << "  NumRanges: " << Table.Ranges.size() << "\n";

@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-objectyaml

Author: None (joaosaffran)

Changes

As request in #150676, DirectX has some switch statements that use to_underlying to handle the cases, this prevents compiler warnings when the statements are not fully covered.

Closes: #150676


Full diff: https://github.com/llvm/llvm-project/pull/151032.diff

6 Files Affected:

  • (modified) llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp (+12-5)
  • (modified) llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp (+3-4)
  • (modified) llvm/lib/MC/DXContainerRootSignature.cpp (+23-12)
  • (modified) llvm/lib/ObjectYAML/DXContainerEmitter.cpp (+17-11)
  • (modified) llvm/lib/ObjectYAML/DXContainerYAML.cpp (+12-6)
  • (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+8-6)
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f59349ae029..742b2c55bcce9 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
@@ -559,11 +560,17 @@ bool MetadataParser::validateRootSignature(
     assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
            "Invalid value for ParameterType");
 
-    switch (Info.Header.ParameterType) {
+    dxbc::RootParameterType PT =
+        static_cast<dxbc::RootParameterType>(Info.Header.ParameterType);
 
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit:
+      // ToDo: Add proper validation.
+      continue;
+
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::UAV:
+    case dxbc::RootParameterType::SRV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           RSD.ParametersContainer.getRootDescriptor(Info.Location);
       if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
@@ -580,7 +587,7 @@ bool MetadataParser::validateRootSignature(
       }
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const mcdxbc::DescriptorTable &Table =
           RSD.ParametersContainer.getDescriptorTable(Info.Location);
       for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index f11c7d2033bfb..e55d3ec15c19d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -53,10 +53,9 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
 
 bool verifyRangeType(uint32_t Type) {
   switch (Type) {
-  case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
-  case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+#define DESCRIPTOR_RANGE(Num, Val)                                             \
+  case llvm::to_underlying(dxbc::DescriptorRangeType::Val):
+#include "llvm/BinaryFormat/DXContainerConstants.def"
     return true;
   };
 
diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 482280b5ef289..c94a39f80eeb2 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/MC/DXContainerRootSignature.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/EndianStream.h"
 
 using namespace llvm;
@@ -35,20 +36,26 @@ size_t RootSignatureDesc::getSize() const {
       StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
 
   for (const RootParameterInfo &I : ParametersContainer) {
-    switch (I.Header.ParameterType) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
+    if (!dxbc::isValidParameterType(I.Header.ParameterType))
+      continue;
+
+    dxbc::RootParameterType PT =
+        static_cast<dxbc::RootParameterType>(I.Header.ParameterType);
+
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit:
       Size += sizeof(dxbc::RTS0::v1::RootConstants);
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV:
       if (Version == 1)
         Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
       else
         Size += sizeof(dxbc::RTS0::v2::RootDescriptor);
 
       break;
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
+    case dxbc::RootParameterType::DescriptorTable:
       const DescriptorTable &Table =
           ParametersContainer.getDescriptorTable(I.Location);
 
@@ -97,8 +104,12 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
   for (size_t I = 0; I < NumParameters; ++I) {
     rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
     const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
-    switch (Type) {
-    case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+    if (!dxbc::isValidParameterType(Type))
+      continue;
+    dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+
+    switch (PT) {
+    case dxbc::RootParameterType::Constants32Bit: {
       const dxbc::RTS0::v1::RootConstants &Constants =
           ParametersContainer.getConstant(Loc);
       support::endian::write(BOS, Constants.ShaderRegister,
@@ -109,9 +120,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
                              llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::CBV):
-    case llvm::to_underlying(dxbc::RootParameterType::SRV):
-    case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+    case dxbc::RootParameterType::CBV:
+    case dxbc::RootParameterType::SRV:
+    case dxbc::RootParameterType::UAV: {
       const dxbc::RTS0::v2::RootDescriptor &Descriptor =
           ParametersContainer.getRootDescriptor(Loc);
 
@@ -123,7 +134,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
         support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
       break;
     }
-    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+    case dxbc::RootParameterType::DescriptorTable: {
       const DescriptorTable &Table =
           ParametersContainer.getDescriptorTable(Loc);
       support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 043b575a43b11..35e1c3e3b5953 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -278,8 +278,19 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
         dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
                                          L.Header.Offset};
 
-        switch (L.Header.Type) {
-        case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+        if (!dxbc::isValidParameterType(L.Header.Type)) {
+          // Handling invalid parameter type edge case. We intentionally let
+          // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
+          // for that to be used as a testing tool more effectively.
+          RS.ParametersContainer.addInvalidParameter(Header);
+          continue;
+        }
+
+        dxbc::RootParameterType ParameterType =
+            static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+        switch (ParameterType) {
+        case dxbc::RootParameterType::Constants32Bit: {
           const DXContainerYAML::RootConstantsYaml &ConstantYaml =
               P.RootSignature->Parameters.getOrInsertConstants(L);
           dxbc::RTS0::v1::RootConstants Constants;
@@ -289,9 +300,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Constants);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::CBV):
-        case llvm::to_underlying(dxbc::RootParameterType::SRV):
-        case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+        case dxbc::RootParameterType::CBV:
+        case dxbc::RootParameterType::SRV:
+        case dxbc::RootParameterType::UAV: {
           const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
               P.RootSignature->Parameters.getOrInsertDescriptor(L);
 
@@ -303,7 +314,7 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Descriptor);
           break;
         }
-        case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+        case dxbc::RootParameterType::DescriptorTable: {
           const DXContainerYAML::DescriptorTableYaml &TableYaml =
               P.RootSignature->Parameters.getOrInsertTable(L);
           mcdxbc::DescriptorTable Table;
@@ -323,11 +334,6 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
           RS.ParametersContainer.addParameter(Header, Table);
           break;
         }
-        default:
-          // Handling invalid parameter type edge case. We intentionally let
-          // obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
-          // for that to be used as a testing tool more effectively.
-          RS.ParametersContainer.addInvalidParameter(Header);
         }
       }
 
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 263f7bdf37bca..03d5725c21491 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -424,22 +424,28 @@ void MappingContextTraits<DXContainerYAML::RootParameterLocationYaml,
   IO.mapRequired("ParameterType", L.Header.Type);
   IO.mapRequired("ShaderVisibility", L.Header.Visibility);
 
-  switch (L.Header.Type) {
-  case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+  if (!dxbc::isValidParameterType(L.Header.Type))
+    return;
+  dxbc::RootParameterType PT =
+      static_cast<dxbc::RootParameterType>(L.Header.Type);
+
+  // We allow ParameterType to be invalid here.
+  switch (PT) {
+  case dxbc::RootParameterType::Constants32Bit: {
     DXContainerYAML::RootConstantsYaml &Constants =
         S.Parameters.getOrInsertConstants(L);
     IO.mapRequired("Constants", Constants);
     break;
   }
-  case llvm::to_underlying(dxbc::RootParameterType::CBV):
-  case llvm::to_underlying(dxbc::RootParameterType::SRV):
-  case llvm::to_underlying(dxbc::RootParameterType::UAV): {
+  case dxbc::RootParameterType::CBV:
+  case dxbc::RootParameterType::SRV:
+  case dxbc::RootParameterType::UAV: {
     DXContainerYAML::RootDescriptorYaml &Descriptor =
         S.Parameters.getOrInsertDescriptor(L);
     IO.mapRequired("Descriptor", Descriptor);
     break;
   }
-  case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+  case dxbc::RootParameterType::DescriptorTable: {
     DXContainerYAML::DescriptorTableYaml &Table =
         S.Parameters.getOrInsertTable(L);
     IO.mapRequired("Table", Table);
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index ebdfcaa566b51..04c7c77953a5b 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -175,8 +175,10 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
       OS << "- Parameter Type: " << Type << "\n"
          << "  Shader Visibility: " << Header.ShaderVisibility << "\n";
 
-      switch (Type) {
-      case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
+      assert(dxbc::isValidParameterType(Type) && "Invalid Parameter Type");
+      dxbc::RootParameterType PT = static_cast<dxbc::RootParameterType>(Type);
+      switch (PT) {
+      case dxbc::RootParameterType::Constants32Bit: {
         const dxbc::RTS0::v1::RootConstants &Constants =
             RS.ParametersContainer.getConstant(Loc);
         OS << "  Register Space: " << Constants.RegisterSpace << "\n"
@@ -184,9 +186,9 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
            << "  Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::CBV):
-      case llvm::to_underlying(dxbc::RootParameterType::UAV):
-      case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+      case dxbc::RootParameterType::CBV:
+      case dxbc::RootParameterType::UAV:
+      case dxbc::RootParameterType::SRV: {
         const dxbc::RTS0::v2::RootDescriptor &Descriptor =
             RS.ParametersContainer.getRootDescriptor(Loc);
         OS << "  Register Space: " << Descriptor.RegisterSpace << "\n"
@@ -195,7 +197,7 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
           OS << "  Flags: " << Descriptor.Flags << "\n";
         break;
       }
-      case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      case dxbc::RootParameterType::DescriptorTable: {
         const mcdxbc::DescriptorTable &Table =
             RS.ParametersContainer.getDescriptorTable(Loc);
         OS << "  NumRanges: " << Table.Ranges.size() << "\n";

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

Comment on lines -56 to +58
case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
#define DESCRIPTOR_RANGE(Num, Val) \
case llvm::to_underlying(dxbc::DescriptorRangeType::Val):
#include "llvm/BinaryFormat/DXContainerConstants.def"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not updating to not use to_underlying

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This specific case is necessary, to keep using llvm::to_underlying, since that is checking if an uint_32t is valid value for RootParametersType. However, I updated to use the tablegen definition, that way we will always have it covering all possible values.

Or we can change how the check is being done, if folks prefer.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to me like we get the same end result by doing the int->enum cast once before the switch and having the switch statement operate on the actual enumerations. I have a strong preference for not using to_underlying since it impairs the frontend's ability to generate warnings.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this point, we are not sure Type is a valid DescriptorRangeType so the casting here int->enum would cause undefined behaviour.

@joaosaffran joaosaffran requested a review from inbelic July 28, 2025 21:23
@@ -35,20 +36,26 @@ size_t RootSignatureDesc::getSize() const {
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);

for (const RootParameterInfo &I : ParametersContainer) {
switch (I.Header.ParameterType) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
if (!dxbc::isValidParameterType(I.Header.ParameterType))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is ParameterType a uint32_t instead of being the enum type?

It seems like a lot of complexity here would go away if we just stored these as the appropriate type, like we do for enums stored in the ProgramSignatureElement structure.

I suspect it is actually impossible to get to this point in the code with a parameter type that isn't legal (if it isn't impossible it probably should be) or we should be surfacing errors.

dxbc::RootParameterType PT =
static_cast<dxbc::RootParameterType>(L.Header.Type);

// We allow ParameterType to be invalid here.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment doesn't match the implementation. It returns above if the parameter type is invalid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the comments to make it clearer what the current behaviour is.

// Handling invalid parameter type edge case. We intentionally let
// obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
// for that to be used as a testing tool more effectively.
RS.ParametersContainer.addInvalidParameter(Header);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I question the utility of this. It doesn't actually enable us to emit a binary that contains potentially well-formed parameters that our implementation doesn't understand.

That said, if we're using it for test I guess it is okay, it just has the drawback of forcing the MC code to be resilient to invalid parameter types which wouldn't be necessary if not for this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX HLSL HLSL Language Support mc Machine (object) code objectyaml
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[DirectX] Clean up to_underlying in case labels
4 participants