Skip to content

[Bench][Linear-Attn] Add Qwen GDN prefill benchmark surface#1605

Draft
superAngGao wants to merge 3 commits into
tile-ai:mainfrom
superAngGao:issue-1597-gdn-prefill-bench
Draft

[Bench][Linear-Attn] Add Qwen GDN prefill benchmark surface#1605
superAngGao wants to merge 3 commits into
tile-ai:mainfrom
superAngGao:issue-1597-gdn-prefill-bench

Conversation

@superAngGao

Copy link
Copy Markdown
Collaborator

Depends on #1596. Addresses #1597.

This is opened against upstream as a draft. Until #1596 lands, the diff will include the GDN prefill implementation commits from that dependency; after #1596 merges this branch should be rebased/updated so the PR contains only the benchmark-surface commit.

Summary

  • Switch the Gated DeltaNet prefill benchmark surface to the FLA/Qwen BTHD contract.
  • Add Qwen3.5-aligned long-context rows for S=32K/64K/128K and H=16/32/48/64 across fp16/bf16.
  • Compare TileOps against FLA with output_final_state=True and surface speedup in profile_run.log.
  • Keep nightly JUnit properties compatible while allowing profile_run.log to show extra metrics and GPU clocks.

Validation

  • python -m py_compile benchmarks/ops/bench_gated_deltanet_prefill.py workloads/gated_deltanet.py benchmarks/benchmark_base.py tileops/perf/formulas.py
  • python scripts/validate_manifest.py
  • python -m pytest --collect-only benchmarks/ops/bench_gated_deltanet_prefill.py -q
  • TMPDIR=/home/ga/tmp python -m pytest tests/ops/test_gated_deltanet_prefill.py -m smoke -q

Notes

  • The local environment does not have FLA installed, so full benchmark execution was only checked for collection/skip behavior here.

@github-actions github-actions Bot added the bench Benchmark updates label Jun 23, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the Gated DeltaNet inference prefill operator (GatedDeltaNetPrefillFwdOp) along with its corresponding kernels, benchmarks, and tests, optimized for serving-oriented zero-state prefill with support for both bhtd and bthd layouts. The review feedback highlights several critical issues in the newly added prefill kernels, including multiple race conditions in the prepare and blocksolve kernels due to missing thread synchronization (T.sync_threads()) during shared memory operations. Additionally, the feedback points out redundant writes in the chunk-local cumsum kernel and a layout mismatch when profiling the FLA baseline in the benchmarking script.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +167 to +174
for _r in T.Serial(num_rounds):
T.clear(temp_frag)
T.gemm(P_shared, S_shared, temp_frag)
for i, j in T.Parallel(block_C, block_C):
S_shared[i, j] = S_shared[i, j] + temp_frag[i, j]
T.clear(temp_frag)
T.gemm(P_shared, P_shared, temp_frag)
T.copy(temp_frag, P_shared)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Inside the polynomial expansion loop, S_shared and P_shared are written to and then read in the next iteration without any thread synchronization. This introduces a critical race condition. Add a T.sync_threads() call at the end of the loop body.

Suggested change
for _r in T.Serial(num_rounds):
T.clear(temp_frag)
T.gemm(P_shared, S_shared, temp_frag)
for i, j in T.Parallel(block_C, block_C):
S_shared[i, j] = S_shared[i, j] + temp_frag[i, j]
T.clear(temp_frag)
T.gemm(P_shared, P_shared, temp_frag)
T.copy(temp_frag, P_shared)
for _r in T.Serial(num_rounds):
T.clear(temp_frag)
T.gemm(P_shared, S_shared, temp_frag)
for i, j in T.Parallel(block_C, block_C):
S_shared[i, j] = S_shared[i, j] + temp_frag[i, j]
T.clear(temp_frag)
T.gemm(P_shared, P_shared, temp_frag)
T.copy(temp_frag, P_shared)
T.sync_threads()

Comment on lines +835 to +842
for _r in T.Serial(num_rounds):
T.clear(temp_frag)
T.gemm(P_shared, S_shared, temp_frag)
for i, j in T.Parallel(block_C, block_C):
S_shared[i, j] = S_shared[i, j] + temp_frag[i, j]
T.clear(temp_frag)
T.gemm(P_shared, P_shared, temp_frag)
T.copy(temp_frag, P_shared)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Inside the polynomial expansion loop for the bthd prepare kernel, S_shared and P_shared are written to and then read in the next iteration without any thread synchronization. This introduces a critical race condition. Add a T.sync_threads() call at the end of the loop body.

