Skip to content

Commit 364bd52

Browse files
committed
[lumen] add support for llvm call arg attributes, namely sret/byval/etc.
1 parent 4daf124 commit 364bd52

File tree

3 files changed

+105
-2
lines changed

3 files changed

+105
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,13 +536,27 @@ def LLVM_CallOp : LLVM_Op<"call",
536536
let results = (outs Variadic<LLVM_Type>);
537537
let builders = [
538538
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
539-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), [{
540-
Type resultType = func.getFunctionType().getReturnType();
539+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes,
540+
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs), [{
541+
auto funcType = func.getFunctionType().cast<LLVMFunctionType>();
542+
Type resultType = funcType.getReturnType();
541543
if (!resultType.isa<LLVM::LLVMVoidType>())
542544
$_state.addTypes(resultType);
543545
$_state.addAttribute("callee", SymbolRefAttr::get(func));
544546
$_state.addAttributes(attributes);
545547
$_state.addOperands(operands);
548+
if (argAttrs.size() > 0) {
549+
assert(funcType.getNumParams() == argAttrs.size() &&
550+
"expected as many argument attribute lists as arguments");
551+
auto nonEmptyAttrsFn = [](const ArrayRef<DictionaryAttr> &attrs) { return !attrs.empty(); };
552+
auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
553+
return ArrayRef<Attribute>(attrs.data(), attrs.size());
554+
};
555+
if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
556+
ArrayAttr attrDicts = $_builder.getArrayAttr(buildFn(argAttrs));
557+
$_state.addAttribute("arg_attrs", attrDicts);
558+
}
559+
}
546560
}]>,
547561
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
548562
CArg<"ValueRange", "{}">:$operands), [{
@@ -553,6 +567,7 @@ def LLVM_CallOp : LLVM_Op<"call",
553567
build($_builder, $_state, results,
554568
StringAttr::get($_builder.getContext(), callee), operands);
555569
}]>];
570+
556571
let hasCustomAssemblyFormat = 1;
557572
let hasVerifier = 1;
558573
}

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,35 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
302302
callInst->addAttributeAtIndex(llvm::AttributeList::FunctionIndex, fnAttrKind);
303303
}
304304
}
305+
if (auto argAttrs = op.getAttrOfType<ArrayAttr>("arg_attrs")) {
306+
for (unsigned i = 0, e = operandsRef.size(); i < e; ++i) {
307+
ArrayRef<NamedAttribute> attrs;
308+
if (argAttrs)
309+
attrs = argAttrs[i].cast<DictionaryAttr>().getValue();
310+
for (auto argAttr : attrs) {
311+
StringRef argAttrKindName = argAttr.getName();
312+
if (argAttrKindName == "llvm.sret") {
313+
auto argPtrTy = op.getOperand(i).getType().cast<LLVMPointerType>();
314+
auto argTy = moduleTranslation.convertType(argPtrTy.getElementType());
315+
callInst->addAttributeAtIndex(llvm::AttributeList::FirstArgIndex + i, llvm::Attribute::get(moduleTranslation.getLLVMContext(),
316+
llvm::Attribute::StructRet, argTy));
317+
} else if (argAttrKindName == "llvm.byval") {
318+
auto tyAttr = argAttr.getValue().cast<TypeAttr>();
319+
auto argTy = moduleTranslation.convertType(tyAttr.getValue());
320+
callInst->addAttributeAtIndex(llvm::AttributeList::FirstArgIndex + i, llvm::Attribute::get(moduleTranslation.getLLVMContext(),
321+
llvm::Attribute::ByVal, argTy));
322+
} else if (auto unitAttr = argAttr.getValue().dyn_cast_or_null<UnitAttr>()) {
323+
llvm::Attribute::AttrKind argAttrKind;
324+
if (argAttrKindName.startswith("llvm."))
325+
argAttrKind = llvm::Attribute::getAttrKindFromName(argAttrKindName.drop_front(5));
326+
else
327+
argAttrKind = llvm::Attribute::getAttrKindFromName(argAttrKindName);
328+
if (argAttrKind != llvm::Attribute::None)
329+
callInst->addAttributeAtIndex(llvm::AttributeList::FirstArgIndex + i, argAttrKind);
330+
}
331+
}
332+
}
333+
}
305334
return callInst;
306335
};
307336

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
827827
llvm::Argument &llvmArg = std::get<1>(kvp);
828828
BlockArgument mlirArg = std::get<0>(kvp);
829829

