[Perf][Linear-Attn] Add optimized Gated DeltaNet prefill op#1596
[Perf][Linear-Attn] Add optimized Gated DeltaNet prefill op#1596superAngGao wants to merge 6 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements the Gated DeltaNet inference prefill operator (GatedDeltaNetPrefillFwdOp) and its corresponding kernel, along with associated tests, benchmarks, and performance formulas. Feedback on the changes highlights two critical issues: first, input tensors should be made contiguous in the operator's forward pass to prevent memory access errors or crashes; second, the optimized scan path condition (use_scan_h) in the prefill kernel needs an additional divisibility check on the sequence length to avoid silent correctness bugs on non-divisible lengths.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
58a4ea4 to
ec23c79
Compare
ec23c79 to
b617d52
Compare
b617d52 to
2170e25
Compare
2170e25 to
17bfa8b
Compare
Gabbering
left a comment
There was a problem hiding this comment.
goose review — 17bfa8b9
honk. Both inline thread comments were addressed by the master. The goose has re-read the conversation and confirmed the following:
-
Thread 1 (
tileops/ops/gated_deltanet.py:303): The master added.contiguous()calls for all five input tensors before validation and kernel dispatch, and wrote a test (test_gated_deltanet_prefill_bthd_layout_matches_bhtd) that passes non-contiguous permuted inputs through the BTHD path. The goose's contiguity concern is resolved. -
Thread 2 (
tileops/kernels/gated_deltanet/gated_deltanet_prefill.py:852): The master definedgroup_chunks = 64before theuse_scan_hguard and addedand seq_len % (chunk_size * group_chunks) == 0to the condition. Non-divisible long sequences now fall back to the standard recurrence path. The goose's scan-path divisibility concern is resolved.
Both fixes are present in the delta 17bfa8b9 and the goose sees no new issues introduced by this commit.
SILENT
Ibuki-wind
left a comment
There was a problem hiding this comment.
Overall
The new manifest-driven benchmark does not record an independent baseline, so the benchmark contract is not satisfied yet.
17bfa8b to
d09c8f2
Compare
Ibuki-wind
left a comment
There was a problem hiding this comment.
Overall
The benchmark baseline blocker is fixed, but the public BTHD layout contract is not represented in the manifest and the approval-gate authoring hygiene still needs cleanup. Update the manifest shape rules for layout, make the PR body follow the template with final verification facts including pre-commit/benchmark/test-node-delta, and edit author replies to outcome-only one-liners.
## Summary - update the GatedDeltaNetPrefillFwdOp spec-only manifest contract to allow chunk_size=None - add the layout parameter used by the implementation path - add an H16/S2K manifest workload row for the target prefill shape family ## Notes - This is a manifest-only precursor for PR #1596. - It intentionally keeps status: spec-only and leaves source paths null; the implementation PR should flip status/source/roofline once the kernel code lands. ## Validation - python scripts/validate_manifest.py --check-op GatedDeltaNetPrefillFwdOp
## Summary - encode the GatedDeltaNetPrefillFwdOp BHTD/BHSD and BTHD layout shape contracts in the manifest - add fixed-rank shape metadata for the manifest tensors - fill the planned roofline/source metadata while keeping the entry spec-only ## Notes - This is a manifest-only prerequisite for PR #1596. - The implementation PR should only flip status and keep any remaining manifest edit within the status-flip carve-out. ## Validation - python scripts/validate_manifest.py
0828ab6 to
067edc7
Compare
|
For the performance comparison, please include FlashQLA as a baseline/reference as well: https://github.com/QwenLM/FlashQLA It would be useful to compare on the same target hardware and representative GDN prefill shapes/dtypes so the remaining gap vs. the current public optimized implementation is explicit. |
1c4337e to
9d21d88
Compare
9d21d88 to
9282b2b
Compare
|
Thanks for the pointer. I added FlashQLA as an explicit baseline in the PR body, measured on the same H200/GPU4 setup with the TileOps benchmark/CUPTI timer and representative Qwen-style GDN prefill rows. The refreshed numbers show the TileOps path is now faster than the public FlashQLA reference on these rows. The main reason is not just a different parallelization strategy or a block/thread schedule tweak. The important change is the overall prefill algorithmic path: the serving path now only materializes I also kept FlashQLA attribution explicit in the PR body: its CP-split fused replay schedule was the key production reference for the back half of this work; the TileOps contribution is the port, integration, A-producer path, dispatch/tuning, and validation against FLA/FlashQLA. |
…shqla-specialized # Conflicts: # tileops/kernels/gated_deltanet/gated_deltanet_fwd.py
Closes #1595.
Summary
This PR adds the serving-oriented
GatedDeltaNetPrefillFwdOpimplementation for zero-state inference prefill.B=1,layout="bthd",fp16,DK=DV=128,chunk=64.bhtd/bhsdcompatibility through the op API.(o, final_state)and avoids materializing the previous globalw/u/S/v_newintermediates on the hot serving path.TILEOPS_GDN_PREFILL_PARTITIONED=0remains as a kill switch.gdn_prefill/layer.(q, k, v, g, beta) -> (o, final_state)without training/backward-only intermediates.tileops/manifest/linear_attention.yamlfromspec-onlytoimplementedand records BTHD serving workload shapes.Partitioned Prefill Dispatch
For
B=1,layout="bthd",fp16,DK=DV=128,chunk=64, production dispatch now uses:The partitioning heuristic lives in
gated_deltanet_prefill.pyas a private serving fast-path helper. It uses the original auto split model for low/medium head counts and raises the H64 long-sequence local chunk floor to 256. This fixes the 64K/H64 case where the earlier heuristic disabled real sequence partitioning and fell back to one long fused-forward loop per head.GDN-Relevant PR Shape
Verification
Local checks retained from the original PR:
Benchmark plumbing smoke for the manifest BTHD prefill case, docker on H200/GPU4. This row is intentionally small (
S=512) and only checks that the manifest, BTHD benchmark path, TileOps op, and FLA baseline agree on layout/shape plumbing; it is not the long-sequence performance headline.Long-Sequence Performance
Fresh cross-backend run on H200/GPU4, CUPTI timer,
warmup=10,repeat=100,trials=7, BTHD,B=1,DK=DV=128,chunk=64,fp16, seed 0.9282b2b516c695d89ad71db1a37f088f22af3115(9282b2b5).flapackage version0.5.1installed in that same benchmark docker.baseline latency / TileOps latency), so higher is better for TileOps.32K/H160.369577 ms2.364970 ms0.583178 ms6.399x1.578x64K/H160.690859 ms5.265457 ms1.331969 ms7.622x1.928x128K/H161.227298 ms10.899666 ms2.623757 ms8.881x2.138x128K/H322.384966 ms16.687767 ms5.395038 ms6.997x2.262xArtifacts:
Hardware SOL Comparison
Estimated from the TileOps manifest roofline model for
GatedDeltaNetPrefillFwdOpand H200 peak assumptions (989 TFLOP/sdense FP16/BF16 tensor core,4.8 TB/sHBM). These are model-based speed-of-light numbers, not Nsight Compute measured utilization.32K/H160.069 TF0.674 GB0.140356 ms0.369577 ms38.0%64K/H160.137 TF1.347 GB0.280603 ms0.690859 ms40.6%128K/H160.275 TF2.693 GB0.561097 ms1.227298 ms45.7%128K/H320.550 TF5.387 GB1.122195 ms2.384966 ms47.1%Correctness
TileOps TL0.1.9 vs FLA 0.5.1,
atol=rtol=5e-2:o_max_absstate_max_abs32K/H160.00112915040.000284669964K/H160.00125122070.0006084554128K/H160.00128173830.0004399344128K/H320.00112915040.0005266964FlashQLA TL0.1.8 vs FLA, same tolerances:
o_max_absstate_max_abs32K/H160.00112915040.000286577364K/H160.00125122070.0006085336128K/H160.00128173830.0004400611128K/H320.00112915040.0005248971Additional H64 correctness and performance diagnostics remain in the result artifacts; the headline table above uses the Qwen-style four-line benchmark shape set.
Artifacts: