From 66ae6c5fc873c68caf0f7c40e90c1f7dfd94d11c Mon Sep 17 00:00:00 2001
From: karthickai
Date: Thu, 2 Oct 2025 21:43:49 -0700
Subject: [PATCH] [Benchmark] bf16 x int16 helion kernel

stack-info: PR: https://github.com/pytorch/helion/pull/794, branch: karthickai/stack/5
---
 benchmarks/run.py           |  14 +++
 examples/bf16xint16_gemm.py | 169 ++++++++++++++++++++++++++++++++++++
 test/test_examples.expected |  86 ++++++++++++++++++
 test/test_examples.py       |  32 +++++++
 4 files changed, 301 insertions(+)
 create mode 100644 examples/bf16xint16_gemm.py

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 11af98119..bcbe0461f 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -285,6 +285,11 @@ class RunResult:
         "examples.low_mem_dropout",
         "low_mem_dropout_tritonbench",
     ),
+    "bf16xint16_gemm": (
+        "tritonbench.operators.bf16xint16_gemm.bf16xint16_gemm",
+        "examples.bf16xint16_gemm",
+        "bf16xint16_gemm_tritonbench",
+    ),
 }
 
 
@@ -551,6 +556,15 @@ class RunResult:
         "helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
         "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
     },
+    "bf16xint16_gemm": {
+        "bf16xbf16": "baseline",
+        "bf16xint16-speedup": "triton_speedup",
+        "bf16xint16-accuracy": "triton_accuracy",
+        "torch_compile_bf16xbf16-speedup": "torch_compile_speedup",
+        "torch_compile_bf16xbf16-accuracy": "torch_compile_accuracy",
+        "helion_bf16xint16_gemm_tritonbench-speedup": "helion_speedup",
+        "helion_bf16xint16_gemm_tritonbench-accuracy": "helion_accuracy",
+    },
 }
 
 
diff --git a/examples/bf16xint16_gemm.py b/examples/bf16xint16_gemm.py
new file mode 100644
index 000000000..afaeaced7
--- /dev/null
+++ b/examples/bf16xint16_gemm.py
@@ -0,0 +1,169 @@
+"""
+BF16 x INT16 GEMM with Helion
+============================================================
+This example multiplies a bfloat16 matrix by an int16 matrix; the int16 tiles
+are upcast to bfloat16 inside the kernel, just before the multiply-accumulate.
+"""
+
+# %%
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+from torch import Tensor
+
+import helion
+import helion.language as hl
+
+
+# %%
+@helion.kernel(static_shapes=True)
+def _bf16xint16_gemm(x: Tensor, w: Tensor) -> Tensor:
+    """
+    x is bf16, w is int16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f"size mismatch {K} != {K2}"
+
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+
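+    # Tile the [M, N] output; each tile accumulates partial products over K in
+    # float32. The int16 tile of w is upcast to bf16 on the fly, so a full
+    # bf16 copy of w is never materialized in global memory.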
+    for tile_m, tile_n in hl.tile([M, N]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(K):
+            x_tile = x[tile_m, tile_k]
+            w_tile = w[tile_k, tile_n].to(torch.bfloat16)
+            acc = hl.dot(x_tile, w_tile, acc=acc)
+        out[tile_m, tile_n] = acc.to(torch.bfloat16)
+
+    return out
+
+
+# %%
+@helion.kernel(static_shapes=True)
+def _int16xbf16_gemm(x: Tensor, w: Tensor) -> Tensor:
+    """
+    x is int16, w is bf16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f"size mismatch {K} != {K2}"
+
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+
+    for tile_m, tile_n in hl.tile([M, N]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(K):
+            x_tile = x[tile_m, tile_k].to(torch.bfloat16)
+            w_tile = w[tile_k, tile_n]
+            acc = hl.dot(x_tile, w_tile, acc=acc)
+        out[tile_m, tile_n] = acc.to(torch.bfloat16)
+
+    return out
+
+
+# %%
+def bf16xint16_gemm(x: Tensor, w: Tensor, transpose: bool = False) -> Tensor:
+    """
+    Dispatch to the appropriate GEMM kernel based on the transpose flag.
+
+    Args:
+        x (Tensor): Input tensor (bf16, or int16 when transpose is True).
+        w (Tensor): Weight tensor (int16, or bf16 when transpose is True).
+        transpose (bool): If True, assumes x is int16 and w is bf16. Default: False.
+
+    Returns:
+        Tensor: Output tensor in bfloat16 format.
+    """
+    if transpose:
+        return _int16xbf16_gemm(x, w)
+    return _bf16xint16_gemm(x, w)
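+
+
+# Illustrative usage of the dispatcher (shapes are arbitrary; main() below
+# exercises the same pattern via check()):
+#
+#   x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
+#   w = torch.randint(-(2**15), 2**15, (512, 128), device="cuda", dtype=torch.int16)
+#   out = bf16xint16_gemm(x, w)  # -> bf16 tensor of shape [256, 128]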
+
+
+# %%
+def bf16xint16_gemm_tritonbench(
+    tb_op: object, x: torch.Tensor, w: torch.Tensor
+) -> Callable[[], torch.Tensor]:
+    """
+    Wrapper for TritonBench compatibility.
+
+    Args:
+        tb_op: TritonBench operator instance.
+        x (torch.Tensor): Input tensor in bfloat16 format (int16 in transpose mode).
+        w (torch.Tensor): Weight tensor in int16 format (bfloat16 in transpose mode).
+
+    Returns:
+        Callable that returns the output tensor in bfloat16 format.
+    """
+    # Read the transpose flag from the tritonbench operator, if present.
+    transpose = getattr(tb_op, "transpose", False)
+
+    def run_kernel() -> torch.Tensor:
+        return bf16xint16_gemm(x, w, transpose=transpose)
+
+    return run_kernel
+
+
+# %%
+def reference_bf16xint16_pytorch(
+    x: torch.Tensor, w: torch.Tensor, transpose: bool = False
+) -> torch.Tensor:
+    """
+    Reference implementation using PyTorch operations.
+
+    Args:
+        x (torch.Tensor): Input tensor.
+        w (torch.Tensor): Weight tensor.
+        transpose (bool): If True, assumes x is int16 and w is bf16.
+
+    Returns:
+        torch.Tensor: Output tensor in bfloat16 format.
+    """
+    if transpose:
+        x_bf16 = x.to(torch.bfloat16)
+        return torch.matmul(x_bf16, w)
+    w_bf16 = w.to(torch.bfloat16)
+    return torch.matmul(x, w_bf16)
+
+
+# %%
+def check(m: int, k: int, n: int) -> None:
+    """
+    Test the bf16 x int16 GEMM implementation against the PyTorch reference.
+
+    Args:
+        m (int): Number of rows.
+        k (int): Shared dimension.
+        n (int): Number of columns.
+    """
+    x = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
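+    # torch.randint's upper bound is exclusive, so this draws values from
+    # [-32768, 32766], all representable in int16.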
+    w = torch.randint(-(2**15), 2**15 - 1, (k, n), device="cuda", dtype=torch.int16)
+
+    result = bf16xint16_gemm(x, w, transpose=False)
+    expected = reference_bf16xint16_pytorch(x, w, transpose=False)
+    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+
+    x_int16 = torch.randint(
+        -(2**15), 2**15 - 1, (m, k), device="cuda", dtype=torch.int16
+    )
+    w_bf16 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+
+    result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
+    expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)
+    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+
+
+# %%
+def main() -> None:
+    """
+    Run the bf16 x int16 kernel verification across several tensor sizes.
+    """
+    check(256, 256, 256)
+    check(512, 512, 512)
+    check(65536, 1024, 1280)
+
+
+# %%
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 5e56dbe53..fa652db03 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -404,6 +404,92 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=2)
     return out.view(q_in.size())
 
+--- assertExpectedJournal(TestExamples.test_bf16xint16)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__bf16xint16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        x_tile = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
+        v_0 = tl.cast(load_1, tl.bfloat16)
+        acc = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(v_0, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    v_1 = tl.cast(acc, tl.bfloat16)
+    tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)
+
+def _bf16xint16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
+    """
+    x is bf16, w is int16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f'size mismatch {K} != {K2}'
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion__bf16xint16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
+    return out
+
+--- assertExpectedJournal(TestExamples.test_bf16xint16)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__int16xbf16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
+        v_0 = tl.cast(load, tl.bfloat16)
+        w_tile = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
+        acc = tl.dot(tl.cast(v_0, tl.bfloat16), tl.cast(w_tile, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    v_1 = tl.cast(acc, tl.bfloat16)
+    tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)
+
+def _int16xbf16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
+    """
+    x is int16, w is bf16.
+ """ + M, K = x.shape + K2, N = w.shape + assert K == K2, f'size mismatch {K} != {K2}' + out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device) + _BLOCK_SIZE_0 = 16 + _BLOCK_SIZE_1 = 16 + _BLOCK_SIZE_2 = 16 + _launcher(_helion__int16xbf16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) + return out + --- assertExpectedJournal(TestExamples.test_bmm) from __future__ import annotations diff --git a/test/test_examples.py b/test/test_examples.py index 45dfdcb79..24e28fd44 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -355,6 +355,38 @@ def test_low_mem_dropout(self): check_example("low_mem_dropout", (p, grad_y, seed), grad_x), ) + @skipIfRocm("precision differences with bf16xint16 operations on rocm") + def test_bf16xint16(self): + from examples.bf16xint16_gemm import reference_bf16xint16_pytorch + + m, k, n = 65536, 1024, 1280 + + x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16) + w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16) + + self.assertExpectedJournal( + check_example( + "bf16xint16_gemm", + (x, w), + reference_bf16xint16_pytorch(x, w, False), + fn_name="_bf16xint16_gemm", + ) + ) + + x_int16 = torch.randint( + -(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16 + ) + w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16) + + self.assertExpectedJournal( + check_example( + "bf16xint16_gemm", + (x_int16, w_bf16), + reference_bf16xint16_pytorch(x_int16, w_bf16, True), + fn_name="_int16xbf16_gemm", + ) + ) + def test_rms_norm_fwd(self): args = ( torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
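
To try the example locally (this assumes a CUDA device; the example allocates
its tensors with device="cuda"):

    python examples/bf16xint16_gemm.py

This runs the check() sweep in main() and compares both kernel variants
against reference_bf16xint16_pytorch.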