830+
/*
830831
if (auto attr = func.getArgAttrOfType<UnitAttr>(
831832
argIdx, LLVMDialect::getNoAliasAttrName())) {
832833
// NB: Attribute already verified to be boolean, so check if we can indeed
@@ -877,6 +878,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
877878
llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
878879
.addAttribute(llvm::Attribute::Nest));
879880
}
881+
*/
880882

881883
mapValue(mlirArg, &llvmArg);
882884
argIdx++;
@@ -937,6 +939,63 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
937939
mapFunction(function.getName(), llvmFunc);
938940
addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
939941

942+
// Set any parameter attributes
943+
unsigned int argIdx = 0;
944+
for (auto kvp : llvm::zip(function.getFunctionType().getParams(), llvmFunc->args())) {
945+
Type argTy = std::get<0>(kvp);
946+
llvm::Argument &llvmArg = std::get<1>(kvp);
947+
948+
if (auto attr = function.getArgAttrOfType<UnitAttr>(
949+
argIdx, LLVMDialect::getNoAliasAttrName())) {
950+
// NB: Attribute already verified to be boolean, so check if we can indeed
951+
// attach the attribute to this argument, based on its type.
952+
if (!argTy.isa<LLVM::LLVMPointerType>())
953+
return function.emitError(
954+
"llvm.noalias attribute attached to LLVM non-pointer argument");
955+
llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
956+
}
957+
958+
if (auto attr = function.getArgAttrOfType<IntegerAttr>(
959+
argIdx, LLVMDialect::getAlignAttrName())) {
960+
// NB: Attribute already verified to be int, so check if we can indeed
961+
// attach the attribute to this argument, based on its type.
962+
if (!argTy.isa<LLVM::LLVMPointerType>())
963+
return function.emitError(
964+
"llvm.align attribute attached to LLVM non-pointer argument");
965+
llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
966+
.addAlignmentAttr(llvm::Align(attr.getInt())));
967+
}
968+
969+
if (auto attr = function.getArgAttrOfType<UnitAttr>(argIdx, "llvm.sret")) {
970+
auto argPtrTy = argTy.dyn_cast<LLVM::LLVMPointerType>();
971+
if (!argPtrTy)
972+
return function.emitError(
973+
"llvm.sret attribute attached to LLVM non-pointer argument");
974+
llvmArg.addAttrs(
975+
llvm::AttrBuilder(llvmArg.getContext())
976+
.addStructRetAttr(convertType(argPtrTy.getElementType())));
977+
}
978+
979+
if (auto attr = function.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) {
980+
auto argPtrTy = argTy.dyn_cast<LLVM::LLVMPointerType>();
981+
if (!argPtrTy)
982+
return function.emitError(
983+
"llvm.byval attribute attached to LLVM non-pointer argument");
984+
llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
985+
.addByValAttr(convertType(argPtrTy.getElementType())));
986+
}
987+
988+
if (auto attr = function.getArgAttrOfType<UnitAttr>(argIdx, "llvm.nest")) {
989+
if (!argTy.isa<LLVM::LLVMPointerType>())
990+
return function.emitError(
991+
"llvm.nest attribute attached to LLVM non-pointer argument");
992+
llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
993+
.addAttribute(llvm::Attribute::Nest));
994+
}
995+
996+
argIdx++;
997+
}
998+
940999
// Forward the pass-through attributes to LLVM.
9411000
if (failed(forwardPassthroughAttributes(
9421001
function.getLoc(), function.getPassthrough(), llvmFunc)))

0 commit comments

Comments
 (0)