Skip to content

Commit 384d4fc

Browse files
DefTruthyanjun.qiu
andauthored
feat: add triton fused-softmax (xlite-dev#301)
Co-authored-by: yanjun.qiu <[email protected]>
1 parent d5ade54 commit 384d4fc

File tree

7 files changed

+977
-1
lines changed

7 files changed

+977
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi
469469
|📖 Triton Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
470470
|:---|:---|:---|:---|:---|
471471
| ✔️ [triton_vector_add_kernel](./kernels/openai-triton/elementwise/)|all|all|[link](./kernels/openai-triton/elementwise/)|⭐️⭐️|
472+
| ✔️ [triton_fused_softmax(multi-stages)](./kernels/openai-triton/fused-softmax/)|f16/bf16/f32|f32|[link](./kernels/openai-triton/fused-softmax//)|⭐️⭐️⭐️|
472473
| ✔️ [triton_merge_attn_states_kernel(w/ CUDA)](./kernels/openai-triton/merge-attn-states/)|f16/bf16/f32|f32|[link](./kernels/openai-triton/merge-attn-states/)|⭐️⭐️⭐️|
473474

474475

docs/.gitignore

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
*.so
2+
*.a
3+
*.dylib
4+
*.dll
5+
*.lib
6+
.DS_Store
7+
build
8+
*.whl
9+
tmp
10+
__pycache__
11+
*.onnx
12+
*.engine
13+
*.pt
14+
*.pth
15+
*.nsys*
16+
*.ncu*
17+
*.sqlite*
18+
*.engine
19+
*.bin
20+
outupt
21+
bin
22+
*.log
23+
*.txt
24+
*.tex
25+
__pycache__
26+
pdfs
27+
build*

kernels/openai-triton/softmax/.gitignore renamed to kernels/openai-triton/fused-softmax/.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ tmp
1717
*.nsys
1818
*.nvvp
1919
*.nsys*
20-
*.sqlite
20+
*.sqlite
21+
*.csv
22+
*.html
23+
cache
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Triton Fused Softmax
2+
4x speedup vs naive_softmax, support multi-stages.
3+
4+
## Install
5+
6+
```bash
7+
python3 -m pip install -r requirements.txt
8+
```
9+
10+
## Benchmark
11+
12+
```bash
13+
python3 triton_fused_softmax.py
14+
```
15+
16+
performance plot (Higher GB/s is better)
17+
18+
![](./softmax-performance.png)
19+
20+
sofmax performance data:
21+
22+
```bash
23+
python3 triton_fused_softmax.py
24+
NUM_SM: 92, NUM_REGS: 65536, SIZE_SMEM: 49152, WARP_SIZE: 32
25+
softmax-performance: (GB/s)
26+
M Triton Fused Softmax Torch Fused Softmax Torch Naive Softmax
27+
0 256.0 362.393505 305.358765 124.678148
28+
1 512.0 445.778519 405.032092 194.896749
29+
2 768.0 514.098122 438.300277 250.346181
30+
3 1024.0 535.408538 484.151615 282.482734
31+
4 1280.0 569.918518 495.302961 298.126599
32+
5 1536.0 548.277633 520.750543 314.245872
33+
6 1792.0 533.588304 522.417207 316.189641
34+
7 2048.0 567.566790 507.963708 323.995917
35+
8 2304.0 538.947392 534.294155 330.039737
36+
9 2560.0 563.100752 541.798988 331.690770
37+
10 2816.0 566.233438 550.942414 331.748259
38+
11 3072.0 567.218890 554.429254 331.566886
39+
12 3328.0 555.099312 551.214055 329.411258
40+
13 3584.0 572.163959 557.957312 317.685035
41+
14 3840.0 575.175853 548.535608 322.132599
42+
15 4096.0 559.898472 575.041421 313.671864
43+
16 4352.0 580.166989 573.423500 309.974206
44+
17 4608.0 593.755089 570.839527 302.986224
45+
18 4864.0 575.728831 559.051306 302.688434
46+
19 5120.0 575.766222 572.413610 301.564769
47+
20 5376.0 572.423563 573.764541 296.201227
48+
21 5632.0 575.843562 580.459896 291.670395
49+
22 5888.0 579.751142 576.305867 288.366253
50+
23 6144.0 582.716775 579.061272 281.005347
51+
24 6400.0 579.270649 589.253228 274.376955
52+
25 6656.0 585.002711 582.368493 270.766555
53+
26 6912.0 587.171336 584.960726 263.143868
54+
27 7168.0 587.587015 587.994311 253.358858
55+
28 7424.0 589.017087 588.470371 251.164961
56+
29 7680.0 581.928985 586.436541 245.850602
57+
30 7936.0 585.914317 602.394388 236.971384
58+
31 8192.0 591.928478 590.293530 232.782519
59+
32 8448.0 593.058718 593.496120 221.579701
60+
33 8704.0 594.695425 595.261556 216.494918
61+
34 8960.0 591.452098 591.325040 212.167842
62+
35 9216.0 591.788325 593.896863 205.272973
63+
36 9472.0 594.511034 598.075544 196.612392
64+
37 9728.0 593.693925 599.458087 195.754667
65+
38 9984.0 598.928478 600.050556 189.836630
66+
39 10240.0 591.822549 600.587319 186.929136
67+
40 10496.0 593.427400 604.631795 185.559242
68+
41 10752.0 598.165555 599.295332 184.250498
69+
42 11008.0 600.996379 603.951158 183.771913
70+
43 11264.0 601.199174 605.015824 180.982734
71+
44 11520.0 601.669498 605.464543 180.641048
72+
45 11776.0 600.274667 609.133846 180.645646
73+
46 12032.0 603.406759 607.627874 178.938962
74+
47 12288.0 603.722180 606.700927 177.688192
75+
48 12544.0 601.190215 611.649160 176.379854
76+
49 12800.0 603.581671 610.237343 176.527952
77+
50 13056.0 604.757733 610.384488 175.390770
78+
51 13312.0 603.985819 615.509130 174.097735
79+
52 13568.0 605.297037 612.408629 173.256892
80+
53 13824.0 605.908459 614.069065 174.128498
81+
54 14080.0 607.920747 613.264977 173.922682
82+
55 14336.0 607.635887 618.994853 173.581517
83+
56 14592.0 607.396226 617.410641 174.207748
84+
57 14848.0 609.851532 619.969111 172.837354
85+
58 15104.0 609.737443 619.281199 172.619360
86+
59 15360.0 614.387454 619.042827 173.398807
87+
60 15616.0 612.544618 622.329807 171.825113
88+
61 15872.0 609.665642 621.071773 172.623526
89+
62 16128.0 610.741546 625.835582 172.361260
90+
```
91+
## PTX gen code
92+
check [softmax_kernel.ptx](./softmax_kernel.ptx) for more details.
93+
94+
```NASM
95+
// begin inline asm
96+
@%p2 cp.async.ca.shared.global [ %r82 + 0 ], [ %rd4 + 0 ], 0x4, %r34;
97+
// end inline asm
98+
selp.b32 %r36, %r45, 0, %p7;
99+
// begin inline asm
100+
@%p2 cp.async.ca.shared.global [ %r84 + 0 ], [ %rd5 + 0 ], 0x4, %r36;
101+
// end inline asm
102+
selp.b32 %r38, %r45, 0, %p8;
103+
// begin inline asm
104+
@%p2 cp.async.ca.shared.global [ %r86 + 0 ], [ %rd6 + 0 ], 0x4, %r38;
105+
// end inline asm
106+
selp.b32 %r40, %r45, 0, %p9;
107+
// begin inline asm
108+
@%p2 cp.async.ca.shared.global [ %r88 + 0 ], [ %rd7 + 0 ], 0x4, %r40;
109+
// end inline asm
110+
// begin inline asm
111+
cp.async.commit_group ;
112+
// end inline asm
113+
.loc 1 44 57 // triton_fused_softmax.py:44:57
114+
@%p10 bra $L__BB0_3;
115+
// %bb.1: // %.lr.ph
116+
.loc 1 0 57 // triton_fused_softmax.py:0:57
117+
ld.param.u32 %r28, [softmax_kernel_param_3];
118+
ld.param.u64 %rd2, [softmax_kernel_param_0];
119+
// begin inline asm
120+
mov.u32 %r32, %nctaid.x;
121+
// end inline asm
122+
cvt.u64.u32 %rd1, %r41;
123+
.loc 1 49 35 // triton_fused_softmax.py:49:35
124+
and.b32 %r11, %r3, 31;
125+
sub.s32 %r12, %r29, %r32;
126+
add.s32 %r48, %r44, 4096;
127+
shr.u32 %r49, %r3, 3;
128+
and.b32 %r50, %r49, 28;
129+
add.s32 %r54, %r48, %r50;
130+
setp.lt.s32 %p15, %r3, 8;
131+
shl.b32 %r51, %r3, 2;
132+
add.s32 %r57, %r48, %r51;
133+
and.b32 %r52, %r3, 7;
134+
setp.eq.s32 %p13, %r52, 0;
135+
and.pred %p16, %p15, %p13;
136+
.loc 1 44 57 // triton_fused_softmax.py:44:57
137+
add.s32 %r53, %r131, %r32;
138+
mul.lo.s32 %r129, %r27, %r53;
139+
mul.lo.s32 %r16, %r32, %r27;
140+
mul.lo.s32 %r128, %r131, %r28;
141+
mul.lo.s32 %r18, %r32, %r28;
142+
mov.b32 %r130, -1;
143+
$L__BB0_2: // =>This Inner Loop Header: Depth=1
144+
.loc 1 0 57 // triton_fused_softmax.py:0:57
145+
cvt.u32.u64 %r90, %rd1;
146+
setp.eq.s32 %p14, %r11, 0;
147+
.loc 1 52 29 // triton_fused_softmax.py:52:29
148+
setp.lt.s32 %p20, %r90, %r30;
149+
.loc 1 44 57 // triton_fused_softmax.py:44:57
150+
setp.lt.s32 %p28, %r131, %r12;
151+
add.s32 %r91, %r130, 1;
152+
setp.gt.u32 %p29, %r130, 2147483646;
153+
selp.b32 %r130, %r91, 0, %p29;
154+
.loc 1 53 22 // triton_fused_softmax.py:53:22
155+
// begin inline asm
156+
cp.async.wait_group 0x0;
157+
```
Loading

0 commit comments

Comments
 (0)