From ffcf8559f899ee265faea4df230b801d954fd525 Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Tue, 23 Jun 2026 14:31:27 +0800 Subject: [PATCH] [Bench][Linear-Attn] Add Qwen GDN prefill benchmark surface --- benchmarks/benchmark_base.py | 32 +++++- .../ops/bench_gated_deltanet_prefill.py | 107 +++++++++++++----- tileops/manifest/linear_attention.yaml | 21 +++- tileops/perf/formulas.py | 9 +- workloads/gated_deltanet.py | 28 ++++- 5 files changed, 161 insertions(+), 36 deletions(-) diff --git a/benchmarks/benchmark_base.py b/benchmarks/benchmark_base.py index 28758b058..29e6b0ff0 100644 --- a/benchmarks/benchmark_base.py +++ b/benchmarks/benchmark_base.py @@ -278,6 +278,31 @@ def _get_env_metadata() -> list[str]: driver = "N/A" lines.append(f"- **Driver version**: {driver}") + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=clocks.current.sm,clocks.current.memory," + "clocks.applications.graphics,clocks.applications.memory", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + clock_values = [part.strip() for part in result.stdout.splitlines()[0].split(",")] + if len(clock_values) == 4: + sm_clock, mem_clock, app_sm_clock, app_mem_clock = clock_values + lines.append( + "- **GPU clocks**: " + f"SM current {sm_clock} MHz, memory current {mem_clock} MHz, " + f"application SM {app_sm_clock} MHz, " + f"application memory {app_mem_clock} MHz" + ) + except (FileNotFoundError, IndexError, subprocess.TimeoutExpired): + pass + return lines @@ -569,7 +594,7 @@ def dump(path: str) -> None: lines.extend(_get_env_metadata()) lines.append("") - result_keys = ["latency_ms", "tflops", "bandwidth_tbs"] + default_result_keys = ["latency_ms", "tflops", "bandwidth_tbs"] for name, entries in BenchmarkReport._records.items(): if not entries: @@ -582,6 +607,11 @@ def dump(path: str) -> None: tag_entries = {} for entry in entries: tag_entries.setdefault(entry["tag"], []).append(entry) + result_keys = list(default_result_keys) + for entry in entries: + for key in entry["result"]: + if key not in result_keys: + result_keys.append(key) for tag, tag_group in tag_entries.items(): lines.append(f"### {tag}") diff --git a/benchmarks/ops/bench_gated_deltanet_prefill.py b/benchmarks/ops/bench_gated_deltanet_prefill.py index 6d87e1f5c..9aea62c87 100644 --- a/benchmarks/ops/bench_gated_deltanet_prefill.py +++ b/benchmarks/ops/bench_gated_deltanet_prefill.py @@ -1,7 +1,8 @@ """Benchmark: TileOPs Gated DeltaNet inference prefill. -When FLA is installed, record it as the independent baseline. Otherwise fall -back to a pure-torch reference so every benchmark row has a non-TileOps entry. +Rows use the FLA/Qwen BTHD public contract directly: + q/k: [B, T, H, DK], v/o: [B, T, H, DV], g/beta: [B, T, H] +and compare TileOps against FLA with output_final_state=True. """ import inspect @@ -13,7 +14,6 @@ from benchmarks.benchmark_base import BenchmarkReport, ManifestBenchmark from benchmarks.ops.attention.manifest_params import manifest_params from benchmarks.ops.bench_gated_deltanet import ( - _to_fla_layout, compute_w_u_torch, kernel2_gated_deltanet_torch, prepare_wy_repr_gated_torch, @@ -42,6 +42,7 @@ def ref_program( g: torch.Tensor, beta: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + q, k, v, g, beta = _to_bhtd_layout(q, k, v, g, beta, self.layout) batch, heads, seq_len, dim_k = k.shape dim_v = v.shape[-1] chunk_size = self.chunk_size @@ -61,7 +62,7 @@ def ref_program( final_state, o = kernel2_gated_deltanet_torch( q, k, g_cum, w, u, initial_state, chunk_size ) - return o.to(self.dtype), final_state.to(self.dtype) + return _from_bhtd_layout(o.to(self.dtype), final_state.to(self.dtype), self.layout) def _fla_prefill_fwd(): @@ -70,7 +71,8 @@ def _fla_prefill_fwd(): return None signature = inspect.signature(chunk_gated_delta_rule) - supports_output_final_state = "output_final_state" in signature.parameters + if "output_final_state" not in signature.parameters: + return None def baseline_fn( q: torch.Tensor, @@ -79,27 +81,79 @@ def baseline_fn( g: torch.Tensor, beta: torch.Tensor, ): - kwargs: dict[str, Any] = {"scale": 1.0} - if supports_output_final_state: - kwargs["output_final_state"] = True - return chunk_gated_delta_rule(q, k, v, g, beta, **kwargs) + return chunk_gated_delta_rule( + q, + k, + v, + g, + beta, + scale=1.0, + output_final_state=True, + ) return baseline_fn -def _gdn_prefill_args(workload: dict[str, Any]) -> tuple[int, int, int, int, int, int]: - batch, heads, seq_len, dim_k = workload["q_shape"] - _, _, v_seq_len, dim_v = workload["v_shape"] - if v_seq_len != seq_len: - raise ValueError("GDN prefill q_shape and v_shape must share seq_len") - return batch, heads, seq_len, dim_k, dim_v, workload.get("chunk_size", 64) +def _to_bhtd_layout( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + layout: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + layout = layout.lower() + if layout == "bthd": + return ( + q.permute(0, 2, 1, 3).contiguous(), + k.permute(0, 2, 1, 3).contiguous(), + v.permute(0, 2, 1, 3).contiguous(), + g.permute(0, 2, 1).contiguous(), + beta.permute(0, 2, 1).contiguous(), + ) + if layout in ("bhtd", "bhsd"): + return q, k, v, g, beta + raise ValueError(f"Unsupported layout: {layout}") + + +def _from_bhtd_layout( + o: torch.Tensor, + final_state: torch.Tensor, + layout: str, +) -> tuple[torch.Tensor, torch.Tensor]: + if layout.lower() == "bthd": + return o.permute(0, 2, 1, 3).contiguous(), final_state + return o, final_state + + +def _gdn_prefill_args( + workload: dict[str, Any], +) -> tuple[int, int, int, int, int, int, str]: + layout = workload.get("layout", "bthd").lower() + if layout == "bthd": + batch, seq_len, heads, dim_k = workload["q_shape"] + _, v_seq_len, v_heads, dim_v = workload["v_shape"] + else: + batch, heads, seq_len, dim_k = workload["q_shape"] + _, v_heads, v_seq_len, dim_v = workload["v_shape"] + if v_seq_len != seq_len or v_heads != heads: + raise ValueError("GDN prefill q_shape and v_shape must share seq_len and heads") + return ( + batch, + heads, + seq_len, + dim_k, + dim_v, + workload.get("chunk_size", 64), + layout, + ) _BENCH_PARAMS = manifest_params(load_workloads(_OP_NAME), _gdn_prefill_args, tune=False) @pytest.mark.parametrize( - "batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype, tune", + "batch, heads, seq_len, dim_k, dim_v, chunk_size, layout, dtype, tune", _BENCH_PARAMS, ) def test_gated_deltanet_prefill_fwd_bench( @@ -109,11 +163,16 @@ def test_gated_deltanet_prefill_fwd_bench( dim_k: int, dim_v: int, chunk_size: int, + layout: str, dtype: torch.dtype, tune: bool, ) -> None: + fla_fn = _fla_prefill_fwd() + if fla_fn is None: + pytest.skip("FLA chunk_gated_delta_rule with output_final_state=True is required") + test = _GatedDeltaNetPrefillFwdTestBaseline( - batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype + batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype, layout=layout ) inputs = test.gen_inputs() @@ -126,19 +185,15 @@ def test_gated_deltanet_prefill_fwd_bench( chunk_size, dtype, tune=tune, + layout=layout, ) bm = ManifestBenchmark(_OP_NAME, op, test) result = bm.profile(op, *inputs) - BenchmarkReport.record(op, locals(), result, tag="tileops") + result_fla = bm.profile(fla_fn, *inputs) + result["speedup_vs_fla"] = result_fla["latency_ms"] / result["latency_ms"] - fla_fn = _fla_prefill_fwd() - if fla_fn is not None: - fla_inputs = _to_fla_layout(*inputs) - result_fla = bm.profile(fla_fn, *fla_inputs) - BenchmarkReport.record(op, locals(), result_fla, tag="fla") - else: - result_ref = bm.profile(test.ref_program, *inputs) - BenchmarkReport.record(op, locals(), result_ref, tag="torch-ref") + BenchmarkReport.record(op, locals(), result, tag="tileops") + BenchmarkReport.record(op, locals(), result_fla, tag="fla") if __name__ == "__main__": diff --git a/tileops/manifest/linear_attention.yaml b/tileops/manifest/linear_attention.yaml index d00233ef1..55f5cd7e3 100644 --- a/tileops/manifest/linear_attention.yaml +++ b/tileops/manifest/linear_attention.yaml @@ -43,12 +43,21 @@ GatedDeltaNetPrefillFwdOp: - "chunk_size is None or S % chunk_size == 0" workloads: - # Inference-oriented BHSD rows. Unit tests should use smaller - # branch-covering shapes in tests/ops, not these benchmark workloads. - - {q_shape: [1, 16, 512, 128], k_shape: [1, 16, 512, 128], v_shape: [1, 16, 512, 128], g_shape: [1, 16, 512], beta_shape: [1, 16, 512], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b1-s512-h16-d128"} - - {q_shape: [4, 16, 2048, 64], k_shape: [4, 16, 2048, 64], v_shape: [4, 16, 2048, 128], g_shape: [4, 16, 2048], beta_shape: [4, 16, 2048], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b4-s2k-h16-dk64-dv128"} - - {q_shape: [1, 16, 2048, 128], k_shape: [1, 16, 2048, 128], v_shape: [1, 16, 2048, 128], g_shape: [1, 16, 2048], beta_shape: [1, 16, 2048], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b1-s2k-h16-d128"} - - {q_shape: [1, 16, 8192, 128], k_shape: [1, 16, 8192, 128], v_shape: [1, 16, 8192, 128], g_shape: [1, 16, 8192], beta_shape: [1, 16, 8192], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b1-s8k-h16-d128"} + # Qwen3.5-aligned inference prefill rows using the public FLA BTHD + # contract. Tests use smaller branch-covering shapes in tests/ops; these + # rows are intentionally long-context benchmark coverage. + - {q_shape: [1, 32768, 16, 128], k_shape: [1, 32768, 16, 128], v_shape: [1, 32768, 16, 128], g_shape: [1, 32768, 16], beta_shape: [1, 32768, 16], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h16-d128-bthd"} + - {q_shape: [1, 65536, 16, 128], k_shape: [1, 65536, 16, 128], v_shape: [1, 65536, 16, 128], g_shape: [1, 65536, 16], beta_shape: [1, 65536, 16], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h16-d128-bthd"} + - {q_shape: [1, 131072, 16, 128], k_shape: [1, 131072, 16, 128], v_shape: [1, 131072, 16, 128], g_shape: [1, 131072, 16], beta_shape: [1, 131072, 16], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h16-d128-bthd"} + - {q_shape: [1, 32768, 32, 128], k_shape: [1, 32768, 32, 128], v_shape: [1, 32768, 32, 128], g_shape: [1, 32768, 32], beta_shape: [1, 32768, 32], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h32-d128-bthd"} + - {q_shape: [1, 65536, 32, 128], k_shape: [1, 65536, 32, 128], v_shape: [1, 65536, 32, 128], g_shape: [1, 65536, 32], beta_shape: [1, 65536, 32], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h32-d128-bthd"} + - {q_shape: [1, 131072, 32, 128], k_shape: [1, 131072, 32, 128], v_shape: [1, 131072, 32, 128], g_shape: [1, 131072, 32], beta_shape: [1, 131072, 32], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h32-d128-bthd"} + - {q_shape: [1, 32768, 48, 128], k_shape: [1, 32768, 48, 128], v_shape: [1, 32768, 48, 128], g_shape: [1, 32768, 48], beta_shape: [1, 32768, 48], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h48-d128-bthd"} + - {q_shape: [1, 65536, 48, 128], k_shape: [1, 65536, 48, 128], v_shape: [1, 65536, 48, 128], g_shape: [1, 65536, 48], beta_shape: [1, 65536, 48], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h48-d128-bthd"} + - {q_shape: [1, 131072, 48, 128], k_shape: [1, 131072, 48, 128], v_shape: [1, 131072, 48, 128], g_shape: [1, 131072, 48], beta_shape: [1, 131072, 48], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h48-d128-bthd"} + - {q_shape: [1, 32768, 64, 128], k_shape: [1, 32768, 64, 128], v_shape: [1, 32768, 64, 128], g_shape: [1, 32768, 64], beta_shape: [1, 32768, 64], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h64-d128-bthd"} + - {q_shape: [1, 65536, 64, 128], k_shape: [1, 65536, 64, 128], v_shape: [1, 65536, 64, 128], g_shape: [1, 65536, 64], beta_shape: [1, 65536, 64], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h64-d128-bthd"} + - {q_shape: [1, 131072, 64, 128], k_shape: [1, 131072, 64, 128], v_shape: [1, 131072, 64, 128], g_shape: [1, 131072, 64], beta_shape: [1, 131072, 64], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h64-d128-bthd"} roofline: func: "tileops.perf.formulas.gated_deltanet_prefill_fwd_roofline" diff --git a/tileops/perf/formulas.py b/tileops/perf/formulas.py index 80a9f758c..4b235619f 100644 --- a/tileops/perf/formulas.py +++ b/tileops/perf/formulas.py @@ -222,8 +222,13 @@ def gated_deltanet_prefill_fwd_roofline(op: Any | None = None, **kwargs: Any) -> """ data = _shape_or_attrs(op, kwargs) if "q_shape" in data: - batch, heads, seq_len, dim_k = data["q_shape"] - _, _, _, dim_v = data["v_shape"] + layout = str(data.get("layout", "bhtd")).lower() + if layout == "bthd": + batch, seq_len, heads, dim_k = data["q_shape"] + _, _, _, dim_v = data["v_shape"] + else: + batch, heads, seq_len, dim_k = data["q_shape"] + _, _, _, dim_v = data["v_shape"] chunk_size = data.get("chunk_size", 64) else: batch, heads, seq_len, dim_k, dim_v, chunk_size = ( diff --git a/workloads/gated_deltanet.py b/workloads/gated_deltanet.py index 1200802bc..05f7731e3 100644 --- a/workloads/gated_deltanet.py +++ b/workloads/gated_deltanet.py @@ -45,9 +45,35 @@ def __init__( dim_v: int, chunk_size: int, dtype: torch.dtype, + layout: str = "bhtd", ) -> None: super().__init__(batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype) - self.shape = (batch, heads, seq_len, dim_k) + self.layout = self._normalize_layout(layout) + if self.layout == "bthd": + self.shape = (batch, seq_len, heads, dim_k) + else: + self.shape = (batch, heads, seq_len, dim_k) + + @staticmethod + def _normalize_layout(layout: str) -> str: + layout = layout.lower() + if layout == "bhsd": + return "bhtd" + if layout in ("bhtd", "bthd"): + return layout + raise ValueError(f"Unsupported layout: {layout}") + + def gen_inputs(self) -> tuple[torch.Tensor, ...]: + if self.layout != "bthd": + return super().gen_inputs() + + B, H, S, DK, DV = self.batch, self.heads, self.seq_len, self.dim_k, self.dim_v + q = torch.randn(B, S, H, DK, device="cuda", dtype=self.dtype) * 0.1 + k = torch.randn(B, S, H, DK, device="cuda", dtype=self.dtype) * 0.1 + v = torch.randn(B, S, H, DV, device="cuda", dtype=self.dtype) * 0.1 + g = -torch.rand(B, S, H, device="cuda", dtype=self.dtype) + beta = torch.rand(B, S, H, device="cuda", dtype=self.dtype) * 0.5 + return q, k, v, g, beta class GatedDeltaNetDecodeTest(WorkloadBase):