Summary
Investigated implementing FlashAttention-3 (FA3)-style intra-warpgroup WGMMA overlap for the GQA forward kernel in TileOPs. The goal is to overlap PV GEMM execution with softmax computation on SM90 (Hopper), matching FA3's architecture.
Key Findings
1. Current Pipeline Analysis
The existing T.Pipelined kernel with order/stage/group reordering already achieves:
- TMA prefetch overlap: K[i+1]/V[i+1] loads overlap with compute[i] via double-buffering
- Reordered execution: QK(0) → rescale(1) → PV(2) → softmax(3) in the fused pipeline body
However, the PV GEMM is followed by a synchronous warpgroup_wait<0>() (~100+ cycles), which blocks softmax until PV fully completes.
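To make the reordered body concrete, the Python sketch below models a software-pipelined loop in which each op has an order (its position in the fused body) and a stage (how many iterations it lags behind stage-0 ops). The stage values (QK and softmax at stage 0, rescale and PV at stage 1) are assumptions chosen to be consistent with the order above, not values read from fwd.py, and the model ignores T.Pipelined's group semantics.

```python
from typing import List, Tuple

Op = Tuple[str, int, int]  # (name, order, stage)

# Assumed stages; only the order values come from the analysis above.
OPS: List[Op] = [
    ("QK", 0, 0),
    ("rescale", 1, 1),
    ("PV", 2, 1),
    ("softmax", 3, 0),
]


def pipeline_schedule(ops: List[Op], n_iters: int) -> List[List[str]]:
    """Ops issued in each fused body step, including prologue and epilogue."""
    max_stage = max(stage for _, _, stage in ops)
    steps = []
    for k in range(n_iters + max_stage):
        step = []
        # Step k runs each op at stage s on loop iteration k - s, if it exists.
        for name, _, stage in sorted(ops, key=lambda op: op[1]):
            it = k - stage
            if 0 <= it < n_iters:
                step.append(f"{name}[{it}]")
        steps.append(step)
    return steps


if __name__ == "__main__":
    for k, step in enumerate(pipeline_schedule(OPS, 3)):
        print(k, " -> ".join(step))
```

In the steady-state step k, PV[k-1] is issued immediately before softmax[k], so a blocking warpgroup_wait<0>() after PV serializes the two.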
2. Root Cause of wg_wait=-1 Slowdown (~20%)
The diff between the CUDA dumps of the baseline and wg_wait=-1 builds shows only 2 lines changed: the warpgroup_wait<0>() after the PV GEMM is removed.
// Baseline (PV GEMM section):
tl::warpgroup_commit_batch();
tl::warpgroup_wait<0>(); // ← full wait: smem reads + register writes
tl::warpgroup_fence_operand(...);
mbarrier[v_buf].arrive(); // ← signals TMA: V buffer can be reused
// wg_wait=-1 (what TileLang generates):
tl::warpgroup_commit_batch();
// ← nothing! smem reads may still be in progress
tl::warpgroup_fence_operand(...);
mbarrier[v_buf].arrive(); // ← unsafe: WGMMA may still read v_shared
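The counting rule behind warpgroup_commit_batch() and warpgroup_wait<N>() can be captured in a small Python model. It encodes only the PTX rule that wgmma.wait_group N returns once at most N of the most recently committed groups are still pending, and it conservatively treats a group's shared-memory reads as finished only after a wait has retired that group. Latency and the fence instruction are not modeled.

```python
from collections import deque
from typing import Deque, Optional, Set


class WgmmaGroups:
    """Counting model of wgmma.commit_group / wgmma.wait_group."""

    def __init__(self) -> None:
        self.pending: Deque[str] = deque()  # oldest committed group on the left
        self.retired: Set[str] = set()

    def commit(self, label: str) -> None:
        # warpgroup_commit_batch(): previously issued wgmma ops form one pending group.
        self.pending.append(label)

    def wait(self, n: int) -> None:
        # warpgroup_wait<n>(): return once at most n of the newest groups are pending.
        while len(self.pending) > n:
            self.retired.add(self.pending.popleft())

    def is_retired(self, label: str) -> bool:
        return label in self.retired


def v_release_is_safe(wait_n: Optional[int]) -> bool:
    """Can mbarrier[v_buf].arrive() run right after the PV GEMM?

    wait_n=None models wg_wait=-1, where no wait is emitted at all.
    """
    groups = WgmmaGroups()
    groups.commit("PV")
    if wait_n is not None:
        groups.wait(wait_n)
    # Conservative rule: V may be recycled only after its reading group has retired.
    return groups.is_retired("PV")


if __name__ == "__main__":
    print(v_release_is_safe(0))     # True  (baseline: warpgroup_wait<0>)
    print(v_release_is_safe(None))  # False (wg_wait=-1)
    g = WgmmaGroups()
    g.commit("QK")
    g.commit("PV")
    g.wait(1)  # retires the older group only; the newest stays in flight
    print(g.is_retired("QK"), g.is_retired("PV"))  # True False
```

The last case is the primitive any overlap scheme builds on: wait<1> retires the older group while the most recent one is still in flight.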
FA3's solution: insert warpgroup_arrive() (= PTX wgmma.fence.sync.aligned, ~3 cycles), which is intended to ensure that shared memory reads complete WITHOUT waiting for register writes:
// FA3 ideal:
tl::warpgroup_commit_batch();
tl::warpgroup_arrive(); // ← smem fence only (~3 cycles)
tl::warpgroup_fence_operand(...);
mbarrier[v_buf].arrive(); // ← safe: smem reads done
// softmax runs here while PV register writes continue in background
3. Why User-Level Injection Fails
| Attempt | Result | Reason |
|---|---|---|
| T.warpgroup_arrive() in Python | Optimized away | Pipeline pass deduplicates with next GEMM's arrive |
| tir.call_extern("tl::warpgroup_arrive") | Optimized away | DCE pass removes it |
| Patch gemm_sm90.h template | No effect | codegen (gemm.cc) doesn't call template functions; it inline-generates WGMMA code directly |
| Reorder pipeline (softmax before PV) | Incorrect | acc_s_cast fragment has no double-buffer; softmax overwrites before PV reads |
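The acc_s_cast failure in the last row can be reproduced with a toy Python model: softmax[k] writes P_k into a fragment slot and PV[k-1] reads P_{k-1} from it within the same pipeline step. The slot list and the single-step schedule are simplifications; real fragments live in registers and are not indexed like a list.

```python
from typing import List, Optional


def run(n_iters: int, n_slots: int, softmax_first: bool) -> List[Optional[str]]:
    """Return the P tile that each PV[i] actually read, in order."""
    acc_s_cast: List[Optional[str]] = [None] * n_slots  # hypothetical fragment slots
    consumed: List[Optional[str]] = []

    def softmax(k: int) -> None:
        acc_s_cast[k % n_slots] = f"P{k}"  # write probabilities for tile k

    def pv(i: int) -> None:
        consumed.append(acc_s_cast[i % n_slots])  # read probabilities for tile i

    softmax(0)  # prologue
    for k in range(1, n_iters):  # steady state: body k pairs softmax[k] with PV[k-1]
        if softmax_first:
            softmax(k)
            pv(k - 1)
        else:
            pv(k - 1)
            softmax(k)
    pv(n_iters - 1)  # epilogue
    return consumed


if __name__ == "__main__":
    print(run(3, n_slots=1, softmax_first=False))  # ['P0', 'P1', 'P2']
    print(run(3, n_slots=1, softmax_first=True))   # ['P1', 'P2', 'P2']
    print(run(3, n_slots=2, softmax_first=True))   # ['P0', 'P1', 'P2']
```

With one slot and softmax first, PV[0] consumes P1 instead of P0. Two slots indexed by k % 2 restore the correct pairing, at the cost of a second copy of the fragment in registers.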
4. Alternative Approaches Tested
| Approach | Performance (relative to baseline) | Issue |
|---|---|---|
| T.Pipelined baseline (wg_wait=0) | 1.0x | No FA3 overlap |
| T.Pipelined + wg_wait=-1 | 0.82x | Missing fence, instruction scheduling degradation |
| T.serial + disable_tma (FA3 logic validated) | 0.67x | SIMT copy too slow |
| T.serial + cp.async + barriers | 0.61x | cp.async much slower than TMA on SM90 |
| WS (if tx<128) + cp.async | 0.30x | cp.async bottleneck + WS overhead |
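Reading the performance multipliers as throughput relative to the wg_wait=0 baseline (an assumption; the table does not state a unit), they convert to extra runtime as follows.

```python
# Assumption: each multiplier is throughput relative to the wg_wait=0 baseline.
RESULTS = {
    "T.Pipelined baseline (wg_wait=0)": 1.0,
    "T.Pipelined + wg_wait=-1": 0.82,
    "T.serial + disable_tma": 0.67,
    "T.serial + cp.async + barriers": 0.61,
    "WS (if tx<128) + cp.async": 0.30,
}


def runtime_overhead_pct(speed: float) -> float:
    """Extra runtime vs. baseline, in percent, for a relative throughput `speed`."""
    return (1.0 / speed - 1.0) * 100.0


if __name__ == "__main__":
    for name, speed in RESULTS.items():
        print(f"{name:<34} {speed:.2f}x -> +{runtime_overhead_pct(speed):.0f}% runtime")
```

Under that reading, 0.82x is about 18% lower throughput or about 22% longer runtime, both in line with the ~20% slowdown in finding 2.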
5. Discovered APIs for Proper WS Implementation
Found in tilelang/examples/warp_specialize/:
# T.tma_copy — TMA with explicit barrier binding (bypasses InjectTmaBarrier pass)
T.tma_copy(global_tensor[slice], shared_buffer, barrier=my_barrier)
# T.ws(N) — Declarative warp specialization
with T.ws(0): # Producer warpgroup
T.tma_copy(K[...], k_shared, barrier=k_ready)
with T.ws(1): # Consumer warpgroup
T.barrier_wait(k_ready, parity)
T.wgmma_gemm(q_shared, k_shared, acc_s, ...)
# T.wgmma_gemm — WGMMA in WS context
T.wgmma_gemm(A_shared, B_shared, acc, transpose_B=True, clear_accum=True)
Reference: example_warp_specialize_flashmla.py — complete FlashMLA decode with 2-WG pingpong, TMA, barriers, and online softmax.
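The T.barrier_wait(k_ready, parity) call above relies on mbarrier phase parity. The Python sketch below models a 2-slot ring buffer with a "full" barrier (tile landed) and an "empty" barrier (slot free) per slot, waiting on parity (i // n_stages) & 1 for the i-th use of a slot. The full/empty pairing is a common producer/consumer layout and an assumption here, not something read from example_warp_specialize_flashmla.py; TMA and the two warpgroups are replaced by a deterministic step loop.

```python
from typing import List, Optional


class MBarrier:
    """Single-arrival barrier; the phase flips each time it completes."""

    def __init__(self) -> None:
        self.completed_phases = 0

    def arrive(self) -> None:
        self.completed_phases += 1  # one arrival completes the current phase

    def test_wait(self, parity: int) -> bool:
        # Succeeds once the phase with this parity has completed, i.e. when the
        # in-flight phase has the opposite parity (mbarrier.test_wait.parity style).
        return (self.completed_phases & 1) != parity


def simulate(n_tiles: int, n_stages: int = 2) -> List[Optional[int]]:
    full = [MBarrier() for _ in range(n_stages)]   # producer -> consumer: tile loaded
    empty = [MBarrier() for _ in range(n_stages)]  # consumer -> producer: slot free
    slots: List[Optional[int]] = [None] * n_stages
    produced = consumed = 0
    order: List[Optional[int]] = []
    while consumed < n_tiles:
        # Producer step (T.ws(0) role): load the next tile if its slot is free.
        if produced < n_tiles:
            s, phase = produced % n_stages, (produced // n_stages) & 1
            # Empty barriers start out "free", so the producer waits on the flipped parity.
            if empty[s].test_wait(phase ^ 1):
                slots[s] = produced  # stands in for T.tma_copy(...) into slot s
                full[s].arrive()
                produced += 1
        # Consumer step (T.ws(1) role): consume the next tile once it has landed.
        s, phase = consumed % n_stages, (consumed // n_stages) & 1
        if full[s].test_wait(phase):
            order.append(slots[s])  # stands in for T.wgmma_gemm on slot s
            empty[s].arrive()
            consumed += 1
    return order


if __name__ == "__main__":
    print(simulate(5))  # [0, 1, 2, 3, 4]
```

Because the producer can run at most n_stages tiles ahead, a slot is never overwritten before the consumer has released it.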
Proposed Implementation
Rewrite GQA forward using T.tma_copy + T.ws() + T.wgmma_gemm:
Architecture: 2 warpgroups (256 threads), pingpong processing alternating KV tiles:
- WG0: QK[even] + softmax[even] + PV_left[even], prefetches odd tiles via T.tma_copy
- WG1: QK[odd] + softmax[odd] + PV_right[odd], prefetches even tiles via T.tma_copy
A shared scores_max buffer coordinates the online softmax across WGs (same pattern as the FlashMLA example).
Expected benefit: the PV WGMMA of one WG overlaps with the other WG's softmax computation, and TMA loads overlap with both WGs' compute. The aim is to keep SM90's TMA engine, Tensor Cores, and CUDA cores busy simultaneously.
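The shared-max coordination can be checked numerically before writing the kernel. This pure-Python sketch runs the online-softmax recurrence over KV tiles (tile t notionally handled by WG t % 2), keeps one running max, and splits the output columns into left and right halves as a stand-in for PV_left/PV_right, then compares against naive attention. It models only the arithmetic (no concurrency, shared memory, or WGMMA), and the 1/sqrt(d) softmax scale is omitted from both paths.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vector, K: Matrix, V: Matrix) -> Vector:
    """Reference: softmax(q K^T) V for one query row (no 1/sqrt(d) scale)."""
    s = [dot(q, k) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    denom = sum(p)
    return [sum(p[j] * V[j][c] for j in range(len(V))) / denom for c in range(len(V[0]))]


def pingpong_online_attention(q: Vector, K: Matrix, V: Matrix, tile: int) -> Vector:
    """Online softmax over KV tiles with one shared running max (scores_max role)."""
    d_v = len(V[0])
    half = d_v // 2
    m = -math.inf    # shared running max
    denom = 0.0      # running softmax denominator
    o = [0.0] * d_v  # columns [0, half) ~ PV_left, [half, d_v) ~ PV_right
    for start in range(0, len(K), tile):
        # Hypothetical: WG ((start // tile) % 2) computes QK + softmax for this tile.
        k_tile, v_tile = K[start:start + tile], V[start:start + tile]
        s = [dot(q, k) for k in k_tile]
        m_new = max(m, max(s))       # publish the updated shared max
        scale = math.exp(m - m_new)  # rescale old state; exp(-inf) == 0.0 on tile 0
        p = [math.exp(x - m_new) for x in s]
        denom = denom * scale + sum(p)
        # Both halves consume every tile's P so the split covers the full output.
        for lo, hi in ((0, half), (half, d_v)):
            for c in range(lo, hi):
                o[c] = o[c] * scale + sum(p[j] * v_tile[j][c] for j in range(len(p)))
        m = m_new
    return [x / denom for x in o]


if __name__ == "__main__":
    random.seed(0)
    q = [random.uniform(-1, 1) for _ in range(4)]
    K = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(10)]
    V = [[random.uniform(-1, 1) for _ in range(6)] for _ in range(10)]
    ref = naive_attention(q, K, V)
    out = pingpong_online_attention(q, K, V, tile=3)
    print(max(abs(a - b) for a, b in zip(ref, out)))
```

The result should match the reference up to floating-point rounding for any tile size, which is the property the shared scores_max must preserve when the tiles are split across warpgroups.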
Files Modified During Investigation
- tileops/kernels/flash_attn/fwd.py: various experimental changes (all reverted to baseline)
- Temporary test files: _test_manual_pipeline.py, _test_ws.py, _bench_ab.py, _bench_manual.py, _dump_cuda.py, _dump_patched.py, _patch_and_test.py, _run_fence_test.sh, _test_fence_patch.sh
Next Steps
- Implement GQA forward with T.tma_copy + T.ws() + T.wgmma_gemm following the FlashMLA example pattern
- Benchmark against baseline and FA3
- Consider contributing warpgroup_arrive() fence support to TileLang's gemm.cc codegen (one-line fix for wg_wait=-1 in T.Pipelined context)