Skip to content

Commit a25862d

Browse files
committed
[NVPTX] legalize v2i32 to improve codegen of v2f32 ops
Since v2f32 is legal but v2i32 is not, some sequences of operations — such as bitcast (build_vector) — are lowered inefficiently. Legalizing v2i32 improves the codegen for these v2f32 operations.
1 parent 4ab8dab commit a25862d

12 files changed

+902
-403
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i16,
10181018
case MVT::f32:
10191019
return Opcode_i32;
10201020
case MVT::v2f32:
1021+
case MVT::v2i32:
10211022
case MVT::i64:
10221023
case MVT::f64:
10231024
return Opcode_i64;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 115 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -226,21 +226,20 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
226226
switch (VectorVT.SimpleTy) {
227227
default:
228228
return std::nullopt;
229+
229230
case MVT::v4i64:
230231
case MVT::v4f64:
231-
case MVT::v8i32:
232-
// This is a "native" vector type iff the address space is global
233-
// and the target supports 256-bit loads/stores
232+
// This is a "native" vector type iff the address space is global and the
233+
// target supports 256-bit loads/stores
234234
if (!CanLowerTo256Bit)
235235
return std::nullopt;
236236
LLVM_FALLTHROUGH;
237237
case MVT::v2i8:
238-
case MVT::v2i32:
239238
case MVT::v2i64:
240239
case MVT::v2f64:
241-
case MVT::v4i32:
242240
// This is a "native" vector type
243241
return std::pair(NumElts, EltVT);
242+
244243
case MVT::v16f16: // <8 x f16x2>
245244
case MVT::v16bf16: // <8 x bf16x2>
246245
case MVT::v16i16: // <8 x i16x2>
@@ -264,12 +263,18 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
264263
case MVT::v16i8: // <4 x i8x4>
265264
PackRegSize = 32;
266265
break;
267-
case MVT::v8f32: // <4 x f32x2>
266+
267+
case MVT::v8f32: // <4 x f32x2>
268+
case MVT::v8i32: // <4 x i32x2>
269+
// This is a "native" vector type iff the address space is global and the
270+
// target supports 256-bit loads/stores
268271
if (!CanLowerTo256Bit)
269272
return std::nullopt;
270273
LLVM_FALLTHROUGH;
271-
case MVT::v2f32: // <1 x f32x2>
272-
case MVT::v4f32: // <2 x f32x2>
274+
case MVT::v2f32: // <1 x f32x2>
275+
case MVT::v4f32: // <2 x f32x2>
276+
case MVT::v2i32: // <1 x i32x2>
277+
case MVT::v4i32: // <2 x i32x2>
273278
if (!STI.hasF32x2Instructions())
274279
return std::pair(NumElts, EltVT);
275280
PackRegSize = 64;
@@ -590,8 +595,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
590595
addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
591596
addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
592597

593-
if (STI.hasF32x2Instructions())
598+
if (STI.hasF32x2Instructions()) {
594599
addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
600+
addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
601+
}
595602

596603
// Conversion to/from FP16/FP16x2 is always legal.
597604
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -628,12 +635,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
628635
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
629636
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
630637

631-
// No support for these operations with v2f32.
632-
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
633-
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
638+
// No support for these operations with v2f32/v2i32
639+
setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
640+
setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
634641
// Need custom lowering in case the index is dynamic.
635642
if (STI.hasF32x2Instructions())
636-
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
643+
setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
644+
Custom);
637645

638646
// Custom conversions to/from v2i8.
639647
setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -661,14 +669,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
661669
// Operations not directly supported by NVPTX.
662670
for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
663671
MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
664-
MVT::v4i8, MVT::i32, MVT::i64}) {
672+
MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
665673
setOperationAction(ISD::SELECT_CC, VT, Expand);
666674
setOperationAction(ISD::BR_CC, VT, Expand);
667675
}
668676

669-
// Not directly supported. TLI would attempt to expand operations like
670-
// FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
671-
setOperationAction(ISD::VSELECT, MVT::v2f32, Expand);
677+
// We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
678+
setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
672679

673680
// Some SIGN_EXTEND_INREG can be done using cvt instruction.
674681
// For others we will expand to a SHL/SRA pair.
@@ -815,7 +822,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
815822
setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
816823
ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
817824
ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
818-
MVT::v2i16, Expand);
825+
{MVT::v2i16, MVT::v2i32}, Expand);
826+
827+
// v2i32 is not supported for any arithmetic operations
828+
setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
829+
ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
830+
ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
831+
ISD::SREM, ISD::UREM},
832+
MVT::v2i32, Expand);
819833

820834
setOperationAction(ISD::ADDC, MVT::i32, Legal);
821835
setOperationAction(ISD::ADDE, MVT::i32, Legal);
@@ -829,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
829843
}
830844

831845
setOperationAction(ISD::CTTZ, MVT::i16, Expand);
832-
setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
846+
setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
833847
setOperationAction(ISD::CTTZ, MVT::i32, Expand);
834848
setOperationAction(ISD::CTTZ, MVT::i64, Expand);
835849

@@ -1067,7 +1081,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10671081
// Custom lowering for tcgen05.st vector operands
10681082
setOperationAction(ISD::INTRINSIC_VOID,
10691083
{MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1070-
MVT::v32i32, MVT::v64i32, MVT::v128i32},
1084+
MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
10711085
Custom);
10721086

10731087
// Enable custom lowering for the following:
@@ -2572,7 +2586,7 @@ static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
25722586
return V;
25732587
}
25742588

2575-
static SDValue LowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
2589+
static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
25762590
SDNode *N = Op.getNode();
25772591
SDLoc DL(N);
25782592
SmallVector<SDValue, 32> Ops;
@@ -2598,7 +2612,52 @@ static SDValue LowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
25982612
return Tcgen05StNode;
25992613
}
26002614

2601-
static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
2615+
// Lower vector return type of tcgen05.ld intrinsics
2616+
static std::optional<std::pair<SDValue, SDValue>>
2617+
lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
2618+
SDLoc DL(N);
2619+
EVT ResVT = N->getValueType(0);
2620+
if (!ResVT.isVector())
2621+
return {}; // already legalized.
2622+
2623+
const unsigned NumElts = ResVT.getVectorNumElements();
2624+
2625+
// Create the return type of the instructions
2626+
SmallVector<EVT, 5> ListVTs;
2627+
for (unsigned i = 0; i < NumElts; ++i)
2628+
ListVTs.push_back(MVT::i32);
2629+
2630+
ListVTs.push_back(N->getValueType(1)); // Chain
2631+
2632+
SDVTList ResVTs = DAG.getVTList(ListVTs);
2633+
2634+
SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
2635+
N->getOperand(2)};
2636+
2637+
if (HasOffset) {
2638+
Ops.push_back(N->getOperand(3)); // offset
2639+
Ops.push_back(N->getOperand(4)); // Pack flag
2640+
} else
2641+
Ops.push_back(N->getOperand(3)); // Pack flag
2642+
2643+
MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2644+
SDValue NewNode =
2645+
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
2646+
MemSD->getMemoryVT(), MemSD->getMemOperand());
2647+
2648+
// split the vector result
2649+
SmallVector<SDValue, 4> ScalarRes;
2650+
for (unsigned i = 0; i < NumElts; ++i) {
2651+
SDValue Res = NewNode.getValue(i);
2652+
ScalarRes.push_back(Res);
2653+
}
2654+
2655+
SDValue Chain = NewNode.getValue(NumElts);
2656+
SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
2657+
return {{BuildVector, Chain}};
2658+
}
2659+
2660+
static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
26022661
SDNode *N = Op.getNode();
26032662
SDValue Intrin = N->getOperand(1);
26042663

@@ -2644,7 +2703,7 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
26442703
case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
26452704
case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
26462705
case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2647-
return LowerTcgen05St(Op, DAG);
2706+
return lowerTcgen05St(Op, DAG);
26482707
}
26492708
return Op;
26502709
}
@@ -2717,6 +2776,26 @@ static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
27172776
SDValue Selector = (Op->op_end() - 1)->get();
27182777
return getPRMT(A, B, Selector, DL, DAG, Mode);
27192778
}
2779+
2780+
static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
2781+
switch (Op->getConstantOperandVal(1)) {
2782+
default:
2783+
return Op;
2784+
2785+
case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
2786+
case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
2787+
case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
2788+
if (auto Pair = lowerTcgen05Ld(Op.getNode(), DAG))
2789+
return DAG.getMergeValues({Pair->first, Pair->second}, SDLoc(Op));
2790+
return SDValue();
2791+
2792+
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
2793+
if (auto Pair = lowerTcgen05Ld(Op.getNode(), DAG, /*HasOffset=*/true))
2794+
return DAG.getMergeValues({Pair->first, Pair->second}, SDLoc(Op));
2795+
return SDValue();
2796+
}
2797+
}
2798+
27202799
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
27212800
switch (Op->getConstantOperandVal(0)) {
27222801
default:
@@ -2879,11 +2958,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28792958
case ISD::ADDRSPACECAST:
28802959
return LowerADDRSPACECAST(Op, DAG);
28812960
case ISD::INTRINSIC_W_CHAIN:
2882-
return Op;
2961+
return lowerIntrinsicWChain(Op, DAG);
28832962
case ISD::INTRINSIC_WO_CHAIN:
28842963
return lowerIntrinsicWOChain(Op, DAG);
28852964
case ISD::INTRINSIC_VOID:
2886-
return LowerIntrinsicVoid(Op, DAG);
2965+
return lowerIntrinsicVoid(Op, DAG);
28872966
case ISD::BUILD_VECTOR:
28882967
return LowerBUILD_VECTOR(Op, DAG);
28892968
case ISD::BITCAST:
@@ -5673,7 +5752,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
56735752
IsPTXVectorType(VectorVT.getSimpleVT()))
56745753
return SDValue(); // Native vector loads already combine nicely w/
56755754
// extract_vector_elt.
5676-
// Don't mess with singletons or packed types (v2f32, v2*16, v4i8 and v8i8),
5755+
// Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
56775756
// we already handle them OK.
56785757
if (VectorVT.getVectorNumElements() == 1 ||
56795758
NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)
@@ -6045,53 +6124,6 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
60456124
DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
60466125
}
60476126

