Skip to content

[OpenMP][Flang] Emit default declare mappers implicitly for derived types #140562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

TIFitis
Copy link
Member

@TIFitis TIFitis commented May 19, 2025

This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp flang:semantics labels May 19, 2025
@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-openmp

Author: Akash Banerjee (TIFitis)

Changes

This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types.


Patch is 33.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140562.diff

5 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+23-16)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+131-1)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+1-1)
  • (modified) flang/lib/Semantics/resolve-names.cpp (+47-48)
  • (modified) flang/test/Lower/OpenMP/derived-type-map.f90 (+21-1)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 8dcc8be9be5bf..cf25e91a437b8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1102,23 +1102,30 @@ void ClauseProcessor::processMapObjects(
 
   auto getDefaultMapperID = [&](const omp::Object &object,
                                 std::string &mapperIdName) {
-    if (!mlir::isa<mlir::omp::DeclareMapperOp>(
-            firOpBuilder.getRegion().getParentOp())) {
-      const semantics::DerivedTypeSpec *typeSpec = nullptr;
-
-      if (object.sym()->owner().IsDerivedType())
-        typeSpec = object.sym()->owner().derivedTypeSpec();
-      else if (object.sym()->GetType() &&
-               object.sym()->GetType()->category() ==
-                   semantics::DeclTypeSpec::TypeDerived)
-        typeSpec = &object.sym()->GetType()->derivedTypeSpec();
-
-      if (typeSpec) {
-        mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
-        if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
-          mapperIdName = converter.mangleName(mapperIdName, sym->owner());
-      }
+    const semantics::DerivedTypeSpec *typeSpec = nullptr;
+
+    if (object.sym()->GetType() && object.sym()->GetType()->category() ==
+                                       semantics::DeclTypeSpec::TypeDerived)
+      typeSpec = &object.sym()->GetType()->derivedTypeSpec();
+    else if (object.sym()->owner().IsDerivedType())
+      typeSpec = object.sym()->owner().derivedTypeSpec();
+
+    if (typeSpec) {
+      mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
+      if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
+        mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+      else
+        mapperIdName =
+            converter.mangleName(mapperIdName, *typeSpec->GetScope());
     }
+
+    // Make sure we don't return a mapper to self
+    llvm::StringRef parentOpName;
+    if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
+            firOpBuilder.getRegion().getParentOp()))
+      parentOpName = declMapOp.getSymName();
+    if (mapperIdName == parentOpName)
+      mapperIdName = "";
   };
 
   // Create the mapper symbol from its name, if specified.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index cfcba0159db8d..8d5c26a4f2d58 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2348,6 +2348,124 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::FlatSymbolRefAttr
+genImplicitDefaultDeclareMapper(lower::AbstractConverter &converter,
+                                mlir::Location loc, fir::RecordType recordType,
+                                llvm::StringRef mapperNameStr) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  lower::StatementContext stmtCtx;
+
+  // Save current insertion point before moving to the module scope to create
+  // the DeclareMapperOp
+  mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+
+  firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
+  auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
+      loc, mapperNameStr, recordType);
+  auto &region = declMapperOp.getRegion();
+  firOpBuilder.createBlock(&region);
+  auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
+
+  auto declareOp =
+      firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/"");
+
+  const auto genBoundsOps = [&](mlir::Value mapVal,
+                                llvm::SmallVectorImpl<mlir::Value> &bounds) {
+    fir::ExtendedValue extVal =
+        hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
+                                        hlfir::Entity{mapVal},
+                                        /*contiguousHint=*/true)
+            .first;
+    fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+        firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
+    bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                mlir::omp::MapBoundsType>(
+        firOpBuilder, info, extVal,
+        /*dataExvIsAssumedSize=*/false, mapVal.getLoc());
+  };
+
+  // Return a reference to the contents of a derived type with one field.
+  // Also return the field type.
+  const auto getFieldRef =
+      [&](mlir::Value rec,
+          unsigned index) -> std::tuple<mlir::Value, mlir::Type> {
+    auto recType = mlir::dyn_cast<fir::RecordType>(
+        fir::unwrapPassByRefType(rec.getType()));
+    auto [fieldName, fieldTy] = recType.getTypeList()[index];
+    mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
+        loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
+        fir::getTypeParams(rec));
+    return {firOpBuilder.create<fir::CoordinateOp>(
+                loc, firOpBuilder.getRefType(fieldTy), rec, field),
+            fieldTy};
+  };
+
+  mlir::omp::DeclareMapperInfoOperands clauseOps;
+  llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
+  llvm::SmallVector<mlir::Value> memberMapOps;
+
+  llvm::omp::OpenMPOffloadMappingFlags mapFlag =
+      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
+      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+  mlir::omp::VariableCaptureKind captureKind =
+      mlir::omp::VariableCaptureKind::ByRef;
+  int64_t index = 0;
+
+  // Populate the declareMapper region with the map information.
+  for (const auto &[memberName, memberType] :
+       mlir::dyn_cast<fir::RecordType>(recordType).getTypeList()) {
+    auto [ref, type] = getFieldRef(declareOp.getBase(), index);
+    mlir::FlatSymbolRefAttr mapperId;
+    if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
+      std::string mapperIdName =
+          recType.getName().str() + ".omp.default.mapper";
+      if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
+        mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+      else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName))
+        mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+
+      if (converter.getModuleOp().lookupSymbol(mapperIdName))
+        mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
+                                                mapperIdName);
+      else
+        mapperId = genImplicitDefaultDeclareMapper(converter, loc, recType,
+                                                   mapperIdName);
+    }
+
+    llvm::SmallVector<mlir::Value> bounds;
+    genBoundsOps(ref, bounds);
+    mlir::Value mapOp = createMapInfoOp(
+        firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, "", bounds,
+        /*members=*/{},
+        /*membersIndex=*/mlir::ArrayAttr{},
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            mapFlag),
+        captureKind, ref.getType(), /*partialMap=*/false, mapperId);
+    memberMapOps.emplace_back(mapOp);
+    memberPlacementIndices.emplace_back(llvm::SmallVector<int64_t>{index++});
+  }
+
+  llvm::SmallVector<mlir::Value> bounds;
+  genBoundsOps(declareOp.getOriginalBase(), bounds);
+  mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+      firOpBuilder, loc, declareOp.getOriginalBase(),
+      /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
+      firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices),
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          mapFlag),
+      captureKind, declareOp.getType(0),
+      /*partialMap=*/true);
+
+  clauseOps.mapVars.emplace_back(mapOp);
+
+  firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
+  // declMapperOp->dumpPretty();
+  return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
+                                      mapperNameStr);
+}
+
 static mlir::omp::TargetOp
 genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             lower::StatementContext &stmtCtx,
