
Commit ea069f0

[Benchmark] bf16 x int16 helion kernel
stack-info: PR: #794, branch: karthickai/stack/5
1 parent 05fb47d commit ea069f0

File tree

4 files changed: +301 -0 lines changed

benchmarks/run.py

Lines changed: 14 additions & 0 deletions
@@ -285,6 +285,11 @@ class RunResult:
         "examples.low_mem_dropout",
         "low_mem_dropout_tritonbench",
     ),
+    "bf16xint16_gemm": (
+        "tritonbench.operators.bf16xint16_gemm.bf16xint16_gemm",
+        "examples.bf16xint16_gemm",
+        "bf16xint16_gemm_tritonbench",
+    ),
 }

@@ -551,6 +556,15 @@ class RunResult:
         "helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
         "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
     },
+    "bf16xint16_gemm": {
+        "bf16xbf16": "baseline",
+        "bf16xint16-speedup": "triton_speedup",
+        "bf16xint16-accuracy": "triton_accuracy",
+        "torch_compile_bf16xbf16-speedup": "torch_compile_speedup",
+        "torch_compile_bf16xbf16-accuracy": "torch_compile_accuracy",
+        "helion_bf16xint16_gemm_tritonbench-speedup": "helion_speedup",
+        "helion_bf16xint16_gemm_tritonbench-accuracy": "helion_accuracy",
+    },
 }
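For orientation: the first hunk registers the new operator as a (tritonbench operator module, Helion example module, wrapper function name) triple, and the second maps tritonbench's per-backend metric columns onto the normalized baseline/speedup/accuracy labels used in reports. Below is a minimal sketch of how such a triple could be resolved at run time; the variable and helper names are illustrative, not the actual code in benchmarks/run.py.

import importlib

# Illustrative copy of the tuple registered above; the real dict name in
# benchmarks/run.py is not shown in this hunk.
entry = (
    "tritonbench.operators.bf16xint16_gemm.bf16xint16_gemm",  # tritonbench operator module
    "examples.bf16xint16_gemm",                               # Helion example module
    "bf16xint16_gemm_tritonbench",                            # wrapper exported by that module
)

def resolve_helion_wrapper(entry: tuple[str, str, str]):
    """Hypothetical helper: import the example module and return its wrapper callable."""
    _tb_module, example_module, wrapper_name = entry
    return getattr(importlib.import_module(example_module), wrapper_name)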

examples/bf16xint16_gemm.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
+"""
+BF16 x INT16 GEMM with Helion
+============================================================
+The kernel performs matrix multiplication where one matrix is in bfloat16 format and the other is in int16 format.
+The int16 values are converted to bfloat16 before performing the matrix multiplication.
+"""
+
+# %%
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+from torch import Tensor
+
+import helion
+import helion.language as hl
+
+
+# %%
+@helion.kernel(static_shapes=True)
+def _bf16xint16_gemm(x: Tensor, w: Tensor) -> Tensor:
+    """
+    x is bf16, w is int16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f"size mismatch {K} != {K2}"
+
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+
+    for tile_m, tile_n in hl.tile([M, N]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(K):
+            x_tile = x[tile_m, tile_k]
+            w_tile = w[tile_k, tile_n].to(torch.bfloat16)
+            acc = hl.dot(x_tile, w_tile, acc=acc)
+        out[tile_m, tile_n] = acc.to(torch.bfloat16)
+
+    return out
+
+
+# %%
+@helion.kernel(static_shapes=True)
+def _int16xbf16_gemm(x: Tensor, w: Tensor) -> Tensor:
+    """
+    x is int16, w is bf16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f"size mismatch {K} != {K2}"
+
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+
+    for tile_m, tile_n in hl.tile([M, N]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        for tile_k in hl.tile(K):
+            x_tile = x[tile_m, tile_k].to(torch.bfloat16)
+            w_tile = w[tile_k, tile_n]
+            acc = hl.dot(x_tile, w_tile, acc=acc)
+        out[tile_m, tile_n] = acc.to(torch.bfloat16)
+
+    return out
+
+
+# %%
+def bf16xint16_gemm(x: Tensor, w: Tensor, transpose: bool = False) -> Tensor:
+    """
+    This function dispatches to the appropriate kernel based on the transpose flag.
+
+    Args:
+        x (Tensor): Input tensor.
+        w (Tensor): Weight tensor.
+        transpose (bool): If True, assumes x is int16 and w is bf16. Default: False.
+
+    Returns:
+        Tensor: Output tensor in bfloat16 format.
+    """
+    if transpose:
+        return _int16xbf16_gemm(x, w)
+    return _bf16xint16_gemm(x, w)
+
+
+# %%
+def bf16xint16_gemm_tritonbench(
+    tb_op: object, x: torch.Tensor, w: torch.Tensor
+) -> Callable[[], torch.Tensor]:
+    """
+    Wrapper for TritonBench compatibility.
+
+    Args:
+        tb_op: TritonBench operator instance
+        x (torch.Tensor): Input tensor in bfloat16 format.
+        w (torch.Tensor): Weight tensor in int16 format.
+
+    Returns:
+        Callable that returns output tensor in bfloat16 format.
+    """
+    # Check if transpose mode based on tritonbench operator
+    transpose = getattr(tb_op, "transpose", False)
+
+    def run_kernel() -> torch.Tensor:
+        return bf16xint16_gemm(x, w, transpose=transpose)
+
+    return run_kernel
+
+
+# %%
+def reference_bf16xint16_pytorch(
+    x: torch.Tensor, w: torch.Tensor, transpose: bool = False
+) -> torch.Tensor:
+    """
+    Reference implementation using PyTorch operations.
+
+    Args:
+        x (torch.Tensor): Input tensor.
+        w (torch.Tensor): Weight tensor.
+        transpose (bool): Transpose mode flag.
+
+    Returns:
+        torch.Tensor: Output tensor in bfloat16 format.
+    """
+    if transpose:
+        x_bf16 = x.to(torch.bfloat16)
+        return torch.matmul(x_bf16, w)
+    w_bf16 = w.to(torch.bfloat16)
+    return torch.matmul(x, w_bf16)
+
+
+# %%
+def check(m: int, k: int, n: int) -> None:
+    """
+    Test the bf16 x int16 GEMM implementation against the PyTorch reference.
+
+    Args:
+        m (int): Number of rows.
+        k (int): Shared dimension.
+        n (int): Number of cols.
+    """
+    x = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
+    w = torch.randint(-(2**15), 2**15 - 1, (k, n), device="cuda", dtype=torch.int16)
+
+    result = bf16xint16_gemm(x, w, transpose=False)
+    expected = reference_bf16xint16_pytorch(x, w, transpose=False)
+    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+
+    x_int16 = torch.randint(
+        -(2**15), 2**15 - 1, (m, k), device="cuda", dtype=torch.int16
+    )
+    w_bf16 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+
+    result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
+    expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)
+    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
+
+
+# %%
+def main() -> None:
+    """
+    Main entry point that runs the bf16xint16 kernel verification with different tensor sizes.
+    """
+    check(256, 256, 256)
+    check(512, 512, 512)
+    check(65536, 1024, 1280)
+
+
+# %%
+if __name__ == "__main__":
+    main()
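The module's public entry point is bf16xint16_gemm; a minimal usage sketch mirroring the check() helper above (it assumes a CUDA device and the Helion package installed):

import torch

from examples.bf16xint16_gemm import bf16xint16_gemm, reference_bf16xint16_pytorch

m, k, n = 256, 256, 256
x = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
w = torch.randint(-(2**15), 2**15 - 1, (k, n), device="cuda", dtype=torch.int16)

# The kernel upcasts the int16 operand to bf16 tile by tile and accumulates in fp32.
out = bf16xint16_gemm(x, w, transpose=False)
ref = reference_bf16xint16_pytorch(x, w, transpose=False)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)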

test/test_examples.expected

Lines changed: 86 additions & 0 deletions
@@ -404,6 +404,92 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
     _launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=2)
     return out.view(q_in.size())

+--- assertExpectedJournal(TestExamples.test_bf16xint16)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__bf16xint16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        x_tile = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
+        load_1 = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
+        v_0 = tl.cast(load_1, tl.bfloat16)
+        acc = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(v_0, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    v_1 = tl.cast(acc, tl.bfloat16)
+    tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)
+
+def _bf16xint16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
+    """
+    x is bf16, w is int16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f'size mismatch {K} != {K2}'
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion__bf16xint16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestExamples.test_bf16xint16)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__int16xbf16_gemm(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(65536, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 1024, _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
+        v_0 = tl.cast(load, tl.bfloat16)
+        w_tile = tl.load(w + (indices_2[:, None] * 1280 + indices_1[None, :] * 1), None)
+        acc = tl.dot(tl.cast(v_0, tl.bfloat16), tl.cast(w_tile, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    v_1 = tl.cast(acc, tl.bfloat16)
+    tl.store(out + (indices_0[:, None] * 1280 + indices_1[None, :] * 1), v_1, None)
+
+def _int16xbf16_gemm(x: Tensor, w: Tensor, *, _launcher=_default_launcher):
+    """
+    x is int16, w is bf16.
+    """
+    M, K = x.shape
+    K2, N = w.shape
+    assert K == K2, f'size mismatch {K} != {K2}'
+    out = torch.empty([M, N], dtype=torch.bfloat16, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion__int16xbf16_gemm, (triton.cdiv(65536, _BLOCK_SIZE_0) * triton.cdiv(1280, _BLOCK_SIZE_1),), x, w, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_bmm)
 from __future__ import annotations

test/test_examples.py

Lines changed: 32 additions & 0 deletions
@@ -355,6 +355,38 @@ def test_low_mem_dropout(self):
             check_example("low_mem_dropout", (p, grad_y, seed), grad_x),
         )

+    @skipIfRocm("precision differences with bf16xint16 operations on rocm")
+    def test_bf16xint16(self):
+        from examples.bf16xint16_gemm import reference_bf16xint16_pytorch
+
+        m, k, n = 65536, 1024, 1280
+
+        x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
+        w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)
+
+        self.assertExpectedJournal(
+            check_example(
+                "bf16xint16_gemm",
+                (x, w),
+                reference_bf16xint16_pytorch(x, w, False),
+                fn_name="_bf16xint16_gemm",
+            )
+        )
+
+        x_int16 = torch.randint(
+            -(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16
+        )
+        w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)
+
+        self.assertExpectedJournal(
+            check_example(
+                "bf16xint16_gemm",
+                (x_int16, w_bf16),
+                reference_bf16xint16_pytorch(x_int16, w_bf16, True),
+                fn_name="_int16xbf16_gemm",
+            )
+        )
+
     def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
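For completeness, the new test can be exercised on its own; a small sketch using pytest's programmatic entry point (the usual route is simply the pytest CLI with -k test_bf16xint16, and a CUDA device is assumed):

# Run only the new expected-journal test for the bf16 x int16 kernels.
import pytest

raise SystemExit(pytest.main(["test/test_examples.py", "-k", "test_bf16xint16"]))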