6048-
// Lower vector return type of tcgen05.ld intrinsics
6049-
static void ReplaceTcgen05Ld(SDNode *N, SelectionDAG &DAG,
6050-
SmallVectorImpl<SDValue> &Results,
6051-
bool hasOffset = false) {
6052-
SDLoc DL(N);
6053-
EVT ResVT = N->getValueType(0);
6054-
if (!ResVT.isVector())
6055-
return; // already legalized.
6056-
6057-
const unsigned NumElts = ResVT.getVectorNumElements();
6058-
6059-
// Create the return type of the instructions
6060-
SmallVector<EVT, 5> ListVTs;
6061-
for (unsigned i = 0; i < NumElts; ++i)
6062-
ListVTs.push_back(MVT::i32);
6063-
6064-
ListVTs.push_back(N->getValueType(1)); // Chain
6065-
6066-
SDVTList ResVTs = DAG.getVTList(ListVTs);
6067-
6068-
SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
6069-
N->getOperand(2)};
6070-
6071-
if (hasOffset) {
6072-
Ops.push_back(N->getOperand(3)); // offset
6073-
Ops.push_back(N->getOperand(4)); // Pack flag
6074-
} else
6075-
Ops.push_back(N->getOperand(3)); // Pack flag
6076-
6077-
MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
6078-
SDValue NewNode =
6079-
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
6080-
MemSD->getMemoryVT(), MemSD->getMemOperand());
6081-
6082-
// split the vector result
6083-
SmallVector<SDValue, 4> ScalarRes;
6084-
for (unsigned i = 0; i < NumElts; ++i) {
6085-
SDValue Res = NewNode.getValue(i);
6086-
ScalarRes.push_back(Res);
6087-
}
6088-
6089-
SDValue Chain = NewNode.getValue(NumElts);
6090-
SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
6091-
Results.push_back(BuildVector); // Build Vector
6092-
Results.push_back(Chain); // Chain
6093-
}
6094-
60956127
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
60966128
SmallVectorImpl<SDValue> &Results) {
60976129
SDValue Chain = N->getOperand(0);
@@ -6227,7 +6259,11 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
62276259
case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
62286260
case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
62296261
case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
6230-
return ReplaceTcgen05Ld(N, DAG, Results);
6262+
if (auto Pair = lowerTcgen05Ld(N, DAG)) {
6263+
Results.push_back(Pair->first);
6264+
Results.push_back(Pair->second);
6265+
}
6266+
return;
62316267

62326268
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
62336269
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
@@ -6236,7 +6272,11 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
62366272
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
62376273
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
62386274
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
6239-
return ReplaceTcgen05Ld(N, DAG, Results, /* Offset */ true);
6275+
if (auto Pair = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
6276+
Results.push_back(Pair->first);
6277+
Results.push_back(Pair->second);
6278+
}
6279+
return;
62406280
}
62416281
}
62426282

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -754,8 +754,10 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
754754
(SELP_b32rr $a, $b, $p)>;
755755
}
756756

757-
def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
757+
foreach vt = [v2f32, v2i32] in {
758+
def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
758759
(SELP_b64rr $a, $b, $p)>;
760+
}
759761

760762
//-----------------------------------
761763
// Test Instructions
@@ -2092,8 +2094,8 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
20922094
(V2I16toI32 $a, $b)>;
20932095
}
20942096

2095-
// Same thing for the 64-bit type v2f32.
2096-
foreach vt = [v2f32] in {
2097+
// Handle extracting one element from the pair (64-bit types)
2098+
foreach vt = [v2f32, v2i32] in {
20972099
def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
20982100
def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;
20992101

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def B16 : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
5454
def B32 : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
5555
(add (sequence "R%u", 0, 4),
5656
VRFrame32, VRFrameLocal32)>;
57-
def B64 : NVPTXRegClass<[i64, v2f32, f64], 64, (add (sequence "RL%u", 0, 4),
57+
def B64 : NVPTXRegClass<[i64, v2i32, v2f32, f64], 64,
58+
(add (sequence "RL%u", 0, 4),
5859
VRFrame64, VRFrameLocal64)>;
5960
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6061
def B128 : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ namespace NVPTX {
9999
// register. NOTE: This must be kept in sync with the register classes
100100
// defined in NVPTXRegisterInfo.td.
101101
inline auto packed_types() {
102-
static const auto PackedTypes = {MVT::v4i8, MVT::v2f16, MVT::v2bf16,
103-
MVT::v2i16, MVT::v2f32};
102+
static const auto PackedTypes = {MVT::v4i8, MVT::v2f16, MVT::v2bf16,
103+
MVT::v2i16, MVT::v2f32, MVT::v2i32};
104104
return PackedTypes;
105105
}
106106

0 commit comments

Comments (0)