Skip to content

Commit 9b73baf

Browse files
committed
Faster int4 gemm
stack-info: PR: #751, branch: PaulZhang12/stack/11
1 parent 5a772d1 commit 9b73baf

File tree

3 files changed

+53
-52
lines changed

3 files changed

+53
-52
lines changed

examples/int4_gemm.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
5252
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
5353

5454
for tile_k_packed in hl.tile(K // 2, block_size=block_size_k_packed):
55+
# Load corresponding tiles from A (need to load twice the packed tile size)
56+
# We need to map tile_k_packed to the corresponding range in A
57+
a_tile_begin = tile_k_packed.begin * 2
58+
a_tile_len = block_size_k_packed * 2
59+
a_tile = A[
60+
tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
61+
].to(torch.float32) # [BLOCK_SIZE_M, BLOCK_SIZE_K]
62+
5563
# Load packed int8 data from B
5664
b_tile = B[tile_k_packed, tile_n] # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
5765

@@ -60,29 +68,19 @@ def matmul_bf16_int4(A: Tensor, B: Tensor) -> Tensor:
6068
b_lo = ((b_tile << 4) >> 4).to(torch.int8) # Sign-extend low 4 bits
6169
b_hi = (b_tile >> 4).to(torch.int8) # Sign-extend high 4 bits
6270

63-
# Convert to bfloat16
64-
b_lo_bf16 = b_lo.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
65-
b_hi_bf16 = b_hi.to(torch.bfloat16) # [BLOCK_SIZE_K//2, BLOCK_SIZE_N]
66-
6771
# Stack and reshape to interleave low and high bits
6872
# Stack along a new dimension to get [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N]
69-
b_stacked = torch.stack([b_lo_bf16, b_hi_bf16], dim=1)
73+
b_stacked = torch.stack([b_lo, b_hi], dim=1)
7074

7175
# Reshape to interleave: [BLOCK_SIZE_K//2, 2, BLOCK_SIZE_N] -> [BLOCK_SIZE_K, BLOCK_SIZE_N]
7276
# This will place elements in the order: b_lo[0], b_hi[0], b_lo[1], b_hi[1], ...
7377
b_unpacked = b_stacked.reshape(
7478
tile_k_packed.block_size * 2, tile_n.block_size
75-
)
76-
77-
# Load corresponding tiles from A (need to load twice the packed tile size)
78-
# We need to map tile_k_packed to the corresponding range in A
79-
a_tile_begin = tile_k_packed.begin * 2
80-
a_tile_len = tile_k_packed.block_size * 2
81-
a_tile = A[
82-
tile_m, a_tile_begin : (a_tile_begin + a_tile_len)
83-
] # [BLOCK_SIZE_M, BLOCK_SIZE_K]
79+
).to(torch.float32)
8480

85-
acc = acc + hl.dot(a_tile, b_unpacked) # [BLOCK_SIZE_M, BLOCK_SIZE_N]
81+
a_tile = a_tile.unsqueeze(2) # [BLOCK_SIZE_M, BLOCK_SIZE_K, 1]
82+
b_unpacked = b_unpacked.unsqueeze(0)
83+
acc = acc + (a_tile * b_unpacked).sum(dim=1) # [BLOCK_SIZE_M, BLOCK_SIZE_N]
8684

8785
C[tile_m, tile_n] = acc.to(torch.bfloat16)
8886

