Reland "[MLGO][IR2Vec] Integrating IR2Vec with MLInliner (#143479)" #145664

Status: Open · 4 commits · base: main

Conversation

svkeerthy (Contributor)

Relanding #143479 after fixes.

Removed NumberOfFeatures from the FeatureIndex enum as the number of features used depends on whether IR2Vec embeddings are used.

@llvmbot added the mlgo and llvm:analysis (Includes value tracking, cost tables and constant folding) labels on Jun 25, 2025
@llvmbot (Member) commented on Jun 25, 2025

@llvm/pr-subscribers-mlgo

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

Changes

Relanding #143479 after fixes.

Removed NumberOfFeatures from the FeatureIndex enum as the number of features used depends on whether IR2Vec embeddings are used.
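As a rough sketch (not part of the patch, assuming only the FeatureMap declared in InlineModelFeatureMaps.h), callers that previously relied on the removed NumberOfFeatures constant would now query the runtime size of the feature map, which grows by the two embedding entries only when an IR2Vec vocabulary is loaded:

  // Hedged sketch, not from the patch: the feature count is now a runtime
  // property of llvm::FeatureMap rather than a compile-time enum value.
  #include <cstddef>
  #include "llvm/Analysis/InlineModelFeatureMaps.h"

  std::size_t getRuntimeFeatureCount() {
    // callee_embedding and caller_embedding specs are appended by
    // MLInlineAdvisor only when IR2Vec is enabled, so the size can vary.
    return llvm::FeatureMap.size();
  }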


Patch is 36.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145664.diff

8 Files Affected:

  • (modified) llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h (+21-5)
  • (modified) llvm/include/llvm/Analysis/InlineAdvisor.h (+4)
  • (modified) llvm/include/llvm/Analysis/InlineModelFeatureMaps.h (+11-5)
  • (modified) llvm/include/llvm/Analysis/MLInlineAdvisor.h (+1)
  • (modified) llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp (+112-3)
  • (modified) llvm/lib/Analysis/InlineAdvisor.cpp (+29)
  • (modified) llvm/lib/Analysis/MLInlineAdvisor.cpp (+35-3)
  • (modified) llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp (+131-14)
diff --git a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
index babb6d9d6cf0c..06dbfc35a5294 100644
--- a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
+++ b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
@@ -15,6 +15,7 @@
 #define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/Compiler.h"
@@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
   void updateAggregateStats(const Function &F, const LoopInfo &LI);
   void reIncludeBB(const BasicBlock &BB);
 
+  ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
+  std::optional<ir2vec::Vocab> IR2VecVocab;
+
 public:
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
-                            const LoopInfo &LI);
+                            const LoopInfo &LI,
+                            const IR2VecVocabResult *VocabResult);
 
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
 
-  bool operator==(const FunctionPropertiesInfo &FPI) const {
-    return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
-  }
+  bool operator==(const FunctionPropertiesInfo &FPI) const;
 
   bool operator!=(const FunctionPropertiesInfo &FPI) const {
     return !(*this == FPI);
@@ -137,6 +140,19 @@ class FunctionPropertiesInfo {
   int64_t CallReturnsVectorPointerCount = 0;
   int64_t CallWithManyArgumentsCount = 0;
   int64_t CallWithPointerArgumentCount = 0;
+
+  const ir2vec::Embedding &getFunctionEmbedding() const {
+    return FunctionEmbedding;
+  }
+
+  const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
+    return IR2VecVocab;
+  }
+
+  // Helper intended to be useful for unittests
+  void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
+    FunctionEmbedding = Embedding;
+  }
 };
 
 // Analysis pass
@@ -192,7 +208,7 @@ class FunctionPropertiesUpdater {
 
   DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;
 
-  DenseSet<const BasicBlock *> Successors;
+  DenseSet<const BasicBlock *> Successors, CallUsers;
 
   // Edges we might potentially need to remove from the dominator tree.
   SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;
diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h
index 9d15136e81d10..50ba3c13da70f 100644
--- a/llvm/include/llvm/Analysis/InlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/InlineAdvisor.h
@@ -331,6 +331,10 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
   };
 
   Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
+
+private:
+  static bool initializeIR2VecVocabIfRequested(Module &M,
+                                               ModuleAnalysisManager &MAM);
 };
 
 /// Printer pass for the InlineAdvisorAnalysis results.
diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
index 961d5091bf9f3..25a35df3efe2c 100644
--- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
+++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
@@ -132,6 +132,11 @@ constexpr bool isHeuristicInlineCostFeature(InlineCostFeatureIndex Feature) {
     "not "                                                                     \
     "fully inlined by ElimAvailExtern)")
 
+// Not all features listed in FeatureIndex are used by the ML model.
+// Specifically, callee_embedding and caller_embedding are used only when
+// IR2Vec embeddings are explicitly enabled. Consequently, the number of
+// features is not static, so it cannot be determined from the number of
+// elements in this enum.
 // clang-format off
 enum class FeatureIndex : size_t {
 #define POPULATE_INDICES(DTYPE, SHAPE, NAME, COMMENT) NAME,
@@ -142,7 +147,11 @@ enum class FeatureIndex : size_t {
   INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
 #undef POPULATE_INDICES
 
-  NumberOfFeatures
+// IR2Vec embeddings
+// Dimensions of embeddings are not known at compile time (until the vocab is
+// read). Hence macros cannot be used here.
+  callee_embedding,
+  caller_embedding
 };
 // clang-format on
 
@@ -151,10 +160,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
   return static_cast<FeatureIndex>(static_cast<size_t>(Feature));
 }
 
-constexpr size_t NumberOfFeatures =
-    static_cast<size_t>(FeatureIndex::NumberOfFeatures);
-
-LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
+LLVM_ABI extern std::vector<TensorSpec> FeatureMap;
 
 LLVM_ABI extern const char *const DecisionName;
 LLVM_ABI extern const TensorSpec InlineDecisionSpec;
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
index 580dd5e95d760..8262dd0846ede 100644
--- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
@@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
   int64_t NodeCount = 0;
   int64_t EdgeCount = 0;
   int64_t EdgesOfLastSeenNodes = 0;
+  const bool UseIR2Vec;
 
   std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
   const int32_t InitialIRSize = 0;
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 9d044c8a35910..29d3aaf46dc06 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
 #undef CHECK_OPERAND
     }
   }
+
+  if (IR2VecVocab) {
+    // We instantiate the IR2Vec embedder each time, as keeping a unique
+    // pointer to the embedder as a member of the class would make it
+    // non-copyable. Instantiating the embedder itself is not costly.
+    auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+                                             *BB.getParent(), *IR2VecVocab);
+    if (Error Err = EmbOrErr.takeError()) {
+      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+        BB.getContext().emitError("Error creating IR2Vec embeddings: " +
+                                  EI.message());
+      });
+      return;
+    }
+    auto Embedder = std::move(*EmbOrErr);
+    const auto &BBEmbedding = Embedder->getBBVector(BB);
+    // Subtract BBEmbedding from Function embedding if the direction is -1,
+    // and add it if the direction is +1.
+    if (Direction == -1)
+      FunctionEmbedding -= BBEmbedding;
+    else
+      FunctionEmbedding += BBEmbedding;
+  }
 }
 
 void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
@@ -220,14 +243,24 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
     Function &F, FunctionAnalysisManager &FAM) {
+  // We use the cached result of the IR2VecVocabAnalysis run by
+  // InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
+  // use IR2Vec embeddings.
+  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
   return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
-                                   FAM.getResult<LoopAnalysis>(F));
+                                   FAM.getResult<LoopAnalysis>(F), VocabResult);
 }
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
-    const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
+    const Function &F, const DominatorTree &DT, const LoopInfo &LI,
+    const IR2VecVocabResult *VocabResult) {
 
   FunctionPropertiesInfo FPI;
+  if (VocabResult && VocabResult->isValid()) {
+    FPI.IR2VecVocab = VocabResult->getVocabulary();
+    FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
+  }
   for (const auto &BB : F)
     if (DT.isReachableFromEntry(&BB))
       FPI.reIncludeBB(BB);
@@ -235,6 +268,66 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
   return FPI;
 }
 
+bool FunctionPropertiesInfo::operator==(
+    const FunctionPropertiesInfo &FPI) const {
+  if (BasicBlockCount != FPI.BasicBlockCount ||
+      BlocksReachedFromConditionalInstruction !=
+          FPI.BlocksReachedFromConditionalInstruction ||
+      Uses != FPI.Uses ||
+      DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
+      LoadInstCount != FPI.LoadInstCount ||
+      StoreInstCount != FPI.StoreInstCount ||
+      MaxLoopDepth != FPI.MaxLoopDepth ||
+      TopLevelLoopCount != FPI.TopLevelLoopCount ||
+      TotalInstructionCount != FPI.TotalInstructionCount ||
+      BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
+      BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
+      BasicBlocksWithMoreThanTwoSuccessors !=
+          FPI.BasicBlocksWithMoreThanTwoSuccessors ||
+      BasicBlocksWithSinglePredecessor !=
+          FPI.BasicBlocksWithSinglePredecessor ||
+      BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
+      BasicBlocksWithMoreThanTwoPredecessors !=
+          FPI.BasicBlocksWithMoreThanTwoPredecessors ||
+      BigBasicBlocks != FPI.BigBasicBlocks ||
+      MediumBasicBlocks != FPI.MediumBasicBlocks ||
+      SmallBasicBlocks != FPI.SmallBasicBlocks ||
+      CastInstructionCount != FPI.CastInstructionCount ||
+      FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
+      IntegerInstructionCount != FPI.IntegerInstructionCount ||
+      ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
+      ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
+      ConstantOperandCount != FPI.ConstantOperandCount ||
+      InstructionOperandCount != FPI.InstructionOperandCount ||
+      BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
+      GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
+      InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
+      ArgumentOperandCount != FPI.ArgumentOperandCount ||
+      UnknownOperandCount != FPI.UnknownOperandCount ||
+      CriticalEdgeCount != FPI.CriticalEdgeCount ||
+      ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
+      UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
+      IntrinsicCount != FPI.IntrinsicCount ||
+      DirectCallCount != FPI.DirectCallCount ||
+      IndirectCallCount != FPI.IndirectCallCount ||
+      CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
+      CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
+      CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
+      CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
+      CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
+      CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
+      CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
+      CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
+    return false;
+  }
+  // Check the equality of the function embeddings. We don't check the equality
+  // of Vocabulary as it remains the same.
+  if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding))
+    return false;
+
+  return true;
+}
+
 void FunctionPropertiesInfo::print(raw_ostream &OS) const {
 #define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";
 
@@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
   // The caller's entry BB may change due to new alloca instructions.
   LikelyToChangeBBs.insert(&*Caller.begin());
 
+  // The users of the value returned by the call instruction can change,
+  // which in turn changes the embeddings computed for their blocks.
+  // We conservatively add the BBs with such uses to LikelyToChangeBBs.
+  for (const auto *User : CB.users())
+    CallUsers.insert(dyn_cast<Instruction>(User)->getParent());
+  // CallSiteBB is removed from CallUsers if present; it is handled
+  // separately.
+  CallUsers.erase(&CallSiteBB);
+  LikelyToChangeBBs.insert_range(CallUsers);
+
   // The successors may become unreachable in the case of `invoke` inlining.
   // We track successors separately, too, because they form a boundary, together
   // with the CB BB ('Entry') between which the inlined callee will be pasted.
@@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
   if (&CallSiteBB != &*Caller.begin())
     Reinclude.insert(&*Caller.begin());
 
+  // Reinclude the BBs which use the values returned by the call instruction.
+  Reinclude.insert_range(CallUsers);
+
   // Distribute the successors to the 2 buckets.
   for (const auto *Succ : Successors)
     if (DT.isReachableFromEntry(Succ))
@@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
     return false;
   DominatorTree DT(F);
   LoopInfo LI(DT);
-  auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
+  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
+  auto Fresh =
+      FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
   return FPI == Fresh;
 }
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 3d30f3d10a9d0..28b14c2562df1 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/EphemeralValuesCache.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -64,6 +65,13 @@ static cl::opt<bool>
                         cl::desc("If true, annotate inline advisor remarks "
                                  "with LTO and pass information."));
 
+// This flag enables IR2Vec embeddings in the ML inliner; it is only valid
+// with the ML inliner. The vocab file is used to initialize the embeddings.
+static cl::opt<std::string> IR2VecVocabFile(
+    "ml-inliner-ir2vec-vocab-file", cl::Hidden,
+    cl::desc("Vocab file for IR2Vec; Setting this enables "
+             "configuring the model to use IR2Vec embeddings."));
+
 namespace llvm {
 extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
 } // namespace llvm
@@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
 AnalysisKey InlineAdvisorAnalysis::Key;
 AnalysisKey PluginInlineAdvisorAnalysis::Key;
 
+bool InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(
+    Module &M, ModuleAnalysisManager &MAM) {
+  if (!IR2VecVocabFile.empty()) {
+    auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+    if (!IR2VecVocabResult.isValid()) {
+      M.getContext().emitError("Failed to load IR2Vec vocabulary");
+      return false;
+    }
+  }
+  // Not specifying a vocab file is OK; we just don't use IR2Vec
+  // embeddings.
+  return true;
+}
+
 bool InlineAdvisorAnalysis::Result::tryCreate(
     InlineParams Params, InliningAdvisorMode Mode,
     const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
@@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
                                              /* EmitRemarks =*/true, IC);
     }
     break;
+    // Run IR2VecVocabAnalysis once per module to get the vocabulary.
+    // We run it here because it is immutable and we want to avoid running it
+    // multiple times.
   case InliningAdvisorMode::Development:
 #ifdef LLVM_HAVE_TFLITE
     LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(M, MAM))
+      return false;
     Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
 #endif
     break;
   case InliningAdvisorMode::Release:
     LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(M, MAM))
+      return false;
     Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
     break;
   }
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 81a3bc94a6ad8..b1a5040b68b75 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache(
     cl::init(false));
 
 // clang-format off
-const std::vector<TensorSpec> llvm::FeatureMap{
+std::vector<TensorSpec> llvm::FeatureMap{
 #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
 // InlineCost features - these must come first
   INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
@@ -144,6 +144,7 @@ MLInlineAdvisor::MLInlineAdvisor(
           M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
       ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
       CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
+      UseIR2Vec(MAM.getCachedResult<IR2VecVocabAnalysis>(M) != nullptr),
       InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
       PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
   assert(ModelRunner);
@@ -186,6 +187,19 @@ MLInlineAdvisor::MLInlineAdvisor(
     EdgeCount += getLocalCalls(KVP.first->getFunction());
   }
   NodeCount = AllNodes.size();
+
+  if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
+    if (!IR2VecVocabResult->isValid()) {
+      M.getContext().emitError("IR2VecVocabAnalysis is not valid");
+      return;
+    }
+    // Add the IR2Vec features to the feature map
+    auto IR2VecDim = IR2VecVocabResult->getDimension();
+    FeatureMap.push_back(
+        TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
+    FeatureMap.push_back(
+        TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
+  }
 }
 
 unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
@@ -433,6 +447,24 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
       Caller.hasAvailableExternallyLinkage();
 
+  if (UseIR2Vec) {
+    // Python side expects float embeddings. The IR2Vec embeddings are doubles
+    // as of now due to the restriction of fromJSON method used by the
+    // readVocabulary method in ir2vec::Embeddings.
+    auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
+                            FeatureIndex Index) {
+      auto Embedding_float =
+          std::vector<float>(Embedding.begin(), Embedding.end());
+      std::memcpy(ModelRunner->getTensor<float>(Index), Embedding_float.data(),
+                  Embedding.size() * sizeof(float));
+    };
+
+    setEmbedding(CalleeBefore.getFunctionEmbedding(),
+                 FeatureIndex::callee_embedding);
+    setEmbedding(CallerBefore.getFunctionEmbedding(),
+                 FeatureIndex::caller_embedding);
+  }
+
   // Add the cost features
   for (size_t I = 0;
        I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
@@ -441,7 +473,7 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
   }
   // This one would have been set up to be right at the end.
   if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault)
-    *ModelRunner->getTensor<int64_t>(FeatureIndex::NumberOfFeatures) =
+    *ModelRunner->getTensor<int64_t>((int)FeatureMap.size()) =
         GetDefaultAdvice(CB);
   return getAdviceFromModel(CB, ORE);
 }
@@ -520,7 +552,7 @@ void MLInlineAdvice::reportContextForRemark(
     DiagnosticInfoOptimizationBase &OR) {
   using namespace ore;
   OR << NV("Callee", Callee->getName());
-  for (size_t I = 0; I < NumberOfFeatures; ++I)
+  for (size_t I = 0; I < FeatureMap.size(); ++I)
     OR << NV(FeatureMap[I].name(),
              *getAdvisor()->getModelRunner().getTensor<int64_t>(I));
   OR << NV("ShouldInline", isInliningRecommended());
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/ll...
[truncated]
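The unit-test changes are cut off by the truncation above. As an illustrative sketch only (the function name and the FPI instance are assumptions, not taken from the actual test file), the new setFunctionEmbeddingForTest helper lets a test install a deterministic embedding without loading a vocabulary:

  #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
  #include "llvm/Analysis/IR2Vec.h"

  // Hypothetical test helper: overrides the function embedding directly so
  // comparisons via FunctionPropertiesInfo::operator== are deterministic.
  void overrideEmbeddingForTest(llvm::FunctionPropertiesInfo &FPI) {
    llvm::ir2vec::Embedding E(2, 0.5); // 2-dimensional embedding, all 0.5
    FPI.setFunctionEmbeddingForTest(E);
  }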

@svkeerthy requested a review from mtrofin on Jun 25, 2025 10:31