@@ -325,6 +325,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
325325};
326326} //  namespace
327327
328+ static  uint32_t  GetIntConstAttrArg (ASTContext &astContext, const  Expr *expr,
329+                                    uint32_t  defaultVal = 0 ) {
330+   if  (expr) {
331+     llvm::APSInt apsInt;
332+     APValue apValue;
333+     if  (expr->isIntegerConstantExpr (apsInt, astContext))
334+       return  (uint32_t )apsInt.getSExtValue ();
335+     if  (expr->isVulkanSpecConstantExpr (astContext, &apValue) && apValue.isInt ())
336+       return  (uint32_t )apValue.getInt ().getSExtValue ();
337+   }
338+   return  defaultVal;
339+ }
340+ 
328341// ------------------------------------------------------------------------------
329342// 
330343//  CGMSHLSLRuntime methods.
@@ -1419,6 +1432,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14191432  }
14201433
14211434  DiagnosticsEngine &Diags = CGM.getDiags ();
1435+   ASTContext &astContext = CGM.getTypes ().getContext ();
14221436
14231437  std::unique_ptr<DxilFunctionProps> funcProps =
14241438      llvm::make_unique<DxilFunctionProps>();
@@ -1629,10 +1643,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16291643
16301644  //  Populate numThreads
16311645  if  (const  HLSLNumThreadsAttr *Attr = FD->getAttr <HLSLNumThreadsAttr>()) {
1632- 
1633-     funcProps->numThreads [0 ] = Attr->getX ();
1634-     funcProps->numThreads [1 ] = Attr->getY ();
1635-     funcProps->numThreads [2 ] = Attr->getZ ();
1646+     funcProps->numThreads [0 ] = GetIntConstAttrArg (astContext, Attr->getX (), 1 );
1647+     funcProps->numThreads [1 ] = GetIntConstAttrArg (astContext, Attr->getY (), 1 );
1648+     funcProps->numThreads [2 ] = GetIntConstAttrArg (astContext, Attr->getZ (), 1 );
16361649
16371650    if  (isEntry && !SM->IsCS () && !SM->IsMS () && !SM->IsAS ()) {
16381651      unsigned  DiagID = Diags.getCustomDiagID (
@@ -1805,7 +1818,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18051818
18061819    if  (const  auto  *pAttr = FD->getAttr <HLSLNodeIdAttr>()) {
18071820      funcProps->NodeShaderID .Name  = pAttr->getName ().str ();
1808-       funcProps->NodeShaderID .Index  = pAttr->getArrayIndex ();
1821+       funcProps->NodeShaderID .Index  =
1822+           GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18091823    } else  {
18101824      funcProps->NodeShaderID .Name  = FD->getName ().str ();
18111825      funcProps->NodeShaderID .Index  = 0 ;
@@ -1816,20 +1830,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18161830    }
18171831    if  (const  auto  *pAttr = FD->getAttr <HLSLNodeShareInputOfAttr>()) {
18181832      funcProps->NodeShaderSharedInput .Name  = pAttr->getName ().str ();
1819-       funcProps->NodeShaderSharedInput .Index  = pAttr->getArrayIndex ();
1833+       funcProps->NodeShaderSharedInput .Index  =
1834+           GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18201835    }
18211836    if  (const  auto  *pAttr = FD->getAttr <HLSLNodeDispatchGridAttr>()) {
1822-       funcProps->Node .DispatchGrid [0 ] = pAttr->getX ();
1823-       funcProps->Node .DispatchGrid [1 ] = pAttr->getY ();
1824-       funcProps->Node .DispatchGrid [2 ] = pAttr->getZ ();
1837+       funcProps->Node .DispatchGrid [0 ] =
1838+           GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1839+       funcProps->Node .DispatchGrid [1 ] =
1840+           GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1841+       funcProps->Node .DispatchGrid [2 ] =
1842+           GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
18251843    }
18261844    if  (const  auto  *pAttr = FD->getAttr <HLSLNodeMaxDispatchGridAttr>()) {
1827-       funcProps->Node .MaxDispatchGrid [0 ] = pAttr->getX ();
1828-       funcProps->Node .MaxDispatchGrid [1 ] = pAttr->getY ();
1829-       funcProps->Node .MaxDispatchGrid [2 ] = pAttr->getZ ();
1845+       funcProps->Node .MaxDispatchGrid [0 ] =
1846+           GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1847+       funcProps->Node .MaxDispatchGrid [1 ] =
1848+           GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1849+       funcProps->Node .MaxDispatchGrid [2 ] =
1850+           GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
18301851    }
18311852    if  (const  auto  *pAttr = FD->getAttr <HLSLNodeMaxRecursionDepthAttr>()) {
1832-       funcProps->Node .MaxRecursionDepth  = pAttr->getCount ();
1853+       funcProps->Node .MaxRecursionDepth  =
1854+           GetIntConstAttrArg (astContext, pAttr->getCount (), 0 );
18331855    }
18341856    if  (!FD->getAttr <HLSLNumThreadsAttr>()) {
18351857      //  NumThreads wasn't specified.
@@ -2343,8 +2365,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23432365          NodeInputRecordParams[ArgIt].MetadataIdx  = NodeInputParamIdx++;
23442366
23452367          if  (parmDecl->hasAttr <HLSLMaxRecordsAttr>()) {
2346-             node.MaxRecords  =
2347-                 parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ();
2368+             node.MaxRecords  = GetIntConstAttrArg (
2369+                 astContext,
2370+                 parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount (), 1 );
23482371          }
23492372          if  (parmDecl->hasAttr <HLSLGloballyCoherentAttr>())
23502373            node.Flags .SetGloballyCoherent ();
@@ -2375,7 +2398,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23752398          //  OutputID from attribute
23762399          if  (const  auto  *Attr = parmDecl->getAttr <HLSLNodeIdAttr>()) {
23772400            node.OutputID .Name  = Attr->getName ().str ();
2378-             node.OutputID .Index  = Attr->getArrayIndex ();
2401+             node.OutputID .Index  =
2402+                 GetIntConstAttrArg (astContext, Attr->getArrayIndex (), 0 );
23792403          } else  {
23802404            node.OutputID .Name  = parmDecl->getName ().str ();
23812405            node.OutputID .Index  = 0 ;
@@ -2434,7 +2458,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24342458      node.MaxRecordsSharedWith  = ix;
24352459    }
24362460    if  (const  auto  *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>())
2437-       node.MaxRecords  = Attr->getMaxCount ();
2461+       node.MaxRecords  = GetIntConstAttrArg (astContext,  Attr->getMaxCount (),  0 );
24382462  }
24392463
24402464  if  (inputPatchCount > 1 ) {
0 commit comments