@@ -2420,15 +2538,26 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       name << sym.name().ToString();
 
       mlir::FlatSymbolRefAttr mapperId;
-      if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
+      if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived &&
+          defaultMaps.empty()) {
         auto &typeSpec = sym.GetType()->derivedTypeSpec();
         std::string mapperIdName =
             typeSpec.name().ToString() + ".omp.default.mapper";
         if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
           mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+        else
+          mapperIdName =
+              converter.mangleName(mapperIdName, *typeSpec.GetScope());
+
         if (converter.getModuleOp().lookupSymbol(mapperIdName))
           mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
                                                   mapperIdName);
+        else
+          mapperId = genImplicitDefaultDeclareMapper(
+              converter, loc,
+              mlir::cast<fir::RecordType>(
+                  converter.genType(sym.GetType()->derivedTypeSpec())),
+              mapperIdName);
       }
 
       fir::factory::AddrAndBoundsInfo info =
@@ -4039,6 +4168,7 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processMap(loc, stmtCtx, clauseOps);
   firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
+  // declMapperOp->dumpPretty();
 }
 
 static void
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 86095d708f7e5..1ef6239e51fa8 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -377,7 +377,7 @@ class MapInfoFinalizationPass
                                    getDescriptorMapType(mapType, target)),
             op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
             newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
-            /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+            op.getMapperIdAttr(), op.getNameAttr(),
             /*partial_map=*/builder.getBoolAttr(false));
     op.replaceAllUsesWith(newDescParentMapOp.getResult());
     op->erase();
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 322562b06b87f..318af99a57848 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -2163,7 +2163,7 @@ void AttrsVisitor::SetBindNameOn(Symbol &symbol) {
   }
   symbol.SetBindName(std::move(*label));
   if (!oldBindName.empty()) {
-    if (const std::string * newBindName{symbol.GetBindName()}) {
+    if (const std::string *newBindName{symbol.GetBindName()}) {
       if (oldBindName != *newBindName) {
         Say(symbol.name(),
             "The entity '%s' has multiple BIND names ('%s' and '%s')"_err_en_US,
@@ -2285,7 +2285,7 @@ void DeclTypeSpecVisitor::Post(const parser::TypeSpec &typeSpec) {
   // expression semantics if the DeclTypeSpec is a valid TypeSpec.
   // The grammar ensures that it's an intrinsic or derived type spec,
   // not TYPE(*) or CLASS(*) or CLASS(T).
-  if (const DeclTypeSpec * spec{state_.declTypeSpec}) {
+  if (const DeclTypeSpec *spec{state_.declTypeSpec}) {
     switch (spec->category()) {
     case DeclTypeSpec::Numeric:
     case DeclTypeSpec::Logical:
@@ -2293,7 +2293,7 @@ void DeclTypeSpecVisitor::Post(const parser::TypeSpec &typeSpec) {
       typeSpec.declTypeSpec = spec;
       break;
     case DeclTypeSpec::TypeDerived:
-      if (const DerivedTypeSpec * derived{spec->AsDerived()}) {
+      if (const DerivedTypeSpec *derived{spec->AsDerived()}) {
         CheckForAbstractType(derived->typeSymbol()); // C703
         typeSpec.declTypeSpec = spec;
       }
@@ -2802,8 +2802,8 @@ Symbol &ScopeHandler::MakeSymbol(const parser::Name &name, Attrs attrs) {
 Symbol &ScopeHandler::MakeHostAssocSymbol(
     const parser::Name &name, const Symbol &hostSymbol) {
   Symbol &symbol{*NonDerivedTypeScope()
-                      .try_emplace(name.source, HostAssocDetails{hostSymbol})
-                      .first->second};
+          .try_emplace(name.source, HostAssocDetails{hostSymbol})
+          .first->second};
   name.symbol = &symbol;
   symbol.attrs() = hostSymbol.attrs(); // TODO: except PRIVATE, PUBLIC?
   // These attributes can be redundantly reapplied without error
@@ -2891,7 +2891,7 @@ void ScopeHandler::ApplyImplicitRules(
   if (context().HasError(symbol) || !NeedsType(symbol)) {
     return;
   }
-  if (const DeclTypeSpec * type{GetImplicitType(symbol)}) {
+  if (const DeclTypeSpec *type{GetImplicitType(symbol)}) {
     if (!skipImplicitTyping_) {
       symbol.set(Symbol::Flag::Implicit);
       symbol.SetType(*type);
@@ -2991,7 +2991,7 @@ const DeclTypeSpec *ScopeHandler::GetImplicitType(
   const auto *type{implicitRulesMap_->at(scope).GetType(
       symbol.name(), respectImplicitNoneType)};
   if (type) {
-    if (const DerivedTypeSpec * derived{type->AsDerived()}) {
+    if (const DerivedTypeSpec *derived{type->AsDerived()}) {
       // Resolve any forward-referenced derived type; a quick no-op else.
       auto &instantiatable{*const_cast<DerivedTypeSpec *>(derived)};
       instantiatable.Instantiate(currScope());
@@ -3706,10 +3706,10 @@ void ModuleVisitor::DoAddUse(SourceName location, SourceName localName,
     } else if (IsPointer(p1) || IsPointer(p2)) {
       return false;
     } else if (const auto *subp{p1.detailsIf<SubprogramDetails>()};
-               subp && !subp->isInterface()) {
+        subp && !subp->isInterface()) {
       return false; // defined in module, not an external
     } else if (const auto *subp{p2.detailsIf<SubprogramDetails>()};
-               subp && !subp->isInterface()) {
+        subp && !subp->isInterface()) {
       return false; // defined in module, not an external
     } else {
       // Both are external interfaces, perhaps to the same procedure
@@ -3969,7 +3969,7 @@ Scope *ModuleVisitor::FindModule(const parser::Name &name,
   if (scope) {
     if (DoesScopeContain(scope, currScope())) { // 14.2.2(1)
       std::optional<SourceName> submoduleName;
-      if (const Scope * container{FindModuleOrSubmoduleContaining(currScope())};
+      if (const Scope *container{FindModuleOrSubmoduleContaining(currScope())};
           container && container->IsSubmodule()) {
         submoduleName = container->GetName();
       }
@@ -4074,7 +4074,7 @@ bool InterfaceVisitor::isAbstract() const {
 
 void InterfaceVisitor::AddSpecificProcs(
     const std::list<parser::Name> &names, ProcedureKind kind) {
-  if (Symbol * symbol{GetGenericInfo().symbol};
+  if (Symbol *symbol{GetGenericInfo().symbol};
       symbol && symbol->has<GenericDetails>()) {
     for (const auto &name : names) {
       specificsForGenericProcs_.emplace(symbol, std::make_pair(&name, kind));
@@ -4174,7 +4174,7 @@ void GenericHandler::DeclaredPossibleSpecificProc(Symbol &proc) {
 }
 
 void InterfaceVisitor::ResolveNewSpecifics() {
-  if (Symbol * generic{genericInfo_.top().symbol};
+  if (Symbol *generic{genericInfo_.top().symbol};
       generic && generic->has<GenericDetails>()) {
     ResolveSpecificsInGeneric(*generic, false);
   }
@@ -4259,7 +4259,7 @@ bool SubprogramVisitor::HandleStmtFunction(const parser::StmtFunctionStmt &x) {
           name.source);
       MakeSymbol(name, Attrs{}, UnknownDetails{});
     } else if (auto *entity{ultimate.detailsIf<EntityDetails>()};
-               entity && !ultimate.has<ProcEntityDetails>()) {
+        entity && !ultimate.has<ProcEntityDetails>()) {
       resultType = entity->type();
       ultimate.details() = UnknownDetails{}; // will be replaced below
     } else {
@@ -4315,7 +4315,7 @@ bool SubprogramVisitor::Pre(const parser::Suffix &suffix) {
     } else {
       Message &msg{Say(*suffix.resultName,
           "RESULT(%s) may appear only in a function"_err_en_US)};
-      if (const Symbol * subprogram{InclusiveScope().symbol()}) {
+      if (const Symbol *subprogram{InclusiveScope().symbol()}) {
         msg.Attach(subprogram->name(), "Containing subprogram"_en_US);
       }
     }
@@ -4833,7 +4833,7 @@ Symbol *ScopeHandler::FindSeparateModuleProcedureInterface(
       symbol = generic->specific();
     }
   }
-  if (const Symbol * defnIface{FindSeparateModuleSubprogramInterface(symbol)}) {
+  if (const Symbol *defnIface{FindSeparateModuleSubprogramInterface(symbol)}) {
     // Error recovery in case of multiple definitions
     symbol = const_cast<Symbol *>(defnIface);
   }
@@ -4970,8 +4970,8 @@ bool SubprogramVisitor::HandlePreviousCalls(
     return generic->specific() &&
         HandlePreviousCalls(name, *generic->specific(), subpFlag);
   } else if (const auto *proc{symbol.detailsIf<ProcEntityDetails>()}; proc &&
-             !proc->isDummy() &&
-             !symbol.attrs().HasAny(Attrs{Attr::INTRINSIC, Attr::POINTER})) {
+      !proc->isDummy() &&
+      !symbol.attrs().HasAny(Attrs{Attr::INTRINSIC, Attr::POINTER})) {
     // There's a symbol created for previous calls to this subprogram or
     // ENTRY's name.  We have to replace that symbol in situ to avoid the
     // obligation to rewrite symbol pointers in the parse tree.
@@ -5012,7 +5012,7 @@ void SubprogramVisitor::CheckExtantProc(
   if (auto *prev{FindSymbol(name)}) {
     if (IsDummy(*prev)) {
     } else if (auto *entity{prev->detailsIf<EntityDetails>()};
-               IsPointer(*prev) && entity && !entity->type()) {
+        IsPointer(*prev) && entity && !entity->type()) {
       // POINTER attribute set before interface
     } else if (inInterfaceBlock() && currScope() != prev->owner()) {
       // Procedures in an INTERFACE block do not resolve to symbols
@@ -5068,7 +5068,7 @@ Symbol &SubprogramVisitor::PushSubprogramScope(const parser::Name &name,
     }
     set_inheritFromParent(false); // interfaces don't inherit, even if MODULE
   }
-  if (Symbol * found{FindSymbol(name)};
+  if (Symbol *found{FindSymbol(name)};
       found && found->has<HostAssocDetails>()) {
     found->set(subpFlag); // PushScope() created symbol
   }
@@ -5910,9 +5910,9 @@ void DeclarationVisitor::Post(const parser::VectorTypeSpec &x) {
     vectorDerivedType.CookParameters(GetFoldingContext());
   }
 
-  if (const DeclTypeSpec *
-      extant{ppcBuiltinTypesScope->FindInstantiatedDerivedType(
-          vectorDerivedType, DeclTypeSpec::Category::TypeDerived)}) {
+  if (const DeclTypeSpec *extant{
+          ppcBuiltinTypesScope->FindInstantiatedDerivedType(
+              vectorDerivedType, DeclTypeSpec::Category::TypeDerived)}) {
     // This derived type and parameter expressions (if any) are already present
     // in the __ppc_intrinsics scope.
     SetDeclTypeSpec(*extant);
@@ -5934,7 +5934,7 @@ bool DeclarationVisitor::Pre(const parser::DeclarationTypeSpec::Type &) {
 
 void DeclarationVisitor::Post(const parser::DeclarationTypeSpec::Type &type) {
   const parser::Name &derivedName{std::get<parser::Name>(type.derived.t)};
-  if (const Symbol * derivedSymbol{derivedName.symbol}) {
+  if (const Symbol *derivedSymbol{derivedName.symbol}) {
     CheckForAbstractType(*derivedSymbol); // C706
   }
 }
@@ -6003,8 +6003,8 @@ void DeclarationVisitor::Post(const parser::DerivedTypeSpec &x) {
   if (!spec->MightBeParameterized()) {
     spec->EvaluateParameters(context());
   }
-  if (const DeclTypeSpec *
-      extant{currScope().FindInstantiatedDerivedType(*spec, category)}) {
+  if (const DeclTypeSpec *extant{
+          currScope().FindInstantiatedDerivedType(*spec, category)}) {
     // This derived type and parameter expressions (if any) are already present
     // in this scope.
     SetDeclTypeSpec(*extant);
@@ -6035,8 +6035,7 @@ void DeclarationVisitor::Post(const parser::DeclarationTypeSpec::Record &rec) {
   if (auto spec{ResolveDerivedType(typeName)}) {
     spec->CookParameters(GetFoldingContext());
   ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Akash Banerjee (TIFitis)

Changes

This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types.


Patch is 33.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140562.diff

5 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+23-16)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+131-1)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+1-1)
  • (modified) flang/lib/Semantics/resolve-names.cpp (+47-48)
  • (modified) flang/test/Lower/OpenMP/derived-type-map.f90 (+21-1)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 8dcc8be9be5bf..cf25e91a437b8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1102,23 +1102,30 @@ void ClauseProcessor::processMapObjects(
 
   auto getDefaultMapperID = [&](const omp::Object &object,
                                 std::string &mapperIdName) {
-    if (!mlir::isa<mlir::omp::DeclareMapperOp>(
-            firOpBuilder.getRegion().getParentOp())) {
-      const semantics::DerivedTypeSpec *typeSpec = nullptr;
-
-      if (object.sym()->owner().IsDerivedType())
-        typeSpec = object.sym()->owner().derivedTypeSpec();
-      else if (object.sym()->GetType() &&
-               object.sym()->GetType()->category() ==
-                   semantics::DeclTypeSpec::TypeDerived)
-        typeSpec = &object.sym()->GetType()->derivedTypeSpec();
-
-      if (typeSpec) {
-        mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
-        if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
-          mapperIdName = converter.mangleName(mapperIdName, sym->owner());
-      }
+    const semantics::DerivedTypeSpec *typeSpec = nullptr;
+
+    if (object.sym()->GetType() && object.sym()->GetType()->category() ==
+                                       semantics::DeclTypeSpec::TypeDerived)
+      typeSpec = &object.sym()->GetType()->derivedTypeSpec();
+    else if (object.sym()->owner().IsDerivedType())
+      typeSpec = object.sym()->owner().derivedTypeSpec();
+
+    if (typeSpec) {
+      mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
+      if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
+        mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+      else
+        mapperIdName =
+            converter.mangleName(mapperIdName, *typeSpec->GetScope());
     }
+
+    // Make sure we don't return a mapper to self
+    llvm::StringRef parentOpName;
+    if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
+            firOpBuilder.getRegion().getParentOp()))
+      parentOpName = declMapOp.getSymName();
+    if (mapperIdName == parentOpName)
+      mapperIdName = "";
   };
 
   // Create the mapper symbol from its name, if specified.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index cfcba0159db8d..8d5c26a4f2d58 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2348,6 +2348,124 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 }
 
+static mlir::FlatSymbolRefAttr
+genImplicitDefaultDeclareMapper(lower::AbstractConverter &converter,
+                                mlir::Location loc, fir::RecordType recordType,
+                                llvm::StringRef mapperNameStr) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  lower::StatementContext stmtCtx;
+
+  // Save current insertion point before moving to the module scope to create
+  // the DeclareMapperOp
+  mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+
+  firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
+  auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
+      loc, mapperNameStr, recordType);
+  auto &region = declMapperOp.getRegion();
+  firOpBuilder.createBlock(&region);
+  auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
+
+  auto declareOp =
+      firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/"");
+
+  const auto genBoundsOps = [&](mlir::Value mapVal,
+                                llvm::SmallVectorImpl<mlir::Value> &bounds) {
+    fir::ExtendedValue extVal =
+        hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
+                                        hlfir::Entity{mapVal},
+                                        /*contiguousHint=*/true)
+            .first;
+    fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+        firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
+    bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                mlir::omp::MapBoundsType>(
+        firOpBuilder, info, extVal,
+        /*dataExvIsAssumedSize=*/false, mapVal.getLoc());
+  };
+
+  // Return a reference to the contents of a derived type with one field.
+  // Also return the field type.
+  const auto getFieldRef =
+      [&](mlir::Value rec,
+          unsigned index) -> std::tuple<mlir::Value, mlir::Type> {
+    auto recType = mlir::dyn_cast<fir::RecordType>(
+        fir::unwrapPassByRefType(rec.getType()));
+    auto [fieldName, fieldTy] = recType.getTypeList()[index];
+    mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
+        loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
+        fir::getTypeParams(rec));
+    return {firOpBuilder.create<fir::CoordinateOp>(
+                loc, firOpBuilder.getRefType(fieldTy), rec, field),
+            fieldTy};
+  };
+
+  mlir::omp::DeclareMapperInfoOperands clauseOps;
+  llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
+  llvm::SmallVector<mlir::Value> memberMapOps;
+
+  llvm::omp::OpenMPOffloadMappingFlags mapFlag =
+      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
+      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+  mlir::omp::VariableCaptureKind captureKind =
+      mlir::omp::VariableCaptureKind::ByRef;
+  int64_t index = 0;
+
+  // Populate the declareMapper region with the map information.
+  for (const auto &[memberName, memberType] :
+       mlir::dyn_cast<fir::RecordType>(recordType).getTypeList()) {
+    auto [ref, type] = getFieldRef(declareOp.getBase(), index);
+    mlir::FlatSymbolRefAttr mapperId;
+    if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
+      std::string mapperIdName =
+          recType.getName().str() + ".omp.default.mapper";
+      if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
+        mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+      else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName))
+        mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+
+      if (converter.getModuleOp().lookupSymbol(mapperIdName))
+        mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
+                                                mapperIdName);
+      else
+        mapperId = genImplicitDefaultDeclareMapper(converter, loc, recType,
+                                                   mapperIdName);
+    }
+
+    llvm::SmallVector<mlir::Value> bounds;
+    genBoundsOps(ref, bounds);
+    mlir::Value mapOp = createMapInfoOp(
+        firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, "", bounds,
+        /*members=*/{},
+        /*membersIndex=*/mlir::ArrayAttr{},
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            mapFlag),
+        captureKind, ref.getType(), /*partialMap=*/false, mapperId);
+    memberMapOps.emplace_back(mapOp);
+    memberPlacementIndices.emplace_back(llvm::SmallVector<int64_t>{index++});
+  }
+
+  llvm::SmallVector<mlir::Value> bounds;
+  genBoundsOps(declareOp.getOriginalBase(), bounds);
+  mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+      firOpBuilder, loc, declareOp.getOriginalBase(),
+      /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
+      firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices),
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          mapFlag),
+      captureKind, declareOp.getType(0),
+      /*partialMap=*/true);
+
+  clauseOps.mapVars.emplace_back(mapOp);
+
+  firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
+  // declMapperOp->dumpPretty();
+  return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
+                                      mapperNameStr);
+}
+
 static mlir::omp::TargetOp
 genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             lower::StatementContext &stmtCtx,
@@ -2420,15 +2538,26 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       name << sym.name().ToString();
 
       mlir::FlatSymbolRefAttr mapperId;
-      if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
+      if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived &&
+          defaultMaps.empty()) {
         auto &typeSpec = sym.GetType()->derivedTypeSpec();
         std::string mapperIdName =
             typeSpec.name().ToString() + ".omp.default.mapper";
         if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
           mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+        else
+          mapperIdName =
+              converter.mangleName(mapperIdName, *typeSpec.GetScope());
+
         if (converter.getModuleOp().lookupSymbol(mapperIdName))
           mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
                                                   mapperIdName);
+        else
+          mapperId = genImplicitDefaultDeclareMapper(
+              converter, loc,
+              mlir::cast<fir::RecordType>(
+                  converter.genType(sym.GetType()->derivedTypeSpec())),
+              mapperIdName);
       }
 
       fir::factory::AddrAndBoundsInfo info =
@@ -4039,6 +4168,7 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processMap(loc, stmtCtx, clauseOps);
   firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
+  // declMapperOp->dumpPretty();
 }
 
 static void
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 86095d708f7e5..1ef6239e51fa8 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -377,7 +377,7 @@ class MapInfoFinalizationPass
                                    getDescriptorMapType(mapType, target)),
             op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
             newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
-            /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+            op.getMapperIdAttr(), op.getNameAttr(),
             /*partial_map=*/builder.getBoolAttr(false));
     op.replaceAllUsesWith(newDescParentMapOp.getResult());
     op->erase();
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 322562b06b87f..318af99a57848 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -2163,7 +2163,7 @@ void AttrsVisitor::SetBindNameOn(Symbol &symbol) {
   }
   symbol.SetBindName(std::move(*label));
   if (!oldBindName.empty()) {
-    if (const std::string * newBindName{symbol.GetBindName()}) {
+    if (const std::string *newBindName{symbol.GetBindName()}) {
       if (oldBindName != *newBindName) {
         Say(symbol.name(),
             "The entity '%s' has multiple BIND names ('%s' and '%s')"_err_en_US,
@@ -2285,7 +2285,7 @@ void DeclTypeSpecVisitor::Post(const parser::TypeSpec &typeSpec) {
   // expression semantics if the DeclTypeSpec is a valid TypeSpec.
   // The grammar ensures that it's an intrinsic or derived type spec,
   // not TYPE(*) or CLASS(*) or CLASS(T).
-  if (const DeclTypeSpec * spec{state_.declTypeSpec}) {
+  if (const DeclTypeSpec *spec{state_.declTypeSpec}) {
     switch (spec->category()) {
     case DeclTypeSpec::Numeric:
     case DeclTypeSpec::Logical:
@@ -2293,7 +2293,7 @@ void DeclTypeSpecVisitor::Post(const parser::TypeSpec &typeSpec) {
       typeSpec.declTypeSpec = spec;
       break;
     case DeclTypeSpec::TypeDerived:
-      if (const DerivedTypeSpec * derived{spec->AsDerived()}) {
+      if (const DerivedTypeSpec *derived{spec->AsDerived()}) {
         CheckForAbstractType(derived->typeSymbol()); // C703
         typeSpec.declTypeSpec = spec;
       }
@@ -2802,8 +2802,8 @@ Symbol &ScopeHandler::MakeSymbol(const parser::Name &name, Attrs attrs) {
 Symbol &ScopeHandler::MakeHostAssocSymbol(
     const parser::Name &name, const Symbol &hostSymbol) {
   Symbol &symbol{*NonDerivedTypeScope()
-                      .try_emplace(name.source, HostAssocDetails{hostSymbol})
-                      .first->second};
+          .try_emplace(name.source, HostAssocDetails{hostSymbol})
+          .first->second};
   name.symbol = &symbol;
   symbol.attrs() = hostSymbol.attrs(); // TODO: except PRIVATE, PUBLIC?
   // These attributes can be redundantly reapplied without error
@@ -2891,7 +2891,7 @@ void ScopeHandler::ApplyImplicitRules(
   if (context().HasError(symbol) || !NeedsType(symbol)) {
     return;
   }
-  if (const DeclTypeSpec * type{GetImplicitType(symbol)}) {
+  if (const DeclTypeSpec *type{GetImplicitType(symbol)}) {
     if (!skipImplicitTyping_) {
       symbol.set(Symbol::Flag::Implicit);
       symbol.SetType(*type);
@@ -2991,7 +2991,7 @@ const DeclTypeSpec *ScopeHandler::GetImplicitType(
   const auto *type{implicitRulesMap_->at(scope).GetType(
       symbol.name(), respectImplicitNoneType)};
   if (type) {
-    if (const DerivedTypeSpec * derived{type->AsDerived()}) {
+    if (const DerivedTypeSpec *derived{type->AsDerived()}) {
       // Resolve any forward-referenced derived type; a quick no-op else.
       auto &instantiatable{*const_cast<DerivedTypeSpec *>(derived)};
       instantiatable.Instantiate(currScope());
@@ -3706,10 +3706,10 @@ void ModuleVisitor::DoAddUse(SourceName location, SourceName localName,
     } else if (IsPointer(p1) || IsPointer(p2)) {
       return false;
     } else if (const auto *subp{p1.detailsIf<SubprogramDetails>()};
-               subp && !subp->isInterface()) {
+        subp && !subp->isInterface()) {
       return false; // defined in module, not an external
     } else if (const auto *subp{p2.detailsIf<SubprogramDetails>()};
-               subp && !subp->isInterface()) {
+        subp && !subp->isInterface()) {
       return false; // defined in module, not an external
     } else {
       // Both are external interfaces, perhaps to the same procedure
@@ -3969,7 +3969,7 @@ Scope *ModuleVisitor::FindModule(const parser::Name &name,
   if (scope) {
     if (DoesScopeContain(scope, currScope())) { // 14.2.2(1)
       std::optional<SourceName> submoduleName;
-      if (const Scope * container{FindModuleOrSubmoduleContaining(currScope())};
+      if (const Scope *container{FindModuleOrSubmoduleContaining(currScope())};
           container && container->IsSubmodule()) {
         submoduleName = container->GetName();
       }
@@ -4074,7 +4074,7 @@ bool InterfaceVisitor::isAbstract() const {
 
 void InterfaceVisitor::AddSpecificProcs(
     const std::list<parser::Name> &names, ProcedureKind kind) {
-  if (Symbol * symbol{GetGenericInfo().symbol};
+  if (Symbol *symbol{GetGenericInfo().symbol};
       symbol && symbol->has<GenericDetails>()) {
     for (const auto &name : names) {
       specificsForGenericProcs_.emplace(symbol, std::make_pair(&name, kind));
@@ -4174,7 +4174,7 @@ void GenericHandler::DeclaredPossibleSpecificProc(Symbol &proc) {
 }
 
 void InterfaceVisitor::ResolveNewSpecifics() {
-  if (Symbol * generic{genericInfo_.top().symbol};
+  if (Symbol *generic{genericInfo_.top().symbol};
       generic && generic->has<GenericDetails>()) {
     ResolveSpecificsInGeneric(*generic, false);
   }
@@ -4259,7 +4259,7 @@ bool SubprogramVisitor::HandleStmtFunction(const parser::StmtFunctionStmt &x) {
           name.source);
       MakeSymbol(name, Attrs{}, UnknownDetails{});
     } else if (auto *entity{ultimate.detailsIf<EntityDetails>()};
-               entity && !ultimate.has<ProcEntityDetails>()) {
+        entity && !ultimate.has<ProcEntityDetails>()) {
       resultType = entity->type();
       ultimate.details() = UnknownDetails{}; // will be replaced below
     } else {
@@ -4315,7 +4315,7 @@ bool SubprogramVisitor::Pre(const parser::Suffix &suffix) {
     } else {
       Message &msg{Say(*suffix.resultName,
           "RESULT(%s) may appear only in a function"_err_en_US)};
-      if (const Symbol * subprogram{InclusiveScope().symbol()}) {
+      if (const Symbol *subprogram{InclusiveScope().symbol()}) {
         msg.Attach(subprogram->name(), "Containing subprogram"_en_US);
       }
     }
@@ -4833,7 +4833,7 @@ Symbol *ScopeHandler::FindSeparateModuleProcedureInterface(
       symbol = generic->specific();
     }
   }
-  if (const Symbol * defnIface{FindSeparateModuleSubprogramInterface(symbol)}) {
+  if (const Symbol *defnIface{FindSeparateModuleSubprogramInterface(symbol)}) {
     // Error recovery in case of multiple definitions
     symbol = const_cast<Symbol *>(defnIface);
   }
@@ -4970,8 +4970,8 @@ bool SubprogramVisitor::HandlePreviousCalls(
     return generic->specific() &&
         HandlePreviousCalls(name, *generic->specific(), subpFlag);
   } else if (const auto *proc{symbol.detailsIf<ProcEntityDetails>()}; proc &&
-             !proc->isDummy() &&
-             !symbol.attrs().HasAny(Attrs{Attr::INTRINSIC, Attr::POINTER})) {
+      !proc->isDummy() &&
+      !symbol.attrs().HasAny(Attrs{Attr::INTRINSIC, Attr::POINTER})) {
     // There's a symbol created for previous calls to this subprogram or
     // ENTRY's name.  We have to replace that symbol in situ to avoid the
     // obligation to rewrite symbol pointers in the parse tree.
@@ -5012,7 +5012,7 @@ void SubprogramVisitor::CheckExtantProc(
   if (auto *prev{FindSymbol(name)}) {
     if (IsDummy(*prev)) {
     } else if (auto *entity{prev->detailsIf<EntityDetails>()};
-               IsPointer(*prev) && entity && !entity->type()) {
+        IsPointer(*prev) && entity && !entity->type()) {
       // POINTER attribute set before interface
     } else if (inInterfaceBlock() && currScope() != prev->owner()) {
       // Procedures in an INTERFACE block do not resolve to symbols
@@ -5068,7 +5068,7 @@ Symbol &SubprogramVisitor::PushSubprogramScope(const parser::Name &name,
     }
     set_inheritFromParent(false); // interfaces don't inherit, even if MODULE
   }
-  if (Symbol * found{FindSymbol(name)};
+  if (Symbol *found{FindSymbol(name)};
       found && found->has<HostAssocDetails>()) {
     found->set(subpFlag); // PushScope() created symbol
   }
@@ -5910,9 +5910,9 @@ void DeclarationVisitor::Post(const parser::VectorTypeSpec &x) {
     vectorDerivedType.CookParameters(GetFoldingContext());
   }
 
-  if (const DeclTypeSpec *
-      extant{ppcBuiltinTypesScope->FindInstantiatedDerivedType(
-          vectorDerivedType, DeclTypeSpec::Category::TypeDerived)}) {
+  if (const DeclTypeSpec *extant{
+          ppcBuiltinTypesScope->FindInstantiatedDerivedType(
+              vectorDerivedType, DeclTypeSpec::Category::TypeDerived)}) {
     // This derived type and parameter expressions (if any) are already present
     // in the __ppc_intrinsics scope.
     SetDeclTypeSpec(*extant);
@@ -5934,7 +5934,7 @@ bool DeclarationVisitor::Pre(const parser::DeclarationTypeSpec::Type &) {
 
 void DeclarationVisitor::Post(const parser::DeclarationTypeSpec::Type &type) {
   const parser::Name &derivedName{std::get<parser::Name>(type.derived.t)};
-  if (const Symbol * derivedSymbol{derivedName.symbol}) {
+  if (const Symbol *derivedSymbol{derivedName.symbol}) {
     CheckForAbstractType(*derivedSymbol); // C706
   }
 }
@@ -6003,8 +6003,8 @@ void DeclarationVisitor::Post(const parser::DerivedTypeSpec &x) {
   if (!spec->MightBeParameterized()) {
     spec->EvaluateParameters(context());
   }
-  if (const DeclTypeSpec *
-      extant{currScope().FindInstantiatedDerivedType(*spec, category)}) {
+  if (const DeclTypeSpec *extant{
+          currScope().FindInstantiatedDerivedType(*spec, category)}) {
     // This derived type and parameter expressions (if any) are already present
     // in this scope.
     SetDeclTypeSpec(*extant);
@@ -6035,8 +6035,7 @@ void DeclarationVisitor::Post(const parser::DeclarationTypeSpec::Record &rec) {
   if (auto spec{ResolveDerivedType(typeName)}) {
     spec->CookParameters(GetFoldingContext());
   ...
[truncated]

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR implements implicit generation of default OpenMP declare mappers for derived types (including allocatable ones) when no user-supplied mapper is provided, updating tests and lowering passes to exercise and emit these mappers.

  • Added FileCheck tests for implicit declare mappers in derived-type-map.f90
  • Introduced genImplicitDefaultDeclareMapper in the OpenMP lowering to emit DeclareMapperOp regions
  • Wired through the new mapper IDs in the map info finalization and clause processing

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
flang/test/Lower/OpenMP/derived-type-map.f90 Added CHECK lines for implicit declare mappers
flang/lib/Semantics/resolve-names.cpp Minor brace-formatting cleanup
flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp Use the real mapperIdAttr instead of a blank attr
flang/lib/Lower/OpenMP/OpenMP.cpp Added genImplicitDefaultDeclareMapper and hook up mapper IDs
flang/lib/Lower/OpenMP/ClauseProcessor.cpp Adjusted default-mapper lookup and avoid self-mapping
Comments suppressed due to low confidence (2)

flang/test/Lower/OpenMP/derived-type-map.f90:23

  • The uniq_name uses Escalar instead of Scalar. Correcting this typo will keep naming consistent with the derived type scalar_and_array.
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.heap<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"}

flang/lib/Lower/OpenMP/OpenMP.cpp:2351

  • [nitpick] Consider adding a brief comment above genImplicitDefaultDeclareMapper to explain its purpose, inputs, and outputs to improve readability and maintainability.
static mlir::FlatSymbolRefAttr

mlir::Location loc, fir::RecordType recordType,
llvm::StringRef mapperNameStr) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
lower::StatementContext stmtCtx;
Copy link
Preview

Copilot AI May 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The local variable stmtCtx is declared but never used in this function. Removing it will clean up the code and avoid compiler warnings.

Suggested change
lower::StatementContext stmtCtx;

Copilot uses AI. Check for mistakes.

Copy link

github-actions bot commented May 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

…ypes

This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types.
This supports nested derived types as well.
@TIFitis TIFitis force-pushed the users/Akash/implicit_default_mapper branch from 5d735f1 to 580e862 Compare May 19, 2025 15:41
Base automatically changed from users/Akash/mapper_semantic to main May 28, 2025 13:32
Copy link
Member

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you Akash for this work, and excuse the delay!

@@ -3540,7 +3540,7 @@ WRAPPER_CLASS(OmpLocatorList, std::list<OmpLocator>);
struct OmpMapperSpecifier {
// Absent mapper-identifier is equivalent to DEFAULT.
TUPLE_CLASS_BOILERPLATE(OmpMapperSpecifier);
std::tuple<std::optional<Name>, TypeSpec, Name> t;
std::tuple<std::string, TypeSpec, Name> t;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we haven't figured out a better alternative for this situation, I think we can hopefully get away with making this PFT node owner of the identifier name for now.

However, this is a departure from the convention of only storing references to externally allocated strings, so I'd suggest adding a comment here to explain why we're making an exception (i.e. this identifier might be compiler-generated and not part of the source).

}

// Make sure we don't return a mapper to self
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Make sure we don't return a mapper to self
// Make sure we don't return a mapper to self.

@@ -1389,8 +1389,28 @@ TYPE_PARSER(
TYPE_PARSER(sourced(construct<OpenMPDeclareTargetConstruct>(
verbatim("DECLARE TARGET"_tok), Parser<OmpDeclareTargetSpecifier>{})))

static OmpMapperSpecifier ConstructOmpMapperSpecifier(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This name seems to follow other existing similar functions better, IMO. In any case, the first letter should be lowercase.

Suggested change
static OmpMapperSpecifier ConstructOmpMapperSpecifier(
static OmpMapperSpecifier makeMapperSpecifier(

@@ -149,7 +150,7 @@ subroutine declare_mapper_4
integer :: num
end type

!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER:_QQFdeclare_mapper_4my_type.default]] : [[MY_TYPE:!fir\.type<_QFdeclare_mapper_4Tmy_type\{num:i32\}>]]
!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER:_QQFdeclare_mapper_4my_type.omp.default.mapper]] : [[MY_TYPE:!fir\.type<_QFdeclare_mapper_4Tmy_type\{num:i32\}>]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER:_QQFdeclare_mapper_4my_type.omp.default.mapper]] : [[MY_TYPE:!fir\.type<_QFdeclare_mapper_4Tmy_type\{num:i32\}>]]
!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER:_QQFdeclare_mapper_4my_type\.omp\.default\.mapper]] : [[MY_TYPE:!fir\.type<_QFdeclare_mapper_4Tmy_type\{num:i32\}>]]

end type

!CHECK: omp.declare_mapper @[[INNER_MAPPER_NAMED:_QQFFuse_innermy_mapper]] : [[MY_TYPE:!fir\.type<_QFTmytype\{x:i32,y:i32\}>]]
!CHECK: omp.declare_mapper @[[INNER_MAPPER_DEFAULT:_QQFFuse_innermytype.omp.default.mapper]] : [[MY_TYPE]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
!CHECK: omp.declare_mapper @[[INNER_MAPPER_DEFAULT:_QQFFuse_innermytype.omp.default.mapper]] : [[MY_TYPE]]
!CHECK: omp.declare_mapper @[[INNER_MAPPER_DEFAULT:_QQFFuse_innermytype\.omp\.default\.mapper]] : [[MY_TYPE]]

Comment on lines +2427 to +2432
if (converter.getModuleOp().lookupSymbol(mapperIdName))
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperIdName);
else
mapperId = genImplicitDefaultDeclareMapper(converter, loc, recType,
mapperIdName);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Since both calls to the new function are wrapped inside of the same logic, perhaps genImplicitDefaultDeclareMapper could be getOrGenImplicitDefaultDeclareMapper and do the lookup inside of the function. That way we ensure we won't create duplicated DeclareMapperOps, if we decide to call this function from other spots in the future.

@@ -2348,6 +2348,122 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
queue, item, clauseOps);
}

static mlir::FlatSymbolRefAttr
genImplicitDefaultDeclareMapper(lower::AbstractConverter &converter,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this exclusively needed for implicitly mapped variables into target regions or should it be called for explicit map clauses for derived types when they don't specify any mapper? In particular, in target, declare mapper, target enter data and target exit data directives.

Would we need this for any other map-like clause, like use_device_ptr, use_device_addr or has_device_addr?

llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
int64_t index = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Use llvm::enumerate instead of an outside index variable. That way we avoid letting it go out-of-sync in the future if we introduce any continue statements inside of the loop body.

const auto getFieldRef =
[&](mlir::Value rec,
unsigned index) -> std::tuple<mlir::Value, mlir::Type> {
auto recType = mlir::dyn_cast<fir::RecordType>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's assumed that recType != nullptr below, so either assert this is the case by using cast or we need to handle that situation.

Suggested change
auto recType = mlir::dyn_cast<fir::RecordType>(
auto recType = mlir::cast<fir::RecordType>(

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is recType not going to always match the recordType argument? This seems quite a complicated implementation if that's the case. [fieldName, fieldTy] would match [memberName, memberType] of the caller below.

Let me know if I'm missing something here.

llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(ref, bounds);
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, "", bounds,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, "", bounds,
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds,

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Could the symbol name be passed to this function, so that the name here and below could be populated? I would imagine it might be populated as "symName + "%" + memberName" here (which could also be the symName passed to the recursive self call above) and just symName below.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants