Skip to content

[openacc][flang] Support two type bindName representation in acc routine #149147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2025

Conversation

delaram-talaashrafi
Copy link
Contributor

Based on the OpenACC specification — which states that if the bind name is given as an identifier it should be resolved according to the compiled language, and if given as a string it should be used unmodified — we introduce two distinct bindName representations for acc routine to handle each case appropriately: one as an array of SymbolRefAttr for identifiers and another as an array of StringAttr for strings.

To ensure correct correspondence between bind names and devices, this patch also introduces two separate sets of device attributes. The routine operation is extended accordingly, along with the necessary updates to the OpenACC dialect and its lowering.

…s and device attribute in acc.routine

Based on the OpenACC specification — which states that if the bind name is given as an identifier it should
be resolved according to the compiled language, and if given as a string it should be used unmodified — we
introduce two distinct `bindName` representations for `acc routine` to handle each case appropriately:
one as an array of `SymbolRefAttr` for identifiers and another as an array of `StringAttr` for strings.

To ensure correct correspondence between bind names and devices, this patch also introduces two separate
sets of device attributes. The routine operation is extended accordingly, along with the necessary updates
to the OpenACC dialect and its lowering.
@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openacc flang:fir-hlfir openacc labels Jul 16, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 16, 2025

@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (delaram-talaashrafi)

Changes

Based on the OpenACC specification — which states that if the bind name is given as an identifier it should be resolved according to the compiled language, and if given as a string it should be used unmodified — we introduce two distinct bindName representations for acc routine to handle each case appropriately: one as an array of SymbolRefAttr for identifiers and another as an array of StringAttr for strings.

To ensure correct correspondence between bind names and devices, this patch also introduces two separate sets of device attributes. The routine operation is extended accordingly, along with the necessary updates to the OpenACC dialect and its lowering.


Patch is 23.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149147.diff

7 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+72-31)
  • (modified) flang/test/Lower/OpenACC/acc-routine.f90 (+4-3)
  • (modified) flang/test/Lower/OpenACC/acc-routine03.f90 (+1-1)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACC.h (+1)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+7-5)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+88-25)
  • (modified) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+37-7)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 39e4444cde4e3..93664dc5d8c2d 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -4396,10 +4396,35 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
   return std::nullopt;
 }
 
