Skip to content

Commit 07a35b6

Browse files
integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script
stack-info: PR: #2785, branch: danielvegamyhre/stack/44
1 parent fbe08c3 commit 07a35b6

File tree

5 files changed

+282
-38
lines changed

5 files changed

+282
-38
lines changed

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
5858
(16640, 5120, 8192),
5959
(16640, 8192, 5120),
6060
]
61-
out_dtypes = [torch.float32, torch.bfloat16]
61+
out_dtypes = [torch.bfloat16]
6262
configs = []
6363
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
6464
m, n, k = mnk

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
5858
(16640, 5120, 8192),
5959
(16640, 8192, 5120),
6060
]
61-
out_dtypes = [torch.float32, torch.bfloat16]
61+
out_dtypes = [torch.bfloat16]
6262
configs = []
6363
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
6464
m, n, k = mnk
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from torch.nn import functional as F
15+
from tqdm import tqdm
16+
from triton.testing import do_bench
17+
18+
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear
19+
20+
device = torch.device("cuda")
21+
22+
# This benchmark requires CUDA 12.9+
23+
assert torch.version.cuda is not None, "CUDA is not available"
24+
cuda_major, cuda_minor = map(int, torch.version.cuda.split("."))
25+
assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required"
26+
27+
# Needed since changing args to function causes recompiles
28+
torch._dynamo.config.cache_size_limit = 1000
29+
30+
31+
@dataclass(frozen=True)
32+
class ExperimentConfig:
33+
out_dtype: torch.dtype
34+
m: int
35+
n: int
36+
k: int
37+
38+
39+
@dataclass(frozen=True)
40+
class ExperimentResult:
41+
bf16_linear_us: float
42+
fp8_triton_linear_us: float
43+
fp8_scaled_mm_linear_us: float
44+
45+
46+
@dataclass(frozen=True)
47+
class Experiment:
48+
config: ExperimentConfig
49+
result: ExperimentResult
50+
51+
52+
def get_configs() -> List[ExperimentConfig]:
53+
mnk_list = [
54+
# Llama4 shapes
55+
(16640, 5120, 8192),
56+
(16640, 8192, 5120),
57+
]
58+
out_dtypes = [torch.bfloat16]
59+
configs = []
60+
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
61+
m, n, k = mnk
62+
configs.append(
63+
ExperimentConfig(
64+
out_dtype=out_dtype,
65+
m=m,
66+
n=n,
67+
k=k,
68+
)
69+
)
70+
return configs
71+
72+
73+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
74+
M, N, K = config.m, config.n, config.k
75+
inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
76+
bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
77+
fp8_triton_linear = Float8BlockwiseLinear(
78+
K, N, dtype=config.out_dtype, device="cuda", use_triton=True
79+
)
80+
fp8_scaled_mm_linear = Float8BlockwiseLinear(
81+
K, N, dtype=config.out_dtype, device="cuda", use_triton=False
82+
)
83+
84+
def warmup(func, *args, **kwargs):
85+
for _ in range(10):
86+
func(*args, **kwargs)
87+
88+
def fwd_bwd(func, inputs, labels, *args, **kwargs):
89+
out = func(inputs, *args, **kwargs)
90+
loss = F.mse_loss(out, labels)
91+
loss.backward()
92+
torch.cuda.synchronize()
93+
94+
# Warmup then run bf16 torch.mm
95+
labels = inputs.new_empty(M, N).fill_(1.0)
96+
warmup(fwd_bwd, bf16_linear, inputs, labels)
97+
98+
bf16_linear_us = benchmark_cuda_function_in_microseconds(
99+
fwd_bwd, bf16_linear, inputs, labels
100+
)
101+
102+
# Warm up then run triton bench
103+
warmup(
104+
fwd_bwd,
105+
fp8_triton_linear,
106+
inputs,
107+
labels,
108+
)
109+
110+
fp8_triton_linear_us = benchmark_cuda_function_in_microseconds(
111+
fwd_bwd,
112+
fp8_triton_linear,
113+
inputs,
114+
labels,
115+
)
116+
117+
warmup(
118+
fwd_bwd,
119+
fp8_scaled_mm_linear,
120+
inputs,
121+
labels,
122+
)
123+
124+
fp8_scaled_mm_linear_us = benchmark_cuda_function_in_microseconds(
125+
fwd_bwd,
126+
fp8_scaled_mm_linear,
127+
inputs,
128+
labels,
129+
)
130+
131+
return ExperimentResult(
132+
bf16_linear_us=bf16_linear_us,
133+
fp8_triton_linear_us=fp8_triton_linear_us,
134+
fp8_scaled_mm_linear_us=fp8_scaled_mm_linear_us,
135+
)
136+
137+
138+
def print_results(experiments: List[Experiment]):
139+
headers = [
140+
"M",
141+
"N",
142+
"K",
143+
"out_dtype",
144+
"bf16_mm_linear_us",
145+
"fp8_triton_linear_us",
146+
"fp8_scaled_mm_linear_us",
147+
]
148+
rows = []
149+
for experiment in experiments:
150+
m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
151+
rows.append(
152+
[
153+
m,
154+
n,
155+
k,
156+
experiment.config.out_dtype,
157+
experiment.result.bf16_linear_us,
158+
experiment.result.fp8_triton_linear_us,
159+
experiment.result.fp8_scaled_mm_linear_us,
160+
]
161+
)
162+
print(tabulate(rows, headers=headers))
163+
164+
165+
def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
166+
return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
167+
168+
169+
def main():
170+
torch.random.manual_seed(123)
171+
configs = get_configs()
172+
results = []
173+
for config in tqdm(configs):
174+
result = run_experiment(config)
175+
results.append(Experiment(config=config, result=result))
176+
177+
# Use Tabulate to print results
178+
print_results(results)
179+
180+
181+
if __name__ == "__main__":
182+
main()

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@
2626
for num_stages in [2, 4]
2727
]
2828