Suggested change
for _r in T.Serial(num_rounds):
T.clear(temp_frag)
T.gemm(P_shared, S_shared, temp_frag)
for i, j in T.Parallel(block_C, block_C):
S_shared[i, j] = S_shared[i, j] + temp_frag[i, j]
T.clear(temp_frag)
T.gemm(P_shared, P_shared, temp_frag)
T.copy(temp_frag, P_shared)
for _r in T.Serial(num_rounds):
T.clear(temp_frag)
T.gemm(P_shared, S_shared, temp_frag)
for i, j in T.Parallel(block_C, block_C):
S_shared[i, j] = S_shared[i, j] + temp_frag[i, j]
T.clear(temp_frag)
T.gemm(P_shared, P_shared, temp_frag)
T.copy(temp_frag, P_shared)
T.sync_threads()

Comment on lines +528 to +547
for i, j in T.Parallel(block_c, block_c):
a_s[1, i, j] = -tmp[i, j]

T.clear(tmp)
T.gemm(i_s[2, :, :], a_s[4, :, :], tmp)
for i, j in T.Parallel(block_c, block_c):
work_s[0, i, j] = tmp[i, j]
T.sync_threads()
T.clear(tmp)
T.gemm(work_s[0, :, :], i_s[1, :, :], tmp)
for i, j in T.Parallel(block_c, block_c):
a_s[4, i, j] = -tmp[i, j]

T.clear(tmp)
T.gemm(a_s[3, :, :], i_s[0, :, :], tmp)
for i, j in T.Parallel(block_c, block_c):
work_s[0, i, j] = tmp[i, j]
T.clear(tmp)
T.gemm(a_s[4, :, :], a_s[1, :, :], tmp)
for i, j in T.Parallel(block_c, block_c):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In _prefill_blocksolve_A_bthd_tl, shared memory arrays like a_s are written to (e.g., a_s[4, i, j] = -tmp[i, j]) and then read in subsequent T.gemm calls (e.g., T.gemm(a_s[4, :, :], a_s[1, :, :], tmp)) without a T.sync_threads() call in between. This introduces severe race conditions. Ensure that thread synchronization is performed between writing to shared memory and reading from it in subsequent block-level operations.

                for i, j in T.Parallel(block_c, block_c):
                    a_s[1, i, j] = -tmp[i, j]

                T.clear(tmp)
                T.gemm(i_s[2, :, :], a_s[4, :, :], tmp)
                for i, j in T.Parallel(block_c, block_c):
                    work_s[0, i, j] = tmp[i, j]
                T.sync_threads()
                T.clear(tmp)
                T.gemm(work_s[0, :, :], i_s[1, :, :], tmp)
                for i, j in T.Parallel(block_c, block_c):
                    a_s[4, i, j] = -tmp[i, j]
                T.sync_threads()

                T.clear(tmp)
                T.gemm(a_s[3, :, :], i_s[0, :, :], tmp)
                for i, j in T.Parallel(block_c, block_c):
                    work_s[0, i, j] = tmp[i, j]
                T.clear(tmp)
                T.gemm(a_s[4, :, :], a_s[1, :, :], tmp)

Comment on lines +191 to +193
result = bm.profile(op, *inputs)
result_fla = bm.profile(fla_fn, *inputs)
result["speedup_vs_fla"] = result_fla["latency_ms"] / result["latency_ms"]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When layout == "bhtd", the generated inputs will be in bhtd layout. However, fla_fn (which wraps FLA's chunk_gated_delta_rule) expects inputs in bthd layout. Passing bhtd inputs directly to fla_fn will result in a shape mismatch or incorrect profiling results. Transpose the inputs to bthd before profiling the FLA baseline.

Suggested change
result = bm.profile(op, *inputs)
result_fla = bm.profile(fla_fn, *inputs)
result["speedup_vs_fla"] = result_fla["latency_ms"] / result["latency_ms"]
if layout == "bhtd":
fla_inputs = (
inputs[0].permute(0, 2, 1, 3).contiguous(),
inputs[1].permute(0, 2, 1, 3).contiguous(),
inputs[2].permute(0, 2, 1, 3).contiguous(),
inputs[3].permute(0, 2, 1).contiguous(),
inputs[4].permute(0, 2, 1).contiguous(),
)
else:
fla_inputs = inputs
result = bm.profile(op, *inputs)
result_fla = bm.profile(fla_fn, *fla_inputs)

Comment on lines +72 to +74
out_s[0] = T.cast(g_s[0], "float32")
for i in T.Serial(1, chunk_size):
out_s[i] = out_s[i - 1] + T.cast(g_s[i], "float32")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The serial loop for i in T.Serial(1, chunk_size): is executed by all 128 threads in the block, causing redundant writes and potential race conditions on the shared memory out_s. Wrap the serial accumulation in a single-thread parallel block to avoid redundant work.

Suggested change
out_s[0] = T.cast(g_s[0], "float32")
for i in T.Serial(1, chunk_size):
out_s[i] = out_s[i - 1] + T.cast(g_s[i], "float32")
for _ in T.Parallel(1):
out_s[0] = T.cast(g_s[0], "float32")
for i in T.Serial(1, chunk_size):
out_s[i] = out_s[i - 1] + T.cast(g_s[i], "float32")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bench Benchmark updates

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant