Skip to content

Commit 266575c

Browse files
committed
[Prototype][Flang][OpenMP] Swap to attach semantics for descriptor mapping
1 parent e73d555 commit 266575c

20 files changed

+607
-507
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,15 @@ class MapInfoFinalizationPass
8484
/// | |
8585
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
8686

87-
// List of deferrable descriptors to process at the end of
88-
// the pass.
87+
/// List of deferrable descriptors to process at the end of
88+
/// the pass.
8989
llvm::SmallVector<mlir::Operation *> deferrableDesc;
9090

91+
/// List of base addresses already expanded from their
92+
/// descriptors within a parent, currently used to
93+
/// prevent incorrect member index generation.
94+
std::map<mlir::Operation *, llvm::SmallVector<uint64_t>> expandedBaseAddr;
95+
9196
/// Return true if the given path exists in a list of paths.
9297
static bool
9398
containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
@@ -403,26 +408,38 @@ class MapInfoFinalizationPass
403408
/// of the base address index.
404409
void adjustMemberIndices(
405410
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &memberIndices,
406-
size_t memberIndex) {
407-
llvm::SmallVector<int64_t> baseAddrIndex = memberIndices[memberIndex];
411+
ParentAndPlacement parentAndPlacement) {
412+
llvm::SmallVector<int64_t> baseAddrIndex =
413+
memberIndices[parentAndPlacement.index];
414+
auto &expansionIndexes = expandedBaseAddr[parentAndPlacement.parent];
408415

409416
// If we find another member that is "derived/a member of" the descriptor
410417
// that is not the descriptor itself, we must insert a 0 for the new base
411418
// address we have just added for the descriptor into the list at the
412419
// appropriate position to maintain correctness of the positional/index data
413420
// for that member.
414-
for (llvm::SmallVector<int64_t> &member : memberIndices)
421+
for (auto [i, member] : llvm::enumerate(memberIndices)) {
422+
if (std::find(expansionIndexes.begin(), expansionIndexes.end(), i) !=
423+
expansionIndexes.end())
424+
if (member.size() == baseAddrIndex.size() + 1 &&
425+
member[baseAddrIndex.size()] == 0)
426+
continue;
427+
415428
if (member.size() > baseAddrIndex.size() &&
416429
std::equal(baseAddrIndex.begin(), baseAddrIndex.end(),
417430
member.begin()))
418431
member.insert(std::next(member.begin(), baseAddrIndex.size()), 0);
432+
}
419433

420434
// Add the base address index to the main base address member data
421435
baseAddrIndex.push_back(0);
422436

423-
// Insert our newly created baseAddrIndex into the larger list of indices at
424-
// the correct location.
425-
memberIndices.insert(std::next(memberIndices.begin(), memberIndex + 1),
437+
uint64_t newIdxInsert = parentAndPlacement.index + 1;
438+
expansionIndexes.push_back(newIdxInsert);
439+
440+
// Insert our newly created baseAddrIndex into the larger list of
441+
// indices at the correct location.
442+
memberIndices.insert(std::next(memberIndices.begin(), newIdxInsert),
426443
baseAddrIndex);
427444
}
428445

@@ -449,30 +466,23 @@ class MapInfoFinalizationPass
449466
/// descriptor tag to it as it's used differently to a regular mapping
450467
/// and some of the runtime descriptor behaviour at the moment can cause
451468
/// issues.
452-
mlir::omp::ClauseMapFlags getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
453-
mlir::Operation *target) {
469+
mlir::omp::ClauseMapFlags
470+
getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
471+
mlir::Operation *target, bool isHasDeviceAddr) {
454472
using mapFlags = mlir::omp::ClauseMapFlags;
473+
mapFlags flags = mapFlags::none;
474+
if (!isHasDeviceAddr)
475+
flags |= mapFlags::attach;
476+
455477
if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp,
456-
mlir::omp::TargetUpdateOp>(target))
457-
return mapTypeFlag;
458-
459-
mapFlags flags = mapFlags::to | mapFlags::descriptor |
460-
(mapTypeFlag & mapFlags::implicit);
461-
// Descriptors for objects will always be copied. This is because the
462-
// descriptor can be rematerialized by the compiler, and so the addres
463-
// of the descriptor for a given object at one place in the code may
464-
// differ from that address in another place. The contents of the
465-
// descriptor (the base address in particular) will remain unchanged
466-
// though.
467-
// TODO/FIXME: We currently cannot have MAP_CLOSE and MAP_ALWAYS on
468-
// the descriptor at once, these are mutually exclusive and when
469-
// both are applied the runtime will fail to map.
470-
flags |= ((mapTypeFlag & mapFlags::close) == mapFlags::close)
471-
? mapFlags::close
472-
: mapFlags::always;
473-
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
474-
// to ensure device-local placement where required by tests relying on USM +
475-
// close semantics.
478+
mlir::omp::TargetUpdateOp>(target)) {
479+
flags |= mapTypeFlag | mapFlags::descriptor;
480+
return flags;
481+
}
482+
483+
flags |= mapFlags::to | mapFlags::descriptor | mapFlags::always |
484+
(mapTypeFlag & mapFlags::implicit);
485+
476486
if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
477487
flags |= mapFlags::close;
478488
return flags;
@@ -676,7 +686,7 @@ class MapInfoFinalizationPass
676686
auto baseAddr =
677687
genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
678688
ParentAndPlacement mapUser = mapMemberUsers[0];
679-
adjustMemberIndices(memberIndices, mapUser.index);
689+
adjustMemberIndices(memberIndices, mapUser);
680690
llvm::SmallVector<mlir::Value> newMemberOps;
681691
for (auto v : mapUser.parent.getMembers()) {
682692
newMemberOps.push_back(v);
@@ -706,7 +716,7 @@ class MapInfoFinalizationPass
706716
builder, op->getLoc(), op.getResult().getType(), descriptor,
707717
mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
708718
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
709-
getDescriptorMapType(op.getMapType(), target)),
719+
getDescriptorMapType(op.getMapType(), target, isHasDeviceAddrFlag)),
710720
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
711721
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
712722
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
@@ -908,7 +918,8 @@ class MapInfoFinalizationPass
908918
op->getLoc(), op.getResult().getType(), op.getVarPtr(),
909919
op.getVarTypeAttr(),
910920
builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
911-
mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::always |
921+
mlir::omp::ClauseMapFlags::to |
922+
mlir::omp::ClauseMapFlags::always |
912923
mlir::omp::ClauseMapFlags::descriptor),
913924
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
914925
mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
@@ -1007,6 +1018,7 @@ class MapInfoFinalizationPass
10071018
// iterations from previous function scopes.
10081019
localBoxAllocas.clear();
10091020
deferrableDesc.clear();
1021+
expandedBaseAddr.clear();
10101022

10111023
// First, walk `omp.map.info` ops to see if any of them have varPtrs
10121024
// with an underlying type of fir.char<k, ?>, i.e a character

0 commit comments

Comments
 (0)