29-
# For fast compile times during development.
30-
dev_fp8_gemm_configs = [
31-
triton.Config(
32-
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=4, num_stages=3
33-
),
34-
]
35-
3629
EPS = 1e-12
3730

3831

@@ -115,9 +108,9 @@ def blockwise_fp8_gemm_1x128_128x128(
115108
"a must be row-major, b must be column-major"
116109
)
117110

118-
# a_scales must be row-major, b_scales must be column-major
119-
assert _is_row_major(a_s) and _is_column_major(b_s), (
120-
"a_s must be row-major, b_s must be column-major"
111+
# a_scales must be col-major, b_scales must be column-major
112+
assert _is_column_major(a_s) and _is_column_major(b_s), (
113+
"a_s must be col-major, b_s must be column-major"
121114
)
122115

123116
M = a.size(0)
@@ -229,7 +222,9 @@ def blockwise_fp8_gemm_1x128_128x1(
229222
):
230223
# 'a' must be in row-major layout, 'b' must be in column-major layout
231224
assert a.is_contiguous() and not b.is_contiguous()
232-
assert a_s.is_contiguous() and b_s.is_contiguous()
225+
226+
# a_scales must be col-major
227+
assert not a_s.is_contiguous() and b_s.is_contiguous()
233228
M = a.size(0)
234229
K = a.size(1)
235230
N = b.size(1)
@@ -260,6 +255,19 @@ def blockwise_fp8_gemm_1x128_128x1(
260255
return c
261256

262257

258+
# Quantization kernels autotuner configs
259+
quant_kernel_configs = [
260+
triton.Config(
261+
{},
262+
num_warps=warps,
263+
num_stages=stages,
264+
)
265+
for warps in [4, 8]
266+
for stages in [2, 4, 6]
267+
]
268+
269+
270+
@triton.autotune(configs=quant_kernel_configs, key=["K"])
263271
@triton.jit
264272
def fp8_blockwise_act_quant_lhs_kernel(
265273
x_ptr,
@@ -320,7 +328,11 @@ def fp8_blockwise_act_quant_lhs(
320328
], "dtype must be torch.float8_e4m3fn"
321329
M, K = x.size()
322330
y = torch.empty_like(x, dtype=dtype)
323-
s = x.new_empty(M, K // block_size, dtype=torch.float32)
331+
# Write scales to column-major format to align with torch._scaled_mm requirements.
332+
s = x.new_empty(M, K // block_size, dtype=torch.float32).as_strided(
333+
(M, K // block_size),
334+
(1, M),
335+
)
324336
grid = lambda meta: (M, triton.cdiv(K, meta["BLOCK_SIZE"]))
325337
fp8_blockwise_act_quant_lhs_kernel[grid](
326338
x,
@@ -340,6 +352,7 @@ def fp8_blockwise_act_quant_lhs(
340352
return y, s
341353

342354

355+
@triton.autotune(configs=quant_kernel_configs, key=["K"])
343356
@triton.jit
344357
def fp8_blockwise_act_quant_rhs_kernel(
345358
x_ptr,
@@ -424,6 +437,7 @@ def fp8_blockwise_act_quant_rhs(
424437
return y, s
425438

426439

440+
@triton.autotune(configs=quant_kernel_configs, key=["K"])
427441
@triton.jit
428442
def fp8_blockwise_act_quant_transposed_lhs_kernel(
429443
x_ptr,
@@ -497,7 +511,13 @@ def fp8_blockwise_act_quant_transposed_lhs(
497511
# Output should have transposed dims and be in row major format
498512
M, K = x.shape
499513
y = torch.empty(K, M, dtype=dtype, device=x.device)
500-
s = x.new_empty(K, triton.cdiv(M, block_size), dtype=torch.float32)
514+
M_blocks = triton.cdiv(M, block_size)
515+
516+
# Column major scales required for torch._scaled_mm
517+
s = x.new_empty(K, M_blocks, dtype=torch.float32).as_strided(
518+
(K, M_blocks), # shape
519+
(1, K), # stride
520+
)
501521
grid = lambda meta: (
502522
triton.cdiv(M, meta["SCALE_BLOCK_SIZE"]),
503523
triton.cdiv(K, meta["BLOCK_SIZE_K"]),
@@ -522,6 +542,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
522542
return y, s
523543

524544

545+
@triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
525546
@triton.jit
526547
def fp8_blockwise_weight_quant_rhs_kernel(
527548
x_ptr,
@@ -582,8 +603,10 @@ def fp8_blockwise_weight_quant_rhs(
582603
M, N = x.size()
583604
y = torch.empty_like(x, dtype=dtype)
584605
y = y.as_strided(y.size(), (1, y.size(0))) # Column major
585-
s = x.new_empty(
586-
triton.cdiv(M, block_size), triton.cdiv(N, block_size), dtype=torch.float32
606+
M_blocks, N_blocks = triton.cdiv(M, block_size), triton.cdiv(N, block_size)
607+
s = x.new_empty(M_blocks, N_blocks, dtype=torch.float32).as_strided(
608+
(M_blocks, N_blocks), # shape
609+
(1, M_blocks), # stride
587610
)
588611
grid = lambda meta: (
589612
triton.cdiv(M, meta["BLOCK_SIZE"]),
@@ -607,6 +630,7 @@ def fp8_blockwise_weight_quant_rhs(
607630
return y, s
608631

609632

633+
@triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
610634
@triton.jit
611635
def fp8_blockwise_weight_quant_transposed_rhs_kernel(
612636
x_ptr,

0 commit comments

Comments
 (0)