|
| 1 | +# Triton Fused Softmax |
| 2 | +4x speedup vs naive_softmax, support multi-stages. |
| 3 | + |
| 4 | +## Install |
| 5 | + |
| 6 | +```bash |
| 7 | +python3 -m pip install -r requirements.txt |
| 8 | +``` |
| 9 | + |
| 10 | +## Benchmark |
| 11 | + |
| 12 | +```bash |
| 13 | +python3 triton_fused_softmax.py |
| 14 | +``` |
| 15 | + |
| 16 | +performance plot (Higher GB/s is better) |
| 17 | + |
| 18 | + |
| 19 | + |
| 20 | +sofmax performance data: |
| 21 | + |
| 22 | +```bash |
| 23 | +python3 triton_fused_softmax.py |
| 24 | +NUM_SM: 92, NUM_REGS: 65536, SIZE_SMEM: 49152, WARP_SIZE: 32 |
| 25 | +softmax-performance: (GB/s) |
| 26 | + M Triton Fused Softmax Torch Fused Softmax Torch Naive Softmax |
| 27 | +0 256.0 362.393505 305.358765 124.678148 |
| 28 | +1 512.0 445.778519 405.032092 194.896749 |
| 29 | +2 768.0 514.098122 438.300277 250.346181 |
| 30 | +3 1024.0 535.408538 484.151615 282.482734 |
| 31 | +4 1280.0 569.918518 495.302961 298.126599 |
| 32 | +5 1536.0 548.277633 520.750543 314.245872 |
| 33 | +6 1792.0 533.588304 522.417207 316.189641 |
| 34 | +7 2048.0 567.566790 507.963708 323.995917 |
| 35 | +8 2304.0 538.947392 534.294155 330.039737 |
| 36 | +9 2560.0 563.100752 541.798988 331.690770 |
| 37 | +10 2816.0 566.233438 550.942414 331.748259 |
| 38 | +11 3072.0 567.218890 554.429254 331.566886 |
| 39 | +12 3328.0 555.099312 551.214055 329.411258 |
| 40 | +13 3584.0 572.163959 557.957312 317.685035 |
| 41 | +14 3840.0 575.175853 548.535608 322.132599 |
| 42 | +15 4096.0 559.898472 575.041421 313.671864 |
| 43 | +16 4352.0 580.166989 573.423500 309.974206 |
| 44 | +17 4608.0 593.755089 570.839527 302.986224 |
| 45 | +18 4864.0 575.728831 559.051306 302.688434 |
| 46 | +19 5120.0 575.766222 572.413610 301.564769 |
| 47 | +20 5376.0 572.423563 573.764541 296.201227 |
| 48 | +21 5632.0 575.843562 580.459896 291.670395 |
| 49 | +22 5888.0 579.751142 576.305867 288.366253 |
| 50 | +23 6144.0 582.716775 579.061272 281.005347 |
| 51 | +24 6400.0 579.270649 589.253228 274.376955 |
| 52 | +25 6656.0 585.002711 582.368493 270.766555 |
| 53 | +26 6912.0 587.171336 584.960726 263.143868 |
| 54 | +27 7168.0 587.587015 587.994311 253.358858 |
| 55 | +28 7424.0 589.017087 588.470371 251.164961 |
| 56 | +29 7680.0 581.928985 586.436541 245.850602 |
| 57 | +30 7936.0 585.914317 602.394388 236.971384 |
| 58 | +31 8192.0 591.928478 590.293530 232.782519 |
| 59 | +32 8448.0 593.058718 593.496120 221.579701 |
| 60 | +33 8704.0 594.695425 595.261556 216.494918 |
| 61 | +34 8960.0 591.452098 591.325040 212.167842 |
| 62 | +35 9216.0 591.788325 593.896863 205.272973 |
| 63 | +36 9472.0 594.511034 598.075544 196.612392 |
| 64 | +37 9728.0 593.693925 599.458087 195.754667 |
| 65 | +38 9984.0 598.928478 600.050556 189.836630 |
| 66 | +39 10240.0 591.822549 600.587319 186.929136 |
| 67 | +40 10496.0 593.427400 604.631795 185.559242 |
| 68 | +41 10752.0 598.165555 599.295332 184.250498 |
| 69 | +42 11008.0 600.996379 603.951158 183.771913 |
| 70 | +43 11264.0 601.199174 605.015824 180.982734 |
| 71 | +44 11520.0 601.669498 605.464543 180.641048 |
| 72 | +45 11776.0 600.274667 609.133846 180.645646 |
| 73 | +46 12032.0 603.406759 607.627874 178.938962 |
| 74 | +47 12288.0 603.722180 606.700927 177.688192 |
| 75 | +48 12544.0 601.190215 611.649160 176.379854 |
| 76 | +49 12800.0 603.581671 610.237343 176.527952 |
| 77 | +50 13056.0 604.757733 610.384488 175.390770 |
| 78 | +51 13312.0 603.985819 615.509130 174.097735 |
| 79 | +52 13568.0 605.297037 612.408629 173.256892 |
| 80 | +53 13824.0 605.908459 614.069065 174.128498 |
| 81 | +54 14080.0 607.920747 613.264977 173.922682 |
| 82 | +55 14336.0 607.635887 618.994853 173.581517 |
| 83 | +56 14592.0 607.396226 617.410641 174.207748 |
| 84 | +57 14848.0 609.851532 619.969111 172.837354 |
| 85 | +58 15104.0 609.737443 619.281199 172.619360 |
| 86 | +59 15360.0 614.387454 619.042827 173.398807 |
| 87 | +60 15616.0 612.544618 622.329807 171.825113 |
| 88 | +61 15872.0 609.665642 621.071773 172.623526 |
| 89 | +62 16128.0 610.741546 625.835582 172.361260 |
| 90 | +``` |
| 91 | +## PTX gen code |
| 92 | +check [softmax_kernel.ptx](./softmax_kernel.ptx) for more details. |
| 93 | + |
| 94 | +```NASM |
| 95 | +// begin inline asm |
| 96 | + @%p2 cp.async.ca.shared.global [ %r82 + 0 ], [ %rd4 + 0 ], 0x4, %r34; |
| 97 | + // end inline asm |
| 98 | + selp.b32 %r36, %r45, 0, %p7; |
| 99 | + // begin inline asm |
| 100 | + @%p2 cp.async.ca.shared.global [ %r84 + 0 ], [ %rd5 + 0 ], 0x4, %r36; |
| 101 | + // end inline asm |
| 102 | + selp.b32 %r38, %r45, 0, %p8; |
| 103 | + // begin inline asm |
| 104 | + @%p2 cp.async.ca.shared.global [ %r86 + 0 ], [ %rd6 + 0 ], 0x4, %r38; |
| 105 | + // end inline asm |
| 106 | + selp.b32 %r40, %r45, 0, %p9; |
| 107 | + // begin inline asm |
| 108 | + @%p2 cp.async.ca.shared.global [ %r88 + 0 ], [ %rd7 + 0 ], 0x4, %r40; |
| 109 | + // end inline asm |
| 110 | + // begin inline asm |
| 111 | + cp.async.commit_group ; |
| 112 | + // end inline asm |
| 113 | + .loc 1 44 57 // triton_fused_softmax.py:44:57 |
| 114 | + @%p10 bra $L__BB0_3; |
| 115 | +// %bb.1: // %.lr.ph |
| 116 | + .loc 1 0 57 // triton_fused_softmax.py:0:57 |
| 117 | + ld.param.u32 %r28, [softmax_kernel_param_3]; |
| 118 | + ld.param.u64 %rd2, [softmax_kernel_param_0]; |
| 119 | + // begin inline asm |
| 120 | + mov.u32 %r32, %nctaid.x; |
| 121 | + // end inline asm |
| 122 | + cvt.u64.u32 %rd1, %r41; |
| 123 | + .loc 1 49 35 // triton_fused_softmax.py:49:35 |
| 124 | + and.b32 %r11, %r3, 31; |
| 125 | + sub.s32 %r12, %r29, %r32; |
| 126 | + add.s32 %r48, %r44, 4096; |
| 127 | + shr.u32 %r49, %r3, 3; |
| 128 | + and.b32 %r50, %r49, 28; |
| 129 | + add.s32 %r54, %r48, %r50; |
| 130 | + setp.lt.s32 %p15, %r3, 8; |
| 131 | + shl.b32 %r51, %r3, 2; |
| 132 | + add.s32 %r57, %r48, %r51; |
| 133 | + and.b32 %r52, %r3, 7; |
| 134 | + setp.eq.s32 %p13, %r52, 0; |
| 135 | + and.pred %p16, %p15, %p13; |
| 136 | + .loc 1 44 57 // triton_fused_softmax.py:44:57 |
| 137 | + add.s32 %r53, %r131, %r32; |
| 138 | + mul.lo.s32 %r129, %r27, %r53; |
| 139 | + mul.lo.s32 %r16, %r32, %r27; |
| 140 | + mul.lo.s32 %r128, %r131, %r28; |
| 141 | + mul.lo.s32 %r18, %r32, %r28; |
| 142 | + mov.b32 %r130, -1; |
| 143 | +$L__BB0_2: // =>This Inner Loop Header: Depth=1 |
| 144 | + .loc 1 0 57 // triton_fused_softmax.py:0:57 |
| 145 | + cvt.u32.u64 %r90, %rd1; |
| 146 | + setp.eq.s32 %p14, %r11, 0; |
| 147 | + .loc 1 52 29 // triton_fused_softmax.py:52:29 |
| 148 | + setp.lt.s32 %p20, %r90, %r30; |
| 149 | + .loc 1 44 57 // triton_fused_softmax.py:44:57 |
| 150 | + setp.lt.s32 %p28, %r131, %r12; |
| 151 | + add.s32 %r91, %r130, 1; |
| 152 | + setp.gt.u32 %p29, %r130, 2147483646; |
| 153 | + selp.b32 %r130, %r91, 0, %p29; |
| 154 | + .loc 1 53 22 // triton_fused_softmax.py:53:22 |
| 155 | + // begin inline asm |
| 156 | + cp.async.wait_group 0x0; |
| 157 | +``` |
0 commit comments