[Bench][Linear-Attn] Add Qwen GDN prefill benchmark surface#1605
[Bench][Linear-Attn] Add Qwen GDN prefill benchmark surface#1605superAngGao wants to merge 3 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the Gated DeltaNet inference prefill operator (GatedDeltaNetPrefillFwdOp) along with its corresponding kernels, benchmarks, and tests, optimized for serving-oriented zero-state prefill with support for both bhtd and bthd layouts. The review feedback highlights several critical issues in the newly added prefill kernels, including multiple race conditions in the prepare and blocksolve kernels due to missing thread synchronization (T.sync_threads()) during shared memory operations. Additionally, the feedback points out redundant writes in the chunk-local cumsum kernel and a layout mismatch when profiling the FLA baseline in the benchmarking script.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| for _r in T.Serial(num_rounds): | ||
| T.clear(temp_frag) | ||
| T.gemm(P_shared, S_shared, temp_frag) | ||
| for i, j in T.Parallel(block_C, block_C): | ||
| S_shared[i, j] = S_shared[i, j] + temp_frag[i, j] | ||
| T.clear(temp_frag) | ||
| T.gemm(P_shared, P_shared, temp_frag) | ||
| T.copy(temp_frag, P_shared) |
There was a problem hiding this comment.
Inside the polynomial expansion loop, S_shared and P_shared are written to and then read in the next iteration without any thread synchronization. This introduces a critical race condition. Add a T.sync_threads() call at the end of the loop body.
| for _r in T.Serial(num_rounds): | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, S_shared, temp_frag) | |
| for i, j in T.Parallel(block_C, block_C): | |
| S_shared[i, j] = S_shared[i, j] + temp_frag[i, j] | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, P_shared, temp_frag) | |
| T.copy(temp_frag, P_shared) | |
| for _r in T.Serial(num_rounds): | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, S_shared, temp_frag) | |
| for i, j in T.Parallel(block_C, block_C): | |
| S_shared[i, j] = S_shared[i, j] + temp_frag[i, j] | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, P_shared, temp_frag) | |
| T.copy(temp_frag, P_shared) | |
| T.sync_threads() |
| for _r in T.Serial(num_rounds): | ||
| T.clear(temp_frag) | ||
| T.gemm(P_shared, S_shared, temp_frag) | ||
| for i, j in T.Parallel(block_C, block_C): | ||
| S_shared[i, j] = S_shared[i, j] + temp_frag[i, j] | ||
| T.clear(temp_frag) | ||
| T.gemm(P_shared, P_shared, temp_frag) | ||
| T.copy(temp_frag, P_shared) |
There was a problem hiding this comment.
Inside the polynomial expansion loop for the bthd prepare kernel, S_shared and P_shared are written to and then read in the next iteration without any thread synchronization. This introduces a critical race condition. Add a T.sync_threads() call at the end of the loop body.
| for _r in T.Serial(num_rounds): | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, S_shared, temp_frag) | |
| for i, j in T.Parallel(block_C, block_C): | |
| S_shared[i, j] = S_shared[i, j] + temp_frag[i, j] | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, P_shared, temp_frag) | |
| T.copy(temp_frag, P_shared) | |
| for _r in T.Serial(num_rounds): | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, S_shared, temp_frag) | |
| for i, j in T.Parallel(block_C, block_C): | |
| S_shared[i, j] = S_shared[i, j] + temp_frag[i, j] | |
| T.clear(temp_frag) | |
| T.gemm(P_shared, P_shared, temp_frag) | |
| T.copy(temp_frag, P_shared) | |
| T.sync_threads() |
| for i, j in T.Parallel(block_c, block_c): | ||
| a_s[1, i, j] = -tmp[i, j] | ||
|
|
||
| T.clear(tmp) | ||
| T.gemm(i_s[2, :, :], a_s[4, :, :], tmp) | ||
| for i, j in T.Parallel(block_c, block_c): | ||
| work_s[0, i, j] = tmp[i, j] | ||
| T.sync_threads() | ||
| T.clear(tmp) | ||
| T.gemm(work_s[0, :, :], i_s[1, :, :], tmp) | ||
| for i, j in T.Parallel(block_c, block_c): | ||
| a_s[4, i, j] = -tmp[i, j] | ||
|
|
||
| T.clear(tmp) | ||
| T.gemm(a_s[3, :, :], i_s[0, :, :], tmp) | ||
| for i, j in T.Parallel(block_c, block_c): | ||
| work_s[0, i, j] = tmp[i, j] | ||
| T.clear(tmp) | ||
| T.gemm(a_s[4, :, :], a_s[1, :, :], tmp) | ||
| for i, j in T.Parallel(block_c, block_c): |
There was a problem hiding this comment.
In _prefill_blocksolve_A_bthd_tl, shared memory arrays like a_s are written to (e.g., a_s[4, i, j] = -tmp[i, j]) and then read in subsequent T.gemm calls (e.g., T.gemm(a_s[4, :, :], a_s[1, :, :], tmp)) without a T.sync_threads() call in between. This introduces severe race conditions. Ensure that thread synchronization is performed between writing to shared memory and reading from it in subsequent block-level operations.
for i, j in T.Parallel(block_c, block_c):
a_s[1, i, j] = -tmp[i, j]
T.clear(tmp)
T.gemm(i_s[2, :, :], a_s[4, :, :], tmp)
for i, j in T.Parallel(block_c, block_c):
work_s[0, i, j] = tmp[i, j]
T.sync_threads()
T.clear(tmp)
T.gemm(work_s[0, :, :], i_s[1, :, :], tmp)
for i, j in T.Parallel(block_c, block_c):
a_s[4, i, j] = -tmp[i, j]
T.sync_threads()
T.clear(tmp)
T.gemm(a_s[3, :, :], i_s[0, :, :], tmp)
for i, j in T.Parallel(block_c, block_c):
work_s[0, i, j] = tmp[i, j]
T.clear(tmp)
T.gemm(a_s[4, :, :], a_s[1, :, :], tmp)| result = bm.profile(op, *inputs) | ||
| result_fla = bm.profile(fla_fn, *inputs) | ||
| result["speedup_vs_fla"] = result_fla["latency_ms"] / result["latency_ms"] |
There was a problem hiding this comment.
When layout == "bhtd", the generated inputs will be in bhtd layout. However, fla_fn (which wraps FLA's chunk_gated_delta_rule) expects inputs in bthd layout. Passing bhtd inputs directly to fla_fn will result in a shape mismatch or incorrect profiling results. Transpose the inputs to bthd before profiling the FLA baseline.
| result = bm.profile(op, *inputs) | |
| result_fla = bm.profile(fla_fn, *inputs) | |
| result["speedup_vs_fla"] = result_fla["latency_ms"] / result["latency_ms"] | |
| if layout == "bhtd": | |
| fla_inputs = ( | |
| inputs[0].permute(0, 2, 1, 3).contiguous(), | |
| inputs[1].permute(0, 2, 1, 3).contiguous(), | |
| inputs[2].permute(0, 2, 1, 3).contiguous(), | |
| inputs[3].permute(0, 2, 1).contiguous(), | |
| inputs[4].permute(0, 2, 1).contiguous(), | |
| ) | |
| else: | |
| fla_inputs = inputs | |
| result = bm.profile(op, *inputs) | |
| result_fla = bm.profile(fla_fn, *fla_inputs) |
| out_s[0] = T.cast(g_s[0], "float32") | ||
| for i in T.Serial(1, chunk_size): | ||
| out_s[i] = out_s[i - 1] + T.cast(g_s[i], "float32") |
There was a problem hiding this comment.
The serial loop for i in T.Serial(1, chunk_size): is executed by all 128 threads in the block, causing redundant writes and potential race conditions on the shared memory out_s. Wrap the serial accumulation in a single-thread parallel block to avoid redundant work.
| out_s[0] = T.cast(g_s[0], "float32") | |
| for i in T.Serial(1, chunk_size): | |
| out_s[i] = out_s[i - 1] + T.cast(g_s[i], "float32") | |
| for _ in T.Parallel(1): | |
| out_s[0] = T.cast(g_s[0], "float32") | |
| for i in T.Serial(1, chunk_size): | |
| out_s[i] = out_s[i - 1] + T.cast(g_s[i], "float32") |
Depends on #1596. Addresses #1597.
This is opened against upstream as a draft. Until #1596 lands, the diff will include the GDN prefill implementation commits from that dependency; after #1596 merges this branch should be rebased/updated so the PR contains only the benchmark-surface commit.
Summary
Validation
Notes