Symptom / Motivation
GQA forward is split across two parallel warp-specialized kernel families: gqa_fwd_ws.py (square, s_q == s_kv) and gqa_prefill_fwd_ws.py (rectangular, s_q <= s_kv, bottom-right causal). FlashInfer and FlashAttention treat both as one prefill/append operation, in which the square case is simply offset = s_kv - s_q == 0. FA3's flash_attn_func documents bottom-right causal alignment for any seqlen_q/seqlen_k, and FlashInfer's single_prefill_with_kv_cache ("Prefill/Append attention") takes independent qo_len/kv_len. Keeping two kernels forces every capability change to be made twice, and the two paths have diverged in capability:
Neither forward Op supports sm_scale / softcap on the fast path. The square Op has no such parameters, and the prefill FA3 gate rejects a non-default sm_scale and any nonzero softcap, so the prefill Op silently falls back to the slow kernel.
The FA3 path is fp16-only, so bf16 forward (the usual training dtype) falls back and loses the FA3 speedup.
The prefill WS kernel emits no LSE, so a backward pass that needs LSE cannot use the fast forward path. The square WS kernel does emit LSE, which shows the WS schedule can produce it, but the prefill kernel dropped it because emitting LSE unconditionally cost ~35% on long context (PR #1614, "[Feat][Attention] FA3 warp-specialized GQA prefill kernel (Hopper, fp16, causal)").
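The "square is just offset = s_kv - s_q == 0" observation can be made concrete with a minimal pure-Python sketch of bottom-right causal alignment, the convention FlashAttention documents for causal=True when seqlen_q != seqlen_k. Query row i is treated as sitting at absolute position i + offset, so it may attend key j iff j <= i + offset. `bottom_right_causal_mask` is a hypothetical helper for illustration, not a tileops API.

```python
# Sketch: bottom-right-aligned causal mask for s_q <= s_kv.
# Query row i sits at absolute position i + offset, offset = s_kv - s_q,
# so it may attend key j iff j <= i + offset. Square case: offset == 0.


def bottom_right_causal_mask(s_q: int, s_kv: int) -> list[list[bool]]:
    """Return an s_q x s_kv boolean mask; True means the key is visible."""
    if s_q > s_kv:
        raise ValueError("this sketch assumes s_q <= s_kv")
    offset = s_kv - s_q
    return [[j <= i + offset for j in range(s_kv)] for i in range(s_q)]


if __name__ == "__main__":
    # Append case: 2 new queries over a 5-token KV cache (offset = 3).
    for row in bottom_right_causal_mask(2, 5):
        print("".join("x" if visible else "." for visible in row))
    # prints:
    # xxxx.
    # xxxxx
```

With s_q == s_kv the same function yields the ordinary lower-triangular mask, which is why one kernel parameterized by the offset can serve both Ops.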
Root Cause Analysis
tileops/ops/attention/gqa.py:138-152 — _select_gqa_prefill_fwd_kernel_cls: FA3 gate is hopper and is_causal and dim==128 and dtype==torch.float16 and softcap==0.0 and abs(sm_scale - dim**-0.5) < 1e-9; anything else → GQAPrefillFwdKernel (slow fallback).
tileops/ops/attention/gqa.py:227-251 — GroupedQueryAttentionFwdOp (square) has no sm_scale/softcap params and carries seq_len only (# TODO: support s_q != s_kv).
tileops/kernels/attention/gqa_prefill_fwd_ws.py:105-206 and 208-309 — two ~101-line consumer warp-group bodies, near-identical except r0/my_bar/nxt_bar/Qs index (dual-maintenance hazard). gqa_fwd_ws.py has the same _1/_2 duplicated structure.
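To see how narrow the current fast path is, here is the FA3 gate quoted from tileops/ops/attention/gqa.py:138-152, restated as a standalone predicate. This is a hypothetical re-statement, not the tileops code: `fa3_prefill_gate` and `select_kernel_name` are illustrative names, and dtypes are plain strings so the sketch runs without torch.

```python
# Hypothetical restatement of the FA3 prefill gate described in the issue.
def fa3_prefill_gate(is_hopper: bool, is_causal: bool, dim: int,
                     dtype: str, sm_scale: float, softcap: float) -> bool:
    return (
        is_hopper
        and is_causal
        and dim == 128
        and dtype == "float16"                 # bf16 is rejected
        and softcap == 0.0                     # any softcap is rejected
        and abs(sm_scale - dim ** -0.5) < 1e-9  # only the default scale passes
    )


def select_kernel_name(**cfg) -> str:
    # Anything that fails the gate silently lands on the slow kernel.
    return "FA3 WS kernel" if fa3_prefill_gate(**cfg) else "GQAPrefillFwdKernel (slow)"


if __name__ == "__main__":
    base = dict(is_hopper=True, is_causal=True, dim=128,
                dtype="float16", sm_scale=128 ** -0.5, softcap=0.0)
    print(select_kernel_name(**base))                           # FA3 WS kernel
    print(select_kernel_name(**{**base, "dtype": "bfloat16"}))  # GQAPrefillFwdKernel (slow)
    print(select_kernel_name(**{**base, "softcap": 30.0}))      # GQAPrefillFwdKernel (slow)
    print(select_kernel_name(**{**base, "sm_scale": 0.1}))      # GQAPrefillFwdKernel (slow)
```

Nothing in the return value tells the caller which branch was taken, which is what makes the fallback "silent".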
Goal
Converge GQA forward onto a single FA3/warp-specialized kernel that handles s_q <= s_kv (square = offset 0, bottom-right causal) and reaches FlashInfer capability parity: bf16, configurable sm_scale, softcap, and opt-in LSE output (off by default to preserve long-context performance, on when backward needs it, mirroring FlashInfer's single_prefill_with_kv_cache_return_lse). This is the forward slice of the #1610 kernel-layer unification.
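The target semantics can be pinned down with a small CPU reference, sketched below in pure Python. It covers bottom-right causal over s_q <= s_kv, GQA head grouping, configurable sm_scale (defaulting to dim ** -0.5), softcap, and an opt-in LSE. KV is visited in tiles with an online softmax (running max m, running sum l), so the LSE falls out as m + log(l). Assumptions to flag: softcap is applied after scaling as softcap * tanh(score / softcap), the convention FlashAttention uses; the LSE here is natural-log with layout [H_q][s_q], a choice of this sketch rather than the kernel's or FlashInfer's contract. `gqa_fwd_reference` is hypothetical.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]  # [seq][dim]


def gqa_fwd_reference(
    q: List[Matrix],          # [H_q][s_q][d]
    k: List[Matrix],          # [H_kv][s_kv][d]
    v: List[Matrix],          # [H_kv][s_kv][d]
    sm_scale: Optional[float] = None,
    softcap: float = 0.0,
    return_lse: bool = False,
    block_n: int = 4,
) -> Tuple[List[Matrix], Optional[List[List[float]]]]:
    """Causal GQA forward for s_q <= s_kv with bottom-right alignment."""
    h_q, h_kv = len(q), len(k)
    assert h_q % h_kv == 0, "each KV head serves h_q // h_kv query heads"
    group = h_q // h_kv
    s_q, s_kv, d = len(q[0]), len(k[0]), len(q[0][0])
    assert s_q <= s_kv
    scale = d ** -0.5 if sm_scale is None else sm_scale
    offset = s_kv - s_q  # 0 for the square case
    out, lse = [], []
    for h in range(h_q):
        kh, vh = k[h // group], v[h // group]
        out_h, lse_h = [], []
        for i in range(s_q):
            m, l, acc = -math.inf, 0.0, [0.0] * d
            last = i + offset  # last visible key index for this query row
            # Only tiles that contain at least one visible key are visited.
            for start in range(0, last + 1, block_n):
                keys = range(start, min(start + block_n, last + 1))
                s = []
                for j in keys:
                    x = scale * sum(a * b for a, b in zip(q[h][i], kh[j]))
                    if softcap > 0.0:
                        x = softcap * math.tanh(x / softcap)  # tanh soft-capping
                    s.append(x)
                m_new = max(m, max(s))
                alpha = math.exp(m - m_new)  # rescale earlier partial sums
                p = [math.exp(x - m_new) for x in s]
                l = l * alpha + sum(p)
                acc = [a * alpha for a in acc]
                for pj, j in zip(p, keys):
                    acc = [a + pj * vv for a, vv in zip(acc, vh[j])]
                m = m_new
            out_h.append([a / l for a in acc])
            lse_h.append(m + math.log(l))  # natural-log LSE of the scaled scores
        out.append(out_h)
        lse.append(lse_h)
    # In this reference the LSE is always computed; the flag only controls
    # whether it is returned.
    return out, (lse if return_lse else None)
```

Comparing its output and LSE against a direct, non-tiled softmax over each visible row is one way to exercise the return_lse=True path in tests, independent of the GPU kernel.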
Plan
Define the unified forward TileLang kernel covering s_q <= s_kv under bottom-right causal masking, and factor the duplicated warp-group body into one parameterized macro (resolving the dual-WG duplication in both files).
Add dtype support (incl. bf16, with fp32-promoted softmax/accumulators), sm_scale, softcap, and an opt-in return_lse path.
Route both the square (GroupedQueryAttentionFwdOp) and rectangular (GroupedQueryAttentionPrefillFwdOp) forward Ops through the unified kernel; widen the selection gate (coordinate with #1611, "[Refactor] Extract GQA prefill kernel selection logic") so that non-default scale/softcap and bf16 no longer silently fall back.
Keep the legacy kernels as compatibility shims during migration, then retire them.
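The routing and gate-widening steps of the plan can be sketched as follows. Everything here is hypothetical (`FwdProblem`, `unified_gate`, `causal_offset`, `select_fwd_kernel` are illustrative names, not tileops APIs). The Hopper, causal, and dim == 128 conditions are carried over unchanged from the current gate; whether they also widen is outside this sketch. sm_scale and softcap are deliberately absent from the gate because they become kernel arguments, and return_lse is forwarded to the kernel rather than forcing a fallback.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FwdProblem:
    # Hypothetical description of one forward call; field names are illustrative.
    s_q: int
    s_kv: int
    dim: int
    dtype: str        # "float16" or "bfloat16" in this sketch
    is_causal: bool
    sm_scale: float
    softcap: float
    return_lse: bool


FAST_DTYPES = {"float16", "bfloat16"}


def unified_gate(p: FwdProblem, is_hopper: bool) -> bool:
    """Widened gate: sm_scale/softcap are kernel parameters, not gate conditions."""
    return (
        is_hopper
        and p.is_causal
        and p.dim == 128
        and p.dtype in FAST_DTYPES
        and p.s_q <= p.s_kv
    )


def causal_offset(p: FwdProblem) -> int:
    # Square Op: s_q == s_kv gives offset 0; prefill/append Op: s_kv - s_q.
    return p.s_kv - p.s_q


def select_fwd_kernel(p: FwdProblem, is_hopper: bool) -> str:
    # Both Ops call this; the legacy kernels stay only as a migration shim.
    if unified_gate(p, is_hopper):
        return f"unified WS kernel (offset={causal_offset(p)}, return_lse={p.return_lse})"
    return "legacy fallback kernel"


if __name__ == "__main__":
    square = FwdProblem(s_q=1024, s_kv=1024, dim=128, dtype="bfloat16",
                        is_causal=True, sm_scale=0.1, softcap=30.0, return_lse=True)
    print(select_fwd_kernel(square, is_hopper=True))
    # unified WS kernel (offset=0, return_lse=True)
```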
Constraints
Forward only. Decode (separate sub-issue), the backward kernel, and varlen/paged unification are out of scope.
Trust model: any tileops/manifest/attention.yaml signature/kernel_map/source.kernel change is a separate manifest-only PR with human review; this code issue must not edit the manifest to match code.
Parent: #1610
Root Cause Analysis (continued)
In gqa_prefill_fwd_ws.py:367-370, the prefill forward() returns (out, None), whereas gqa_fwd_ws.py declares an lse output.
Related Files
tileops/ops/attention/gqa.py (forward Ops + _select_* gates)
tileops/kernels/attention/gqa_fwd_ws.py, tileops/kernels/attention/gqa_prefill_fwd_ws.py
tileops/kernels/attention/__init__.py (exports)
Acceptance Criteria
… s_q <= s_kv) GQA forward; the duplicated warp-group body exists in a single parameterized form.
sm_scale and softcap are honored on the fast path (no silent fallback).
return_lse=True produces correct LSE; the default (False) preserves the long-context latency from PR #1614 ("[Feat][Attention] FA3 warp-specialized GQA prefill kernel (Hopper, fp16, causal)").