@@ -106,14 +104,13 @@ def int4_gemm_tritonbench(tb_op: object, x: torch.Tensor, w: torch.Tensor) -> Ca
106104
Callable: A function that performs the int4 gemm.
107105
"""
108106

109-
def run_kernel() -> torch.Tensor:
110-
x_2d = x.reshape(-1, x.size(-1))
111-
112-
# Pack w to int4 format (two 4-bit values per int8 byte)
113-
w_int8 = w.to(torch.int8)
114-
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
115-
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
107+
# Pack w to int4 format (two 4-bit values per int8 byte)
108+
x_2d = x.reshape(-1, x.size(-1))
109+
w_int8 = w.to(torch.int8)
110+
w_reshaped = w_int8.reshape(w.shape[0] // 2, 2, w.shape[1]).permute(1, 0, 2)
111+
w_packed = ((w_reshaped[0] & 0xF) | (w_reshaped[1] << 4)).to(torch.int8)
116112

113+
def run_kernel() -> torch.Tensor:
117114
return matmul_bf16_int4(x_2d, w_packed)
118115

119116
return run_kernel

test/test_examples.expected

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,7 @@ from torch._inductor.runtime import triton_helpers
13431343
from helion.runtime import default_launcher as _default_launcher
13441344

13451345
@triton.jit
1346-
def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul: tl.constexpr):
1346+
def _helion_matmul_bf16_int4(A, B, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_1: tl.constexpr):
13471347
num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
13481348
pid_0 = tl.program_id(0) % num_blocks_0
13491349
pid_1 = tl.program_id(0) // num_blocks_0
@@ -1355,37 +1355,40 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
13551355
mask_2 = indices_2 < N
13561356
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
13571357
floordiv = triton_helpers.div_floor_integer(K, 2)
1358-
for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
1359-
indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
1360-
mask_0 = indices_0 < floordiv
1358+
for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
1359+
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
1360+
mask_0 = indices_3 < floordiv
13611361
acc_copy = acc
13621362
acc_copy_0 = acc_copy
1363-
b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
1364-
v_0 = tl.full([], 4, tl.int8)
1365-
v_1 = b_tile << v_0
1366-
v_2 = tl.full([], 4, tl.int8)
1367-
v_3 = v_1 >> v_2
1368-
v_4 = tl.full([], 4, tl.int8)
1369-
v_5 = b_tile >> v_4
1370-
v_6 = tl.cast(v_3, tl.bfloat16)
1371-
v_7 = tl.cast(v_5, tl.bfloat16)
1363+
mul = 2 * offset_3
1364+
iota = mul + tl.arange(0, mul_1)
1365+
load = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
1366+
v_0 = tl.cast(load, tl.float32)
1367+
b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
1368+
v_1 = tl.full([], 4, tl.int8)
1369+
v_2 = b_tile << v_1
1370+
v_3 = tl.full([], 4, tl.int8)
1371+
v_4 = v_2 >> v_3
1372+
v_5 = tl.full([], 4, tl.int8)
1373+
v_6 = b_tile >> v_5
13721374
stack_idx = tl.arange(0, 2)
13731375
broadcast_idx = stack_idx[None, :, None]
1374-
expanded_0 = tl.expand_dims(v_6, 1)
1375-
expanded_1 = tl.expand_dims(v_7, 1)
1376+
expanded_0 = tl.expand_dims(v_4, 1)
1377+
expanded_1 = tl.expand_dims(v_6, 1)
13761378
stacked_result = tl.zeros_like(expanded_0)
1377-
mask_3 = broadcast_idx == 0
1378-
stacked_result = tl.where(mask_3, expanded_0, stacked_result)
1379-
mask_4 = broadcast_idx == 1
1380-
stacked_result = tl.where(mask_4, expanded_1, stacked_result)
1381-
b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
1382-
mul_5 = 2 * offset_0
1383-
iota = mul_5 + tl.arange(0, mul)
1384-
a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
1385-
dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
1386-
acc = acc_copy_0 + dot
1387-
v_9 = tl.cast(acc, tl.bfloat16)
1388-
tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_9, mask_1[:, None] & mask_2[None, :])
1379+
mask_4 = broadcast_idx == 0
1380+
stacked_result = tl.where(mask_4, expanded_0, stacked_result)
1381+
mask_5 = broadcast_idx == 1
1382+
stacked_result = tl.where(mask_5, expanded_1, stacked_result)
1383+
view = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
1384+
v_7 = tl.cast(view, tl.float32)
1385+
a_tile_1 = v_0[:, :, None]
1386+
b_unpacked_1 = v_7[None, :, :]
1387+
v_8 = a_tile_1 * b_unpacked_1
1388+
sum_1 = tl.cast(tl.sum(v_8, 1), tl.float32)
1389+
acc = acc_copy_0 + sum_1
1390+
v_10 = tl.cast(acc, tl.bfloat16)
1391+
tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :])
13891392

13901393
def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
13911394
"""
@@ -1409,7 +1412,8 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
14091412
_BLOCK_SIZE_1 = 64
14101413
_BLOCK_SIZE_2 = 32
14111414
_BLOCK_SIZE_0 = 64
1412-
_launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
1415+
_RDIM_SIZE_3 = triton.next_power_of_2(2 * _BLOCK_SIZE_0)
1416+
_launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
14131417
return C
14141418

14151419
--- assertExpectedJournal(TestExamples.test_jagged_dense_add)

test/test_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from helion._testing import skipIfRocm
1717
from helion._testing import skipIfXPU
1818

19-
torch.backends.cuda.matmul.fp32_precision = "tf32"
20-
torch.backends.cudnn.conv.fp32_precision = "tf32"
19+
# torch.backends.cuda.matmul.fp32_precision = "tf32"
20+
# torch.backends.cudnn.conv.fp32_precision = "tf32"
2121

2222

2323
class TestExamples(RefEagerTestBase, TestCase):

0 commit comments

Comments
 (0)