Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,31 @@ def _get_env_metadata() -> list[str]:
driver = "N/A"
lines.append(f"- **Driver version**: {driver}")

try:
result = subprocess.run(
[
"nvidia-smi",
"--query-gpu=clocks.current.sm,clocks.current.memory,"
"clocks.applications.graphics,clocks.applications.memory",
"--format=csv,noheader,nounits",
],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
clock_values = [part.strip() for part in result.stdout.splitlines()[0].split(",")]
if len(clock_values) == 4:
sm_clock, mem_clock, app_sm_clock, app_mem_clock = clock_values
lines.append(
"- **GPU clocks**: "
f"SM current {sm_clock} MHz, memory current {mem_clock} MHz, "
f"application SM {app_sm_clock} MHz, "
f"application memory {app_mem_clock} MHz"
)
except (FileNotFoundError, IndexError, subprocess.TimeoutExpired):
pass

return lines


Expand Down Expand Up @@ -569,7 +594,7 @@ def dump(path: str) -> None:
lines.extend(_get_env_metadata())
lines.append("")

result_keys = ["latency_ms", "tflops", "bandwidth_tbs"]
default_result_keys = ["latency_ms", "tflops", "bandwidth_tbs"]

for name, entries in BenchmarkReport._records.items():
if not entries:
Expand All @@ -582,6 +607,11 @@ def dump(path: str) -> None:
tag_entries = {}
for entry in entries:
tag_entries.setdefault(entry["tag"], []).append(entry)
result_keys = list(default_result_keys)
for entry in entries:
for key in entry["result"]:
if key not in result_keys:
result_keys.append(key)

for tag, tag_group in tag_entries.items():
lines.append(f"### {tag}")
Expand Down
107 changes: 81 additions & 26 deletions benchmarks/ops/bench_gated_deltanet_prefill.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Benchmark: TileOPs Gated DeltaNet inference prefill.

When FLA is installed, record it as the independent baseline. Otherwise fall
back to a pure-torch reference so every benchmark row has a non-TileOps entry.
Rows use the FLA/Qwen BTHD public contract directly:
q/k: [B, T, H, DK], v/o: [B, T, H, DV], g/beta: [B, T, H]
and compare TileOps against FLA with output_final_state=True.
"""

import inspect
Expand All @@ -13,7 +14,6 @@
from benchmarks.benchmark_base import BenchmarkReport, ManifestBenchmark
from benchmarks.ops.attention.manifest_params import manifest_params
from benchmarks.ops.bench_gated_deltanet import (
_to_fla_layout,
compute_w_u_torch,
kernel2_gated_deltanet_torch,
prepare_wy_repr_gated_torch,
Expand Down Expand Up @@ -42,6 +42,7 @@ def ref_program(
g: torch.Tensor,
beta: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
q, k, v, g, beta = _to_bhtd_layout(q, k, v, g, beta, self.layout)
batch, heads, seq_len, dim_k = k.shape
dim_v = v.shape[-1]
chunk_size = self.chunk_size
Expand All @@ -61,7 +62,7 @@ def ref_program(
final_state, o = kernel2_gated_deltanet_torch(
q, k, g_cum, w, u, initial_state, chunk_size
)
return o.to(self.dtype), final_state.to(self.dtype)
return _from_bhtd_layout(o.to(self.dtype), final_state.to(self.dtype), self.layout)


def _fla_prefill_fwd():
Expand All @@ -70,7 +71,8 @@ def _fla_prefill_fwd():
return None

signature = inspect.signature(chunk_gated_delta_rule)
supports_output_final_state = "output_final_state" in signature.parameters
if "output_final_state" not in signature.parameters:
return None

def baseline_fn(
q: torch.Tensor,
Expand All @@ -79,27 +81,79 @@ def baseline_fn(
g: torch.Tensor,
beta: torch.Tensor,
):
kwargs: dict[str, Any] = {"scale": 1.0}
if supports_output_final_state:
kwargs["output_final_state"] = True
return chunk_gated_delta_rule(q, k, v, g, beta, **kwargs)
return chunk_gated_delta_rule(
q,
k,
v,
g,
beta,
scale=1.0,
output_final_state=True,
)

return baseline_fn


def _gdn_prefill_args(workload: dict[str, Any]) -> tuple[int, int, int, int, int, int]:
batch, heads, seq_len, dim_k = workload["q_shape"]
_, _, v_seq_len, dim_v = workload["v_shape"]
if v_seq_len != seq_len:
raise ValueError("GDN prefill q_shape and v_shape must share seq_len")
return batch, heads, seq_len, dim_k, dim_v, workload.get("chunk_size", 64)
def _to_bhtd_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
layout: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
layout = layout.lower()
if layout == "bthd":
return (
q.permute(0, 2, 1, 3).contiguous(),
k.permute(0, 2, 1, 3).contiguous(),
v.permute(0, 2, 1, 3).contiguous(),
g.permute(0, 2, 1).contiguous(),
beta.permute(0, 2, 1).contiguous(),
)
if layout in ("bhtd", "bhsd"):
return q, k, v, g, beta
raise ValueError(f"Unsupported layout: {layout}")


def _from_bhtd_layout(
o: torch.Tensor,
final_state: torch.Tensor,
layout: str,
) -> tuple[torch.Tensor, torch.Tensor]:
if layout.lower() == "bthd":
return o.permute(0, 2, 1, 3).contiguous(), final_state
return o, final_state


def _gdn_prefill_args(
workload: dict[str, Any],
) -> tuple[int, int, int, int, int, int, str]:
layout = workload.get("layout", "bthd").lower()
if layout == "bthd":
batch, seq_len, heads, dim_k = workload["q_shape"]
_, v_seq_len, v_heads, dim_v = workload["v_shape"]
else:
batch, heads, seq_len, dim_k = workload["q_shape"]
_, v_heads, v_seq_len, dim_v = workload["v_shape"]
if v_seq_len != seq_len or v_heads != heads:
raise ValueError("GDN prefill q_shape and v_shape must share seq_len and heads")
return (
batch,
heads,
seq_len,
dim_k,
dim_v,
workload.get("chunk_size", 64),
layout,
)


_BENCH_PARAMS = manifest_params(load_workloads(_OP_NAME), _gdn_prefill_args, tune=False)


@pytest.mark.parametrize(
"batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype, tune",
"batch, heads, seq_len, dim_k, dim_v, chunk_size, layout, dtype, tune",
_BENCH_PARAMS,
)
def test_gated_deltanet_prefill_fwd_bench(
Expand All @@ -109,11 +163,16 @@ def test_gated_deltanet_prefill_fwd_bench(
dim_k: int,
dim_v: int,
chunk_size: int,
layout: str,
dtype: torch.dtype,
tune: bool,
) -> None:
fla_fn = _fla_prefill_fwd()
if fla_fn is None:
pytest.skip("FLA chunk_gated_delta_rule with output_final_state=True is required")

test = _GatedDeltaNetPrefillFwdTestBaseline(
batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype
batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype, layout=layout
)
inputs = test.gen_inputs()

Expand All @@ -126,19 +185,15 @@ def test_gated_deltanet_prefill_fwd_bench(
chunk_size,
dtype,
tune=tune,
layout=layout,
)
bm = ManifestBenchmark(_OP_NAME, op, test)
result = bm.profile(op, *inputs)
BenchmarkReport.record(op, locals(), result, tag="tileops")
result_fla = bm.profile(fla_fn, *inputs)
result["speedup_vs_fla"] = result_fla["latency_ms"] / result["latency_ms"]

fla_fn = _fla_prefill_fwd()
if fla_fn is not None:
fla_inputs = _to_fla_layout(*inputs)
result_fla = bm.profile(fla_fn, *fla_inputs)
BenchmarkReport.record(op, locals(), result_fla, tag="fla")
else:
result_ref = bm.profile(test.ref_program, *inputs)
BenchmarkReport.record(op, locals(), result_ref, tag="torch-ref")
BenchmarkReport.record(op, locals(), result, tag="tileops")
BenchmarkReport.record(op, locals(), result_fla, tag="fla")


if __name__ == "__main__":
Expand Down
21 changes: 15 additions & 6 deletions tileops/manifest/linear_attention.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@ GatedDeltaNetPrefillFwdOp:
- "chunk_size is None or S % chunk_size == 0"

workloads:
# Inference-oriented BHSD rows. Unit tests should use smaller
# branch-covering shapes in tests/ops, not these benchmark workloads.
- {q_shape: [1, 16, 512, 128], k_shape: [1, 16, 512, 128], v_shape: [1, 16, 512, 128], g_shape: [1, 16, 512], beta_shape: [1, 16, 512], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b1-s512-h16-d128"}
- {q_shape: [4, 16, 2048, 64], k_shape: [4, 16, 2048, 64], v_shape: [4, 16, 2048, 128], g_shape: [4, 16, 2048], beta_shape: [4, 16, 2048], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b4-s2k-h16-dk64-dv128"}
- {q_shape: [1, 16, 2048, 128], k_shape: [1, 16, 2048, 128], v_shape: [1, 16, 2048, 128], g_shape: [1, 16, 2048], beta_shape: [1, 16, 2048], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b1-s2k-h16-d128"}
- {q_shape: [1, 16, 8192, 128], k_shape: [1, 16, 8192, 128], v_shape: [1, 16, 8192, 128], g_shape: [1, 16, 8192], beta_shape: [1, 16, 8192], chunk_size: 64, dtypes: [float16, bfloat16], label: "gdn-prefill-b1-s8k-h16-d128"}
# Qwen3.5-aligned inference prefill rows using the public FLA BTHD
# contract. Tests use smaller branch-covering shapes in tests/ops; these
# rows are intentionally long-context benchmark coverage.
- {q_shape: [1, 32768, 16, 128], k_shape: [1, 32768, 16, 128], v_shape: [1, 32768, 16, 128], g_shape: [1, 32768, 16], beta_shape: [1, 32768, 16], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h16-d128-bthd"}
- {q_shape: [1, 65536, 16, 128], k_shape: [1, 65536, 16, 128], v_shape: [1, 65536, 16, 128], g_shape: [1, 65536, 16], beta_shape: [1, 65536, 16], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h16-d128-bthd"}
- {q_shape: [1, 131072, 16, 128], k_shape: [1, 131072, 16, 128], v_shape: [1, 131072, 16, 128], g_shape: [1, 131072, 16], beta_shape: [1, 131072, 16], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h16-d128-bthd"}
- {q_shape: [1, 32768, 32, 128], k_shape: [1, 32768, 32, 128], v_shape: [1, 32768, 32, 128], g_shape: [1, 32768, 32], beta_shape: [1, 32768, 32], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h32-d128-bthd"}
- {q_shape: [1, 65536, 32, 128], k_shape: [1, 65536, 32, 128], v_shape: [1, 65536, 32, 128], g_shape: [1, 65536, 32], beta_shape: [1, 65536, 32], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h32-d128-bthd"}
- {q_shape: [1, 131072, 32, 128], k_shape: [1, 131072, 32, 128], v_shape: [1, 131072, 32, 128], g_shape: [1, 131072, 32], beta_shape: [1, 131072, 32], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h32-d128-bthd"}
- {q_shape: [1, 32768, 48, 128], k_shape: [1, 32768, 48, 128], v_shape: [1, 32768, 48, 128], g_shape: [1, 32768, 48], beta_shape: [1, 32768, 48], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h48-d128-bthd"}
- {q_shape: [1, 65536, 48, 128], k_shape: [1, 65536, 48, 128], v_shape: [1, 65536, 48, 128], g_shape: [1, 65536, 48], beta_shape: [1, 65536, 48], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h48-d128-bthd"}
- {q_shape: [1, 131072, 48, 128], k_shape: [1, 131072, 48, 128], v_shape: [1, 131072, 48, 128], g_shape: [1, 131072, 48], beta_shape: [1, 131072, 48], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h48-d128-bthd"}
- {q_shape: [1, 32768, 64, 128], k_shape: [1, 32768, 64, 128], v_shape: [1, 32768, 64, 128], g_shape: [1, 32768, 64], beta_shape: [1, 32768, 64], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s32k-h64-d128-bthd"}
- {q_shape: [1, 65536, 64, 128], k_shape: [1, 65536, 64, 128], v_shape: [1, 65536, 64, 128], g_shape: [1, 65536, 64], beta_shape: [1, 65536, 64], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s64k-h64-d128-bthd"}
- {q_shape: [1, 131072, 64, 128], k_shape: [1, 131072, 64, 128], v_shape: [1, 131072, 64, 128], g_shape: [1, 131072, 64], beta_shape: [1, 131072, 64], chunk_size: 64, layout: bthd, dtypes: [float16, bfloat16], label: "qwen35-gdn-prefill-b1-s128k-h64-d128-bthd"}

roofline:
func: "tileops.perf.formulas.gated_deltanet_prefill_fwd_roofline"
Expand Down
9 changes: 7 additions & 2 deletions tileops/perf/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,13 @@ def gated_deltanet_prefill_fwd_roofline(op: Any | None = None, **kwargs: Any) ->
"""
data = _shape_or_attrs(op, kwargs)
if "q_shape" in data:
batch, heads, seq_len, dim_k = data["q_shape"]
_, _, _, dim_v = data["v_shape"]
layout = str(data.get("layout", "bhtd")).lower()
if layout == "bthd":
batch, seq_len, heads, dim_k = data["q_shape"]
_, _, _, dim_v = data["v_shape"]
else:
batch, heads, seq_len, dim_k = data["q_shape"]
_, _, _, dim_v = data["v_shape"]
chunk_size = data.get("chunk_size", 64)
else:
batch, heads, seq_len, dim_k, dim_v, chunk_size = (
Expand Down
28 changes: 27 additions & 1 deletion workloads/gated_deltanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,35 @@ def __init__(
dim_v: int,
chunk_size: int,
dtype: torch.dtype,
layout: str = "bhtd",
) -> None:
super().__init__(batch, heads, seq_len, dim_k, dim_v, chunk_size, dtype)
self.shape = (batch, heads, seq_len, dim_k)
self.layout = self._normalize_layout(layout)
if self.layout == "bthd":
self.shape = (batch, seq_len, heads, dim_k)
else:
self.shape = (batch, heads, seq_len, dim_k)

@staticmethod
def _normalize_layout(layout: str) -> str:
layout = layout.lower()
if layout == "bhsd":
return "bhtd"
if layout in ("bhtd", "bthd"):
return layout
raise ValueError(f"Unsupported layout: {layout}")

def gen_inputs(self) -> tuple[torch.Tensor, ...]:
if self.layout != "bthd":
return super().gen_inputs()

B, H, S, DK, DV = self.batch, self.heads, self.seq_len, self.dim_k, self.dim_v
q = torch.randn(B, S, H, DK, device="cuda", dtype=self.dtype) * 0.1
k = torch.randn(B, S, H, DK, device="cuda", dtype=self.dtype) * 0.1
v = torch.randn(B, S, H, DV, device="cuda", dtype=self.dtype) * 0.1
g = -torch.rand(B, S, H, device="cuda", dtype=self.dtype)
beta = torch.rand(B, S, H, device="cuda", dtype=self.dtype) * 0.5
return q, k, v, g, beta


class GatedDeltaNetDecodeTest(WorkloadBase):
Expand Down
Loading