
Commit b530394

knwngwdziurdz authored and committed
[Bench][AMD] Support Padding and Unswizzling Scale on CDNA4 (#8803)

This PR supports:

- the `CDNA4MXScaleLayout.unswizzle_data` method, used in the GPT-OSS model
- padding tensors with 0 when doing scale preshuffling

Signed-off-by: Witold Dziurdz <[email protected]>
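For a quick sense of what the change enables, here is a minimal round-trip sketch (my own illustration, not code from this commit; the unaligned shape mirrors the new test below):

```python
import torch
from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout

# Deliberately unaligned trailing dims: K_SCALE=254 is not a multiple of 8,
# N=60 is not a multiple of 32, so swizzle_data has to pad before preshuffling.
scales = torch.randint(0, 256, (2, 254, 60), dtype=torch.uint8, device="cuda")
layout = CDNA4MXScaleLayout(scales.shape)

swizzled = layout.swizzle_data(scales)      # zero-pads to (2, 256, 64), then preshuffles
restored = layout.unswizzle_data(swizzled)  # inverts the shuffle and strips the padding
assert torch.equal(restored, scales)
```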
1 parent bc0c8de commit b530394

File tree

3 files changed: +44 −13 lines changed


python/triton_kernels/tests/test_matmul.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -369,8 +369,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm
             pytest.skip("Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet.")
         if "mx" not in weight_dtype_str:
             pytest.skip("Non-scale swizzling not supported on CDNA4 yet")
-        if n % 32 != 0 or k % (32 * 8) != 0:
-            pytest.skip(f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU")
     if is_cuda():
         if torch.cuda.get_device_capability()[0] < 9:
             pytest.skip("NYI. Ampere swizzling.")
```
Lines changed: 24 additions & 0 deletions
```diff
@@ -0,0 +1,24 @@
+import pytest
+import torch
+from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout
+
+# ------------------------------------------------------------
+# Torch tests
+# ------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (3, 4096, 1024),
+        (10, 254, 60),
+        (1, 320, 160),
+        (2, 16, 512),
+        (3, 2, 36),
+    ],
+)
+def test_mxfp4_scale_roundtrip(shape):
+    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
+    layout = CDNA4MXScaleLayout(x.shape)
+    res = layout.unswizzle_data(layout.swizzle_data(x))
+    assert (res == x).all()
```
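Two of the parametrized shapes, `(10, 254, 60)` and `(3, 2, 36)`, are deliberately unaligned, so the round-trip only passes because of the new padding. The rounding they exercise (a quick arithmetic check; `pad_to_multiple` is a hypothetical helper, not part of the diff):

```python
import math

def pad_to_multiple(x: int, align: int) -> int:
    # the same rounding the new __init__ uses: math.ceil(x / align) * align
    return math.ceil(x / align) * align

# (K_SCALE, N) -> (K_SCALE_pad, N_pad) under ALIGN_K_SCALE=8, ALIGN_N=32
assert (pad_to_multiple(254, 8), pad_to_multiple(60, 32)) == (256, 64)  # shape (10, 254, 60)
assert (pad_to_multiple(2, 8), pad_to_multiple(36, 32)) == (8, 64)      # shape (3, 2, 36)
```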

python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py

Lines changed: 20 additions & 11 deletions
```diff
@@ -1,3 +1,5 @@
+import math
+import torch
 from dataclasses import dataclass
 import triton
 import triton.language as tl
@@ -12,24 +14,31 @@ class CDNA4MXScaleLayout(Layout):
 
     def __init__(self, shape) -> None:
         super().__init__(shape)
+        (
+            *self.leading_shape,
+            self.K_SCALE,
+            self.N,
+        ) = shape
+        self.B = math.prod(self.leading_shape)
+        self.ALIGN_K_SCALE = 8
+        self.ALIGN_N = 32
+        self.K_SCALE_pad = math.ceil(self.K_SCALE / self.ALIGN_K_SCALE) * self.ALIGN_K_SCALE
+        self.N_pad = math.ceil(self.N / self.ALIGN_N) * self.ALIGN_N
 
     def swizzle_data(self, data):
-        block_shape = data.shape
-        SCALE_K = block_shape[-2]
-        N = block_shape[-1]
+        data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_SCALE_pad - self.K_SCALE))
         data = data.transpose(-1, -2)
-        data = data.view(-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1)
+        data = data.view(-1, self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, self.K_SCALE_pad // 8, 2, 4, 1)
         data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
-        if len(block_shape) == 3:
-            E = block_shape[0]
-            data = data.reshape(E, N // 32, SCALE_K * 32)
-        else:
-            assert len(block_shape) == 2
-            data = data.reshape(N // 32, SCALE_K * 32)
+        data = data.reshape(self.B, self.N_pad // 32, self.K_SCALE_pad * 32)
         return data.transpose(-1, -2)
 
     def unswizzle_data(self, data):
-        raise NotImplementedError()
+        data = data.transpose(-1, -2)
+        data = data.view(-1, self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE, self.K_SCALE_pad // 8, 4, 16, 2, 2, 1)
+        data = data.permute(0, 1, 6, 4, 2, 5, 3, 7)
+        data = data.reshape(*self.leading_shape, self.N_pad, self.K_SCALE_pad)
+        return data.transpose(-1, -2)[..., :self.K_SCALE, :self.N]
 
     def swizzle_block_shape(self, block_shape):
         SCALE_K = block_shape[-2]
```
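A useful way to read the new `unswizzle_data`: its `permute` indices are the inverse of the `swizzle_data` permutation, and its 8-D `view` lists the swizzle view's dimensions in permuted order. A small self-contained check of that bookkeeping (plain Python, not part of the diff):

```python
swizzle_perm = (0, 1, 4, 6, 3, 5, 2, 7)    # permute() in swizzle_data
unswizzle_perm = (0, 1, 6, 4, 2, 5, 3, 7)  # permute() in unswizzle_data

# unswizzle_perm is the inverse permutation of swizzle_perm:
assert tuple(swizzle_perm.index(i) for i in range(8)) == unswizzle_perm

# The unswizzle view is the swizzle view read in swizzle_perm order
# (B = batch, N32 = N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE, K8 = K_SCALE_pad // 8):
swizzle_view = ("B", "N32", 2, 16, "K8", 2, 4, 1)
assert tuple(swizzle_view[p] for p in swizzle_perm) == ("B", "N32", "K8", 4, 16, 2, 2, 1)
```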
