[Feat][Attention] FA3 warp-specialized GQA prefill kernel (Hopper, fp16, causal)#1614
Conversation
Add GQAPrefillFwdWsPersistentCausalKernel, a faithful port of FlashInfer's single_prefill_sm90 (2-warpgroup ping-pong, register-P, TMA, set_max_nreg, named-barrier WarpScheduler). The prefill op selects it on Hopper for the fp16 / causal / dim==128 / default-scale / no-softcap path and falls back to GQAPrefillFwdKernel everywhere else. Like FlashInfer's forward, the kernel emits no log-sum-exp; the op's (output, lse) contract is satisfied with lse=None. (Emitting lse from the epilogue perturbed the register-tuned schedule and cost ~35% on long context.) Bench (H200, fp16, locked clocks) vs flashinfer: B2 Sq512 Skv4096: 0.807 ms vs 0.815 ms (parity) B4 Sq512 Skv512: 0.228 ms vs flashinfer (faster) bench_gqa: generalize the flashinfer baseline helper to handle Sq != Skv and record it for the prefill bench.
There was a problem hiding this comment.
Code Review
This pull request introduces a new warp-specialized, 2-warpgroup ping-pong FlashInfer FA3 GQA-prefill kernel (GQAPrefillFwdWsPersistentCausalKernel) for causal, fp16, dim-128 attention on Hopper (sm90) architectures. It also updates the benchmarking scripts to profile this new FlashInfer baseline and integrates the kernel selection logic into the GQA operator. The review feedback highlights PEP 8 indentation issues within the newly added kernel file, specifically inside the with T.ws(0): and with T.ws(1): blocks where 8 spaces are used instead of the standard 4.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
bench_kernel drove torch.profiler.schedule with n_repeat calls crammed into each warmup/active "step". step() does not synchronize, so the queued warmup kernels were still executing when recording opened and leaked into the active trace — the window held ~2x n_repeat launches, and summed kernel time / n_repeat over-reported latency (worst for sub-100us kernels). Drop the schedule/on_trace_ready machinery: each trial now opens its own torch.profiler context around exactly n_repeat iterations and reads the trace after it closes, with a synchronize() before close so every kernel in the window is finished and recorded. _sum_kernel_time_us also returns the kernel count, and a sanity check warns if per-trial counts are not a consistent multiple of n_repeat (catches future leaks).
bench_kernel summed every CUDA kernel inside the profiler window, so the per-iteration cache.zero_() L2 flush was counted as kernel time. Wrap only the timed call in a record_function scope and sum the kernels Kineto attributes to that scope's device-timeline annotation window, excluding the flush structurally. Attribution is by scope, not kernel name, so a kernel under test that is itself a vectorized_elementwise fill is timed correctly. Add ProfilerActivity.CPU (required for the annotation to project onto the device timeline) and drop the inter-op sync before the call (same-stream ordering already guarantees the flush completes first).
Summary
Adds
GQAPrefillFwdWsPersistentCausalKernel— a faithful port of FlashInfer'ssingle_prefill_sm90— and wires it intoGroupedQueryAttentionPrefillFwdOp. On Hopper, the prefill op now dispatches to this kernel for the fp16 / causal / dim==128 / default sm_scale / no-softcap path, reaching FlashInfer-level performance; all other cases fall back to the existingGQAPrefillFwdKernel.Performance
Kernel-level parity with FlashInfer. Per-kernel CUDA time from
benchmarks/ops/attention/bench_gqa.py(H200, fp16, causal, L2 flushed per iteration;single_prefill_sm90is the kernel this ports):Within ~3–4% across the workloads: at parity / marginally faster on long context (Skv=4096, ~525 TFLOPS), marginally slower on the short square case (Skv=512). The bench now reports per-kernel CUPTI time directly — it wraps the timed call in a
record_functionscope and sums only the kernels Kineto attributes to it, so the per-iteration L2 flush is excluded from the measurement.Kernel
2-warpgroup ping-pong (1 TMA producer WG + 2 consumer WG, 384 threads); register-P (rs-wgmma) held in a fp16 fragment;
set_max_nregproducer→24 / consumer→240; named-barrier WarpScheduler ping-pong; delayed PV +wait_wgmma(1)so softmax(k) overlaps in-flight PV(k-1); pipelined rescale; smem-staged coalesced output. Causal uses bottom-right alignment (co = seq_len_kv - seq_len_q).Notes
(output, lse)contract is satisfied withlse=None. Emitting lse from the epilogue perturbed the register-tuned schedule and cost ~35% on long context; a backward pass that needs lse must compute it separately.bench_gqa: the flashinfer baseline helper is generalized to handleSq != Skvand is now recorded for the prefill bench.kernel_map/source.kernelare intentionally not touched here — adding a new-file kernel requires asource.kernelchange, which is outside the implementation-PR carve-out and will follow as a separate manifest PR.Verification
ruffclean.