@@ -228,17 +228,9 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
     return std::nullopt;
   case MVT::v4i64:
   case MVT::v4f64:
-  case MVT::v8i32:
-    // This is a "native" vector type iff the address space is global
-    // and the target supports 256-bit loads/stores
-    if (!CanLowerTo256Bit)
-      return std::nullopt;
-    LLVM_FALLTHROUGH;
   case MVT::v2i8:
-  case MVT::v2i32:
   case MVT::v2i64:
   case MVT::v2f64:
-  case MVT::v4i32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
   case MVT::v16f16: // <8 x f16x2>
@@ -264,6 +256,7 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
   case MVT::v16i8: // <4 x i8x4>
     PackRegSize = 32;
     break;
+
   case MVT::v8f32: // <4 x f32x2>
     if (!CanLowerTo256Bit)
       return std::nullopt;
@@ -274,6 +267,17 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
       return std::pair(NumElts, EltVT);
     PackRegSize = 64;
     break;
+
+  case MVT::v8i32: // <4 x i32x2>
+    if (!CanLowerTo256Bit)
+      return std::nullopt;
+    LLVM_FALLTHROUGH;
+  case MVT::v2i32: // <1 x i32x2>
+  case MVT::v4i32: // <2 x i32x2>
+    if (!STI.hasF32x2Instructions())
+      return std::pair(NumElts, EltVT);
+    PackRegSize = 64;
+    break;
   }
 
   // If we reach here, then we can pack 2 or more elements into a single 32-bit
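The new i32 cases mirror the f32 handling above: v8i32 is 256 bits wide, so it is only lowered as a single vector access when 256-bit loads/stores are available, and all three types keep their unpacked "native" shape when the subtarget lacks the f32x2 instructions that make 64-bit packing worthwhile. A self-contained sketch of that decision, using simplified stand-ins rather than the in-tree MVT/EVT machinery:

#include <optional>
#include <utility>

struct Subtarget {
  bool HasF32x2Instructions; // stand-in for STI.hasF32x2Instructions()
};

// Returns {number of lowered values, bits per value} for a <NumElts x i32>
// vector (NumElts is 2, 4, or 8), or nullopt when the vector cannot be
// lowered as a single load/store.
std::optional<std::pair<unsigned, unsigned>>
i32VectorShape(unsigned NumElts, const Subtarget &STI, bool CanLowerTo256Bit) {
  // v8i32 is 256 bits wide; it is only viable with 256-bit loads/stores.
  if (NumElts == 8 && !CanLowerTo256Bit)
    return std::nullopt;
  // Without packing support, keep the native shape: NumElts x i32.
  if (!STI.HasF32x2Instructions)
    return std::pair(NumElts, 32u);
  // Otherwise pack pairs of i32 into 64-bit registers: NumElts/2 x i32x2.
  return std::pair(NumElts / 2, 64u);
}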
@@ -590,8 +594,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
   addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
 
-  if (STI.hasF32x2Instructions())
+  if (STI.hasF32x2Instructions()) {
     addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
+    addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
+  }
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
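Registering v2i32 with the B64 register class is what lets the type legalizer keep a whole <2 x i32> value in one 64-bit register rather than splitting it into two i32 values before instruction selection. A sketch of that layout (the lane order shown is an assumption of this illustration, not something the diff states):

#include <cstdint>

// A <2 x i32> value carried in a single 64-bit register; element 0 in the
// low half is an assumption of this sketch.
uint64_t packV2i32(uint32_t E0, uint32_t E1) {
  return static_cast<uint64_t>(E1) << 32 | E0;
}

uint32_t extractElt(uint64_t Reg, unsigned Idx) {
  return static_cast<uint32_t>(Reg >> (Idx * 32)); // Idx is 0 or 1
}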
@@ -628,12 +634,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
 
-  // No support for these operations with v2f32.
-  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
-  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
+  // No support for these operations with v2f32/v2i32.
+  setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
+  setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
   // Need custom lowering in case the index is dynamic.
   if (STI.hasF32x2Instructions())
-    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
+                       Custom);
 
   // Custom conversions to/from v2i8.
   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -661,14 +668,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Operations not directly supported by NVPTX.
   for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
                  MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
-                 MVT::v4i8, MVT::i32, MVT::i64}) {
+                 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
     setOperationAction(ISD::SELECT_CC, VT, Expand);
     setOperationAction(ISD::BR_CC, VT, Expand);
   }
 
-  // Not directly supported. TLI would attempt to expand operations like
-  // FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
-  setOperationAction(ISD::VSELECT, MVT::v2f32, Expand);
+  // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
+  setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
 
   // Some SIGN_EXTEND_INREG can be done using cvt instruction.
   // For others we will expand to a SHL/SRA pair.
@@ -815,7 +821,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
                       ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                       ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
-                     MVT::v2i16, Expand);
+                     {MVT::v2i16, MVT::v2i32}, Expand);
+
+  // v2i32 is not supported for any arithmetic operations.
+  setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
+                      ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
+                      ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
+                      ISD::SREM, ISD::UREM},
+                     MVT::v2i32, Expand);
 
   setOperationAction(ISD::ADDC, MVT::i32, Legal);
   setOperationAction(ISD::ADDE, MVT::i32, Legal);
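Because there are no packed i32x2 ALU instructions, every arithmetic, shift, and logic operation on the newly legal v2i32 type is marked Expand, so the legalizer unrolls each one into two scalar i32 operations. A rough illustration of the net effect (illustrative only, not compiler output):

#include <array>
#include <cstdint>

using V2i32 = std::array<uint32_t, 2>;

// One packed <2 x i32> add is expanded into two independent 32-bit adds;
// the same unrolling applies to the other ops in the Expand list above.
V2i32 addV2i32(V2i32 A, V2i32 B) {
  return {A[0] + B[0], A[1] + B[1]};
}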
@@ -829,7 +842,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   }
 
   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
-  setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
+  setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
 
@@ -5673,7 +5686,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
       IsPTXVectorType(VectorVT.getSimpleVT()))
     return SDValue(); // Native vector loads already combine nicely w/
                       // extract_vector_elt.
-  // Don't mess with singletons or packed types (v2f32, v2*16, v4i8 and v8i8),
+  // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
   // we already handle them OK.
   if (VectorVT.getVectorNumElements() == 1 ||
       NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)