Skip to content

Commit e2eade4

Browse files
authored
[MLIR] [OpenMP] Initial support for OMP ALLOCATE directive op. (#147900)
This patch includes adding support for OMP ALLOCATE directive along with ALIGN clause and ALLOCATOR clause which are used within OMP ALLOCATE directive
1 parent 1db9eb2 commit e2eade4

File tree

6 files changed

+176
-0
lines changed

6 files changed

+176
-0
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,31 @@
2222
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
2323
include "mlir/IR/SymbolInterfaces.td"
2424

25+
//===----------------------------------------------------------------------===//
26+
// V5.2: [6.3] `align` clause
27+
//===----------------------------------------------------------------------===//
28+
29+
class OpenMP_AlignClauseSkip<
30+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
31+
bit description = false, bit extraClassDeclaration = false
32+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
33+
extraClassDeclaration> {
34+
let arguments = (ins
35+
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$align
36+
);
37+
38+
let optAssemblyFormat = [{
39+
`align` `(` $align `)`
40+
}];
41+
42+
let description = [{
43+
The `align` clause is used to specify the byte alignment to use for
44+
allocations associated with the construct on which the clause appears.
45+
}];
46+
}
47+
48+
def OpenMP_AlignClause : OpenMP_AlignClauseSkip<>;
49+
2550
//===----------------------------------------------------------------------===//
2651
// V5.2: [5.11] `aligned` clause
2752
//===----------------------------------------------------------------------===//
@@ -84,6 +109,32 @@ class OpenMP_AllocateClauseSkip<
84109

85110
def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>;
86111

112+
//===----------------------------------------------------------------------===//
113+
// V5.2: [6.4] `allocator` clause
114+
//===----------------------------------------------------------------------===//
115+
116+
class OpenMP_AllocatorClauseSkip<
117+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
118+
bit description = false, bit extraClassDeclaration = false
119+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
120+
extraClassDeclaration> {
121+
122+
let arguments = (ins
123+
OptionalAttr<AllocatorHandleAttr>:$allocator
124+
);
125+
126+
let optAssemblyFormat = [{
127+
`allocator` `(` custom<ClauseAttr>($allocator) `)`
128+
}];
129+
130+
let description = [{
131+
`allocator` specifies the memory allocator to be used for allocations
132+
associated with the construct on which the clause appears.
133+
}];
134+
}
135+
136+
def OpenMP_AllocatorClause : OpenMP_AllocatorClauseSkip<>;
137+
87138
//===----------------------------------------------------------------------===//
88139
// LLVM OpenMP extension `ompx_bare` clause
89140
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,4 +263,34 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
263263
let assemblyFormat = "`(` $value `)`";
264264
}
265265

266+
267+
//===----------------------------------------------------------------------===//
268+
// allocator_handle enum.
269+
//===----------------------------------------------------------------------===//
270+
271+
def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
272+
def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
273+
def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
274+
def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
275+
def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
276+
def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
277+
def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
278+
def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
279+
def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;
280+
281+
def AllocatorHandle : OpenMP_I32EnumAttr<
282+
"AllocatorHandle",
283+
"OpenMP allocator_handle", [
284+
OpenMP_AllocatorHandleNullAllocator,
285+
OpenMP_AllocatorHandleDefaultMemAlloc,
286+
OpenMP_AllocatorHandleLargeCapMemAlloc,
287+
OpenMP_AllocatorHandleConstMemAlloc,
288+
OpenMP_AllocatorHandleHighBwMemAlloc,
289+
OpenMP_AllocatorHandleLowLatMemAlloc,
290+
OpenMP_AllocatorHandleCgroupMemAlloc,
291+
OpenMP_AllocatorHandlePteamMemAlloc,
292+
OpenMP_AllocatorHandlethreadMemAlloc
293+
]>;
294+
295+
def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
266296
#endif // OPENMP_ENUMS

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,4 +2090,27 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
20902090
];
20912091
}
20922092

2093+
//===----------------------------------------------------------------------===//
2094+
// [Spec 5.2] 6.5 allocate Directive
2095+
//===----------------------------------------------------------------------===//
2096+
def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
2097+
OpenMP_AlignClause, OpenMP_AllocatorClause
2098+
]> {
2099+
let summary = "allocate directive";
2100+
let description = [{
2101+
The storage for each list item that appears in the allocate directive is
2102+
provided an allocation through the memory allocator.
2103+
}] # clausesDescription;
2104+
2105+
let arguments = !con((ins Variadic<AnyType>:$varList),
2106+
clausesArgs);
2107+
2108+
// Override inherited assembly format to include `varList`.
2109+
let assemblyFormat = " `(` $varList `:` type($varList) `)` oilist(" #
2110+
clausesOptAssemblyFormat #
2111+
") attr-dict ";
2112+
2113+
let hasVerifier = 1;
2114+
}
2115+
20932116
#endif // OPENMP_OPS

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/ADT/StringExtras.h"
3434
#include "llvm/ADT/StringRef.h"
3535
#include "llvm/ADT/TypeSwitch.h"
36+
#include "llvm/ADT/bit.h"
3637
#include "llvm/Frontend/OpenMP/OMPConstants.h"
3738
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
3839
#include <cstddef>
@@ -3863,6 +3864,20 @@ LogicalResult ScanOp::verify() {
38633864
"reduction modifier");
38643865
}
38653866

3867+
/// Verifies align clause in allocate directive
3868+
3869+
LogicalResult AllocateDirOp::verify() {
3870+
std::optional<u_int64_t> align = this->getAlign();
3871+
3872+
if (align.has_value()) {
3873+
if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
3874+
return emitError() << "ALIGN value : " << align.value()
3875+
<< " must be power of 2";
3876+
}
3877+
3878+
return success();
3879+
}
3880+
38663881
#define GET_ATTRDEF_CLASSES
38673882
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
38683883

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2993,3 +2993,27 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
29932993
}
29942994
llvm.return
29952995
}
2996+
2997+
// -----
2998+
func.func @invalid_allocate_align_1(%arg0 : memref<i32>) -> () {
2999+
// expected-error @below {{failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
3000+
omp.allocate_dir (%arg0 : memref<i32>) align(-1)
3001+
3002+
return
3003+
}
3004+
3005+
// -----
3006+
func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
3007+
// expected-error @below {{must be power of 2}}
3008+
omp.allocate_dir (%arg0 : memref<i32>) align(3)
3009+
3010+
return
3011+
}
3012+
3013+
// -----
3014+
func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
3015+
// expected-error @below {{invalid clause value}}
3016+
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
3017+
3018+
return
3019+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3197,3 +3197,36 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
31973197
}
31983198
return
31993199
}
3200+
3201+
// CHECK-LABEL: func.func @omp_allocate_dir(
3202+
// CHECK-SAME: %[[ARG0:.*]]: memref<i32>,
3203+
// CHECK-SAME: %[[ARG1:.*]]: memref<i32>) {
3204+
func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
3205+
3206+
// Test with one data var
3207+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>)
3208+
omp.allocate_dir (%arg0 : memref<i32>)
3209+
3210+
// Test with two data vars
3211+
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>)
3212+
omp.allocate_dir (%arg0, %arg1: memref<i32>, memref<i32>)
3213+
3214+
// Test with one data var and align clause
3215+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2)
3216+
omp.allocate_dir (%arg0 : memref<i32>) align(2)
3217+
3218+
// Test with one data var and allocator clause
3219+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
3220+
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_pteam_mem_alloc)
3221+
3222+
// Test with one data var, align clause and allocator clause
3223+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
3224+
omp.allocate_dir (%arg0 : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
3225+
3226+
// Test with two data vars, align clause and allocator clause
3227+
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
3228+
omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
3229+
3230+
return
3231+
}
3232+

0 commit comments

Comments
 (0)