-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[VectorCombine] Fold vector.interleave2 with two constant splats #125144
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Min-Yih Hsu (mshockwave) Changes: If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32> <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first before casting it back into `<vscale x 16 x i32>`. This is split out from #120490. Full diff: https://github.com/llvm/llvm-project/pull/125144.diff — 2 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 59920b5a4dd20a..fd49620b5e3ac3 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -125,6 +125,7 @@ class VectorCombine {
bool foldShuffleFromReductions(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+ bool foldInterleaveIntrinsics(Instruction &I);
bool shrinkType(Instruction &I);
void replaceValue(Value &Old, Value &New) {
@@ -3145,6 +3146,45 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
return true;
}
+bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
+ // If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
+ // <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
+ // larger splat
+ // `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first before casting it
+ // back into `<vscale x 16 x i32>`.
+ using namespace PatternMatch;
+ const APInt *SplatVal0, *SplatVal1;
+ if (!match(&I, m_Intrinsic<Intrinsic::vector_interleave2>(
+ m_APInt(SplatVal0), m_APInt(SplatVal1))))
+ return false;
+
+ LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
+ << "\n");
+
+ auto *VTy =
+ cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
+ auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
+ unsigned Width = VTy->getElementType()->getIntegerBitWidth();
+
+ if (TTI.getInstructionCost(&I, CostKind) <
+ TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
+ TTI::CastContextHint::None, CostKind)) {
+ LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
+ << *I.getType() << " is too high.\n");
+ return false;
+ }
+
+ APInt NewSplatVal = SplatVal1->zext(Width * 2);
+ NewSplatVal <<= Width;
+ NewSplatVal |= SplatVal0->zext(Width * 2);
+ auto *NewSplat = ConstantVector::getSplat(
+ ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
+
+ IRBuilder<> Builder(&I);
+ replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
+ return true;
+}
+
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
@@ -3189,6 +3229,7 @@ bool VectorCombine::run() {
MadeChange |= scalarizeBinopOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
+ MadeChange |= foldInterleaveIntrinsics(I);
}
if (Opcode == Instruction::Store)
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll
new file mode 100644
index 00000000000000..f2eb4e4e2dbc85
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll
@@ -0,0 +1,14 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -mtriple=riscv64 -mattr=+v,+m,+zvfh %s -passes=vector-combine | FileCheck %s
+; RUN: opt -S -mtriple=riscv32 -mattr=+v,+m,+zvfh %s -passes=vector-combine | FileCheck %s
+
+define void @store_factor2_const_splat(ptr %dst) {
+; CHECK-LABEL: define void @store_factor2_const_splat(
+; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> bitcast (<vscale x 8 x i64> splat (i64 3337189589658) to <vscale x 16 x i32>), ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
+; CHECK-NEXT: ret void
+;
+ %interleave2 = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
+ call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> %interleave2, ptr %dst, <vscale x 16 x i1> splat (i1 true), i32 88)
+ ret void
+}
|
@llvm/pr-subscribers-vectorizers Author: Min-Yih Hsu (mshockwave) Changes: If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32> <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first before casting it back into `<vscale x 16 x i32>`. This is split out from #120490. Full diff: https://github.com/llvm/llvm-project/pull/125144.diff — 2 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 59920b5a4dd20a..fd49620b5e3ac3 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -125,6 +125,7 @@ class VectorCombine {
bool foldShuffleFromReductions(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+ bool foldInterleaveIntrinsics(Instruction &I);
bool shrinkType(Instruction &I);
void replaceValue(Value &Old, Value &New) {
@@ -3145,6 +3146,45 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
return true;
}
+bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
+ // If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
+ // <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
+ // larger splat
+ // `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first before casting it
+ // back into `<vscale x 16 x i32>`.
+ using namespace PatternMatch;
+ const APInt *SplatVal0, *SplatVal1;
+ if (!match(&I, m_Intrinsic<Intrinsic::vector_interleave2>(
+ m_APInt(SplatVal0), m_APInt(SplatVal1))))
+ return false;
+
+ LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
+ << "\n");
+
+ auto *VTy =
+ cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
+ auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
+ unsigned Width = VTy->getElementType()->getIntegerBitWidth();
+
+ if (TTI.getInstructionCost(&I, CostKind) <
+ TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
+ TTI::CastContextHint::None, CostKind)) {
+ LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
+ << *I.getType() << " is too high.\n");
+ return false;
+ }
+
+ APInt NewSplatVal = SplatVal1->zext(Width * 2);
+ NewSplatVal <<= Width;
+ NewSplatVal |= SplatVal0->zext(Width * 2);
+ auto *NewSplat = ConstantVector::getSplat(
+ ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
+
+ IRBuilder<> Builder(&I);
+ replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
+ return true;
+}
+
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
@@ -3189,6 +3229,7 @@ bool VectorCombine::run() {
MadeChange |= scalarizeBinopOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
+ MadeChange |= foldInterleaveIntrinsics(I);
}
if (Opcode == Instruction::Store)
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll
new file mode 100644
index 00000000000000..f2eb4e4e2dbc85
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll
@@ -0,0 +1,14 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -mtriple=riscv64 -mattr=+v,+m,+zvfh %s -passes=vector-combine | FileCheck %s
+; RUN: opt -S -mtriple=riscv32 -mattr=+v,+m,+zvfh %s -passes=vector-combine | FileCheck %s
+
+define void @store_factor2_const_splat(ptr %dst) {
+; CHECK-LABEL: define void @store_factor2_const_splat(
+; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> bitcast (<vscale x 8 x i64> splat (i64 3337189589658) to <vscale x 16 x i32>), ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
+; CHECK-NEXT: ret void
+;
+ %interleave2 = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
+ call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> %interleave2, ptr %dst, <vscale x 16 x i1> splat (i1 true), i32 88)
+ ret void
+}
|
unsigned Width = VTy->getElementType()->getIntegerBitWidth();

if (TTI.getInstructionCost(&I, CostKind) <
    TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're really just worrying about the legalization cost here, should ExtVTy
be an illegal type.
@@ -0,0 +1,14 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=riscv64 -mattr=+v,+m,+zvfh %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv32 -mattr=+v,+m,+zvfh %s -passes=vector-combine | FileCheck %s
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a test that uses Zve32x instead of V. We shouldn't form an i64 vector type in that case. I think it would crash the backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe test to make sure we don't do i64->i128 too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM w/ the zve32x test
// larger splat
// `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first before casting it
// back into `<vscale x 16 x i32>`.
using namespace PatternMatch;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can remove this since PatternMatch is already included at the top of VectorCombine.cpp
// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
// larger splat
// `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first before casting it
// back into `<vscale x 16 x i32>`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit, make this a doc comment by moving above the signature + use ///
If we're interleaving 2 constant splats, for instance
`<vscale x 8 x i32> <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`,
we can create a larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>`
first before casting it back into `<vscale x 16 x i32>`. This is split out from #120490.