
[VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations #148350

Open
wants to merge 1 commit into main

Conversation


@rhyadav rhyadav commented Jul 12, 2025

This patch generalizes the existing foldBitOpOfBitcasts optimization in the VectorCombine pass to handle
additional cast operations beyond just bitcast.

Fixes: #146037

Summary

The optimization now supports folding bitwise operations (AND/OR/XOR) with the following cast operations:

  • bitcast (original functionality)
  • trunc (truncate)
  • sext (sign extend)
  • zext (zero extend)

The transformation pattern is:
bitop(castop(x), castop(y)) -> castop(bitop(x, y))

This reduces the number of cast instructions from 2 to 1, which can improve performance on targets where cast
operations are expensive or where performing bitwise operations on narrower types is beneficial.
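
For orientation, the core of the rewrite boils down to the shape below. This is a condensed, illustrative sketch only - the helper name is invented for this description, and the actual implementation in the diff additionally performs the per-cast validation, the TTI cost comparison, multi-use accounting, flag propagation, and worklist updates.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Sketch of bitop(castop(x), castop(y)) -> castop(bitop(x, y)).
// Returns the replacement value, or nullptr if the pattern does not apply.
static Value *foldBitOpOfCastopsSketch(BinaryOperator &BinOp,
                                       IRBuilder<> &Builder) {
  if (!BinOp.isBitwiseLogicOp())
    return nullptr;
  auto *LHSCast = dyn_cast<CastInst>(BinOp.getOperand(0));
  auto *RHSCast = dyn_cast<CastInst>(BinOp.getOperand(1));
  // Both operands must be the same kind of cast from the same source type.
  if (!LHSCast || !RHSCast || LHSCast->getOpcode() != RHSCast->getOpcode() ||
      LHSCast->getSrcTy() != RHSCast->getSrcTy())
    return nullptr;
  // Perform the bitwise op on the source type (e.g. the narrower type when
  // the inputs are zext/sext), then emit a single cast of the result.
  Value *NewOp = Builder.CreateBinOp(BinOp.getOpcode(), LHSCast->getOperand(0),
                                     RHSCast->getOperand(0), BinOp.getName());
  return Builder.CreateCast(LHSCast->getOpcode(), NewOp, BinOp.getType());
}

The net effect is that exactly one cast remains on the result, regardless of which of the four supported cast opcodes was matched.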

Implementation Details

  • Renamed foldBitOpOfBitcasts to foldBitOpOfCastops to reflect broader functionality
  • Extended pattern matching to handle any CastInst operation
  • Added validation for each cast type's constraints (e.g., trunc requires source elements wider than destination elements; see the sketch after this list)
  • Updated cost model to use the actual cast opcode
  • Preserves IR flags from original instructions
  • Handles multi-use scenarios appropriately
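
To make the validation bullet concrete, the constraints reduce to the predicate below (a condensed sketch of the switch in the patch; the helper name is invented here).

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Per-cast legality rules used by the fold:
//  - bitcast: total vector bit width must be preserved
//  - trunc:   source elements must be wider than destination elements
//  - sext/zext: source elements must be narrower than destination elements
// Both vectors must have integer elements.
static bool isSupportedCastPair(Instruction::CastOps Opc,
                                FixedVectorType *SrcVecTy,
                                FixedVectorType *DstVecTy) {
  if (!SrcVecTy->getScalarType()->isIntegerTy() ||
      !DstVecTy->getScalarType()->isIntegerTy())
    return false;
  switch (Opc) {
  case Instruction::BitCast:
    return SrcVecTy->getPrimitiveSizeInBits() ==
           DstVecTy->getPrimitiveSizeInBits();
  case Instruction::Trunc:
    return SrcVecTy->getScalarSizeInBits() > DstVecTy->getScalarSizeInBits();
  case Instruction::SExt:
  case Instruction::ZExt:
    return SrcVecTy->getScalarSizeInBits() < DstVecTy->getScalarSizeInBits();
  default:
    return false;
  }
}

One consequence of the integer-element requirement (carried over from the original bitcast-only fold) is that the float-source bitcast tests in the new test file are left untransformed; their CHECK lines keep both bitcasts.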

Testing

  • Added comprehensive tests in test/Transforms/VectorCombine/bitop-of-castops.ll
  • Tests cover all supported cast types with all bitwise operations
  • Includes negative tests for unsupported patterns
  • All existing VectorCombine tests pass

…perations

This patch generalizes the foldBitOpOfBitcasts function (renamed to
foldBitOpOfCastops) to handle additional cast operations beyond just
bitcast. The optimization now supports:
- trunc (truncate)
- sext (sign extend)
- zext (zero extend)
- bitcast (original functionality)

The optimization transforms:
  bitop(cast(x), cast(y)) -> cast(bitop(x, y))

This reduces the number of cast instructions from 2 to 1, which can
improve performance on targets where cast operations are expensive or
where performing bitwise operations on narrower types is beneficial.

Changes:
- Renamed foldBitOpOfBitcasts to foldBitOpOfCastops
- Extended pattern matching to handle any CastInst
- Added validation for each cast type's constraints
- Updated cost model to use actual cast opcode
- Added comprehensive tests for all supported cast types

Fixes: llvm#146037
rhyadav marked this pull request as draft July 12, 2025 09:17

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository; in that case, you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "pinging" the PR with a comment containing "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

Member

llvmbot commented Jul 12, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Rahul Yadav (rhyadav)

Full diff: https://github.com/llvm/llvm-project/pull/148350.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+100-34)
  • (added) llvm/test/Transforms/VectorCombine/bitop-of-castops.ll (+263)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index fe8d74c43dfdc..58aa53694b22e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -52,9 +52,9 @@ STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
 STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
 
-static cl::opt<bool> DisableVectorCombine(
-    "disable-vector-combine", cl::init(false), cl::Hidden,
-    cl::desc("Disable all vector combine transforms"));
+static cl::opt<bool>
+    DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden,
+                         cl::desc("Disable all vector combine transforms"));
 
 static cl::opt<bool> DisableBinopExtractShuffle(
     "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
@@ -115,7 +115,7 @@ class VectorCombine {
   bool foldInsExtFNeg(Instruction &I);
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
-  bool foldBitOpOfBitcasts(Instruction &I);
+  bool foldBitOpOfCastops(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
   bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
@@ -808,46 +808,105 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
   return true;
 }
 
-bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
-  // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
-  Value *LHSSrc, *RHSSrc;
-  if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
-                                m_BitCast(m_Value(RHSSrc)))))
+bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
+  // Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
+  // Supports: bitcast, trunc, sext, zext
+
+  // Check if this is a bitwise logic operation
+  auto *BinOp = dyn_cast<BinaryOperator>(&I);
+  if (!BinOp || !BinOp->isBitwiseLogicOp())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n");
+
+  // Get the cast instructions
+  auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
+  auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
+  if (!LHSCast || !RHSCast) {
+    LLVM_DEBUG(dbgs() << "  One or both operands are not cast instructions\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "  LHS cast: " << *LHSCast << "\n");
+  LLVM_DEBUG(dbgs() << "  RHS cast: " << *RHSCast << "\n");
+
+  // Both casts must be the same type
+  Instruction::CastOps CastOpcode = LHSCast->getOpcode();
+  if (CastOpcode != RHSCast->getOpcode())
     return false;
 
+  // Only handle supported cast operations
+  switch (CastOpcode) {
+  case Instruction::BitCast:
+  case Instruction::Trunc:
+  case Instruction::SExt:
+  case Instruction::ZExt:
+    break;
+  default:
+    return false;
+  }
+
+  Value *LHSSrc = LHSCast->getOperand(0);
+  Value *RHSSrc = RHSCast->getOperand(0);
+
   // Source types must match
   if (LHSSrc->getType() != RHSSrc->getType())
     return false;
-  if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
-    return false;
 
-  // Only handle vector types
+  // Only handle vector types with integer elements
   auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
   auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
   if (!SrcVecTy || !DstVecTy)
     return false;
 
-  // Same total bit width
-  assert(SrcVecTy->getPrimitiveSizeInBits() ==
-             DstVecTy->getPrimitiveSizeInBits() &&
-         "Bitcast should preserve total bit width");
+  if (!SrcVecTy->getScalarType()->isIntegerTy() ||
+      !DstVecTy->getScalarType()->isIntegerTy())
+    return false;
+
+  // Validate cast operation constraints
+  switch (CastOpcode) {
+  case Instruction::BitCast:
+    // Total bit width must be preserved
+    if (SrcVecTy->getPrimitiveSizeInBits() !=
+        DstVecTy->getPrimitiveSizeInBits())
+      return false;
+    break;
+  case Instruction::Trunc:
+    // Source elements must be wider
+    if (SrcVecTy->getScalarSizeInBits() <= DstVecTy->getScalarSizeInBits())
+      return false;
+    break;
+  case Instruction::SExt:
+  case Instruction::ZExt:
+    // Source elements must be narrower
+    if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits())
+      return false;
+    break;
+  }
 
   // Cost Check :
-  // OldCost = bitlogic + 2*bitcasts
-  // NewCost = bitlogic + bitcast
-  auto *BinOp = cast<BinaryOperator>(&I);
+  // OldCost = bitlogic + 2*casts
+  // NewCost = bitlogic + cast
   InstructionCost OldCost =
       TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
-      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
-                           TTI::CastContextHint::None) +
-      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
-                           TTI::CastContextHint::None);
+      TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+                           TTI::CastContextHint::None) *
+          2;
+
   InstructionCost NewCost =
       TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
-      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
+      TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
                            TTI::CastContextHint::None);
 
-  LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
+  // Account for multi-use casts
+  if (!LHSCast->hasOneUse())
+    NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+                                    TTI::CastContextHint::None);
+  if (!RHSCast->hasOneUse())
+    NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+                                    TTI::CastContextHint::None);
+
+  LLVM_DEBUG(dbgs() << "Found bitwise logic op of cast ops: " << I
                     << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                     << "\n");
 
@@ -862,8 +921,15 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
 
   Worklist.pushValue(NewOp);
 
-  // Bitcast the result back
-  Value *Result = Builder.CreateBitCast(NewOp, I.getType());
+  // Create the cast operation
+  Value *Result = Builder.CreateCast(CastOpcode, NewOp, I.getType());
+
+  // Preserve cast instruction flags
+  if (auto *NewCast = dyn_cast<CastInst>(Result)) {
+    NewCast->copyIRFlags(LHSCast);
+    NewCast->andIRFlags(RHSCast);
+  }
+
   replaceValue(I, *Result);
   return true;
 }
@@ -1020,8 +1086,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
   InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
 
   // Determine scalar opcode
-  std::optional<unsigned> FunctionalOpcode =
-      VPI.getFunctionalOpcode();
+  std::optional<unsigned> FunctionalOpcode = VPI.getFunctionalOpcode();
   std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
   if (!FunctionalOpcode) {
     ScalarIntrID = VPI.getFunctionalIntrinsicID();
@@ -1044,8 +1109,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
       (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
   InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
 
-  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
-                    << "\n");
+  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI << "\n");
   LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
                     << ", Cost of scalarizing:" << NewCost << "\n");
 
@@ -2015,10 +2079,12 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
   }
 
   unsigned NumOpElts = Op0Ty->getNumElements();
-  bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
+  bool IsIdentity0 =
+      ShuffleDstTy == Op0Ty &&
       all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
       ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
-  bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
+  bool IsIdentity1 =
+      ShuffleDstTy == Op1Ty &&
       all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
       ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);
 
@@ -3773,7 +3839,7 @@ bool VectorCombine::run() {
       case Instruction::And:
       case Instruction::Or:
       case Instruction::Xor:
-        MadeChange |= foldBitOpOfBitcasts(I);
+        MadeChange |= foldBitOpOfCastops(I);
         break;
       default:
         MadeChange |= shrinkType(I);
diff --git a/llvm/test/Transforms/VectorCombine/bitop-of-castops.ll b/llvm/test/Transforms/VectorCombine/bitop-of-castops.ll
new file mode 100644
index 0000000000000..003e14bebd169
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/bitop-of-castops.ll
@@ -0,0 +1,263 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64-- | FileCheck %s
+
+; Test bitwise operations with bitcast
+define <4 x i32> @and_bitcast_v4f32_to_v4i32(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @and_bitcast_v4f32_to_v4i32(
+; CHECK-NEXT:    [[BC1:%.*]] = bitcast <4 x float> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[BC2:%.*]] = bitcast <4 x float> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[AND:%.*]] = and <4 x i32> [[BC1]], [[BC2]]
+; CHECK-NEXT:    ret <4 x i32> [[AND]]
+;
+  %bc1 = bitcast <4 x float> %a to <4 x i32>
+  %bc2 = bitcast <4 x float> %b to <4 x i32>
+  %and = and <4 x i32> %bc1, %bc2
+  ret <4 x i32> %and
+}
+
+define <4 x i32> @or_bitcast_v4f32_to_v4i32(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @or_bitcast_v4f32_to_v4i32(
+; CHECK-NEXT:    [[BC1:%.*]] = bitcast <4 x float> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[BC2:%.*]] = bitcast <4 x float> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[OR:%.*]] = or <4 x i32> [[BC1]], [[BC2]]
+; CHECK-NEXT:    ret <4 x i32> [[OR]]
+;
+  %bc1 = bitcast <4 x float> %a to <4 x i32>
+  %bc2 = bitcast <4 x float> %b to <4 x i32>
+  %or = or <4 x i32> %bc1, %bc2
+  ret <4 x i32> %or
+}
+
+define <4 x i32> @xor_bitcast_v4f32_to_v4i32(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @xor_bitcast_v4f32_to_v4i32(
+; CHECK-NEXT:    [[BC1:%.*]] = bitcast <4 x float> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[BC2:%.*]] = bitcast <4 x float> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[XOR:%.*]] = xor <4 x i32> [[BC1]], [[BC2]]
+; CHECK-NEXT:    ret <4 x i32> [[XOR]]
+;
+  %bc1 = bitcast <4 x float> %a to <4 x i32>
+  %bc2 = bitcast <4 x float> %b to <4 x i32>
+  %xor = xor <4 x i32> %bc1, %bc2
+  ret <4 x i32> %xor
+}
+
+; Test bitwise operations with truncate
+define <4 x i16> @and_trunc_v4i32_to_v4i16(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @and_trunc_v4i32_to_v4i16(
+; CHECK-NEXT:    [[AND_INNER:%.*]] = and <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AND:%.*]] = trunc <4 x i32> [[AND_INNER]] to <4 x i16>
+; CHECK-NEXT:    ret <4 x i16> [[AND]]
+;
+  %t1 = trunc <4 x i32> %a to <4 x i16>
+  %t2 = trunc <4 x i32> %b to <4 x i16>
+  %and = and <4 x i16> %t1, %t2
+  ret <4 x i16> %and
+}
+
+define <8 x i8> @or_trunc_v8i16_to_v8i8(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: @or_trunc_v8i16_to_v8i8(
+; CHECK-NEXT:    [[OR_INNER:%.*]] = or <8 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = trunc <8 x i16> [[OR_INNER]] to <8 x i8>
+; CHECK-NEXT:    ret <8 x i8> [[OR]]
+;
+  %t1 = trunc <8 x i16> %a to <8 x i8>
+  %t2 = trunc <8 x i16> %b to <8 x i8>
+  %or = or <8 x i8> %t1, %t2
+  ret <8 x i8> %or
+}
+
+define <2 x i32> @xor_trunc_v2i64_to_v2i32(<2 x i64> %a, <2 x i64> %b) {
+; CHECK-LABEL: @xor_trunc_v2i64_to_v2i32(
+; CHECK-NEXT:    [[XOR_INNER:%.*]] = xor <2 x i64> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = trunc <2 x i64> [[XOR_INNER]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[XOR]]
+;
+  %t1 = trunc <2 x i64> %a to <2 x i32>
+  %t2 = trunc <2 x i64> %b to <2 x i32>
+  %xor = xor <2 x i32> %t1, %t2
+  ret <2 x i32> %xor
+}
+
+; Test bitwise operations with zero extend
+define <4 x i32> @and_zext_v4i16_to_v4i32(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_zext_v4i16_to_v4i32(
+; CHECK-NEXT:    [[AND_INNER:%.*]] = and <4 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AND:%.*]] = zext <4 x i16> [[AND_INNER]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[AND]]
+;
+  %z1 = zext <4 x i16> %a to <4 x i32>
+  %z2 = zext <4 x i16> %b to <4 x i32>
+  %and = and <4 x i32> %z1, %z2
+  ret <4 x i32> %and
+}
+
+define <8 x i16> @or_zext_v8i8_to_v8i16(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: @or_zext_v8i8_to_v8i16(
+; CHECK-NEXT:    [[OR_INNER:%.*]] = or <8 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = zext <8 x i8> [[OR_INNER]] to <8 x i16>
+; CHECK-NEXT:    ret <8 x i16> [[OR]]
+;
+  %z1 = zext <8 x i8> %a to <8 x i16>
+  %z2 = zext <8 x i8> %b to <8 x i16>
+  %or = or <8 x i16> %z1, %z2
+  ret <8 x i16> %or
+}
+
+define <2 x i64> @xor_zext_v2i32_to_v2i64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @xor_zext_v2i32_to_v2i64(
+; CHECK-NEXT:    [[XOR_INNER:%.*]] = xor <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = zext <2 x i32> [[XOR_INNER]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[XOR]]
+;
+  %z1 = zext <2 x i32> %a to <2 x i64>
+  %z2 = zext <2 x i32> %b to <2 x i64>
+  %xor = xor <2 x i64> %z1, %z2
+  ret <2 x i64> %xor
+}
+
+; Test bitwise operations with sign extend
+define <4 x i32> @and_sext_v4i16_to_v4i32(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_sext_v4i16_to_v4i32(
+; CHECK-NEXT:    [[AND_INNER:%.*]] = and <4 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AND:%.*]] = sext <4 x i16> [[AND_INNER]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[AND]]
+;
+  %s1 = sext <4 x i16> %a to <4 x i32>
+  %s2 = sext <4 x i16> %b to <4 x i32>
+  %and = and <4 x i32> %s1, %s2
+  ret <4 x i32> %and
+}
+
+define <8 x i16> @or_sext_v8i8_to_v8i16(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: @or_sext_v8i8_to_v8i16(
+; CHECK-NEXT:    [[OR_INNER:%.*]] = or <8 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = sext <8 x i8> [[OR_INNER]] to <8 x i16>
+; CHECK-NEXT:    ret <8 x i16> [[OR]]
+;
+  %s1 = sext <8 x i8> %a to <8 x i16>
+  %s2 = sext <8 x i8> %b to <8 x i16>
+  %or = or <8 x i16> %s1, %s2
+  ret <8 x i16> %or
+}
+
+define <2 x i64> @xor_sext_v2i32_to_v2i64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @xor_sext_v2i32_to_v2i64(
+; CHECK-NEXT:    [[XOR_INNER:%.*]] = xor <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = sext <2 x i32> [[XOR_INNER]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[XOR]]
+;
+  %s1 = sext <2 x i32> %a to <2 x i64>
+  %s2 = sext <2 x i32> %b to <2 x i64>
+  %xor = xor <2 x i64> %s1, %s2
+  ret <2 x i64> %xor
+}
+
+; Negative test: mismatched cast types (zext and sext)
+define <4 x i32> @and_zext_sext_mismatch(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_zext_sext_mismatch(
+; CHECK-NEXT:    [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[S2:%.*]] = sext <4 x i16> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[AND:%.*]] = and <4 x i32> [[Z1]], [[S2]]
+; CHECK-NEXT:    ret <4 x i32> [[AND]]
+;
+  %z1 = zext <4 x i16> %a to <4 x i32>
+  %s2 = sext <4 x i16> %b to <4 x i32>
+  %and = and <4 x i32> %z1, %s2
+  ret <4 x i32> %and
+}
+
+; Negative test: mismatched source types
+define <4 x i32> @or_zext_different_src_types(<4 x i16> %a, <4 x i8> %b) {
+; CHECK-LABEL: @or_zext_different_src_types(
+; CHECK-NEXT:    [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[Z2:%.*]] = zext <4 x i8> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[OR:%.*]] = or <4 x i32> [[Z1]], [[Z2]]
+; CHECK-NEXT:    ret <4 x i32> [[OR]]
+;
+  %z1 = zext <4 x i16> %a to <4 x i32>
+  %z2 = zext <4 x i8> %b to <4 x i32>
+  %or = or <4 x i32> %z1, %z2
+  ret <4 x i32> %or
+}
+
+; Negative test: scalar types (not vectors)
+define i32 @xor_zext_scalar(i16 %a, i16 %b) {
+; CHECK-LABEL: @xor_zext_scalar(
+; CHECK-NEXT:    [[Z1:%.*]] = zext i16 [[A:%.*]] to i32
+; CHECK-NEXT:    [[Z2:%.*]] = zext i16 [[B:%.*]] to i32
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Z1]], [[Z2]]
+; CHECK-NEXT:    ret i32 [[XOR]]
+;
+  %z1 = zext i16 %a to i32
+  %z2 = zext i16 %b to i32
+  %xor = xor i32 %z1, %z2
+  ret i32 %xor
+}
+
+; Test multi-use: one cast has multiple uses
+define <4 x i32> @and_zext_multiuse(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_zext_multiuse(
+; CHECK-NEXT:    [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[AND_INNER:%.*]] = and <4 x i16> [[A]], [[B:%.*]]
+; CHECK-NEXT:    [[AND:%.*]] = zext <4 x i16> [[AND_INNER]] to <4 x i32>
+; CHECK-NEXT:    [[ADD:%.*]] = add <4 x i32> [[Z1]], [[AND]]
+; CHECK-NEXT:    ret <4 x i32> [[ADD]]
+;
+  %z1 = zext <4 x i16> %a to <4 x i32>
+  %z2 = zext <4 x i16> %b to <4 x i32>
+  %and = and <4 x i32> %z1, %z2
+  %add = add <4 x i32> %z1, %and  ; z1 has multiple uses
+  ret <4 x i32> %add
+}
+
+; Test with different vector sizes
+define <16 x i16> @or_zext_v16i8_to_v16i16(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: @or_zext_v16i8_to_v16i16(
+; CHECK-NEXT:    [[OR_INNER:%.*]] = or <16 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = zext <16 x i8> [[OR_INNER]] to <16 x i16>
+; CHECK-NEXT:    ret <16 x i16> [[OR]]
+;
+  %z1 = zext <16 x i8> %a to <16 x i16>
+  %z2 = zext <16 x i8> %b to <16 x i16>
+  %or = or <16 x i16> %z1, %z2
+  ret <16 x i16> %or
+}
+
+; Test bitcast with different element counts
+define <8 x i16> @xor_bitcast_v4i32_to_v8i16(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @xor_bitcast_v4i32_to_v8i16(
+; CHECK-NEXT:    [[XOR_INNER:%.*]] = xor <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = bitcast <4 x i32> [[XOR_INNER]] to <8 x i16>
+; CHECK-NEXT:    ret <8 x i16> [[XOR]]
+;
+  %bc1 = bitcast <4 x i32> %a to <8 x i16>
+  %bc2 = bitcast <4 x i32> %b to <8 x i16>
+  %xor = xor <8 x i16> %bc1, %bc2
+  ret <8 x i16> %xor
+}
+
+; Test truncate with flag preservation
+define <4 x i16> @and_trunc_nuw_nsw(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @and_trunc_nuw_nsw(
+; CHECK-NEXT:    [[AND_INNER:%.*]] = and <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AND:%.*]] = trunc nuw nsw <4 x i32> [[AND_INNER]] to <4 x i16>
+; CHECK-NEXT:    ret <4 x i16> [[AND]]
+;
+  %t1 = trunc nuw nsw <4 x i32> %a to <4 x i16>
+  %t2 = trunc nuw nsw <4 x i32> %b to <4 x i16>
+  %and = and <4 x i16> %t1, %t2
+  ret <4 x i16> %and
+}
+
+; Test zero extend with nneg flag
+define <4 x i32> @or_zext_nneg(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @or_zext_nneg(
+; CHECK-NEXT:    [[OR_INNER:%.*]] = or <4 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = zext nneg <4 x i16> [[OR_INNER]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[OR]]
+;
+  %z1 = zext nneg <4 x i16> %a to <4 x i32>
+  %z2 = zext nneg <4 x i16> %b to <4 x i32>
+  %or = or <4 x i32> %z1, %z2
+  ret <4 x i32> %or
+}

Author

rhyadav commented Jul 12, 2025

@RKSimon request your review

rhyadav marked this pull request as ready for review July 12, 2025 09:19
dtcxzyw requested a review from RKSimon July 12, 2025 13:31
cl::desc("Disable all vector combine transforms"));
static cl::opt<bool>
DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden,
cl::desc("Disable all vector combine transforms"));
Collaborator

(style) don't clang-format lines unrelated to a patch

m_BitCast(m_Value(RHSSrc)))))
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
// Supports: bitcast, trunc, sext, zext
Collaborator

(style) Move the description comment up a few lines to outside the function def

if (!BinOp || !BinOp->isBitwiseLogicOp())
return false;

LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n");
Collaborator

probably drop this?

}

LLVM_DEBUG(dbgs() << " LHS cast: " << *LHSCast << "\n");
LLVM_DEBUG(dbgs() << " RHS cast: " << *RHSCast << "\n");
Collaborator

probably drop this?

if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits())
return false;
break;
}
Collaborator

not sure if any of these checks are necessary - even the old assertion didn't contribute much, as Builder.CreateCast should assert the cast is valid for us if we get to that stage.
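
For reference, cast validity can also be queried directly, without re-deriving the width rules per opcode - for example via CastInst::castIsValid (an illustrative stand-alone snippet, not part of the patch), here checking the v4i32 -> v4i16 trunc case from the tests:

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/LLVMContext.h"

using namespace llvm;

// Returns true: a trunc from <4 x i32> to <4 x i16> is a valid cast, so the
// rebuilt cast in the fold is valid by construction and arguably needs at
// most an assert rather than an early-out.
static bool truncV4I32ToV4I16IsValid(LLVMContext &Ctx) {
  auto *Src = FixedVectorType::get(Type::getInt32Ty(Ctx), 4);
  auto *Dst = FixedVectorType::get(Type::getInt16Ty(Ctx), 4);
  return CastInst::castIsValid(Instruction::Trunc, Src, Dst);
}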

TTI::CastContextHint::None);
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None) *
2;
Collaborator

Can we hoist the separate getCastInstrCost calls here to avoid calling them again for the !hasOneUse cases below? Add the Instruction* args as well to help improve costs - we can't do that for the new cost calc, but it's still useful for the old costs. We're missing the CostKind as well.
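
A rough sketch of what the suggested hoisting could look like (illustrative only - the helper name and plumbing are invented here, CostKind stands in for whatever cost kind the pass uses, and the Instruction* context can only be supplied for the pre-existing casts, i.e. on the old-cost side):

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InstrTypes.h"
#include <utility>

using namespace llvm;

// Compute each cast's cost once and reuse it, instead of re-querying TTI for
// the !hasOneUse adjustments. Returns {OldCost, NewCost}.
static std::pair<InstructionCost, InstructionCost>
castFoldCosts(const TargetTransformInfo &TTI, TTI::TargetCostKind CostKind,
              BinaryOperator &BinOp, CastInst &LHSCast, CastInst &RHSCast,
              FixedVectorType *SrcVecTy, FixedVectorType *DstVecTy) {
  Instruction::CastOps CastOpcode = LHSCast.getOpcode();
  // Pass the existing cast instructions so TTI can refine the old cost.
  InstructionCost LHSCastCost = TTI.getCastInstrCost(
      CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind,
      &LHSCast);
  InstructionCost RHSCastCost = TTI.getCastInstrCost(
      CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind,
      &RHSCast);

  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(BinOp.getOpcode(), DstVecTy, CostKind) +
      LHSCastCost + RHSCastCost;

  // The new cast does not exist yet, so no Instruction* context here.
  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(BinOp.getOpcode(), SrcVecTy, CostKind) +
      TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
                           TTI::CastContextHint::None, CostKind);
  // Casts with other users must be kept, so charge them to the new cost.
  if (!LHSCast.hasOneUse())
    NewCost += LHSCastCost;
  if (!RHSCast.hasOneUse())
    NewCost += RHSCastCost;
  return {OldCost, NewCost};
}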

InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);
Collaborator

Missing CostKind.

TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None) *
2;

InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
Collaborator

Missing CostKind.

@@ -1020,8 +1086,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

// Determine scalar opcode
std::optional<unsigned> FunctionalOpcode =
VPI.getFunctionalOpcode();
std::optional<unsigned> FunctionalOpcode = VPI.getFunctionalOpcode();
Collaborator

(style) don't clang-format lines unrelated to a patch

RKSimon requested a review from davemgreen July 14, 2025 08:06
Development

Successfully merging this pull request may close these issues.

[VectorCombine] Generalise foldBitOpOfBitcasts into foldBitOpOfCastops
3 participants