Skip to content

[openacc][flang] Support two type bindName representation in acc routine #149147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 71 additions & 31 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4396,10 +4396,34 @@ getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
return std::nullopt;
}

// Helper function to extract string value from bind name variant
static std::optional<llvm::StringRef> getBindNameStringValue(
const std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
&bindNameValue) {
if (!bindNameValue.has_value())
return std::nullopt;

return std::visit(
[](const auto &attr) -> std::optional<llvm::StringRef> {
if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
mlir::StringAttr>) {
return attr.getValue();
} else if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
mlir::SymbolRefAttr>) {
return attr.getLeafReference();
} else {
return std::nullopt;
}
},
bindNameValue.value());
}

static bool compareDeviceTypeInfo(
mlir::acc::RoutineOp op,
llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindIdNameArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindStrNameArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
Expand All @@ -4409,9 +4433,13 @@ static bool compareDeviceTypeInfo(
for (uint32_t dtypeInt = 0;
dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
if (op.getBindNameValue(dtype) !=
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
auto bindNameValue = getBindNameStringValue(op.getBindNameValue(dtype));
if (bindNameValue !=
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
bindIdNameArrayAttr, bindIdNameDeviceTypeArrayAttr, dtype) &&
bindNameValue !=
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
bindStrNameArrayAttr, bindStrNameDeviceTypeArrayAttr, dtype))
return false;
if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
return false;
Expand Down Expand Up @@ -4458,8 +4486,10 @@ getArrayAttrOrNull(fir::FirOpBuilder &builder,
void createOpenACCRoutineConstruct(
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
mlir::ModuleOp mod, mlir::func::FuncOp funcOp, std::string funcName,
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindNames,
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
bool hasNohost, llvm::SmallVector<mlir::Attribute> &bindIdNames,
llvm::SmallVector<mlir::Attribute> &bindStrNames,
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
llvm::SmallVector<mlir::Attribute> &gangDimValues,
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes,
Expand All @@ -4472,7 +4502,8 @@ void createOpenACCRoutineConstruct(
0) {
// If the routine is already specified with the same clauses, just skip
// the operation creation.
if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
if (compareDeviceTypeInfo(routineOp, bindIdNames, bindStrNames,
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
gangDeviceTypes, gangDimValues,
gangDimDeviceTypes, seqDeviceTypes,
workerDeviceTypes, vectorDeviceTypes) &&
Expand All @@ -4489,8 +4520,10 @@ void createOpenACCRoutineConstruct(
modBuilder.create<mlir::acc::RoutineOp>(
loc, routineOpStr,
mlir::SymbolRefAttr::get(builder.getContext(), funcName),
getArrayAttrOrNull(builder, bindNames),
getArrayAttrOrNull(builder, bindNameDeviceTypes),
getArrayAttrOrNull(builder, bindIdNames),
getArrayAttrOrNull(builder, bindStrNames),
getArrayAttrOrNull(builder, bindIdNameDeviceTypes),
getArrayAttrOrNull(builder, bindStrNameDeviceTypes),
getArrayAttrOrNull(builder, workerDeviceTypes),
getArrayAttrOrNull(builder, vectorDeviceTypes),
getArrayAttrOrNull(builder, seqDeviceTypes), hasNohost,
Expand All @@ -4507,8 +4540,10 @@ static void interpretRoutineDeviceInfo(
llvm::SmallVector<mlir::Attribute> &seqDeviceTypes,
llvm::SmallVector<mlir::Attribute> &vectorDeviceTypes,
llvm::SmallVector<mlir::Attribute> &workerDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindNames,
llvm::SmallVector<mlir::Attribute> &bindIdNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindStrNameDeviceTypes,
llvm::SmallVector<mlir::Attribute> &bindIdNames,
llvm::SmallVector<mlir::Attribute> &bindStrNames,
llvm::SmallVector<mlir::Attribute> &gangDeviceTypes,
llvm::SmallVector<mlir::Attribute> &gangDimValues,
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypes) {
Expand Down Expand Up @@ -4541,16 +4576,18 @@ static void interpretRoutineDeviceInfo(
if (dinfo.bindNameOpt().has_value()) {
const auto &bindName = dinfo.bindNameOpt().value();
mlir::Attribute bindNameAttr;
if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
if (const auto &bindSym{
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
bindNameAttr = builder.getSymbolRefAttr(converter.mangleName(*bindSym));
bindIdNames.push_back(bindNameAttr);
bindIdNameDeviceTypes.push_back(getDeviceTypeAttr());
} else if (const auto &bindStr{std::get_if<std::string>(&bindName)}) {
bindNameAttr = builder.getStringAttr(*bindStr);
} else if (const auto &bindSym{
std::get_if<Fortran::semantics::SymbolRef>(&bindName)}) {
bindNameAttr = builder.getStringAttr(converter.mangleName(*bindSym));
bindStrNames.push_back(bindNameAttr);
bindStrNameDeviceTypes.push_back(getDeviceTypeAttr());
} else {
llvm_unreachable("Unsupported bind name type");
}
bindNames.push_back(bindNameAttr);
bindNameDeviceTypes.push_back(getDeviceTypeAttr());
}
}

Expand All @@ -4566,8 +4603,9 @@ void Fortran::lower::genOpenACCRoutineConstruct(
bool hasNohost{false};

llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
gangDimDeviceTypes, gangDimValues;
workerDeviceTypes, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
bindIdNames, bindStrNames, gangDeviceTypes, gangDimDeviceTypes,
gangDimValues;

for (const Fortran::semantics::OpenACCRoutineInfo &info : routineInfos) {
// Device Independent Attributes
Expand All @@ -4576,24 +4614,26 @@ void Fortran::lower::genOpenACCRoutineConstruct(
}
// Note: Device Independent Attributes are set to the
// none device type in `info`.
interpretRoutineDeviceInfo(converter, info, seqDeviceTypes,
vectorDeviceTypes, workerDeviceTypes,
bindNameDeviceTypes, bindNames, gangDeviceTypes,
gangDimValues, gangDimDeviceTypes);
interpretRoutineDeviceInfo(
converter, info, seqDeviceTypes, vectorDeviceTypes, workerDeviceTypes,
bindIdNameDeviceTypes, bindStrNameDeviceTypes, bindIdNames,
bindStrNames, gangDeviceTypes, gangDimValues, gangDimDeviceTypes);

// Device Dependent Attributes
for (const Fortran::semantics::OpenACCRoutineDeviceTypeInfo &dinfo :
info.deviceTypeInfos()) {
interpretRoutineDeviceInfo(
converter, dinfo, seqDeviceTypes, vectorDeviceTypes,
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
gangDimValues, gangDimDeviceTypes);
interpretRoutineDeviceInfo(converter, dinfo, seqDeviceTypes,
vectorDeviceTypes, workerDeviceTypes,
bindIdNameDeviceTypes, bindStrNameDeviceTypes,
bindIdNames, bindStrNames, gangDeviceTypes,
gangDimValues, gangDimDeviceTypes);
}
}
createOpenACCRoutineConstruct(
converter, loc, mod, funcOp, funcName, hasNohost, bindNames,
bindNameDeviceTypes, gangDeviceTypes, gangDimValues, gangDimDeviceTypes,
seqDeviceTypes, workerDeviceTypes, vectorDeviceTypes);
converter, loc, mod, funcOp, funcName, hasNohost, bindIdNames,
bindStrNames, bindIdNameDeviceTypes, bindStrNameDeviceTypes,
gangDeviceTypes, gangDimValues, gangDimDeviceTypes, seqDeviceTypes,
workerDeviceTypes, vectorDeviceTypes);
}

static void
Expand Down
7 changes: 4 additions & 3 deletions flang/test/Lower/OpenACC/acc-routine.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s

! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r14:.*]] func(@_QPacc_routine19) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine17
! [#acc.device_type<default>], @_QPacc_routine16 [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r13:.*]] func(@_QPacc_routine18) bind(@_QPacc_routine17 [#acc.device_type<host>], @_QPacc_routine16 [#acc.device_type<multicore>])
! CHECK: acc.routine @[[r12:.*]] func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
! CHECK: acc.routine @[[r11:.*]] func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
! CHECK: acc.routine @[[r10:.*]] func(@_QPacc_routine11) seq
! CHECK: acc.routine @[[r09:.*]] func(@_QPacc_routine10) seq
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind("_QPacc_routine9a")
! CHECK: acc.routine @[[r08:.*]] func(@_QPacc_routine9) bind(@_QPacc_routine9a)
! CHECK: acc.routine @[[r07:.*]] func(@_QPacc_routine8) bind("routine8_")
! CHECK: acc.routine @[[r06:.*]] func(@_QPacc_routine7) gang(dim: 1 : i64)
! CHECK: acc.routine @[[r05:.*]] func(@_QPacc_routine6) nohost
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenACC/acc-routine03.f90
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ subroutine sub2(a)
end subroutine

! CHECK: acc.routine @acc_routine_1 func(@_QPsub2) worker nohost
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind("_QPsub2") worker
! CHECK: acc.routine @acc_routine_0 func(@_QPsub1) bind(@_QPsub2) worker
! CHECK: func.func @_QPsub1(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
! CHECK: func.func @_QPsub2(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_1]>}
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <variant>

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.h.inc"
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2772,8 +2772,10 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
}];

let arguments = (ins SymbolNameAttr:$sym_name, SymbolRefAttr:$func_name,
OptionalAttr<StrArrayAttr>:$bindName,
OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
OptionalAttr<SymbolRefArrayAttr>:$bindIdName,
OptionalAttr<StrArrayAttr>:$bindStrName,
OptionalAttr<DeviceTypeArrayAttr>:$bindIdNameDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$bindStrNameDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$worker,
OptionalAttr<DeviceTypeArrayAttr>:$vector,
OptionalAttr<DeviceTypeArrayAttr>:$seq, UnitAttr:$nohost,
Expand Down Expand Up @@ -2815,14 +2817,14 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
std::optional<int64_t> getGangDimValue();
std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);

std::optional<llvm::StringRef> getBindNameValue();
std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue();
std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(mlir::acc::DeviceType deviceType);
}];

let assemblyFormat = [{
$sym_name `func` `(` $func_name `)`
oilist (
`bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
`bind` `(` custom<BindName>($bindIdName, $bindStrName ,$bindIdNameDeviceType, $bindStrNameDeviceType) `)`
| `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
| `worker` custom<DeviceTypeArrayAttr>($worker)
| `vector` custom<DeviceTypeArrayAttr>($vector)
Expand Down
112 changes: 87 additions & 25 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include <variant>

using namespace mlir;
using namespace acc;
Expand Down Expand Up @@ -3461,40 +3462,88 @@ LogicalResult acc::RoutineOp::verify() {
return success();
}

static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
mlir::ArrayAttr &deviceTypes) {
llvm::SmallVector<mlir::Attribute> bindNameAttrs;
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
static ParseResult parseBindName(OpAsmParser &parser,
mlir::ArrayAttr &bindIdName,
mlir::ArrayAttr &bindStrName,
mlir::ArrayAttr &deviceIdTypes,
mlir::ArrayAttr &deviceStrTypes) {
llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;

if (failed(parser.parseCommaSeparatedList([&]() {
if (parser.parseAttribute(bindNameAttrs.emplace_back()))
mlir::Attribute newAttr;
bool isSymbolRefAttr;
auto parseResult = parser.parseAttribute(newAttr);
if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
bindIdNameAttrs.push_back(symbolRefAttr);
isSymbolRefAttr = true;
} else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
bindStrNameAttrs.push_back(stringAttr);
isSymbolRefAttr = false;
}
if (parseResult)
return failure();
if (failed(parser.parseOptionalLSquare())) {
deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
if (isSymbolRefAttr) {
deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
} else {
deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
parser.getContext(), mlir::acc::DeviceType::None));
}
} else {
if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
parser.parseRSquare())
return failure();
if (isSymbolRefAttr) {
if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
parser.parseRSquare())
return failure();
} else {
if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
parser.parseRSquare())
return failure();
}
}
return success();
})))
return failure();

bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);

