@@ -226,21 +226,20 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
   switch (VectorVT.SimpleTy) {
   default:
     return std::nullopt;
+
   case MVT::v4i64:
   case MVT::v4f64:
-  case MVT::v8i32:
-    // This is a "native" vector type iff the address space is global
-    // and the target supports 256-bit loads/stores
+    // This is a "native" vector type iff the address space is global and the
+    // target supports 256-bit loads/stores
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
   case MVT::v2i8:
-  case MVT::v2i32:
   case MVT::v2i64:
   case MVT::v2f64:
-  case MVT::v4i32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
+
   case MVT::v16f16:  // <8 x f16x2>
   case MVT::v16bf16: // <8 x bf16x2>
   case MVT::v16i16:  // <8 x i16x2>
@@ -264,12 +263,18 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
   case MVT::v16i8: // <4 x i8x4>
     PackRegSize = 32;
     break;
-  case MVT::v8f32: // <4 x f32x2>
+
+  case MVT::v8f32: // <4 x f32x2>
+  case MVT::v8i32: // <4 x i32x2>
+    // This is a "native" vector type iff the address space is global and the
+    // target supports 256-bit loads/stores
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
-  case MVT::v2f32: // <1 x f32x2>
-  case MVT::v4f32: // <2 x f32x2>
+  case MVT::v2f32: // <1 x f32x2>
+  case MVT::v4f32: // <2 x f32x2>
+  case MVT::v2i32: // <1 x i32x2>
+  case MVT::v4i32: // <2 x i32x2>
     if (!STI.hasF32x2Instructions())
       return std::pair(NumElts, EltVT);
     PackRegSize = 64;
@@ -590,8 +595,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
   addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
 
-  if (STI.hasF32x2Instructions())
+  if (STI.hasF32x2Instructions()) {
     addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
+    addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
+  }
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -628,12 +635,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
 
-  // No support for these operations with v2f32.
-  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
-  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
+  // No support for these operations with v2f32/v2i32
+  setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
+  setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
   // Need custom lowering in case the index is dynamic.
   if (STI.hasF32x2Instructions())
-    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
+                       Custom);
 
   // Custom conversions to/from v2i8.
   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -661,14 +669,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Operations not directly supported by NVPTX.
   for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
                  MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
-                 MVT::v4i8, MVT::i32, MVT::i64}) {
+                 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
     setOperationAction(ISD::SELECT_CC, VT, Expand);
     setOperationAction(ISD::BR_CC, VT, Expand);
   }
 
-  // Not directly supported. TLI would attempt to expand operations like
-  // FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
-  setOperationAction(ISD::VSELECT, MVT::v2f32, Expand);
+  // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
+  setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
 
   // Some SIGN_EXTEND_INREG can be done using cvt instruction.
   // For others we will expand to a SHL/SRA pair.
@@ -815,7 +822,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
                       ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                       ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
-                     MVT::v2i16, Expand);
+                     {MVT::v2i16, MVT::v2i32}, Expand);
+
+  // v2i32 is not supported for any arithmetic operations
+  setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
+                      ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
+                      ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
+                      ISD::SREM, ISD::UREM},
+                     MVT::v2i32, Expand);
 
   setOperationAction(ISD::ADDC, MVT::i32, Legal);
   setOperationAction(ISD::ADDE, MVT::i32, Legal);
@@ -829,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   }
 
   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
-  setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
+  setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
 
@@ -5673,7 +5687,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
       IsPTXVectorType(VectorVT.getSimpleVT()))
     return SDValue(); // Native vector loads already combine nicely w/
                       // extract_vector_elt.
-  // Don't mess with singletons or packed types (v2f32, v2*16, v4i8 and v8i8),
+  // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
   // we already handle them OK.
   if (VectorVT.getVectorNumElements() == 1 ||
       NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)
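
For readers skimming the diff, here is a minimal, self-contained C++ sketch of what the lowering-shape changes above amount to for the i32 vector cases. It is an illustration only, not the LLVM implementation: the function name i32VectorShape and its simplified return value (number of register lanes, bits per lane) are assumptions made for this example, standing in for getVectorLoweringShape's MVT-based logic and the PackRegSize variable shown in the hunks.

// Hypothetical sketch, not LLVM code: mirrors only the v2i32/v4i32/v8i32 cases.
#include <cstdio>
#include <optional>
#include <utility>

// Returns {number of register lanes, bits per lane}, or nullopt when the
// vector is not treated as a packed/native shape in this simplified model.
static std::optional<std::pair<unsigned, unsigned>>
i32VectorShape(unsigned NumElts, bool CanLowerTo256Bit, bool HasF32x2) {
  if (NumElts != 2 && NumElts != 4 && NumElts != 8)
    return std::nullopt;
  if (NumElts == 8 && !CanLowerTo256Bit)
    return std::nullopt;              // v8i32 needs 256-bit global loads/stores
  if (!HasF32x2)
    return std::pair(NumElts, 32u);   // unpacked: one i32 per register
  return std::pair(NumElts / 2, 64u); // packed: one i32x2 pair per b64 register
}

int main() {
  if (auto S = i32VectorShape(8, /*CanLowerTo256Bit=*/true, /*HasF32x2=*/true))
    std::printf("v8i32 -> %u x b%u\n", S->first, S->second); // prints: 4 x b64
  return 0;
}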
0 commit comments