Skip to content

Commit fea1941

Browse files
committed
[VectorCombine] New folding pattern for extract/binop/shuffle chains
Resolves #144654. Part of #143088.

This adds a new `foldShuffleChainsToReduce` for horizontal reduction of patterns like:

```llvm
define i16 @test_reduce_v8i16(<8 x i16> %a0) local_unnamed_addr #0 {
  %1 = shufflevector <8 x i16> %a0, <8 x i16> poison, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison>
  %2 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %a0, <8 x i16> %1)
  %3 = shufflevector <8 x i16> %2, <8 x i16> poison, <8 x i32> <i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %4 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %2, <8 x i16> %3)
  %5 = shufflevector <8 x i16> %4, <8 x i16> poison, <8 x i32> <i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %6 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %4, <8 x i16> %5)
  %7 = extractelement <8 x i16> %6, i64 0
  ret i16 %7
}
```

...which can be reduced to an `llvm.vector.reduce.umin.v8i16(%a0)` intrinsic call. Similar transformations can be applied to other ops when costs permit.
1 parent d4826cd commit fea1941

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class VectorCombine {
129129
bool foldShuffleOfIntrinsics(Instruction &I);
130130
bool foldShuffleToIdentity(Instruction &I);
131131
bool foldShuffleFromReductions(Instruction &I);
132+
bool foldShuffleChainsToReduce(Instruction &I);
132133
bool foldCastFromReductions(Instruction &I);
133134
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
134135
bool foldInterleaveIntrinsics(Instruction &I);
@@ -2910,6 +2911,133 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
29102911
return foldSelectShuffle(*Shuffle, true);
29112912
}
29122913