return success();
}

static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
std::optional<mlir::ArrayAttr> bindName,
std::optional<mlir::ArrayAttr> deviceTypes) {
llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
[&](const auto &pair) {
p << std::get<0>(pair);
printSingleDeviceType(p, std::get<1>(pair));
});
std::optional<mlir::ArrayAttr> bindIdName,
std::optional<mlir::ArrayAttr> bindStrName,
std::optional<mlir::ArrayAttr> deviceIdTypes,
std::optional<mlir::ArrayAttr> deviceStrTypes) {
// Create combined vectors for all bind names and device types
llvm::SmallVector<mlir::Attribute> allBindNames;
llvm::SmallVector<mlir::Attribute> allDeviceTypes;

// Append bindIdName and deviceIdTypes
if (hasDeviceTypeValues(deviceIdTypes)) {
allBindNames.append(bindIdName->begin(), bindIdName->end());
allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
}

// Append bindStrName and deviceStrTypes
if (hasDeviceTypeValues(deviceStrTypes)) {
allBindNames.append(bindStrName->begin(), bindStrName->end());
allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
}

// Print the combined sequence
if (!allBindNames.empty())
llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
[&](const auto &pair) {
p << std::get<0>(pair);
printSingleDeviceType(p, std::get<1>(pair));
});
}

static ParseResult parseRoutineGangClause(OpAsmParser &parser,
Expand Down Expand Up @@ -3654,19 +3703,32 @@ bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
return hasDeviceType(getSeq(), deviceType);
}

std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
RoutineOp::getBindNameValue() {
return getBindNameValue(mlir::acc::DeviceType::None);
}

std::optional<llvm::StringRef>
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
if (!hasDeviceTypeValues(getBindNameDeviceType()))
if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
!hasDeviceTypeValues(getBindStrNameDeviceType())) {
return std::nullopt;
if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
auto attr = (*getBindName())[*pos];
}

if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
auto attr = (*getBindIdName())[*pos];
auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
assert(symbolRefAttr && "expected SymbolRef");
return symbolRefAttr;
}

if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
auto attr = (*getBindStrName())[*pos];
auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
return stringAttr.getValue();
assert(stringAttr && "expected String");
return stringAttr;
}

return std::nullopt;
}

Expand Down
Loading
Loading