@@ -84,10 +84,15 @@ class MapInfoFinalizationPass
8484 // / | |
8585 std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
8686
87- // List of deferrable descriptors to process at the end of
88- // the pass.
87+ // / List of deferrable descriptors to process at the end of
88+ // / the pass.
8989 llvm::SmallVector<mlir::Operation *> deferrableDesc;
9090
91+ // / List of base addresses already expanded from their
92+ // / descriptors within a parent, currently used to
93+ // / prevent incorrect member index generation.
94+ std::map<mlir::Operation *, llvm::SmallVector<uint64_t >> expandedBaseAddr;
95+
9196 // / Return true if the given path exists in a list of paths.
9297 static bool
9398 containsPath (const llvm::SmallVectorImpl<llvm::SmallVector<int64_t >> &paths,
@@ -403,26 +408,38 @@ class MapInfoFinalizationPass
403408 // / of the base address index.
404409 void adjustMemberIndices (
405410 llvm::SmallVectorImpl<llvm::SmallVector<int64_t >> &memberIndices,
406- size_t memberIndex) {
407- llvm::SmallVector<int64_t > baseAddrIndex = memberIndices[memberIndex];
411+ ParentAndPlacement parentAndPlacement) {
412+ llvm::SmallVector<int64_t > baseAddrIndex =
413+ memberIndices[parentAndPlacement.index ];
414+ auto &expansionIndexes = expandedBaseAddr[parentAndPlacement.parent ];
408415
409416 // If we find another member that is "derived/a member of" the descriptor
410417 // that is not the descriptor itself, we must insert a 0 for the new base
411418 // address we have just added for the descriptor into the list at the
412419 // appropriate position to maintain correctness of the positional/index data
413420 // for that member.
414- for (llvm::SmallVector<int64_t > &member : memberIndices)
421+ for (auto [i, member] : llvm::enumerate (memberIndices)) {
422+ if (std::find (expansionIndexes.begin (), expansionIndexes.end (), i) !=
423+ expansionIndexes.end ())
424+ if (member.size () == baseAddrIndex.size () + 1 &&
425+ member[baseAddrIndex.size ()] == 0 )
426+ continue ;
427+
415428 if (member.size () > baseAddrIndex.size () &&
416429 std::equal (baseAddrIndex.begin (), baseAddrIndex.end (),
417430 member.begin ()))
418431 member.insert (std::next (member.begin (), baseAddrIndex.size ()), 0 );
432+ }
419433
420434 // Add the base address index to the main base address member data
421435 baseAddrIndex.push_back (0 );
422436
423- // Insert our newly created baseAddrIndex into the larger list of indices at
424- // the correct location.
425- memberIndices.insert (std::next (memberIndices.begin (), memberIndex + 1 ),
437+ uint64_t newIdxInsert = parentAndPlacement.index + 1 ;
438+ expansionIndexes.push_back (newIdxInsert);
439+
440+ // Insert our newly created baseAddrIndex into the larger list of
441+ // indices at the correct location.
442+ memberIndices.insert (std::next (memberIndices.begin (), newIdxInsert),
426443 baseAddrIndex);
427444 }
428445
@@ -449,30 +466,23 @@ class MapInfoFinalizationPass
449466 // / descriptor tag to it as it's used differently to a regular mapping
450467 // / and some of the runtime descriptor behaviour at the moment can cause
451468 // / issues.
452- mlir::omp::ClauseMapFlags getDescriptorMapType (mlir::omp::ClauseMapFlags mapTypeFlag,
453- mlir::Operation *target) {
469+ mlir::omp::ClauseMapFlags
470+ getDescriptorMapType (mlir::omp::ClauseMapFlags mapTypeFlag,
471+ mlir::Operation *target, bool isHasDeviceAddr) {
454472 using mapFlags = mlir::omp::ClauseMapFlags;
473+ mapFlags flags = mapFlags::none;
474+ if (!isHasDeviceAddr)
475+ flags |= mapFlags::attach;
476+
455477 if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp,
456- mlir::omp::TargetUpdateOp>(target))
457- return mapTypeFlag;
458-
459- mapFlags flags = mapFlags::to | mapFlags::descriptor |
460- (mapTypeFlag & mapFlags::implicit);
461- // Descriptors for objects will always be copied. This is because the
462- // descriptor can be rematerialized by the compiler, and so the addres
463- // of the descriptor for a given object at one place in the code may
464- // differ from that address in another place. The contents of the
465- // descriptor (the base address in particular) will remain unchanged
466- // though.
467- // TODO/FIXME: We currently cannot have MAP_CLOSE and MAP_ALWAYS on
468- // the descriptor at once, these are mutually exclusive and when
469- // both are applied the runtime will fail to map.
470- flags |= ((mapTypeFlag & mapFlags::close) == mapFlags::close)
471- ? mapFlags::close
472- : mapFlags::always;
473- // For unified_shared_memory, we additionally add `CLOSE` on the descriptor
474- // to ensure device-local placement where required by tests relying on USM +
475- // close semantics.
478+ mlir::omp::TargetUpdateOp>(target)) {
479+ flags |= mapTypeFlag | mapFlags::descriptor;
480+ return flags;
481+ }
482+
483+ flags |= mapFlags::to | mapFlags::descriptor | mapFlags::always |
484+ (mapTypeFlag & mapFlags::implicit);
485+
476486 if (moduleRequiresUSM (target->getParentOfType <mlir::ModuleOp>()))
477487 flags |= mapFlags::close;
478488 return flags;
@@ -676,7 +686,7 @@ class MapInfoFinalizationPass
676686 auto baseAddr =
677687 genBaseAddrMap (descriptor, op.getBounds (), op.getMapType (), builder);
678688 ParentAndPlacement mapUser = mapMemberUsers[0 ];
679- adjustMemberIndices (memberIndices, mapUser. index );
689+ adjustMemberIndices (memberIndices, mapUser);
680690 llvm::SmallVector<mlir::Value> newMemberOps;
681691 for (auto v : mapUser.parent .getMembers ()) {
682692 newMemberOps.push_back (v);
@@ -706,7 +716,7 @@ class MapInfoFinalizationPass
706716 builder, op->getLoc (), op.getResult ().getType (), descriptor,
707717 mlir::TypeAttr::get (fir::unwrapRefType (descriptor.getType ())),
708718 builder.getAttr <mlir::omp::ClauseMapFlagsAttr>(
709- getDescriptorMapType (op.getMapType (), target)),
719+ getDescriptorMapType (op.getMapType (), target, isHasDeviceAddrFlag )),
710720 op.getMapCaptureTypeAttr (), /* varPtrPtr=*/ mlir::Value{}, newMembers,
711721 newMembersAttr, /* bounds=*/ mlir::SmallVector<mlir::Value>{},
712722 /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
@@ -908,7 +918,8 @@ class MapInfoFinalizationPass
908918 op->getLoc (), op.getResult ().getType (), op.getVarPtr (),
909919 op.getVarTypeAttr (),
910920 builder.getAttr <mlir::omp::ClauseMapFlagsAttr>(
911- mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::always |
921+ mlir::omp::ClauseMapFlags::to |
922+ mlir::omp::ClauseMapFlags::always |
912923 mlir::omp::ClauseMapFlags::descriptor),
913924 op.getMapCaptureTypeAttr (), /* varPtrPtr=*/ mlir::Value{},
914925 mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
@@ -1007,6 +1018,7 @@ class MapInfoFinalizationPass
10071018 // iterations from previous function scopes.
10081019 localBoxAllocas.clear ();
10091020 deferrableDesc.clear ();
1021+ expandedBaseAddr.clear ();
10101022
10111023 // First, walk `omp.map.info` ops to see if any of them have varPtrs
10121024 // with an underlying type of fir.char<k, ?>, i.e a character
0 commit comments