
GQA Forward FA3-style WGMMA Overlap: Analysis & Implementation Path #3


@superAngGao

Summary

Investigated implementing FA3-style (FlashAttention-3) intra-warpgroup WGMMA overlap for the GQA forward kernel in TileOPs. The goal is to overlap PV GEMM execution with the softmax computation on SM90 (Hopper), matching FA3's architecture.

Key Findings

1. Current Pipeline Analysis

The existing T.Pipelined kernel with order/stage/group reordering already achieves:

  • TMA prefetch overlap: K[i+1]/V[i+1] loads overlap with compute[i] via double-buffering
  • Reordered execution: QK(0) → rescale(1) → PV(2) → softmax(3) in the fused pipeline body
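The double-buffered prefetch in the first bullet can be modeled as a plain loop schedule. This is a minimal sketch in pure Python, not TileLang code: pipelined_schedule and the event tuples are hypothetical names, "load" stands in for a TMA copy of K[i]/V[i], and "compute" stands in for the pipeline body on tile i. It only records issue order; on the GPU the point is that load(i+1) is in flight while compute(i) runs.

```python
# Toy issue-order model of a 2-stage (double-buffered) pipeline.
# Hypothetical names; "load" ~ TMA prefetch of K[i]/V[i], "compute" ~ pipeline body.

NUM_STAGES = 2  # double buffering: two shared-memory slots


def pipelined_schedule(num_tiles):
    """Return (op, tile, slot) tuples in the order they would be issued."""
    events = []
    resident = [None] * NUM_STAGES  # which tile each slot currently holds

    # Prologue: fill slot 0 before entering the loop.
    resident[0] = 0
    events.append(("load", 0, 0))

    for i in range(num_tiles):
        # Issue the prefetch of tile i+1 into the other slot first, so that
        # on real hardware it can run while tile i is being computed.
        if i + 1 < num_tiles:
            nxt = (i + 1) % NUM_STAGES
            resident[nxt] = i + 1
            events.append(("load", i + 1, nxt))
        cur = i % NUM_STAGES
        assert resident[cur] == i, "tile i must be resident before compute"
        events.append(("compute", i, cur))
    return events


for event in pipelined_schedule(3):
    print(event)
```

Each slot is refilled only after the tile it held has been computed; in the real kernel a barrier arrive, rather than Python's sequential execution, is what guarantees that.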

However, the PV GEMM is followed by a synchronous warpgroup_wait<0>() (~100+ cycles), which blocks softmax until the PV GEMM has fully completed.
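For reference, the per-tile math in the pipeline body (QK GEMM, softmax with a running row max, rescale of the accumulator, PV GEMM) is the standard online-softmax recurrence. The sketch below is a plain-Python, single-query-row model with hypothetical helper names, not the TileOPs kernel. It runs each PV step to completion before touching the next tile, which is the ordering the baseline's warpgroup_wait<0>() enforces.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, keys, values):
    """softmax(q . K^T) @ V for one query row, computed in one shot."""
    s = [dot(q, k) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    d = len(values[0])
    return [sum(pj * v[c] for pj, v in zip(p, values)) / sum(p) for c in range(d)]


def online_attention(q, keys, values, tile):
    """Tile-by-tile online softmax; each PV step finishes before the next tile."""
    d = len(values[0])
    m = -math.inf          # running row max (scores_max)
    l = 0.0                # running row sum of exp(scores - m)
    acc_o = [0.0] * d      # unnormalized output accumulator
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        s = [dot(q, k) for k in k_tile]              # QK GEMM
        m_new = max(m, max(s))                       # softmax: update row max
        p = [math.exp(x - m_new) for x in s]         # softmax: probabilities
        scale = math.exp(m - m_new)                  # rescale old state to m_new
        acc_o = [o * scale for o in acc_o]
        l = l * scale + sum(p)
        for pj, v in zip(p, v_tile):                 # PV GEMM, run to completion
            for c in range(d):
                acc_o[c] += pj * v[c]
        m = m_new
    return [o / l for o in acc_o]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.7]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
out = online_attention(q, keys, values, tile=2)
ref = naive_attention(q, keys, values)
assert all(abs(a - b) < 1e-9 for a, b in zip(out, ref))
```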

2. Root Cause of wg_wait=-1 Slowdown (~20%)

A diff of the dumped CUDA source between the baseline and wg_wait=-1 shows only 2 changed lines: the warpgroup_wait<0>() after the PV GEMM is removed.

// Baseline (PV GEMM section):
tl::warpgroup_commit_batch();
tl::warpgroup_wait<0>();           // ← full wait: smem reads + register writes
tl::warpgroup_fence_operand(...);
mbarrier[v_buf].arrive();          // ← signals TMA: V buffer can be reused

// wg_wait=-1 (what TileLang generates):
tl::warpgroup_commit_batch();
                                   // ← nothing! smem reads may still be in progress
tl::warpgroup_fence_operand(...);
mbarrier[v_buf].arrive();          // ← unsafe: WGMMA may still read v_shared
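The hazard in the wg_wait=-1 variant can be illustrated with a deterministic toy model. This is a Python analogy, not CUDA: AsyncDot and pv_step are hypothetical names, the "async GEMM" reads its shared-memory operand only when it is waited on (standing in for WGMMA reads still in flight), and overwriting v_shared stands in for TMA reusing the V buffer after mbarrier[v_buf].arrive(). On the GPU the outcome of the race is timing-dependent rather than guaranteed.

```python
# Toy model, not CUDA. Hypothetical names: AsyncDot, pv_step.


class AsyncDot:
    """An 'issued' dot product whose shared-memory operand is read at wait()
    time, modeling WGMMA reads that may still be in flight after issue."""

    def __init__(self, p, smem):
        self.p = p
        self.smem = smem  # holds a reference to the buffer, not a copy

    def wait(self):
        return sum(a * b for a, b in zip(self.p, self.smem))


def pv_step(p, v_shared, next_v, release_early):
    op = AsyncDot(p, v_shared)        # PV GEMM issued on v_shared
    if release_early:
        # wg_wait=-1: the barrier is signaled right away, so the producer
        # refills v_shared with the next V tile before the reads finish.
        v_shared[:] = next_v
        return op.wait()
    result = op.wait()                # baseline: warpgroup_wait<0>() first
    v_shared[:] = next_v              # buffer reused only after completion
    return result


p = [0.25, 0.75]
print(pv_step(p, [1.0, 2.0], [10.0, 20.0], release_early=False))  # 1.75
print(pv_step(p, [1.0, 2.0], [10.0, 20.0], release_early=True))   # 17.5 (read the new tile)
```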

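FA3-style solution: insert warpgroup_arrive() (= PTX wgmma.fence.sync.aligned, ~3 cycles), on the assumption that it ensures the shared-memory reads have completed WITHOUT waiting for the register writes. Caveat: the PTX ISA documents wgmma.fence as ordering accesses to warpgroup registers around wgmma.mma_async, not as waiting for an in-flight WGMMA's shared-memory reads (completion of the WGMMA is what wgmma.wait_group provides), so this assumption needs verification before v_shared is released this early: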

// FA3 ideal:
tl::warpgroup_commit_batch();
tl::warpgroup_arrive();            // ← smem fence only (~3 cycles)
tl::warpgroup_fence_operand(...);
mbarrier[v_buf].arrive();          // ← safe: smem reads done
// softmax runs here while PV register writes continue in background
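The ordering FA3-style overlap aims for (softmax of tile j placed between issuing PV(j-1) and waiting for it) is still numerically exact, provided P is double-buffered and the accumulator is rescaled only after PV(j-1) has completed. The following plain-Python sketch checks that for one query row. The issue/wait comments mark where the asynchronous WGMMA calls would sit, but Python itself runs everything sequentially; all function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reference_attention(q, keys, values):
    s = [dot(q, k) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    d = len(values[0])
    return [sum(pj * v[c] for pj, v in zip(p, values)) / sum(p) for c in range(d)]


def pv_gemm(p, v_tile, d):
    """P (one row) times a V tile -> length-d partial output."""
    return [sum(pj * v[c] for pj, v in zip(p, v_tile)) for c in range(d)]


def fa3_ordered_attention(q, keys, values, tile):
    d = len(values[0])
    tiles = [(keys[i:i + tile], values[i:i + tile]) for i in range(0, len(keys), tile)]

    # Prologue: QK(0) and softmax(0), waited on synchronously.
    s = [dot(q, k) for k in tiles[0][0]]
    m = max(s)
    p_cur = [math.exp(x - m) for x in s]        # P buffer A
    l = sum(p_cur)
    acc_o = [0.0] * d
    v_prev = tiles[0][1]

    for k_tile, v_tile in tiles[1:]:
        s_next = [dot(q, k) for k in k_tile]     # issue QK(j)
        pv = pv_gemm(p_cur, v_prev, d)           # issue PV(j-1); reads p_cur
        # softmax(j): would overlap PV(j-1) on the GPU; writes a separate P buffer
        m_new = max(m, max(s_next))
        p_next = [math.exp(x - m_new) for x in s_next]  # P buffer B
        scale = math.exp(m - m_new)
        # wait for PV(j-1), then rescale acc_o (now including PV(j-1)) to m_new
        acc_o = [(o + x) * scale for o, x in zip(acc_o, pv)]
        l = l * scale + sum(p_next)
        m, p_cur, v_prev = m_new, p_next, v_tile  # swap P buffers

    # Epilogue: PV for the last tile, then normalize.
    acc_o = [o + x for o, x in zip(acc_o, pv_gemm(p_cur, v_prev, d))]
    return [o / l for o in acc_o]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.7], [2.0, -1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
ref = reference_attention(q, keys, values)
for t in (1, 2, 4):
    out = fa3_ordered_attention(q, keys, values, tile=t)
    assert all(abs(a - b) < 1e-9 for a, b in zip(out, ref))
```

The key design point is that the rescale of acc_o is the only step that must wait for PV(j-1); the max and exponentials of tile j do not touch acc_o.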

3. Why User-Level Injection Fails

| Attempt | Result | Reason |
|---|---|---|
| T.warpgroup_arrive() in Python | Optimized away | The pipeline pass deduplicates it with the next GEMM's arrive |
| tir.call_extern("tl::warpgroup_arrive") | Optimized away | The DCE pass removes it |
| Patch gemm_sm90.h template | No effect | Codegen (gemm.cc) does not call the template functions; it generates the WGMMA code inline |
| Reorder pipeline (softmax before PV) | Incorrect results | The acc_s_cast fragment is not double-buffered, so softmax overwrites it before the PV GEMM reads it |
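The failure of the reorder attempt can be reproduced on the CPU. In this hypothetical sketch, softmax(j) writes its probabilities into a P slot before PV(j-1) reads P(j-1). With num_p_buffers=1 (modeling a single acc_s_cast fragment) PV reads the wrong probabilities, while num_p_buffers=2 gives the exact result. It models only the data hazard, not timing.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reference_attention(q, keys, values):
    s = [dot(q, k) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    d = len(values[0])
    return [sum(pj * v[c] for pj, v in zip(p, values)) / sum(p) for c in range(d)]


def softmax_before_pv(q, keys, values, tile, num_p_buffers):
    """softmax(j) writes P(j) into a slot BEFORE PV(j-1) reads P(j-1).
    num_p_buffers=1 models a single acc_s_cast fragment; 2 a double buffer."""
    d = len(values[0])
    tiles = [(keys[i:i + tile], values[i:i + tile]) for i in range(0, len(keys), tile)]
    p_bufs = [[] for _ in range(num_p_buffers)]
    m, l, acc_o = -math.inf, 0.0, [0.0] * d

    def pv(p, v_tile, acc):
        return [o + sum(pj * v[c] for pj, v in zip(p, v_tile)) for c, o in enumerate(acc)]

    for j, (k_tile, _) in enumerate(tiles):
        s = [dot(q, k) for k in k_tile]
        m_new = max(m, max(s))
        slot = p_bufs[j % num_p_buffers]
        slot[:] = [math.exp(x - m_new) for x in s]   # softmax(j) writes P(j)
        if j > 0:
            # PV(j-1) reads "P(j-1)"; with one slot it was just overwritten by P(j)
            acc_o = pv(p_bufs[(j - 1) % num_p_buffers], tiles[j - 1][1], acc_o)
        scale = math.exp(m - m_new)                  # rescale to the new max
        acc_o = [o * scale for o in acc_o]
        l = l * scale + sum(slot)
        m = m_new
    acc_o = pv(p_bufs[(len(tiles) - 1) % num_p_buffers], tiles[-1][1], acc_o)
    return [o / l for o in acc_o]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.7], [2.0, -1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
ref = reference_attention(q, keys, values)
good = softmax_before_pv(q, keys, values, tile=2, num_p_buffers=2)
bad = softmax_before_pv(q, keys, values, tile=2, num_p_buffers=1)
assert all(abs(a - b) < 1e-9 for a, b in zip(good, ref))
assert max(abs(a - b) for a, b in zip(bad, ref)) > 1e-3
```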

4. Alternative Approaches Tested

| Approach | Relative performance | Issue |
|---|---|---|
| T.Pipelined baseline (wg_wait=0) | 1.0x | No FA3 overlap |
| T.Pipelined + wg_wait=-1 | 0.82x | Missing fence; degraded instruction scheduling |
| T.serial + disable_tma (FA3 logic validated) | 0.67x | SIMT copy too slow |
| T.serial + cp.async + barriers | 0.61x | cp.async is much slower than TMA on SM90 |
| WS (if tx < 128) + cp.async | 0.30x | cp.async bottleneck + WS overhead |

5. Discovered APIs for Proper WS Implementation

Found in tilelang/examples/warp_specialize/:

# T.tma_copy — TMA with explicit barrier binding (bypasses InjectTmaBarrier pass)
T.tma_copy(global_tensor[slice], shared_buffer, barrier=my_barrier)

# T.ws(N) — Declarative warp specialization
with T.ws(0):  # Producer warpgroup
    T.tma_copy(K[...], k_shared, barrier=k_ready)
with T.ws(1):  # Consumer warpgroup
    T.barrier_wait(k_ready, parity)
    T.wgmma_gemm(q_shared, k_shared, acc_s, ...)

# T.wgmma_gemm — WGMMA in WS context
T.wgmma_gemm(A_shared, B_shared, acc, transpose_B=True, clear_accum=True)
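The producer/consumer handshake behind T.ws + T.tma_copy + T.barrier_wait can be pictured with a CPU analogy. In this hypothetical sketch two Python threads play the producer and consumer warpgroups, a two-slot list plays the shared-memory ring, and counting semaphores play the role of the "data ready" and "slot free" barriers. Real mbarriers track phase parity instead of counts, and nothing here is TileLang API.

```python
import threading

NUM_STAGES = 2  # two shared-memory slots, as with double-buffered K/V


def warp_specialized_sum(tiles):
    """Producer fills a 2-slot ring (stand-in for T.tma_copy); consumer waits
    on a 'full' signal, computes on the slot, then signals 'empty'."""
    smem = [None] * NUM_STAGES
    full = [threading.Semaphore(0) for _ in range(NUM_STAGES)]   # like k_ready
    empty = [threading.Semaphore(1) for _ in range(NUM_STAGES)]  # slot is free
    consumed = []

    def producer():
        for i, tile in enumerate(tiles):
            slot = i % NUM_STAGES
            empty[slot].acquire()     # wait until the consumer released this slot
            smem[slot] = list(tile)   # "TMA load" into shared memory
            full[slot].release()      # arrive: data ready

    def consumer():
        for i in range(len(tiles)):
            slot = i % NUM_STAGES
            full[slot].acquire()               # like T.barrier_wait(k_ready, parity)
            consumed.append(sum(smem[slot]))   # "compute" on the tile
            empty[slot].release()              # hand the slot back to the producer

    workers = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return consumed


print(warp_specialized_sum([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]))
```

The producer can run at most NUM_STAGES tiles ahead of the consumer, which is the property the barrier pair enforces in a warp-specialized kernel.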

Reference: example_warp_specialize_flashmla.py — complete FlashMLA decode with 2-WG pingpong, TMA, barriers, and online softmax.

Proposed Implementation

Rewrite GQA forward using T.tma_copy + T.ws() + T.wgmma_gemm:

Architecture: 2 warpgroups (256 threads) that pingpong over alternating KV tiles:

  • WG0: QK[even] + softmax[even] + PV_left[even], prefetches odd tiles via T.tma_copy
  • WG1: QK[odd] + softmax[odd] + PV_right[odd], prefetches even tiles via T.tma_copy

A shared scores_max buffer coordinates the online softmax across the two WGs (the same pattern as the FlashMLA example).
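The math of splitting KV tiles across two warpgroups with a shared running max can be checked in isolation. The sketch below is sequentialized plain Python with hypothetical names: even tiles go to "WG0", odd tiles to "WG1", both read and update one shared_m, and each WG remembers which max its own accumulator is scaled to, so the two partial results can be combined at the end. It does not model concurrency, the synchronization the shared max needs on the GPU, or the PV_left/PV_right split of the output.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reference_attention(q, keys, values):
    s = [dot(q, k) for k in keys]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    d = len(values[0])
    return [sum(pj * v[c] for pj, v in zip(p, values)) / sum(p) for c in range(d)]


def pingpong_attention(q, keys, values, tile):
    d = len(values[0])
    shared_m = -math.inf                 # stands in for the shared scores_max
    ref_m = [-math.inf, -math.inf]       # max each WG's acc_o / l is scaled to
    acc_o = [[0.0] * d, [0.0] * d]
    l = [0.0, 0.0]
    for j, start in enumerate(range(0, len(keys), tile)):
        wg = j % 2                       # even tiles -> WG0, odd tiles -> WG1
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        s = [dot(q, k) for k in k_tile]
        shared_m = max(shared_m, max(s))            # update the shared max
        scale = math.exp(ref_m[wg] - shared_m)      # bring this WG's state up to date
        p = [math.exp(x - shared_m) for x in s]
        acc_o[wg] = [o * scale + sum(pj * v[c] for pj, v in zip(p, v_tile))
                     for c, o in enumerate(acc_o[wg])]
        l[wg] = l[wg] * scale + sum(p)
        ref_m[wg] = shared_m
    # Combine: rescale both WGs to the final shared max, then add and normalize.
    sc = [math.exp(m - shared_m) for m in ref_m]
    denom = l[0] * sc[0] + l[1] * sc[1]
    return [(acc_o[0][c] * sc[0] + acc_o[1][c] * sc[1]) / denom for c in range(d)]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.7], [2.0, -1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
ref = reference_attention(q, keys, values)
for t in (1, 2, 3):
    out = pingpong_attention(q, keys, values, tile=t)
    assert all(abs(a - b) < 1e-9 for a, b in zip(out, ref))
```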

Expected benefit: one WG's PV WGMMA overlaps with the other WG's softmax computation, and TMA loads overlap with both WGs' compute, with the aim of keeping SM90's TMA engine, Tensor Cores, and CUDA cores busy simultaneously.

Files Modified During Investigation

  • tileops/kernels/flash_attn/fwd.py — various experimental changes (all reverted to baseline)
  • Temporary test files: _test_manual_pipeline.py, _test_ws.py, _bench_ab.py, _bench_manual.py, _dump_cuda.py, _dump_patched.py, _patch_and_test.py, _run_fence_test.sh, _test_fence_patch.sh

Next Steps

  1. Implement GQA forward with T.tma_copy + T.ws() + T.wgmma_gemm following the FlashMLA example pattern
  2. Benchmark against baseline and FA3
  3. Consider contributing warpgroup_arrive() fence support to TileLang's gemm.cc codegen (potentially a one-line fix for wg_wait=-1 in the T.Pipelined context, if the fence proves sufficient)