+// Helper function to extract string value from bind name variant
+static std::optional<llvm::StringRef> getBindNameStringValue(
+    const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
+        &bindNameValue) {
+  if (!bindNameValue.has_value()) {
+    return std::nullopt;
+  }
+
+  return std::visit(
+      [](const auto &attr) -> std::optional<llvm::StringRef> {
+        if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
+                                     mlir::StringAttr>) {
+          return attr.getValue();
+        } else if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
+                                            mlir::SymbolRefAttr>) {
+          return attr.getLeafReference();
+        } else {
+          return std::nullopt;
+        }
+      },
+      bindNameValue.value());
+}
+
 static bool compareDeviceTypeInfo(
     mlir::acc::RoutineOp op,
-    llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
-    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
     llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
     llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
     llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
@@ -4409,9 +4434,13 @@ static bool compareDeviceTypeInfo(
   for (uint32_t dtypeInt = 0;
        dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
     auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
-    if (op.getBindNameValue(dtype) !=
-        getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
-            bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
+    auto bindNameValue = getBindNameStringValue(op.getBindNameValue(dtype));
+    if (bindNameValue !=
+            getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
+                bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
+        bindNameValue !=
+            getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
+                bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
       return false;
     if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
       return false;
@@ -4458,8 +4487,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
 void createOpenACCRoutineConstruct(
     Fortran::lower::AbstractConverter &converter, mlir::Location loc,
     mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
-    bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
-    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
+    bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
+    llvm::SmallVector<mlir::Attribute> &bindStrNames,
+    llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
     llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
     llvm::SmallVector<mlir::Attribute> &gangDimValues,
     llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
@@ -4472,7 +4503,8 @@ void createOpenACCRoutineConstruct(
         0) {
       // If the routine is already specified with the same clauses, just skip
       // the operation creation.
-      if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
+      if (compareDeviceTypeInfo(routineOp, bindIdNames, bindStrNames,
+                                bindIdNameDeviceTypes, bindStrNameDeviceTypes,
                                 gangDeviceTypes, gangDimValues,
                                 gangDimDeviceTypes, seqDeviceTypes,
                                 workerDeviceTypes, vectorDeviceTypes) &&
@@ -4489,8 +4521,10 @@ void createOpenACCRoutineConstruct(
   modBuilder.create<mlir::acc::RoutineOp>(
       loc, routineOpStr,
       mlir::SymbolRefAttr::get(builder.getContext(), funcName),
-      getArrayAttrOrNull(builder, bindNames),
-      getArrayAttrOrNull(builder, bindNameDeviceTypes),
+      getArrayAttrOrNull(builder, bindIdNames),
+      getArrayAttrOrNull(builder, bindStrNames),
+      getArrayAttrOrNull(builder, bindIdNameDeviceTypes),
+      getArrayAttrOrNull(builder, bindStrNameDeviceTypes),
       getArrayAttrOrNull(builder, workerDeviceTypes),
       getArrayAttrOrNull(builder, vectorDeviceTypes),
       getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost,
@@ -4507,8 +4541,10 @@ static void interpretRoutineDeviceInfo(
     llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
     llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
     llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
-    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
-    llvm::SmallVector<mlir::Attribute> &bindNames,
+    llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
+    llvm::SmallVector<mlir::Attribute> &bindIdNames,
+    llvm::SmallVector<mlir::Attribute> &bindStrNames,
     llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
     llvm::SmallVector<mlir::Attribute> &gangDimValues,
     llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
@@ -4541,16 +4577,18 @@ static void interpretRoutineDeviceInfo(
   if (dinfo.bindNameOpt().has_value()) {
     const auto &bindName = dinfo.bindNameOpt().value();
     mlir::Attribute bindNameAttr;
-    if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
+    if (const auto &bindSym{
+            std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
+      bindNameAttr = builder.getSymbolRefAttr(converter.mangleName(*bindSym));
+      bindIdNames.push_back(bindNameAttr);
+      bindIdNameDeviceTypes.push_back(getDeviceTypeAttr());
+    } else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
       bindNameAttr = builder.getStringAttr(*bindStr);
-    } else if (const auto &bindSym{
-                   std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
-      bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym));
+      bindStrNames.push_back(bindNameAttr);
+      bindStrNameDeviceTypes.push_back(getDeviceTypeAttr());
     } else {
       llvm_unreachable("Unsupported bind name type");
     }
-    bindNames.push_back(bindNameAttr);
-    bindNameDeviceTypes.push_back(getDeviceTypeAttr());
   }
 }
 
@@ -4566,8 +4604,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
   bool hasNohost{false};
 
   llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
-      workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
-      gangDimDeviceTypes, gangDimValues;
+      workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
+      bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
+      gangDimValues;
 
   for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
     // Device Independent Attributes
@@ -4576,24 +4615,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
     }
     // Note: Device Independent Attributes are set to the
     // none device type in `info`.
-    interpretRoutineDeviceInfo(converter, info, seqDeviceTypes,
-                               vectorDeviceTypes, workerDeviceTypes,
-                               bindNameDeviceTypes, bindNames, gangDeviceTypes,
-                               gangDimValues, gangDimDeviceTypes);
+    interpretRoutineDeviceInfo(
+        converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
+        bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames,
+        bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);
 
     // Device Dependent Attributes
     for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
          info.deviceTypeInfos()) {
-      interpretRoutineDeviceInfo(
-          converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
-          workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
-          gangDimValues, gangDimDeviceTypes);
+      interpretRoutineDeviceInfo(converter, dinfo, seqDeviceTypes,
+                                 vectorDeviceTypes, workerDeviceTypes,
+                                 bindIdNameDeviceTypes, bindStrNameDeviceTypes,
+                                 bindIdNames, bindStrNames, gangDeviceTypes,
+                                 gangDimValues, gangDimDeviceTypes);
     }
   }
   createOpenACCRoutineConstruct(
-      converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
-      bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
-      seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
+      converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
+      bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
+      gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
+      workerDeviceTypes, vectorDeviceTypes);
 }
 
 static void
diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 789f3a57e1f79..1a63b4120235c 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -2,13 +2,14 @@
 
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
-! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
-! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
+! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine17
+! [#acc.device_type<default>], @_QPacc_routine16 [#acc.device_type<multicore>])
+! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine16 [#acc.device_type<multicore>])
 ! CHECK: acc.routine @[[r12:.*]] func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
 ! CHECK: acc.routine @[[r11:.*]] func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
 ! CHECK: acc.routine @[[r10:.*]] func(@_QPacc_routine11) seq
 ! CHECK: acc.routine @[[r09:.*]] func(@_QPacc_routine10) seq
-! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind("_QPacc_routine9a")
+! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind(@_QPacc_routine9a)
 ! CHECK: acc.routine @[[r07:.*]] func(@_QPacc_routine8) bind("routine8_")
 ! CHECK: acc.routine @[[r06:.*]] func(@_QPacc_routine7) gang(dim: 1 : i64)
 ! CHECK: acc.routine @[[r05:.*]] func(@_QPacc_routine6) nohost
diff --git a/flang/test/Lower/OpenACC/acc-routine03.f90 b/flang/test/Lower/OpenACC/acc-routine03.f90
index 85e4ef580f983..ddd6bda0367e4 100644
--- a/flang/test/Lower/OpenACC/acc-routine03.f90
+++ b/flang/test/Lower/OpenACC/acc-routine03.f90
@@ -30,6 +30,6 @@ subroutine sub2(a)
 end subroutine
 
 ! CHECK: acc.routine @acc_routine_1 func(@_QPsub2) worker nohost
-! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind("_QPsub2") worker
+! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind(@_QPsub2) worker
 ! CHECK: func.func @_QPsub1(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
 ! CHECK: func.func @_QPsub2(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_1]>}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 4eb666239d4e4..8f87235fcd237 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -29,6 +29,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <variant>
 
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.h.inc"
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 66378f116784e..96b9adcc53b3c 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2772,8 +2772,10 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
   }];
 
   let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$func_name,
-      OptionalAttr<StrArrayAttr>:$bindName,
-      OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
+      OptionalAttr<SymbolRefArrayAttr>:$bindIdName,
+      OptionalAttr<StrArrayAttr>:$bindStrName,
+      OptionalAttr<DeviceTypeArrayAttr>:$bindIdNameDeviceType,
+      OptionalAttr<DeviceTypeArrayAttr>:$bindStrNameDeviceType,
       OptionalAttr<DeviceTypeArrayAttr>:$worker,
       OptionalAttr<DeviceTypeArrayAttr>:$vector,
       OptionalAttr<DeviceTypeArrayAttr>:$seq, UnitAttr:$nohost,
@@ -2815,14 +2817,14 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
     std::optional<int64_t> getGangDimValue();
     std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
 
-    std::optional<llvm::StringRef> getBindNameValue();
-    std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
+    std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue();
+    std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(mlir::acc::DeviceType deviceType);
   }];
 
   let assemblyFormat = [{
     $sym_name `func` `(` $func_name `)`
     oilist (
-        `bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
+        `bind` `(` custom<BindName>($bindIdName, $bindStrName ,$bindIdNameDeviceType, $bindStrNameDeviceType) `)`
       | `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
       | `worker` custom<DeviceTypeArrayAttr>($worker)
       | `vector` custom<DeviceTypeArrayAttr>($vector)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index f2eab62b286af..33519ed08ee0b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/LogicalResult.h"
+#include <variant>
 
 using namespace mlir;
 using namespace acc;
@@ -3461,40 +3462,89 @@ LogicalResult acc::RoutineOp::verify() {
   return success();
 }
 
-static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
-                                 mlir::ArrayAttr &deviceTypes) {
-  llvm::SmallVector<mlir::Attribute> bindNameAttrs;
-  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
+static ParseResult parseBindName(OpAsmParser &parser,
+                                 mlir::ArrayAttr &bindIdName,
+                                 mlir::ArrayAttr &bindStrName,
+                                 mlir::ArrayAttr &deviceIdTypes,
+                                 mlir::ArrayAttr &deviceStrTypes) {
+  llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
+  llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
+  llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
+  llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
 
   if (failed(parser.parseCommaSeparatedList([&]() {
-        if (parser.parseAttribute(bindNameAttrs.emplace_back()))
+        mlir::Attribute newAttr;
+        bool isSymbolRefAttr;
+        auto parseResult = parser.parseAttribute(newAttr);
+        if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
+          bindIdNameAttrs.push_back(symbolRefAttr);
+          isSymbolRefAttr = true;
+        } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
+          bindStrNameAttrs.push_back(stringAttr);
+          isSymbolRefAttr = false;
+        }
+        if (parseResult)
           return failure();
         if (failed(parser.parseOptionalLSquare())) {
-          deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
-              parser.getContext(), mlir::acc::DeviceType::None));
+          if (isSymbolRefAttr) {
+            deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+                parser.getContext(), mlir::acc::DeviceType::None));
+          } else {
+            deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+                parser.getContext(), mlir::acc::DeviceType::None));
+          }
         } else {
-          if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
-              parser.parseRSquare())
-            return failure();
+          if (isSymbolRefAttr) {
+            if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
+                parser.parseRSquare())
+              return failure();
+          } else {
+            if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
+                parser.parseRSquare())
+              return failure();
+          }
         }
         return success();
       })))
     return failure();
 
