Skip to content

[LSR] Clean up code using SCEVPatternMatch (NFC) #145556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 28, 2025
Merged

Conversation

artagnon
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Ramkumar Ramachandra (artagnon)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145556.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp (+48-63)
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 4ba69034d6448..9ffcfb47b0a89 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -923,10 +923,12 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
 /// If S involves the addition of a constant integer value, return that integer
 /// value, and mutate S to point to a new SCEV with that value excluded.
 static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
-  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
-    if (C->getAPInt().getSignificantBits() <= 64) {
-      S = SE.getConstant(C->getType(), 0);
-      return Immediate::getFixed(C->getValue()->getSExtValue());
+  const APInt *C;
+  const SCEV *Op1;
+  if (match(S, m_scev_APInt(C))) {
+    if (C->getSignificantBits() <= 64) {
+      S = SE.getConstant(S->getType(), 0);
+      return Immediate::getFixed(C->getSExtValue());
     }
   } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
     SmallVector<const SCEV *, 8> NewOps(Add->operands());
@@ -942,14 +944,11 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
                            // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                            SCEV::FlagAnyWrap);
     return Result;
-  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
-    if (EnableVScaleImmediates && M->getNumOperands() == 2) {
-      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
-        if (isa<SCEVVScale>(M->getOperand(1))) {
-          S = SE.getConstant(M->getType(), 0);
-          return Immediate::getScalable(C->getValue()->getSExtValue());
-        }
-    }
+  } else if (EnableVScaleImmediates &&
+             match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) &&
+             isa<SCEVVScale>(Op1)) {
+    S = SE.getConstant(S->getType(), 0);
+    return Immediate::getScalable(C->getSExtValue());
   }
   return Immediate::getZero();
 }
@@ -1133,23 +1132,22 @@ static bool isHighCostExpansion(const SCEV *S,
     return false;
   }
 
-  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
-    if (Mul->getNumOperands() == 2) {
-      // Multiplication by a constant is ok
-      if (isa<SCEVConstant>(Mul->getOperand(0)))
-        return isHighCostExpansion(Mul->getOperand(1), Processed, SE);
-
-      // If we have the value of one operand, check if an existing
-      // multiplication already generates this expression.
-      if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Mul->getOperand(1))) {
-        Value *UVal = U->getValue();
-        for (User *UR : UVal->users()) {
-          // If U is a constant, it may be used by a ConstantExpr.
-          Instruction *UI = dyn_cast<Instruction>(UR);
-          if (UI && UI->getOpcode() == Instruction::Mul &&
-              SE.isSCEVable(UI->getType())) {
-            return SE.getSCEV(UI) == Mul;
-          }
+  const SCEV *Op0, *Op1;
+  if (match(S, m_scev_Mul(m_SCEV(Op0), m_SCEV(Op1)))) {
+    // Multiplication by a constant is ok
+    if (isa<SCEVConstant>(Op0))
+      return isHighCostExpansion(Op1, Processed, SE);
+
+    // If we have the value of one operand, check if an existing
+    // multiplication already generates this expression.
+    if (const auto *U = dyn_cast<SCEVUnknown>(Op1)) {
+      Value *UVal = U->getValue();
+      for (User *UR : UVal->users()) {
+        // If U is a constant, it may be used by a ConstantExpr.
+        Instruction *UI = dyn_cast<Instruction>(UR);
+        if (UI && UI->getOpcode() == Instruction::Mul &&
+            SE.isSCEVable(UI->getType())) {
+          return SE.getSCEV(UI) == S;
         }
       }
     }
@@ -3333,14 +3331,12 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
     IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue());
   } else {
     // Look for mul(vscale, constant), to detect a scalable offset.
-    auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
-    if (!IncVScale || IncVScale->getNumOperands() != 2 ||
-        !isa<SCEVVScale>(IncVScale->getOperand(1)))
-      return false;
-    auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
-    if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
+    const APInt *C;
+    const SCEV *Op1;
+    if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) ||
+        !isa<SCEVVScale>(Op1) || C->getSignificantBits() > 64)
       return false;
-    IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue());
+    IncOffset = Immediate::getScalable(C->getSExtValue());
   }
 
   if (!isAddressUse(TTI, UserInst, Operand))
@@ -3818,6 +3814,8 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
     return nullptr;
   }
   const SCEV *Start, *Step;
+  const SCEVConstant *Op0;
+  const SCEV *Op1;
   if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
     // Split a non-zero base out of an addrec.
     if (Start->isZero())
@@ -3839,19 +3837,13 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
                               // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                               SCEV::FlagAnyWrap);
     }
-  } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+  } else if (match(S, m_scev_Mul(m_SCEVConstant(Op0), m_SCEV(Op1)))) {
     // Break (C * (a + b + c)) into C*a + C*b + C*c.
-    if (Mul->getNumOperands() != 2)
-      return S;
-    if (const SCEVConstant *Op0 =
-        dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
-      C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
-      const SCEV *Remainder =
-        CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1);
-      if (Remainder)
-        Ops.push_back(SE.getMulExpr(C, Remainder));
-      return nullptr;
-    }
+    C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
+    const SCEV *Remainder = CollectSubexprs(Op1, C, Ops, L, SE, Depth + 1);
+    if (Remainder)
+      Ops.push_back(SE.getMulExpr(C, Remainder));
+    return nullptr;
   }
   return S;
 }
@@ -6478,13 +6470,10 @@ struct SCEVDbgValueBuilder {
   /// Components of the expression are omitted if they are an identity function.
   /// Chain (non-affine) SCEVs are not supported.
   bool SCEVToValueExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) {
-    assert(SAR.isAffine() && "Expected affine SCEV");
-    // TODO: Is this check needed?
-    if (isa<SCEVAddRecExpr>(SAR.getStart()))
-      return false;
-
-    const SCEV *Start = SAR.getStart();
-    const SCEV *Stride = SAR.getStepRecurrence(SE);
+    const SCEV *Start, *Stride;
+    [[maybe_unused]] bool Match =
+        match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
+    assert(Match && "Expected affine SCEV");
 
     // Skip pushing arithmetic noops.
     if (!isIdentityFunction(llvm::dwarf::DW_OP_mul, Stride)) {
@@ -6549,14 +6538,10 @@ struct SCEVDbgValueBuilder {
   /// Components of the expression are omitted if they are an identity function.
   bool SCEVToIterCountExpr(const llvm::SCEVAddRecExpr &SAR,
                            ScalarEvolution &SE) {
-    assert(SAR.isAffine() && "Expected affine SCEV");
-    if (isa<SCEVAddRecExpr>(SAR.getStart())) {
-      LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV. Unsupported nested AddRec: "
-                        << SAR << '\n');
-      return false;
-    }
-    const SCEV *Start = SAR.getStart();
-    const SCEV *Stride = SAR.getStepRecurrence(SE);
+    const SCEV *Start, *Stride;
+    [[maybe_unused]] bool Match =
+        match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride)));
+    assert(Match && "Expected affine SCEV");
 
     // Skip pushing arithmetic noops.
     if (!isIdentityFunction(llvm::dwarf::DW_OP_minus, Start)) {

@artagnon artagnon requested a review from fhahn June 26, 2025 08:10
@llvmbot llvmbot added the llvm:analysis Includes value tracking, cost tables and constant folding label Jun 27, 2025
Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@artagnon artagnon merged commit 04cd0f2 into llvm:main Jun 28, 2025
7 checks passed
@artagnon artagnon deleted the lsr-scevpm branch June 28, 2025 10:41
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants