
[Perf][Linear-Attn] Add optimized Gated DeltaNet prefill op #1596

Open
superAngGao wants to merge 6 commits into tile-ai:main from superAngGao:impl/gdn-prefill-op

Conversation

@superAngGao commented Jun 17, 2026
Collaborator

Closes #1595.

Summary

This PR adds the serving-oriented GatedDeltaNetPrefillFwdOp implementation for zero-state inference prefill.

  • Adds optimized TileLang production kernels for GDN prefill, including BTHD prepare, recurrence/output, scan guard, recompute, and blocksolve-A.
  • Adds a production partitioned prefill fast path for the serving shape B=1, layout="bthd", fp16, DK=DV=128, chunk=64.
  • Makes BTHD the public prefill contract/default layout, matching FLA/Qwen inference prefill, while keeping bhtd/bhsd compatibility through the op API.
  • The partitioned prefill path returns only (o, final_state) and avoids materializing the previous global w/u/S/v_new intermediates on the hot serving path.
  • Dispatch is enabled by default for the supported shape; TILEOPS_GDN_PREFILL_PARTITIONED=0 remains as a kill switch.
  • Adds TileOps-owned partitioned prefill kernels, prefill-local preprocessing, initial-state correction, and TileLang 0.1.9 GEMM compatibility support under the single gdn_prefill/ layer.
  • Exposes the op path (q, k, v, g, beta) -> (o, final_state) without training/backward-only intermediates.
  • Adds focused unit coverage for numerical parity, layout handling, final-state behavior, invalid-input checks, and the structured transition reference.
  • Adds a GDN prefill benchmark entry that measures the serving BTHD path for both TileOps and FLA, with a torch reference fallback when FLA is unavailable.
  • Updates tileops/manifest/linear_attention.yaml from spec-only to implemented and records BTHD serving workload shapes.
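The op contract above, (q, k, v, g, beta) -> (o, final_state) with a zero initial state, can be illustrated with a token-by-token reference. This is a minimal sketch, assuming the standard gated delta rule (decay the state by exp(g), apply a beta-weighted delta correction, write a rank-1 update, then read with q). The function name, the single-head list layout, and the `scale` parameter are hypothetical and are not taken from this PR; the production kernels are chunked and may apply scaling or normalization not shown here.

```python
import math


def gated_delta_prefill_reference(q, k, v, g, beta, scale=1.0):
    """Hypothetical single-head reference, zero initial state.

    q, k: T vectors of length DK; v: T vectors of length DV;
    g: T log-space decay gates; beta: T write strengths.
    Returns (o, final_state): o is [T][DV], final_state is [DK][DV].
    """
    dk, dv = len(k[0]), len(v[0])
    state = [[0.0] * dv for _ in range(dk)]  # zero-state prefill
    outputs = []
    for q_t, k_t, v_t, g_t, b_t in zip(q, k, v, g, beta):
        # 1) Gate: decay the running state by exp(g_t).
        decay = math.exp(g_t)
        state = [[decay * s for s in row] for row in state]
        # 2) Delta rule: compare v_t with what the state predicts for k_t.
        pred = [sum(k_t[i] * state[i][j] for i in range(dk)) for j in range(dv)]
        delta = [b_t * (v_t[j] - pred[j]) for j in range(dv)]
        # 3) Rank-1 write of the correction along k_t.
        for i in range(dk):
            for j in range(dv):
                state[i][j] += k_t[i] * delta[j]
        # 4) Read the updated state with q_t.
        outputs.append(
            [scale * sum(q_t[i] * state[i][j] for i in range(dk)) for j in range(dv)]
        )
    return outputs, state


# Two tokens sharing one key: with g=0 and beta=1 the second write
# replaces the value stored under that key.
o, final_state = gated_delta_prefill_reference(
    q=[[1.0, 0.0], [1.0, 0.0]],
    k=[[1.0, 0.0], [1.0, 0.0]],
    v=[[2.0, 3.0], [5.0, 7.0]],
    g=[0.0, 0.0],
    beta=[1.0, 1.0],
)
```

Only `o` and `final_state` are returned, mirroring the serving contract described above; nothing comparable to the w/u/S/v_new intermediates is kept.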

Partitioned Prefill Dispatch

For B=1, layout="bthd", fp16, DK=DV=128, chunk=64, production dispatch now uses:

chunk_cumsum_bthd
-> blocksolve_A_bthd
-> intra_card_cp_preprocess
-> fused_gdr_fwd
-> o + final_state

The partitioning heuristic lives in gated_deltanet_prefill.py as a private serving fast-path helper. It uses the original auto split model for low/medium head counts and raises the H64 long-sequence local chunk floor to 256. This fixes the 64K/H64 case where the earlier heuristic disabled real sequence partitioning and fell back to one long fused-forward loop per head.
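The dispatch condition stated at the top of this section, combined with the TILEOPS_GDN_PREFILL_PARTITIONED=0 kill switch from the summary, can be sketched as a single predicate. This is an illustrative sketch only: the function name, argument names, and the string form of the dtype are assumptions, and the actual helper in gated_deltanet_prefill.py may be structured differently.

```python
import os


def use_partitioned_prefill(batch, layout, dtype, dk, dv, chunk_size, environ=None):
    """Hypothetical predicate for the partitioned prefill fast path."""
    env = os.environ if environ is None else environ
    # Kill switch: setting the variable to "0" disables the fast path.
    if env.get("TILEOPS_GDN_PREFILL_PARTITIONED", "1") == "0":
        return False
    # Supported serving shape as listed in the PR description.
    return (
        batch == 1
        and layout == "bthd"
        and dtype == "float16"
        and dk == 128
        and dv == 128
        and chunk_size == 64
    )


# Supported shape, kill switch unset: fast path is taken.
enabled = use_partitioned_prefill(1, "bthd", "float16", 128, 128, 64, environ={})
# Same shape with the kill switch set: falls back to the regular path.
disabled = use_partitioned_prefill(
    1, "bthd", "float16", 128, 128, 64,
    environ={"TILEOPS_GDN_PREFILL_PARTITIONED": "0"},
)
```

Passing `environ` explicitly keeps the predicate easy to test without mutating the process environment.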

GDN-Relevant PR Shape

M tileops/manifest/linear_attention.yaml
A benchmarks/ops/bench_gated_deltanet_prefill.py
A tests/ops/test_gated_deltanet_prefill.py
M tileops/kernels/__init__.py
M tileops/kernels/gated_deltanet/__init__.py
A tileops/kernels/gated_deltanet/gdn_prefill/__init__.py
A tileops/kernels/gated_deltanet/gdn_prefill/cp_fwd.py
A tileops/kernels/gated_deltanet/gdn_prefill/fused_fwd.py
A tileops/kernels/gated_deltanet/gdn_prefill/prepare_h.py
A tileops/kernels/gated_deltanet/gdn_prefill/tilelang_compat.py
A tileops/kernels/gated_deltanet/gdn_prefill/utils.py
M tileops/kernels/gated_deltanet/gated_deltanet_fwd.py
A tileops/kernels/gated_deltanet/gated_deltanet_prefill.py
M tileops/ops/__init__.py
M tileops/ops/gated_deltanet.py
M tileops/perf/formulas.py
M workloads/gated_deltanet.py

Verification

Local checks retained from the original PR:

TMPDIR=/home/ga/tmp/tvm-test TILELANG_CLEANUP_TEMP_FILES=1 PYTHONPATH=/home/ga/TileOPs-pr1596 /home/ga/anaconda3/bin/python -m pytest tests/ops/test_gated_deltanet_prefill.py -q
8 passed, 2 warnings

PYTHONPATH=/home/ga/TileOPs-pr1596 /home/ga/anaconda3/bin/python -m pytest -q benchmarks/tests
8 passed

python -m py_compile tileops/ops/gated_deltanet.py benchmarks/ops/bench_gated_deltanet_prefill.py tests/ops/test_gated_deltanet_prefill.py
git diff --check
No output

Benchmark plumbing smoke test for the manifest BTHD prefill case, run in Docker on H200/GPU4. This row is intentionally small (S=512) and only checks that the manifest, BTHD benchmark path, TileOps op, and FLA baseline agree on layout/shape plumbing; it is not the long-sequence performance headline.

CUDA_VISIBLE_DEVICES=4 TILELANG_CACHE_KEY=bench_prefill_bthd_manifest_default \
/home/ga/Documents/gdn_kernel_bench_2026-06-18/run_tileops_pr1596_docker.sh \
python -m pytest /workspace/TileOPs/benchmarks/ops/bench_gated_deltanet_prefill.py \
-q -k 'gdn-prefill-b1-s512-h16-d128-float16'

1 passed, 7 deselected
TileOps 0.0529 ms, FLA 0.0530 ms

Long-Sequence Performance

Fresh cross-backend run on H200/GPU4, CUPTI timer, warmup=10, repeat=100, trials=7, BTHD, B=1, DK=DV=128, chunk=64, fp16, seed 0.

  • The GitHub PR head and the local benchmark worktree are the same commit: 9282b2b516c695d89ad71db1a37f088f22af3115 (9282b2b5).
  • TileOps was measured in the PR TileLang 0.1.9 docker; the FLA baseline is the fla package version 0.5.1 installed in that same benchmark docker.
  • FlashQLA was measured in the FlashQLA TileLang 0.1.8 docker, which is the reliable public FlashQLA lowering environment for this comparison.
  • Ratio columns report throughput speedup (baseline latency / TileOps latency), so higher is better for TileOps.

| Case | TileOps TL0.1.9 | FLA 0.5.1 | FlashQLA TL0.1.8 | TileOps/FLA | TileOps/FlashQLA |
|---|---|---|---|---|---|
| 32K/H16 | 0.369577 ms | 2.364970 ms | 0.583178 ms | 6.399x | 1.578x |
| 64K/H16 | 0.690859 ms | 5.265457 ms | 1.331969 ms | 7.622x | 1.928x |
| 128K/H16 | 1.227298 ms | 10.899666 ms | 2.623757 ms | 8.881x | 2.138x |
| 128K/H32 | 2.384966 ms | 16.687767 ms | 5.395038 ms | 6.997x | 2.262x |
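
The ratio columns follow the definition given in the bullets above (baseline latency / TileOps latency). As a quick cross-check, the sketch below recomputes them from the latencies reported in the table; the dictionary and function names are illustrative only.

```python
# Latencies in ms copied from the table: (TileOps, FLA, FlashQLA).
LATENCY_MS = {
    "32K/H16": (0.369577, 2.364970, 0.583178),
    "64K/H16": (0.690859, 5.265457, 1.331969),
    "128K/H16": (1.227298, 10.899666, 2.623757),
    "128K/H32": (2.384966, 16.687767, 5.395038),
}


def speedup(baseline_ms, tileops_ms):
    """Speedup of TileOps over a baseline; higher is better for TileOps."""
    return baseline_ms / tileops_ms


SPEEDUPS = {
    case: (round(speedup(fla, tile), 3), round(speedup(flashqla, tile), 3))
    for case, (tile, fla, flashqla) in LATENCY_MS.items()
}

for case, (vs_fla, vs_flashqla) in SPEEDUPS.items():
    print(f"{case}: {vs_fla:.3f}x vs FLA, {vs_flashqla:.3f}x vs FlashQLA")
```

Note that the column headers read "TileOps/FLA" and "TileOps/FlashQLA", but the quantity is the baseline latency divided by the TileOps latency.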

Artifacts:

results/flashqla_specialized/qwen_fourline_tileops_flashqla_20260625.jsonl
results/flashqla_specialized/qwen_fourline_fla_20260625.jsonl

Hardware SOL Comparison

Estimated from the TileOps manifest roofline model for GatedDeltaNetPrefillFwdOp and H200 peak assumptions (989 TFLOP/s dense FP16/BF16 tensor core, 4.8 TB/s HBM). These are model-based speed-of-light numbers, not Nsight Compute measured utilization.

| Case | Roofline work | Roofline bytes | H200 SOL time | TileOps CUPTI | SOL efficiency |
|---|---|---|---|---|---|
| 32K/H16 | 0.069 TF | 0.674 GB | 0.140356 ms | 0.369577 ms | 38.0% |
| 64K/H16 | 0.137 TF | 1.347 GB | 0.280603 ms | 0.690859 ms | 40.6% |
| 128K/H16 | 0.275 TF | 2.693 GB | 0.561097 ms | 1.227298 ms | 45.7% |
| 128K/H32 | 0.550 TF | 5.387 GB | 1.122195 ms | 2.384966 ms | 47.1% |
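
The SOL columns can be approximately reproduced from the stated peaks. The sketch below assumes the usual roofline rule, SOL time = max(work / peak FLOP/s, bytes / peak bandwidth), and takes 1 GB as 1e9 bytes; neither assumption is spelled out in the PR, and the function names are illustrative. Because the table's work and byte figures are rounded, the recomputed SOL times match the reported ones only to about 0.1%.

```python
PEAK_FLOPS = 989e12  # H200 dense FP16/BF16 tensor core peak, FLOP/s (from the PR)
PEAK_BW = 4.8e12     # H200 HBM bandwidth, bytes/s (from the PR)


def sol_time_ms(work_tflop, bytes_gb):
    """Roofline speed-of-light time in ms (assumed max-of-two-bounds rule)."""
    compute_s = work_tflop * 1e12 / PEAK_FLOPS
    memory_s = bytes_gb * 1e9 / PEAK_BW  # assumes 1 GB = 1e9 bytes
    return max(compute_s, memory_s) * 1e3


def sol_efficiency_pct(sol_ms, measured_ms):
    """Fraction of the SOL bound achieved by the measured latency, in percent."""
    return 100.0 * sol_ms / measured_ms


# 32K/H16 row: rounded work and bytes from the table.
sol_32k = sol_time_ms(0.069, 0.674)
# Efficiency from the reported SOL time and the CUPTI latency.
eff_32k = sol_efficiency_pct(0.140356, 0.369577)
print(f"SOL ~ {sol_32k:.4f} ms, efficiency ~ {eff_32k:.1f}%")
```

Under these assumptions the memory term is the larger of the two bounds for every row in the table, so the model treats this op as bandwidth-bound on H200.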

Correctness

TileOps TL0.1.9 vs FLA 0.5.1, atol=rtol=5e-2:

| Case | status | o_max_abs | state_max_abs |
|---|---|---|---|
| 32K/H16 | ok | 0.0011291504 | 0.0002846699 |
| 64K/H16 | ok | 0.0012512207 | 0.0006084554 |
| 128K/H16 | ok | 0.0012817383 | 0.0004399344 |
| 128K/H32 | ok | 0.0011291504 | 0.0005266964 |

FlashQLA TL0.1.8 vs FLA, same tolerances:

| Case | status | o_max_abs | state_max_abs |
|---|---|---|---|
| 32K/H16 | ok | 0.0011291504 | 0.0002865773 |
| 64K/H16 | ok | 0.0012512207 | 0.0006085336 |
| 128K/H16 | ok | 0.0012817383 | 0.0004400611 |
| 128K/H32 | ok | 0.0011291504 | 0.0005248971 |
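
For readers unfamiliar with combined tolerances, the sketch below shows one common way an atol=rtol=5e-2 check and a max-abs metric are computed. It assumes the elementwise rule |actual - expected| <= atol + rtol * |expected|, which is the rule `torch.allclose` uses; the PR does not say which comparison routine produced the tables, and the helper names here are hypothetical.

```python
def within_tolerance(actual, expected, atol=5e-2, rtol=5e-2):
    """Elementwise closeness check (assumed allclose-style rule)."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def max_abs_error(actual, expected):
    """Largest absolute elementwise difference, as in the *_max_abs columns."""
    return max(abs(a - e) for a, e in zip(actual, expected))


# Toy flattened outputs: an error near 1e-3 passes with a wide margin,
# while an error of 0.2 against a reference of 1.0 exceeds 0.05 + 0.05 * 1.0.
close = within_tolerance([1.0011, -0.5], [1.0, -0.5])
far = within_tolerance([1.2, -0.5], [1.0, -0.5])
```

The reported o_max_abs values (about 1.1e-3 to 1.3e-3) sit well below the 5e-2 absolute tolerance alone.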

Additional H64 correctness and performance diagnostics remain in the result artifacts; the headline table above uses the Qwen-style four-line benchmark shape set.

Artifacts:

results/flashqla_specialized/pr1596_tri_compare_correctness_20260624.jsonl
results/flashqla_specialized/pr1596_flashqla_tl018_correctness_20260624.jsonl
results/flashqla_specialized/pr1596_tileops_correctness_128k_h32_20260625.jsonl
results/flashqla_specialized/pr1596_flashqla_tl018_correctness_128k_h32_20260625.jsonl

@Gabbering Gabbering left a comment

Contributor

goose skimmed 58a4ea4, nothing to honk about.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request implements the Gated DeltaNet inference prefill operator (GatedDeltaNetPrefillFwdOp) and its corresponding kernel, along with associated tests, benchmarks, and performance formulas. Feedback on the changes highlights two critical issues: first, input tensors should be made contiguous in the operator's forward pass to prevent memory access errors or crashes; second, the optimized scan path condition (use_scan_h) in the prefill kernel needs an additional divisibility check on the sequence length to avoid silent correctness bugs on non-divisible lengths.


Comment thread tileops/ops/gated_deltanet.py
Comment thread tileops/kernels/gated_deltanet/gated_deltanet_prefill.py
@superAngGao superAngGao force-pushed the impl/gdn-prefill-op branch from 58a4ea4 to ec23c79 Compare June 17, 2026 09:24

@Gabbering Gabbering left a comment

Contributor

goose skimmed ec23c79, nothing to honk about.

@superAngGao superAngGao force-pushed the impl/gdn-prefill-op branch from ec23c79 to b617d52 Compare June 17, 2026 09:37
@superAngGao superAngGao changed the title [Linear-Attn] Add optimized Gated DeltaNet prefill op [Perf][Linear-Attn] Add optimized Gated DeltaNet prefill op Jun 17, 2026

@Gabbering Gabbering left a comment

Contributor

goose skimmed b617d52, nothing to honk about.

@superAngGao superAngGao force-pushed the impl/gdn-prefill-op branch from b617d52 to 2170e25 Compare June 17, 2026 09:39

@Gabbering Gabbering left a comment

Contributor

goose skimmed 2170e25, nothing to honk about.

@Gabbering Gabbering left a comment

Contributor

goose review: 17bfa8b9

honk. Both inline thread comments were addressed by the master. The goose has re-read the conversation and confirmed the following:

  1. Thread 1 (tileops/ops/gated_deltanet.py:303): The master added .contiguous() calls for all five input tensors before validation and kernel dispatch, and wrote a test (test_gated_deltanet_prefill_bthd_layout_matches_bhtd) that passes non-contiguous permuted inputs through the BTHD path. The goose's contiguity concern is resolved.

  2. Thread 2 (tileops/kernels/gated_deltanet/gated_deltanet_prefill.py:852): The master defined group_chunks = 64 before the use_scan_h guard and added and seq_len % (chunk_size * group_chunks) == 0 to the condition. Non-divisible long sequences now fall back to the standard recurrence path. The goose's scan-path divisibility concern is resolved.

Both fixes are present in the delta 17bfa8b9 and the goose sees no new issues introduced by this commit.
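
The divisibility term described in Thread 2 can be sketched in isolation. With chunk_size = 64 (the supported chunk size in this PR) and group_chunks = 64, the scan path requires the sequence length to be a multiple of 64 * 64 = 4096. The function name below is hypothetical, and the other conditions inside the real use_scan_h guard are not shown in this thread, so they are omitted.

```python
CHUNK_SIZE = 64    # supported chunk size from the PR description
GROUP_CHUNKS = 64  # value named in the review thread


def scan_span_divides(seq_len, chunk_size=CHUNK_SIZE, group_chunks=GROUP_CHUNKS):
    """True when seq_len is a whole number of (chunk_size * group_chunks) spans."""
    return seq_len % (chunk_size * group_chunks) == 0


# The benchmarked lengths are multiples of 4096 and stay eligible.
eligible = [scan_span_divides(n) for n in (32 * 1024, 64 * 1024, 128 * 1024)]
# A length that is a whole number of chunks (65 * 64) but not of 4096-token
# spans must take the standard recurrence path instead.
fallback = not scan_span_divides(65 * 64)
```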

SILENT

@superAngGao superAngGao requested a review from Ibuki-wind June 17, 2026 11:29

@Ibuki-wind Ibuki-wind left a comment

Contributor

Overall

The new manifest-driven benchmark does not record an independent baseline, so the benchmark contract is not satisfied yet.

Comment thread benchmarks/ops/bench_gated_deltanet_prefill.py
@superAngGao superAngGao force-pushed the impl/gdn-prefill-op branch from 17bfa8b to d09c8f2 Compare June 18, 2026 03:00

@Gabbering Gabbering left a comment

Contributor

goose skimmed d09c8f2, nothing to honk about.

@superAngGao superAngGao requested a review from Ibuki-wind June 18, 2026 04:48

@Ibuki-wind Ibuki-wind left a comment

Contributor

Overall

The benchmark baseline blocker is fixed, but the public BTHD layout contract is not represented in the manifest and the approval-gate authoring hygiene still needs cleanup. Update the manifest shape rules for layout, make the PR body follow the template with final verification facts including pre-commit/benchmark/test-node-delta, and edit author replies to outcome-only one-liners.

Comment thread tileops/manifest/linear_attention.yaml Outdated
lcy-seso pushed a commit that referenced this pull request Jun 21, 2026
## Summary
- update the GatedDeltaNetPrefillFwdOp spec-only manifest contract to
allow chunk_size=None
- add the layout parameter used by the implementation path
- add an H16/S2K manifest workload row for the target prefill shape
family

## Notes
- This is a manifest-only precursor for PR #1596.
- It intentionally keeps status: spec-only and leaves source paths null;
the implementation PR should flip status/source/roofline once the kernel
code lands.

## Validation
- python scripts/validate_manifest.py --check-op GatedDeltaNetPrefillFwdOp
lcy-seso pushed a commit that referenced this pull request Jun 21, 2026
## Summary
- encode the GatedDeltaNetPrefillFwdOp BHTD/BHSD and BTHD layout shape
contracts in the manifest
- add fixed-rank shape metadata for the manifest tensors
- fill the planned roofline/source metadata while keeping the entry
spec-only

## Notes
- This is a manifest-only prerequisite for PR #1596.
- The implementation PR should only flip status and keep any remaining
manifest edit within the status-flip carve-out.

## Validation
- python scripts/validate_manifest.py
@superAngGao superAngGao force-pushed the impl/gdn-prefill-op branch from 0828ab6 to 067edc7 Compare June 22, 2026 02:16
@superAngGao superAngGao requested a review from Ibuki-wind June 22, 2026 03:39
@zhen8838

Collaborator

For the performance comparison, please include FlashQLA as a baseline/reference as well: https://github.com/QwenLM/FlashQLA

It would be useful to compare on the same target hardware and representative GDN prefill shapes/dtypes so the remaining gap vs. the current public optimized implementation is explicit.

@superAngGao superAngGao force-pushed the impl/gdn-prefill-op branch from 9d21d88 to 9282b2b Compare June 24, 2026 10:51
@superAngGao

Collaborator Author

Thanks for the pointer. I added FlashQLA as an explicit baseline in the PR body, measured on the same H200/GPU4 setup with the TileOps benchmark/CUPTI timer and representative Qwen-style GDN prefill rows.

The refreshed numbers show the TileOps path is now faster than the public FlashQLA reference on these rows. The main reason is not just a different parallelization strategy or a block/thread schedule tweak. The important change is the overall prefill algorithmic path: the serving path now only materializes o + final_state, avoids the old global w/u/S/v_new intermediates, uses CP-split replay with corrected segment initial states, and combines that with the TileOps-owned A producer and dispatch. Parallel schedule tuning still matters, but the larger win comes from this algorithmic decomposition of the inference prefill path.

I also kept FlashQLA attribution explicit in the PR body: its CP-split fused replay schedule was the key production reference for the back half of this work; the TileOps contribution is the port, integration, A-producer path, dispatch/tuning, and validation against FLA/FlashQLA.


Development

Successfully merging this pull request may close these issues.

[Perf][Linear-Attn] Add FLA-compatible Gated DeltaNet prefill op

4 participants