[REFACTOR][GQA] unify square and prefill forward into one fa3 kernel #1622

Description

@zhen8838

Parent: #1610

Symptom / Motivation

GQA forward is split across two parallel warp-specialized kernel families — gqa_fwd_ws.py (square, s_q == s_kv) and gqa_prefill_fwd_ws.py (rectangular, s_q <= s_kv, bottom-right causal). FlashInfer and FlashAttention treat both as one prefill/append operation (square is just offset = s_kv - s_q == 0): FA3's flash_attn_func documents bottom-right causal alignment for any seqlen_q/seqlen_k, and FlashInfer's single_prefill_with_kv_cache ("Prefill/Append attention") takes independent qo_len/kv_len. Keeping two kernels forces every capability change to be made twice, and the two paths have diverged in capability:

  • Neither forward Op supports sm_scale / softcap on the fast path. The square Op has no such params; the prefill FA3 gate rejects non-default sm_scale and any nonzero softcap, silently falling back to the slow kernel.
  • The FA3 path is fp16-only; bf16 (training) forward falls back and loses the FA3 speedup.
  • The prefill WS kernel emits no LSE, so a backward that needs LSE cannot use the fast forward path. The square WS kernel does emit LSE, which shows the WS schedule can support it, but the prefill kernel dropped it because emitting LSE unconditionally cost ~35% on long context (PR #1614, "[Feat][Attention] FA3 warp-specialized GQA prefill kernel (Hopper, fp16, causal)").
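To make the "square is just offset 0" point concrete, here is a minimal pure-Python sketch of a bottom-right-aligned causal mask. The helper names (bottom_right_causal_mask, show) are hypothetical and exist only for illustration; the real kernels apply this condition per tile rather than materializing a mask.

```python
def bottom_right_causal_mask(s_q: int, s_kv: int) -> list[list[bool]]:
    """mask[i][j] is True where query row i may attend to key column j."""
    if s_q > s_kv:
        raise ValueError("this sketch assumes s_q <= s_kv")
    offset = s_kv - s_q  # 0 in the square case
    return [[j <= i + offset for j in range(s_kv)] for i in range(s_q)]


def show(mask: list[list[bool]]) -> list[str]:
    """Render each mask row as a string: 'x' = visible, '.' = masked."""
    return ["".join("x" if ok else "." for ok in row) for row in mask]


square = show(bottom_right_causal_mask(3, 3))  # ordinary lower-triangular mask
rect = show(bottom_right_causal_mask(2, 5))    # last query row sees every key

print("\n".join(square))
print("\n".join(rect))
```

With s_q == s_kv the offset vanishes and the mask is the familiar lower triangle, so a kernel written against the offset form covers the square case without a separate code path.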

Root Cause Analysis

  • tileops/ops/attention/gqa.py:138-152, _select_gqa_prefill_fwd_kernel_cls: the FA3 gate is hopper and is_causal and dim == 128 and dtype == torch.float16 and softcap == 0.0 and abs(sm_scale - dim**-0.5) < 1e-9; anything else selects GQAPrefillFwdKernel (slow fallback).
  • tileops/ops/attention/gqa.py:227-251, GroupedQueryAttentionFwdOp (square): has no sm_scale/softcap params and carries seq_len only (# TODO: support s_q != s_kv).
  • tileops/kernels/attention/gqa_prefill_fwd_ws.py:105-206 and 208-309 — two ~101-line consumer warp-group bodies, near-identical except r0/my_bar/nxt_bar/Qs index (dual-maintenance hazard). gqa_fwd_ws.py has the same _1/_2 duplicated structure.
  • gqa_prefill_fwd_ws.py:367-370 — prefill forward() returns (out, None); gqa_fwd_ws.py declares an lse output.
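The gate condition quoted above can be restated as a standalone predicate to see which configurations fall back today. This is a sketch only: FwdConfig, current_gate and widened_gate are hypothetical names, dtype strings stand in for torch dtypes, and the widened gate is one possible shape, not the agreed design.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FwdConfig:  # hypothetical container, not a TileOPs type
    is_hopper: bool
    is_causal: bool
    dim: int
    dtype: str  # "float16" or "bfloat16"; strings stand in for torch dtypes
    sm_scale: float
    softcap: float = 0.0


def current_gate(cfg: FwdConfig) -> bool:
    """Restates the FA3 gate condition quoted in the root-cause list."""
    return (
        cfg.is_hopper
        and cfg.is_causal
        and cfg.dim == 128
        and cfg.dtype == "float16"
        and cfg.softcap == 0.0
        and abs(cfg.sm_scale - cfg.dim ** -0.5) < 1e-9
    )


def widened_gate(cfg: FwdConfig) -> bool:
    """Assumed widened gate: dtype, sm_scale and softcap no longer disqualify."""
    return (
        cfg.is_hopper
        and cfg.is_causal
        and cfg.dim == 128
        and cfg.dtype in ("float16", "bfloat16")
    )


default = FwdConfig(True, True, 128, "float16", 128 ** -0.5)
bf16 = FwdConfig(True, True, 128, "bfloat16", 128 ** -0.5)
scaled = FwdConfig(True, True, 128, "float16", 0.1)
capped = FwdConfig(True, True, 128, "float16", 128 ** -0.5, softcap=30.0)

for name, cfg in [("default", default), ("bf16", bf16), ("scaled", scaled), ("capped", capped)]:
    print(name, current_gate(cfg), widened_gate(cfg))
```

Under the current condition only the default configuration takes the fast path; the other three are the silent fallbacks described in the symptom list.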

Related Files

  • tileops/ops/attention/gqa.py (forward Ops + _select_* gates)
  • tileops/kernels/attention/gqa_fwd_ws.py, tileops/kernels/attention/gqa_prefill_fwd_ws.py
  • tileops/kernels/attention/__init__.py (exports)

Goal

Converge GQA forward onto a single FA3/warp-specialized kernel that handles s_q <= s_kv (square = offset 0, bottom-right causal) and reaches FlashInfer capability parity: bf16, configurable sm_scale, softcap, and opt-in LSE output (off by default to preserve long-context performance, on when backward needs it — mirroring FlashInfer's single_prefill_with_kv_cache_return_lse). This is the forward slice of the #1610 kernel-layer unification.
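As a reference for what the unified kernel is expected to compute, the following single-head pure-Python sketch combines bottom-right causal masking, a configurable sm_scale, softcap and opt-in LSE. It makes several assumptions that the issue does not state: softcap is applied as softcap * tanh(score / softcap) on the already scaled score, LSE uses the natural logarithm, and the function name ref_attention is hypothetical. A real correctness test would compare against the project's existing reference instead.

```python
import math


def ref_attention(q, k, v, sm_scale=None, softcap=0.0, return_lse=False):
    """Single-head causal attention on nested lists.

    q is [s_q][d]; k is [s_kv][d]; v is [s_kv][d_v]. Requires s_q <= s_kv.
    """
    s_q, s_kv, d = len(q), len(k), len(q[0])
    if s_q > s_kv:
        raise ValueError("this sketch assumes s_q <= s_kv")
    if sm_scale is None:
        sm_scale = d ** -0.5  # default scale
    offset = s_kv - s_q  # bottom-right alignment; 0 when square
    out, lse = [], []
    for i in range(s_q):
        scores = []
        for j in range(i + offset + 1):  # keys visible to query row i
            s = sm_scale * sum(a * b for a, b in zip(q[i], k[j]))
            if softcap > 0.0:
                s = softcap * math.tanh(s / softcap)  # assumed softcap form
            scores.append(s)
        m = max(scores)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        denom = sum(exps)
        out.append([
            sum(e * v[j][c] for j, e in enumerate(exps)) / denom
            for c in range(len(v[0]))
        ])
        lse.append(m + math.log(denom))  # natural-log LSE (assumption)
    return (out, lse) if return_lse else out


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0], [3.0]]
out, lse = ref_attention(q, k, v, sm_scale=1.0, return_lse=True)
print(out, lse)
```

Returning LSE only on request mirrors the opt-in behavior described above: the default call has the same shape of result as today's prefill forward, and the backward path can ask for the extra output.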

Plan

  1. Define the unified forward TileLang kernel covering s_q <= s_kv + bottom-right causal, with the duplicated warp-group body factored into one parameterized macro (resolves the dual-WG duplication in both files).
  2. Add dtype support (incl. bf16, with fp32-promoted softmax/accumulators), sm_scale, softcap, and an opt-in return_lse path.
  3. Route both the square (GroupedQueryAttentionFwdOp) and rectangular (GroupedQueryAttentionPrefillFwdOp) forward Ops through the unified kernel; widen the selection gate (coordinate with #1611, "[Refactor] Extract GQA prefill kernel selection logic") so non-default scale/softcap and bf16 no longer silently fall back.
  4. Keep the legacy kernels as compatibility shims during migration, then retire.
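Steps 3 and 4 amount to turning the two existing entry points into thin wrappers over one implementation. The sketch below shows that shape with plain functions; unified_fwd_plan, square_fwd_shim and prefill_fwd_shim are hypothetical names, and the returned dictionary only stands in for whatever the real unified kernel would be configured with.

```python
def unified_fwd_plan(s_q, s_kv, dim, *, sm_scale=None, softcap=0.0, return_lse=False):
    """Hypothetical unified entry point: normalizes arguments in one place."""
    if s_q > s_kv:
        raise ValueError("unified forward covers s_q <= s_kv only")
    return {
        "s_q": s_q,
        "s_kv": s_kv,
        "causal_offset": s_kv - s_q,  # bottom-right alignment
        "sm_scale": dim ** -0.5 if sm_scale is None else sm_scale,
        "softcap": softcap,
        "return_lse": return_lse,
    }


def square_fwd_shim(seq_len, dim, **kwargs):
    """Legacy square entry point kept as a shim: square is the offset-0 case."""
    return unified_fwd_plan(seq_len, seq_len, dim, **kwargs)


def prefill_fwd_shim(s_q, s_kv, dim, **kwargs):
    """Legacy rectangular entry point, forwarding unchanged."""
    return unified_fwd_plan(s_q, s_kv, dim, **kwargs)


print(square_fwd_shim(1024, 128))
print(prefill_fwd_shim(128, 4096, 128, softcap=30.0, return_lse=True))
```

Because both shims forward keyword arguments untouched, a capability added to the unified entry point becomes available to both Ops at once, which is the dual-maintenance cost the plan is meant to remove.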

Constraints

Acceptance Criteria

  • Modified files pass unit tests
  • One forward kernel serves both square and rectangular (s_q <= s_kv) GQA forward; the duplicated warp-group body exists in a single parameterized form.
  • fp16 and bf16 causal forward both run on the FA3/WS path (no silent fallback) and match the reference within tolerance.
  • Non-default sm_scale and softcap are honored on the fast path (no silent fallback).
  • return_lse=True produces correct LSE; default (False) preserves the long-context latency from PR #1614, "[Feat][Attention] FA3 warp-specialized GQA prefill kernel (Hopper, fp16, causal)".
  • Related manifest update (kernel_map for the unified kernel) is tracked as a separate manifest PR.

Labels

refactor: Code restructuring without behavior change