-  bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
-  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+  bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
+  bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
+  deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
+  deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
 
   return success();
 }
 
 static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
-                          std::optional<mlir::ArrayAttr> bindName,
-                          std::optional<mlir::ArrayAttr> deviceTypes) {
-  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
-                        [&](const auto &pair) {
-                          p << std::get<0>(pair);
-                          printSingleDeviceType(p, std::get<1>(pair));
-                        });
+                          std::optional<mlir::ArrayAttr> bindIdName,
+                          std::optional<mlir::ArrayAttr> bindStrName,
+                          std::optional<mlir::ArrayAttr> deviceIdTypes,
+                          std::optional<mlir::ArrayAttr> deviceStrTypes) {
+  // Create combined vectors for all bind names and device types
+  llvm::SmallVector<mlir::Attribute> allBindNames;
+  llvm::SmallVector<mlir::Attribute> allDeviceTypes;
+
+  // Append bindIdName and deviceIdTypes
+  if (hasDeviceTypeValues(deviceIdTypes)) {
+    allBindNames.append(bindIdName->begin(), bindIdName->end());
+    allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
+  }
+
+  // Append bindStrName and deviceStrTypes
+  if (hasDeviceTypeValues(deviceStrTypes)) {
+    allBindNames.append(bindStrName->begin(), bindStrName->end());
+    allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
+  }
+
+  // Print the combined sequence
+  if (!allBindNames.empty()) {
+    llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
+                          [&](const auto &pair) {
+                            p << std::get<0>(pair);
+                            printSingleDeviceType(p, std::get<1>(pair));
+                          });
+  }
 }
 
 static ParseResult parseRoutineGangClause(OpAsmParser &parser,
@@ -3654,19 +3704,32 @@ bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
   return hasDeviceType(getSeq(), deviceType);
 }
 
-std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
+std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
+RoutineOp::getBindNameValue() {
   return getBindNameValue(mlir::acc::DeviceType::None);
 }
 
-std::optional<llvm::StringRef>
+std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
-  if (!hasDeviceTypeValues(getBindName...
[truncated]

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@razvanlupusoru razvanlupusoru self-requested a review July 16, 2025 17:41
Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work!

@razvanlupusoru razvanlupusoru merged commit 0dae924 into llvm:main Jul 17, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants