Skip to content

Commit 2faf49b

Browse files
committed
[NVPTX] legalize v2i32 to improve codegen of v2f32 ops
Since v2f32 is legal but v2i32 is not, some sequences of operations — such as bitcast (build_vector) — are lowered inefficiently. Legalizing v2i32 lets these patterns stay in packed 64-bit registers.
1 parent 4ab8dab commit 2faf49b

13 files changed

+894
-504
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i16,
10181018
case MVT::f32:
10191019
return Opcode_i32;
10201020
case MVT::v2f32:
1021+
case MVT::v2i32:
10211022
case MVT::i64:
10221023
case MVT::f64:
10231024
return Opcode_i64;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -228,17 +228,9 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
228228
return std::nullopt;
229229
case MVT::v4i64:
230230
case MVT::v4f64:
231-
case MVT::v8i32:
232-
// This is a "native" vector type iff the address space is global
233-
// and the target supports 256-bit loads/stores
234-
if (!CanLowerTo256Bit)
235-
return std::nullopt;
236-
LLVM_FALLTHROUGH;
237231
case MVT::v2i8:
238-
case MVT::v2i32:
239232
case MVT::v2i64:
240233
case MVT::v2f64:
241-
case MVT::v4i32:
242234
// This is a "native" vector type
243235
return std::pair(NumElts, EltVT);
244236
case MVT::v16f16: // <8 x f16x2>
@@ -264,6 +256,7 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
264256
case MVT::v16i8: // <4 x i8x4>
265257
PackRegSize = 32;
266258
break;
259+
267260
case MVT::v8f32: // <4 x f32x2>
268261
if (!CanLowerTo256Bit)
269262
return std::nullopt;
@@ -274,6 +267,17 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
274267
return std::pair(NumElts, EltVT);
275268
PackRegSize = 64;
276269
break;
270+
271+
case MVT::v8i32: // <4 x i32x2>
272+
if (!CanLowerTo256Bit)
273+
return std::nullopt;
274+
LLVM_FALLTHROUGH;
275+
case MVT::v2i32: // <1 x i32x2>
276+
case MVT::v4i32: // <2 x i32x2>
277+
if (!STI.hasF32x2Instructions())
278+
return std::pair(NumElts, EltVT);
279+
PackRegSize = 64;
280+
break;
277281
}
278282

279283
// If we reach here, then we can pack 2 or more elements into a single 32-bit
@@ -590,8 +594,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
590594
addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
591595
addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
592596

593-
if (STI.hasF32x2Instructions())
597+
if (STI.hasF32x2Instructions()) {
594598
addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
599+
addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
600+
}
595601

596602
// Conversion to/from FP16/FP16x2 is always legal.
597603
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -628,12 +634,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
628634
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
629635
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
630636

631-
// No support for these operations with v2f32.
632-
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
633-
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
637+
// No support for these operations with v2f32/v2i32
638+
setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
639+
setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
634640
// Need custom lowering in case the index is dynamic.
635641
if (STI.hasF32x2Instructions())
636-
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
642+
setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
643+
Custom);
637644

638645
// Custom conversions to/from v2i8.
639646
setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -661,14 +668,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
661668
// Operations not directly supported by NVPTX.
662669
for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
663670
MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
664-
MVT::v4i8, MVT::i32, MVT::i64}) {
671+
MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
665672
setOperationAction(ISD::SELECT_CC, VT, Expand);
666673
setOperationAction(ISD::BR_CC, VT, Expand);
667674
}
668675

669-
// Not directly supported. TLI would attempt to expand operations like
670-
// FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
671-
setOperationAction(ISD::VSELECT, MVT::v2f32, Expand);
676+
// We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
677+
setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
672678

673679
// Some SIGN_EXTEND_INREG can be done using cvt instruction.
674680
// For others we will expand to a SHL/SRA pair.
@@ -815,7 +821,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
815821
setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
816822
ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
817823
ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
818-
MVT::v2i16, Expand);
824+
{MVT::v2i16, MVT::v2i32}, Expand);
825+
826+
// v2i32 is not supported for any arithmetic operations
827+
setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
828+
ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
829+
ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
830+
ISD::SREM, ISD::UREM},
831+
MVT::v2i32, Expand);
819832

820833
setOperationAction(ISD::ADDC, MVT::i32, Legal);
821834
setOperationAction(ISD::ADDE, MVT::i32, Legal);
@@ -829,7 +842,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
829842
}
830843

831844
setOperationAction(ISD::CTTZ, MVT::i16, Expand);
832-
setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
845+
setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
833846
setOperationAction(ISD::CTTZ, MVT::i32, Expand);
834847
setOperationAction(ISD::CTTZ, MVT::i64, Expand);
835848

@@ -5673,7 +5686,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
56735686
IsPTXVectorType(VectorVT.getSimpleVT()))
56745687
return SDValue(); // Native vector loads already combine nicely w/
56755688
// extract_vector_elt.
5676-
// Don't mess with singletons or packed types (v2f32, v2*16, v4i8 and v8i8),
5689+
// Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
56775690
// we already handle them OK.
56785691
if (VectorVT.getVectorNumElements() == 1 ||
56795692
NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -754,8 +754,10 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
754754
(SELP_b32rr $a, $b, $p)>;
755755
}
756756

757-
def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
757+
foreach vt = [v2f32, v2i32] in {
758+
def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
758759
(SELP_b64rr $a, $b, $p)>;
760+
}
759761

760762
//-----------------------------------
761763
// Test Instructions
@@ -2092,8 +2094,8 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
20922094
(V2I16toI32 $a, $b)>;
20932095
}
20942096

