@@ -1343,7 +1343,7 @@ from torch._inductor.runtime import triton_helpers
1343
1343
from helion.runtime import default_launcher as _default_launcher
1344
1344
1345
1345
@triton.jit
1346
- def _helion_matmul_bf16_int4(B, A , C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul : tl.constexpr):
1346
+ def _helion_matmul_bf16_int4(A, B , C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_1 : tl.constexpr):
1347
1347
num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
1348
1348
pid_0 = tl.program_id(0) % num_blocks_0
1349
1349
pid_1 = tl.program_id(0) // num_blocks_0
@@ -1355,37 +1355,40 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
1355
1355
mask_2 = indices_2 < N
1356
1356
acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
1357
1357
floordiv = triton_helpers.div_floor_integer(K, 2)
1358
- for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
1359
- indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
1360
- mask_0 = indices_0 < floordiv
1358
+ for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
1359
+ indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
1360
+ mask_0 = indices_3 < floordiv
1361
1361
acc_copy = acc
1362
1362
acc_copy_0 = acc_copy
1363
- b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
1364
- v_0 = tl.full([], 4, tl.int8)
1365
- v_1 = b_tile << v_0
1366
- v_2 = tl.full([], 4, tl.int8)
1367
- v_3 = v_1 >> v_2
1368
- v_4 = tl.full([], 4, tl.int8)
1369
- v_5 = b_tile >> v_4
1370
- v_6 = tl.cast(v_3, tl.bfloat16)
1371
- v_7 = tl.cast(v_5, tl.bfloat16)
1363
+ mul = 2 * offset_3
1364
+ iota = mul + tl.arange(0, mul_1)
1365
+ load = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
1366
+ v_0 = tl.cast(load, tl.float32)
1367
+ b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
1368
+ v_1 = tl.full([], 4, tl.int8)
1369
+ v_2 = b_tile << v_1
1370
+ v_3 = tl.full([], 4, tl.int8)
1371
+ v_4 = v_2 >> v_3
1372
+ v_5 = tl.full([], 4, tl.int8)
1373
+ v_6 = b_tile >> v_5
1372
1374
stack_idx = tl.arange(0, 2)
1373
1375
broadcast_idx = stack_idx[None, :, None]
1374
- expanded_0 = tl.expand_dims(v_6 , 1)
1375
- expanded_1 = tl.expand_dims(v_7 , 1)
1376
+ expanded_0 = tl.expand_dims(v_4 , 1)
1377
+ expanded_1 = tl.expand_dims(v_6 , 1)
1376
1378
stacked_result = tl.zeros_like(expanded_0)
1377
- mask_3 = broadcast_idx == 0
1378
- stacked_result = tl.where(mask_3, expanded_0, stacked_result)
1379
- mask_4 = broadcast_idx == 1
1380
- stacked_result = tl.where(mask_4, expanded_1, stacked_result)
1381
- b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
1382
- mul_5 = 2 * offset_0
1383
- iota = mul_5 + tl.arange(0, mul)
1384
- a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
1385
- dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
1386
- acc = acc_copy_0 + dot
1387
- v_9 = tl.cast(acc, tl.bfloat16)
1388
- tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_9, mask_1[:, None] & mask_2[None, :])
1379
+ mask_4 = broadcast_idx == 0
1380
+ stacked_result = tl.where(mask_4, expanded_0, stacked_result)
1381
+ mask_5 = broadcast_idx == 1
1382
+ stacked_result = tl.where(mask_5, expanded_1, stacked_result)
1383
+ view = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
1384
+ v_7 = tl.cast(view, tl.float32)
1385
+ a_tile_1 = v_0[:, :, None]
1386
+ b_unpacked_1 = v_7[None, :, :]
1387
+ v_8 = a_tile_1 * b_unpacked_1
1388
+ sum_1 = tl.cast(tl.sum(v_8, 1), tl.float32)
1389
+ acc = acc_copy_0 + sum_1
1390
+ v_10 = tl.cast(acc, tl.bfloat16)
1391
+ tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :])
1389
1392
1390
1393
def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
1391
1394
"""
@@ -1409,7 +1412,8 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
1409
1412
_BLOCK_SIZE_1 = 64
1410
1413
_BLOCK_SIZE_2 = 32
1411
1414
_BLOCK_SIZE_0 = 64
1412
- _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
1415
+ _RDIM_SIZE_3 = triton.next_power_of_2(2 * _BLOCK_SIZE_0)
1416
+ _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
1413
1417
return C
1414
1418
1415
1419
--- assertExpectedJournal(TestExamples.test_jagged_dense_add)
0 commit comments