@@ -129,6 +129,7 @@ class VectorCombine {
129
129
bool foldShuffleOfIntrinsics (Instruction &I);
130
130
bool foldShuffleToIdentity (Instruction &I);
131
131
bool foldShuffleFromReductions (Instruction &I);
132
+ bool foldShuffleChainsToReduce (Instruction &I);
132
133
bool foldCastFromReductions (Instruction &I);
133
134
bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
134
135
bool foldInterleaveIntrinsics (Instruction &I);
@@ -2910,6 +2911,133 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2910
2911
return foldSelectShuffle (*Shuffle, true );
2911
2912
}
2912
2913
2914
+ bool VectorCombine::foldShuffleChainsToReduce (Instruction &I) {
2915
+ auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
2916
+ if (!SVI)
2917
+ return false ;
2918
+
2919
+ std::queue<Value *> Worklist;
2920
+ SmallVector<Instruction *> ToEraseFromParent;
2921
+
2922
+ SmallVector<int > ShuffleMask;
2923
+ bool IsShuffleOp = true ;
2924
+
2925
+ Worklist.push (SVI);
2926
+ SVI->getShuffleMask (ShuffleMask);
2927
+
2928
+ if (ShuffleMask.size () < 2 )
2929
+ return false ;
2930
+
2931
+ Instruction *Prev0 = nullptr , *Prev1 = nullptr ;
2932
+ Instruction *LastOp = nullptr ;
2933
+
2934
+ int MaskHalfPos = ShuffleMask.size () / 2 ;
2935
+ bool IsFirst = true ;
2936
+
2937
+ while (!Worklist.empty ()) {
2938
+ Value *V = Worklist.front ();
2939
+ Worklist.pop ();
2940
+
2941
+ auto *CI = dyn_cast<Instruction>(V);
2942
+ if (!CI)
2943
+ return false ;
2944
+
2945
+ if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
2946
+ if (!IsShuffleOp || MaskHalfPos < 1 || (!Prev1 && !IsFirst))
2947
+ return false ;
2948
+
2949
+ auto *Op0 = SV->getOperand (0 );
2950
+ auto *Op1 = SV->getOperand (1 );
2951
+ if (!Op0 || !Op1)
2952
+ return false ;
2953
+
2954
+ auto *FVT = dyn_cast<FixedVectorType>(Op1->getType ());
2955
+ if (!FVT || !isa<PoisonValue>(Op1))
2956
+ return false ;
2957
+
2958
+ SmallVector<int > CurrentMask;
2959
+ SV->getShuffleMask (CurrentMask);
2960
+
2961
+ int64_t MaskSize = CurrentMask.size ();
2962
+ for (int MaskPos = 0 ; MaskPos != MaskSize; ++MaskPos) {
2963
+ if (MaskPos < MaskHalfPos &&
2964
+ CurrentMask[MaskPos] != MaskHalfPos + MaskPos)
2965
+ return false ;
2966
+ if (MaskPos >= MaskHalfPos && CurrentMask[MaskPos] != -1 )
2967
+ return false ;
2968
+ }
2969
+ MaskHalfPos /= 2 ;
2970
+ Prev0 = SV;
2971
+ } else if (auto *Call = dyn_cast<CallInst>(V)) {
2972
+ if (IsShuffleOp || !Prev0)
2973
+ return false ;
2974
+
2975
+ auto *II = dyn_cast<IntrinsicInst>(Call);
2976
+ if (!II)
2977
+ return false ;
2978
+
2979
+ switch (II->getIntrinsicID ()) {
2980
+ case Intrinsic::umin: {
2981
+ auto *Op0 = Call->getOperand (0 );
2982
+ auto *Op1 = Call->getOperand (1 );
2983
+ if (!(Op0 == Prev0 && Op1 == Prev1) &&
2984
+ !(Op0 == Prev1 && Op1 == Prev0) && !IsFirst)
2985
+ return false ;
2986
+
2987
+ if (!IsFirst)
2988
+ Prev0 = Prev1;
2989
+ else
2990
+ IsFirst = false ;
2991
+ Prev1 = Call;
2992
+ break ;
2993
+ }
2994
+ default :
2995
+ return false ;
2996
+ }
2997
+ } else if (auto *ExtractElement = dyn_cast<ExtractElementInst>(CI)) {
2998
+ if (!IsShuffleOp || !Prev0 || !Prev1 || MaskHalfPos != 0 )
2999
+ return false ;
3000
+
3001
+ auto *Op0 = ExtractElement->getOperand (0 );
3002
+ auto *Op1 = ExtractElement->getOperand (1 );
3003
+ if (Op0 != Prev1)
3004
+ return false ;
3005
+
3006
+ if (auto *Op1Idx = dyn_cast<ConstantInt>(Op1)) {
3007
+ if (Op1Idx->getValue () != 0 )
3008
+ return false ;
3009
+ } else {
3010
+ return false ;
3011
+ }
3012
+ LastOp = ExtractElement;
3013
+ break ;
3014
+ }
3015
+ IsShuffleOp ^= 1 ;
3016
+ ToEraseFromParent.push_back (CI);
3017
+
3018
+ auto *NextI = CI->getNextNode ();
3019
+ if (!NextI)
3020
+ return false ;
3021
+ Worklist.push (NextI);
3022
+ }
3023
+
3024
+ if (!LastOp)
3025
+ return false ;
3026
+
3027
+ auto *ReducedResult = Builder.CreateIntrinsic (
3028
+ Intrinsic::vector_reduce_umin, {SVI->getType ()}, {SVI->getOperand (0 )});
3029
+ replaceValue (*LastOp, *ReducedResult);
3030
+
3031
+ ToEraseFromParent.push_back (LastOp);
3032
+
3033
+ std::reverse (ToEraseFromParent.begin (), ToEraseFromParent.end ());
3034
+ // for (auto &Instr : ToEraseFromParent)
3035
+ // eraseInstruction(*Instr);
3036
+ // Instr->eraseFromParent();
3037
+
3038
+ return true ;
3039
+ }
3040
+
2913
3041
/// Determine if it's more efficient to fold:
2914
3042
// / reduce(trunc(x)) -> trunc(reduce(x)).
2915
3043
// / reduce(sext(x)) -> sext(reduce(x)).
@@ -3607,6 +3735,7 @@ bool VectorCombine::run() {
3607
3735
MadeChange |= foldShuffleOfIntrinsics (I);
3608
3736
MadeChange |= foldSelectShuffle (I);
3609
3737
MadeChange |= foldShuffleToIdentity (I);
3738
+ MadeChange |= foldShuffleChainsToReduce (I);
3610
3739
break ;
3611
3740
case Instruction::BitCast:
3612
3741
MadeChange |= foldBitcastShuffle (I);
0 commit comments