2095-
// Same thing for the 64-bit type v2f32.
2096-
foreach vt = [v2f32] in {
2097+
// Handle extracting one element from the pair (64-bit types)
2098+
foreach vt = [v2f32, v2i32] in {
20972099
def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
20982100
def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;
20992101

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def B16 : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
5454
def B32 : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
5555
(add (sequence "R%u", 0, 4),
5656
VRFrame32, VRFrameLocal32)>;
57-
def B64 : NVPTXRegClass<[i64, v2f32, f64], 64, (add (sequence "RL%u", 0, 4),
57+
def B64 : NVPTXRegClass<[i64, v2i32, v2f32, f64], 64,
58+
(add (sequence "RL%u", 0, 4),
5859
VRFrame64, VRFrameLocal64)>;
5960
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6061
def B128 : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ namespace NVPTX {
9999
// register. NOTE: This must be kept in sync with the register classes
100100
// defined in NVPTXRegisterInfo.td.
101101
inline auto packed_types() {
102-
static const auto PackedTypes = {MVT::v4i8, MVT::v2f16, MVT::v2bf16,
103-
MVT::v2i16, MVT::v2f32};
102+
static const auto PackedTypes = {MVT::v4i8, MVT::v2f16, MVT::v2bf16,
103+
MVT::v2i16, MVT::v2f32, MVT::v2i32};
104104
return PackedTypes;
105105
}
106106

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all \
3+
; RUN: -verify-machineinstrs | FileCheck --check-prefixes=CHECK,CHECK-SM90A %s
4+
; RUN: %if ptxas-12.7 %{ \
5+
; RUN: llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all \
6+
; RUN: -verify-machineinstrs | %ptxas-verify -arch=sm_90a \
7+
; RUN: %}
8+
; RUN: llc < %s -mcpu=sm_100 -O0 -disable-post-ra -frame-pointer=all \
9+
; RUN: -verify-machineinstrs | FileCheck --check-prefixes=CHECK,CHECK-SM100 %s
10+
; RUN: %if ptxas-12.7 %{ \
11+
; RUN: llc < %s -mcpu=sm_100 -O0 -disable-post-ra -frame-pointer=all \
12+
; RUN: -verify-machineinstrs | %ptxas-verify -arch=sm_100 \
13+
; RUN: %}
14+
15+
; Test that v2i32 -> v2f32 conversions don't emit bitwise operations on i64.
16+
17+
target triple = "nvptx64-nvidia-cuda"
18+
19+
declare <2 x i32> @return_i32x2(i32 %0)
20+
21+
; Test with v2i32.
22+
define ptx_kernel void @store_i32x2(i32 %0, ptr %p) {
23+
; CHECK-SM90A-LABEL: store_i32x2(
24+
; CHECK-SM90A: {
25+
; CHECK-SM90A-NEXT: .reg .b32 %r<6>;
26+
; CHECK-SM90A-NEXT: .reg .b64 %rd<2>;
27+
; CHECK-SM90A-EMPTY:
28+
; CHECK-SM90A-NEXT: // %bb.0:
29+
; CHECK-SM90A-NEXT: ld.param.b64 %rd1, [store_i32x2_param_1];
30+
; CHECK-SM90A-NEXT: ld.param.b32 %r1, [store_i32x2_param_0];
31+
; CHECK-SM90A-NEXT: { // callseq 0, 0
32+
; CHECK-SM90A-NEXT: .param .b32 param0;
33+
; CHECK-SM90A-NEXT: .param .align 8 .b8 retval0[8];
34+
; CHECK-SM90A-NEXT: st.param.b32 [param0], %r1;
35+
; CHECK-SM90A-NEXT: call.uni (retval0), return_i32x2, (param0);
36+
; CHECK-SM90A-NEXT: ld.param.v2.b32 {%r2, %r3}, [retval0];
37+
; CHECK-SM90A-NEXT: } // callseq 0
38+
; CHECK-SM90A-NEXT: add.rn.f32 %r4, %r3, %r3;
39+
; CHECK-SM90A-NEXT: add.rn.f32 %r5, %r2, %r2;
40+
; CHECK-SM90A-NEXT: st.v2.b32 [%rd1], {%r5, %r4};
41+
; CHECK-SM90A-NEXT: ret;
42+
;
43+
; CHECK-SM100-LABEL: store_i32x2(
44+
; CHECK-SM100: {
45+
; CHECK-SM100-NEXT: .reg .b32 %r<2>;
46+
; CHECK-SM100-NEXT: .reg .b64 %rd<4>;
47+
; CHECK-SM100-EMPTY:
48+
; CHECK-SM100-NEXT: // %bb.0:
49+
; CHECK-SM100-NEXT: ld.param.b64 %rd1, [store_i32x2_param_1];
50+
; CHECK-SM100-NEXT: ld.param.b32 %r1, [store_i32x2_param_0];
51+
; CHECK-SM100-NEXT: { // callseq 0, 0
52+
; CHECK-SM100-NEXT: .param .b32 param0;
53+
; CHECK-SM100-NEXT: .param .align 8 .b8 retval0[8];
54+
; CHECK-SM100-NEXT: st.param.b32 [param0], %r1;
55+
; CHECK-SM100-NEXT: call.uni (retval0), return_i32x2, (param0);
56+
; CHECK-SM100-NEXT: ld.param.b64 %rd2, [retval0];
57+
; CHECK-SM100-NEXT: } // callseq 0
58+
; CHECK-SM100-NEXT: add.rn.f32x2 %rd3, %rd2, %rd2;
59+
; CHECK-SM100-NEXT: st.b64 [%rd1], %rd3;
60+
; CHECK-SM100-NEXT: ret;
61+
%v = call <2 x i32> @return_i32x2(i32 %0)
62+
%v.f32x2 = bitcast <2 x i32> %v to <2 x float>
63+
%res = fadd <2 x float> %v.f32x2, %v.f32x2
64+
store <2 x float> %res, ptr %p, align 8
65+
ret void
66+
}
67+
68+
; Test with inline ASM returning { <1 x float>, <1 x float> }, which decays to
69+
; v2i32.
70+
define ptx_kernel void @inlineasm(ptr %p) {
71+
; CHECK-SM90A-LABEL: inlineasm(
72+
; CHECK-SM90A: {
73+
; CHECK-SM90A-NEXT: .reg .b32 %r<7>;
74+
; CHECK-SM90A-NEXT: .reg .b64 %rd<2>;
75+
; CHECK-SM90A-EMPTY:
76+
; CHECK-SM90A-NEXT: // %bb.0:
77+
; CHECK-SM90A-NEXT: ld.param.b64 %rd1, [inlineasm_param_0];
78+
; CHECK-SM90A-NEXT: mov.b32 %r3, 0;
79+
; CHECK-SM90A-NEXT: mov.b32 %r4, %r3;
80+
; CHECK-SM90A-NEXT: mov.b32 %r2, %r4;
81+
; CHECK-SM90A-NEXT: mov.b32 %r1, %r3;
82+
; CHECK-SM90A-NEXT: // begin inline asm
83+
; CHECK-SM90A-NEXT: // nop
84+
; CHECK-SM90A-NEXT: // end inline asm
85+
; CHECK-SM90A-NEXT: mul.rn.f32 %r5, %r2, 0f00000000;
86+
; CHECK-SM90A-NEXT: mul.rn.f32 %r6, %r1, 0f00000000;
87+
; CHECK-SM90A-NEXT: st.v2.b32 [%rd1], {%r6, %r5};
88+
; CHECK-SM90A-NEXT: ret;
89+
;
90+
; CHECK-SM100-LABEL: inlineasm(
91+
; CHECK-SM100: {
92+
; CHECK-SM100-NEXT: .reg .b32 %r<6>;
93+
; CHECK-SM100-NEXT: .reg .b64 %rd<5>;
94+
; CHECK-SM100-EMPTY:
95+
; CHECK-SM100-NEXT: // %bb.0:
96+
; CHECK-SM100-NEXT: ld.param.b64 %rd1, [inlineasm_param_0];
97+
; CHECK-SM100-NEXT: mov.b32 %r3, 0;
98+
; CHECK-SM100-NEXT: mov.b32 %r4, %r3;
99+
; CHECK-SM100-NEXT: mov.b32 %r2, %r4;
100+
; CHECK-SM100-NEXT: mov.b32 %r1, %r3;
101+
; CHECK-SM100-NEXT: // begin inline asm
102+
; CHECK-SM100-NEXT: // nop
103+
; CHECK-SM100-NEXT: // end inline asm
104+
; CHECK-SM100-NEXT: mov.b64 %rd2, {%r1, %r2};
105+
; CHECK-SM100-NEXT: mov.b32 %r5, 0f00000000;
106+
; CHECK-SM100-NEXT: mov.b64 %rd3, {%r5, %r5};
107+
; CHECK-SM100-NEXT: mul.rn.f32x2 %rd4, %rd2, %rd3;
108+
; CHECK-SM100-NEXT: st.b64 [%rd1], %rd4;
109+
; CHECK-SM100-NEXT: ret;
110+
%r = call { <1 x float>, <1 x float> } asm sideeffect "// nop", "=f,=f,0,1"(<1 x float> zeroinitializer, <1 x float> zeroinitializer)
111+
%i0 = extractvalue { <1 x float>, <1 x float> } %r, 0
112+
%i1 = extractvalue { <1 x float>, <1 x float> } %r, 1
113+
%i4 = shufflevector <1 x float> %i0, <1 x float> %i1, <2 x i32> <i32 0, i32 1>
114+
%mul = fmul < 2 x float> %i4, zeroinitializer
115+
store <2 x float> %mul, ptr %p, align 8
116+
ret void
117+
}
118+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
119+
; CHECK: {{.*}}

0 commit comments

Comments
 (0)