2914+
/// Try to fold a chain of "halving" shuffles and element-wise umin calls that
/// ends in an extract of lane 0 into a single llvm.vector.reduce.umin call.
///
/// Starting at the shuffle \p I, the matched sequence must be made of
/// consecutive instructions of the form (N = number of vector lanes):
///
///   S1 = shufflevector Src, poison, <N/2, ..., N-1, poison, ...>
///   R1 = umin(Src, S1)
///   S2 = shufflevector R1, poison, <N/4, ..., N/2-1, poison, ...>
///   R2 = umin(R1, S2)
///   ...
///   E  = extractelement Rk, 0
///
/// On success all uses of E are replaced with vector.reduce.umin(Src). The
/// matched shuffle/umin/extract instructions are not erased by this function.
///
/// \param I Candidate first shuffle of the chain.
/// \returns true if the IR was changed, false if the pattern did not match.
///
/// TODO: No cost-model query is made and only Intrinsic::umin is recognised.
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
  auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
  if (!SVI)
    return false;

  // The reduction is emitted on the shuffle's source operand, so the shuffle
  // must not change the vector type (otherwise the intrinsic's overloaded type
  // and its argument would disagree).
  Value *Src = SVI->getOperand(0);
  auto *VecTy = dyn_cast<FixedVectorType>(SVI->getType());
  if (!VecTy || Src->getType() != VecTy)
    return false;

  // Repeated halving only folds every lane into lane 0 when the lane count is
  // a power of two; e.g. with 6 lanes (halves 3, 1) lane 2 would be dropped.
  const unsigned NumElts = VecTy->getNumElements();
  if (NumElts < 2 || (NumElts & (NumElts - 1)) != 0)
    return false;

  // Running is the partially reduced vector produced so far; Cur is the
  // instruction expected to be the next shuffle of the chain.
  Value *Running = Src;
  Instruction *Cur = SVI;
  for (unsigned Half = NumElts / 2; Half != 0; Half /= 2) {
    // Each step must shuffle the current partial result (not some unrelated
    // vector) against poison.
    auto *Shuffle = Cur ? dyn_cast<ShuffleVectorInst>(Cur) : nullptr;
    if (!Shuffle || Shuffle->getOperand(0) != Running ||
        !isa<PoisonValue>(Shuffle->getOperand(1)))
      return false;

    // The mask must move lanes [Half, 2*Half) down to [0, Half) and leave
    // every remaining lane poison (-1).
    SmallVector<int> Mask;
    Shuffle->getShuffleMask(Mask);
    if (Mask.size() != NumElts)
      return false;
    for (unsigned Lane = 0; Lane != NumElts; ++Lane) {
      const int Expected = Lane < Half ? static_cast<int>(Half + Lane) : -1;
      if (Mask[Lane] != Expected)
        return false;
    }

    // The instruction right after the shuffle must be umin of the partial
    // result and the shuffle, in either operand order.
    Instruction *AfterShuffle = Shuffle->getNextNode();
    auto *MinCall = AfterShuffle ? dyn_cast<IntrinsicInst>(AfterShuffle) : nullptr;
    if (!MinCall || MinCall->getIntrinsicID() != Intrinsic::umin)
      return false;

    Value *LHS = MinCall->getArgOperand(0);
    Value *RHS = MinCall->getArgOperand(1);
    const bool CombinesStep = (LHS == Running && RHS == Shuffle) ||
                              (LHS == Shuffle && RHS == Running);
    if (!CombinesStep)
      return false;

    Running = MinCall;
    Cur = MinCall->getNextNode();
  }

  // The chain must finish by extracting lane 0 of the final partial result.
  auto *Extract = Cur ? dyn_cast<ExtractElementInst>(Cur) : nullptr;
  if (!Extract || Extract->getVectorOperand() != Running)
    return false;

  auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand());
  if (!Index || !Index->isZero())
    return false;

  Value *ReducedResult = Builder.CreateIntrinsic(Intrinsic::vector_reduce_umin,
                                                 {Src->getType()}, {Src});
  replaceValue(*Extract, *ReducedResult);
  return true;
}
3040+
29133041
/// Determine if it's more efficient to fold:
29143042
/// reduce(trunc(x)) -> trunc(reduce(x)).
29153043
/// reduce(sext(x)) -> sext(reduce(x)).
@@ -3607,6 +3735,7 @@ bool VectorCombine::run() {
36073735
MadeChange |= foldShuffleOfIntrinsics(I);
36083736
MadeChange |= foldSelectShuffle(I);
36093737
MadeChange |= foldShuffleToIdentity(I);
3738+
MadeChange |= foldShuffleChainsToReduce(I);
36103739
break;
36113740
case Instruction::BitCast:
36123741
MadeChange |= foldBitcastShuffle(I);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt < %s -passes=vector-combine -S | FileCheck %s
3+
4+
; Horizontal umin reduction of an <8 x i16> vector written as three
; shufflevector + llvm.umin halving steps (upper half 4..7, then 2..3, then 1)
; followed by an extract of lane 0. The autogenerated assertions below expect
; the whole chain to become a single llvm.vector.reduce.umin.v8i16 call.
define i16 @test_reduce_v8i16(<8 x i16> %a0) local_unnamed_addr #0 {
5+
; CHECK-LABEL: define i16 @test_reduce_v8i16(
6+
; CHECK-SAME: <8 x i16> [[A0:%.*]]) local_unnamed_addr {
7+
; CHECK-NEXT: [[TMP1:%.*]] = call i16 @llvm.vector.reduce.umin.v8i16(<8 x i16> [[A0]])
8+
; CHECK-NEXT: ret i16 [[TMP1]]
9+
;
10+
%1 = shufflevector <8 x i16> %a0, <8 x i16> poison, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison>
11+
%2 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %a0, <8 x i16> %1)
12+
%3 = shufflevector <8 x i16> %2, <8 x i16> poison, <8 x i32> <i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
13+
%4 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %2, <8 x i16> %3)
14+
%5 = shufflevector <8 x i16> %4, <8 x i16> poison, <8 x i32> <i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
15+
%6 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %4, <8 x i16> %5)
16+
%7 = extractelement <8 x i16> %6, i64 0
17+
ret i16 %7
18+
}

0 commit comments

Comments
 (0)