@@ -226,21 +226,20 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
226
226
switch (VectorVT.SimpleTy ) {
227
227
default :
228
228
return std::nullopt ;
229
+
229
230
case MVT::v4i64:
230
231
case MVT::v4f64:
231
- case MVT::v8i32:
232
- // This is a "native" vector type iff the address space is global
233
- // and the target supports 256-bit loads/stores
232
+ // This is a "native" vector type iff the address space is global and the
233
+ // target supports 256-bit loads/stores
234
234
if (!CanLowerTo256Bit)
235
235
return std::nullopt ;
236
236
LLVM_FALLTHROUGH;
237
237
case MVT::v2i8:
238
- case MVT::v2i32:
239
238
case MVT::v2i64:
240
239
case MVT::v2f64:
241
- case MVT::v4i32:
242
240
// This is a "native" vector type
243
241
return std::pair (NumElts, EltVT);
242
+
244
243
case MVT::v16f16: // <8 x f16x2>
245
244
case MVT::v16bf16: // <8 x bf16x2>
246
245
case MVT::v16i16: // <8 x i16x2>
@@ -264,12 +263,18 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
264
263
case MVT::v16i8: // <4 x i8x4>
265
264
PackRegSize = 32 ;
266
265
break ;
267
- case MVT::v8f32: // <4 x f32x2>
266
+
267
+ case MVT::v8f32: // <4 x f32x2>
268
+ case MVT::v8i32: // <4 x i32x2>
269
+ // This is a "native" vector type iff the address space is global and the
270
+ // target supports 256-bit loads/stores
268
271
if (!CanLowerTo256Bit)
269
272
return std::nullopt ;
270
273
LLVM_FALLTHROUGH;
271
- case MVT::v2f32: // <1 x f32x2>
272
- case MVT::v4f32: // <2 x f32x2>
274
+ case MVT::v2f32: // <1 x f32x2>
275
+ case MVT::v4f32: // <2 x f32x2>
276
+ case MVT::v2i32: // <1 x i32x2>
277
+ case MVT::v4i32: // <2 x i32x2>
273
278
if (!STI.hasF32x2Instructions ())
274
279
return std::pair (NumElts, EltVT);
275
280
PackRegSize = 64 ;
@@ -590,8 +595,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
590
595
addRegisterClass (MVT::bf16 , &NVPTX::B16RegClass);
591
596
addRegisterClass (MVT::v2bf16, &NVPTX::B32RegClass);
592
597
593
- if (STI.hasF32x2Instructions ())
598
+ if (STI.hasF32x2Instructions ()) {
594
599
addRegisterClass (MVT::v2f32, &NVPTX::B64RegClass);
600
+ addRegisterClass (MVT::v2i32, &NVPTX::B64RegClass);
601
+ }
595
602
596
603
// Conversion to/from FP16/FP16x2 is always legal.
597
604
setOperationAction (ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -628,12 +635,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
628
635
setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
629
636
setOperationAction (ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
630
637
631
- // No support for these operations with v2f32.
632
- setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
633
- setOperationAction (ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
638
+ // No support for these operations with v2f32/v2i32
639
+ setOperationAction (ISD::INSERT_VECTOR_ELT, { MVT::v2f32, MVT::v2i32} , Expand);
640
+ setOperationAction (ISD::VECTOR_SHUFFLE, { MVT::v2f32, MVT::v2i32} , Expand);
634
641
// Need custom lowering in case the index is dynamic.
635
642
if (STI.hasF32x2Instructions ())
636
- setOperationAction (ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
643
+ setOperationAction (ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
644
+ Custom);
637
645
638
646
// Custom conversions to/from v2i8.
639
647
setOperationAction (ISD::BITCAST, MVT::v2i8, Custom);
@@ -661,14 +669,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
661
669
// Operations not directly supported by NVPTX.
662
670
for (MVT VT : {MVT::bf16 , MVT::f16 , MVT::v2bf16, MVT::v2f16, MVT::f32 ,
663
671
MVT::v2f32, MVT::f64 , MVT::i1, MVT::i8 , MVT::i16 , MVT::v2i16,
664
- MVT::v4i8, MVT::i32 , MVT::i64 }) {
672
+ MVT::v4i8, MVT::i32 , MVT::v2i32, MVT:: i64 }) {
665
673
setOperationAction (ISD::SELECT_CC, VT, Expand);
666
674
setOperationAction (ISD::BR_CC, VT, Expand);
667
675
}
668
676
669
- // Not directly supported. TLI would attempt to expand operations like
670
- // FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
671
- setOperationAction (ISD::VSELECT, MVT::v2f32, Expand);
677
+ // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
678
+ setOperationAction (ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
672
679
673
680
// Some SIGN_EXTEND_INREG can be done using cvt instruction.
674
681
// For others we will expand to a SHL/SRA pair.
@@ -815,7 +822,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
815
822
setOperationAction ({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
816
823
ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
817
824
ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
818
- MVT::v2i16, Expand);
825
+ {MVT::v2i16, MVT::v2i32}, Expand);
826
+
827
+ // v2i32 is not supported for any arithmetic operations
828
+ setOperationAction ({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
829
+ ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
830
+ ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
831
+ ISD::SREM, ISD::UREM},
832
+ MVT::v2i32, Expand);
819
833
820
834
setOperationAction (ISD::ADDC, MVT::i32 , Legal);
821
835
setOperationAction (ISD::ADDE, MVT::i32 , Legal);
@@ -829,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
829
843
}
830
844
831
845
setOperationAction (ISD::CTTZ, MVT::i16 , Expand);
832
- setOperationAction (ISD::CTTZ, MVT::v2i16, Expand);
846
+ setOperationAction (ISD::CTTZ, { MVT::v2i16, MVT::v2i32} , Expand);
833
847
setOperationAction (ISD::CTTZ, MVT::i32 , Expand);
834
848
setOperationAction (ISD::CTTZ, MVT::i64 , Expand);
835
849
@@ -1067,7 +1081,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
1067
1081
// Custom lowering for tcgen05.st vector operands
1068
1082
setOperationAction (ISD::INTRINSIC_VOID,
1069
1083
{MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1070
- MVT::v32i32, MVT::v64i32, MVT::v128i32},
1084
+ MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other },
1071
1085
Custom);
1072
1086
1073
1087
// Enable custom lowering for the following:
@@ -2572,7 +2586,7 @@ static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
2572
2586
return V;
2573
2587
}
2574
2588
2575
- static SDValue LowerTcgen05St (SDValue Op, SelectionDAG &DAG) {
2589
+ static SDValue lowerTcgen05St (SDValue Op, SelectionDAG &DAG) {
2576
2590
SDNode *N = Op.getNode ();
2577
2591
SDLoc DL (N);
2578
2592
SmallVector<SDValue, 32 > Ops;
@@ -2598,7 +2612,52 @@ static SDValue LowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
2598
2612
return Tcgen05StNode;
2599
2613
}
2600
2614
2601
- static SDValue LowerIntrinsicVoid (SDValue Op, SelectionDAG &DAG) {
2615
+ // Lower vector return type of tcgen05.ld intrinsics
2616
+ static std::optional<std::pair<SDValue, SDValue>>
2617
+ lowerTcgen05Ld (SDNode *N, SelectionDAG &DAG, bool HasOffset = false ) {
2618
+ SDLoc DL (N);
2619
+ EVT ResVT = N->getValueType (0 );
2620
+ if (!ResVT.isVector ())
2621
+ return {}; // already legalized.
2622
+
2623
+ const unsigned NumElts = ResVT.getVectorNumElements ();
2624
+
2625
+ // Create the return type of the instructions
2626
+ SmallVector<EVT, 5 > ListVTs;
2627
+ for (unsigned i = 0 ; i < NumElts; ++i)
2628
+ ListVTs.push_back (MVT::i32 );
2629
+
2630
+ ListVTs.push_back (N->getValueType (1 )); // Chain
2631
+
2632
+ SDVTList ResVTs = DAG.getVTList (ListVTs);
2633
+
2634
+ SmallVector<SDValue, 8 > Ops{N->getOperand (0 ), N->getOperand (1 ),
2635
+ N->getOperand (2 )};
2636
+
2637
+ if (HasOffset) {
2638
+ Ops.push_back (N->getOperand (3 )); // offset
2639
+ Ops.push_back (N->getOperand (4 )); // Pack flag
2640
+ } else
2641
+ Ops.push_back (N->getOperand (3 )); // Pack flag
2642
+
2643
+ MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2644
+ SDValue NewNode =
2645
+ DAG.getMemIntrinsicNode (ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
2646
+ MemSD->getMemoryVT (), MemSD->getMemOperand ());
2647
+
2648
+ // split the vector result
2649
+ SmallVector<SDValue, 4 > ScalarRes;
2650
+ for (unsigned i = 0 ; i < NumElts; ++i) {
2651
+ SDValue Res = NewNode.getValue (i);
2652
+ ScalarRes.push_back (Res);
2653
+ }
2654
+
2655
+ SDValue Chain = NewNode.getValue (NumElts);
2656
+ SDValue BuildVector = DAG.getNode (ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
2657
+ return {{BuildVector, Chain}};
2658
+ }
2659
+
2660
+ static SDValue lowerIntrinsicVoid (SDValue Op, SelectionDAG &DAG) {
2602
2661
SDNode *N = Op.getNode ();
2603
2662
SDValue Intrin = N->getOperand (1 );
2604
2663
@@ -2644,7 +2703,7 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
2644
2703
case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
2645
2704
case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
2646
2705
case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2647
- return LowerTcgen05St (Op, DAG);
2706
+ return lowerTcgen05St (Op, DAG);
2648
2707
}
2649
2708
return Op;
2650
2709
}
@@ -2717,6 +2776,26 @@ static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
2717
2776
SDValue Selector = (Op->op_end () - 1 )->get ();
2718
2777
return getPRMT (A, B, Selector, DL, DAG, Mode);
2719
2778
}
2779
+
2780
+ static SDValue lowerIntrinsicWChain (SDValue Op, SelectionDAG &DAG) {
2781
+ switch (Op->getConstantOperandVal (1 )) {
2782
+ default :
2783
+ return Op;
2784
+
2785
+ case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
2786
+ case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
2787
+ case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
2788
+ if (auto Pair = lowerTcgen05Ld (Op.getNode (), DAG))
2789
+ return DAG.getMergeValues ({Pair->first , Pair->second }, SDLoc (Op));
2790
+ return SDValue ();
2791
+
2792
+ case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
2793
+ if (auto Pair = lowerTcgen05Ld (Op.getNode (), DAG, /* HasOffset=*/ true ))
2794
+ return DAG.getMergeValues ({Pair->first , Pair->second }, SDLoc (Op));
2795
+ return SDValue ();
2796
+ }
2797
+ }
2798
+
2720
2799
static SDValue lowerIntrinsicWOChain (SDValue Op, SelectionDAG &DAG) {
2721
2800
switch (Op->getConstantOperandVal (0 )) {
2722
2801
default :
@@ -2879,11 +2958,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2879
2958
case ISD::ADDRSPACECAST:
2880
2959
return LowerADDRSPACECAST (Op, DAG);
2881
2960
case ISD::INTRINSIC_W_CHAIN:
2882
- return Op ;
2961
+ return lowerIntrinsicWChain (Op, DAG) ;
2883
2962
case ISD::INTRINSIC_WO_CHAIN:
2884
2963
return lowerIntrinsicWOChain (Op, DAG);
2885
2964
case ISD::INTRINSIC_VOID:
2886
- return LowerIntrinsicVoid (Op, DAG);
2965
+ return lowerIntrinsicVoid (Op, DAG);
2887
2966
case ISD::BUILD_VECTOR:
2888
2967
return LowerBUILD_VECTOR (Op, DAG);
2889
2968
case ISD::BITCAST:
@@ -5673,7 +5752,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
5673
5752
IsPTXVectorType (VectorVT.getSimpleVT ()))
5674
5753
return SDValue (); // Native vector loads already combine nicely w/
5675
5754
// extract_vector_elt.
5676
- // Don't mess with singletons or packed types (v2f32 , v2*16, v4i8 and v8i8),
5755
+ // Don't mess with singletons or packed types (v2*32 , v2*16, v4i8 and v8i8),
5677
5756
// we already handle them OK.
5678
5757
if (VectorVT.getVectorNumElements () == 1 ||
5679
5758
NVPTX::isPackedVectorTy (VectorVT) || VectorVT == MVT::v8i8)
@@ -6045,53 +6124,6 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
6045
6124
DAG.getNode (ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
6046
6125
}
6047
6126
6048
- // Lower vector return type of tcgen05.ld intrinsics
6049
- static void ReplaceTcgen05Ld (SDNode *N, SelectionDAG &DAG,
6050
- SmallVectorImpl<SDValue> &Results,
6051
- bool hasOffset = false ) {
6052
- SDLoc DL (N);
6053
- EVT ResVT = N->getValueType (0 );
6054
- if (!ResVT.isVector ())
6055
- return ; // already legalized.
6056
-
6057
- const unsigned NumElts = ResVT.getVectorNumElements ();
6058
-
6059
- // Create the return type of the instructions
6060
- SmallVector<EVT, 5 > ListVTs;
6061
- for (unsigned i = 0 ; i < NumElts; ++i)
6062
- ListVTs.push_back (MVT::i32 );
6063
-
6064
- ListVTs.push_back (N->getValueType (1 )); // Chain
6065
-
6066
- SDVTList ResVTs = DAG.getVTList (ListVTs);
6067
-
6068
- SmallVector<SDValue, 8 > Ops{N->getOperand (0 ), N->getOperand (1 ),
6069
- N->getOperand (2 )};
6070
-
6071
- if (hasOffset) {
6072
- Ops.push_back (N->getOperand (3 )); // offset
6073
- Ops.push_back (N->getOperand (4 )); // Pack flag
6074
- } else
6075
- Ops.push_back (N->getOperand (3 )); // Pack flag
6076
-
6077
- MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
6078
- SDValue NewNode =
6079
- DAG.getMemIntrinsicNode (ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
6080
- MemSD->getMemoryVT (), MemSD->getMemOperand ());
6081
-
6082
- // split the vector result
6083
- SmallVector<SDValue, 4 > ScalarRes;
6084
- for (unsigned i = 0 ; i < NumElts; ++i) {
6085
- SDValue Res = NewNode.getValue (i);
6086
- ScalarRes.push_back (Res);
6087
- }
6088
-
6089
- SDValue Chain = NewNode.getValue (NumElts);
6090
- SDValue BuildVector = DAG.getNode (ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
6091
- Results.push_back (BuildVector); // Build Vector
6092
- Results.push_back (Chain); // Chain
6093
- }
6094
-
6095
6127
static void ReplaceINTRINSIC_W_CHAIN (SDNode *N, SelectionDAG &DAG,
6096
6128
SmallVectorImpl<SDValue> &Results) {
6097
6129
SDValue Chain = N->getOperand (0 );
@@ -6227,7 +6259,11 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
6227
6259
case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
6228
6260
case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
6229
6261
case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
6230
- return ReplaceTcgen05Ld (N, DAG, Results);
6262
+ if (auto Pair = lowerTcgen05Ld (N, DAG)) {
6263
+ Results.push_back (Pair->first );
6264
+ Results.push_back (Pair->second );
6265
+ }
6266
+ return ;
6231
6267
6232
6268
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
6233
6269
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
@@ -6236,7 +6272,11 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
6236
6272
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
6237
6273
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
6238
6274
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
6239
- return ReplaceTcgen05Ld (N, DAG, Results, /* Offset */ true );
6275
+ if (auto Pair = lowerTcgen05Ld (N, DAG, /* HasOffset=*/ true )) {
6276
+ Results.push_back (Pair->first );
6277
+ Results.push_back (Pair->second );
6278
+ }
6279
+ return ;
6240
6280
}
6241
6281
}
6242
6282
0 commit comments