integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script #2785

Open · wants to merge 1 commit into base: main
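For orientation, here is a minimal sketch (not part of the diff) of what the new benchmark compares: Float8BlockwiseLinear running its existing Triton blockwise-fp8 GEMM kernels versus the torch._scaled_mm path this PR exercises, selected via the use_triton flag. The shapes and the MSE-loss forward/backward mirror the new bench_linear_fwd_bwd.py script added below; everything else is illustrative.

import torch
from torch.nn import functional as F

from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

M, N, K = 16640, 8192, 5120  # one of the Llama4 shapes used in the benchmarks
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
labels = x.new_full((M, N), 1.0)

# Existing path: handwritten Triton blockwise-fp8 GEMM kernels.
triton_linear = Float8BlockwiseLinear(K, N, dtype=torch.bfloat16, device="cuda", use_triton=True)

# Path benchmarked by this PR: blockwise-scaled GEMMs via torch._scaled_mm
# (hence the CUDA 12.9+ requirement asserted in the new bench script).
scaled_mm_linear = Float8BlockwiseLinear(K, N, dtype=torch.bfloat16, device="cuda", use_triton=False)

for linear in (triton_linear, scaled_mm_linear):
    out = linear(x)                     # forward through blockwise-fp8 GEMM
    F.mse_loss(out, labels).backward()  # backward also runs blockwise-fp8 GEMMs
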
@@ -15,9 +15,9 @@
from triton.testing import do_bench

from torchao.prototype.blockwise_fp8_training.kernels import (
blockwise_fp8_gemm_1x128_128x128,
fp8_blockwise_act_quant_lhs,
fp8_blockwise_weight_quant_transposed_rhs,
triton_fp8_gemm_1x128_128x128,
)

device = torch.device("cuda")
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
(16640, 5120, 8192),
(16640, 8192, 5120),
]
out_dtypes = [torch.float32, torch.bfloat16]
out_dtypes = [torch.bfloat16]
configs = []
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
m, n, k = mnk
@@ -94,19 +94,21 @@ def warmup(func, *args, **kwargs):

# Warm up then run triton bench
warmup(
blockwise_fp8_gemm_1x128_128x128,
triton_fp8_gemm_1x128_128x128,
A_q,
1.0 / A_s,
B_t_q,
1.0 / A_s,
1.0 / B_t_s,
out_dtype=config.out_dtype,
)

fp8_triton_us = benchmark_cuda_function_in_microseconds(
blockwise_fp8_gemm_1x128_128x128,
triton_fp8_gemm_1x128_128x128,
A_q,
1.0 / A_s,
B_t_q,
1.0 / A_s,
1.0 / B_t_s,
out_dtype=config.out_dtype,
)

# Warm up then run torch bench
@@ -15,9 +15,9 @@
from triton.testing import do_bench

from torchao.prototype.blockwise_fp8_training.kernels import (
blockwise_fp8_gemm_1x128_128x1,
fp8_blockwise_act_quant_rhs,
fp8_blockwise_act_quant_transposed_lhs,
triton_fp8_gemm_1x128_128x1,
)

device = torch.device("cuda")
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
(16640, 5120, 8192),
(16640, 8192, 5120),
]
out_dtypes = [torch.float32, torch.bfloat16]
out_dtypes = [torch.bfloat16]
configs = []
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
m, n, k = mnk
@@ -92,24 +92,23 @@ def warmup(func, *args, **kwargs):

# Warm up then run triton bench
warmup(
blockwise_fp8_gemm_1x128_128x1,
triton_fp8_gemm_1x128_128x1,
A_t_q,
1.0 / A_t_s,
B_q,
1.0 / A_t_s,
1.0 / B_s,
out_dtype=config.out_dtype,
)

fp8_triton_us = benchmark_cuda_function_in_microseconds(
blockwise_fp8_gemm_1x128_128x1,
triton_fp8_gemm_1x128_128x1,
A_t_q,
1.0 / A_t_s,
B_q,
1.0 / A_t_s,
1.0 / B_s,
out_dtype=config.out_dtype,
)

# torch._scaled_mm requires the scales to be in column-major format
A_t_s = A_t_s.t().contiguous().t()

# Warm up then run torch bench
warmup(
torch._scaled_mm,
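
An aside on the t().contiguous().t() idiom used above: it leaves the tensor's shape and values unchanged but rewrites its storage to be column-major (stride 1 along dim 0), which is the layout the comment above says torch._scaled_mm expects for the scales. A self-contained illustration with a made-up tensor, not taken from the PR:

import torch

s = torch.randn(4, 8)            # row-major: stride (8, 1)
s_col = s.t().contiguous().t()   # same shape and values, column-major storage: stride (1, 4)

assert torch.equal(s_col, s)     # values are unchanged
assert s_col.stride() == (1, 4)  # stride 1 along dim 0, i.e. column-major
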
181 changes: 181 additions & 0 deletions benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py
@@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from torch.nn import functional as F
from tqdm import tqdm
from triton.testing import do_bench

from benchmarks.utils import bench_fwd_bwd_microseconds
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

device = torch.device("cuda")

# This benchmark requires CUDA 12.9+
assert torch.version.cuda is not None, "CUDA is not available"
cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")[:2])
assert (cuda_major, cuda_minor) >= (12, 9), "CUDA 12.9+ is required"

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
out_dtype: torch.dtype
m: int
n: int
k: int


@dataclass(frozen=True)
class ExperimentResult:
bf16_linear_us: float
fp8_triton_linear_us: float
fp8_scaled_mm_linear_us: float


@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
mnk_list = [
# Llama4 shapes
(16640, 5120, 8192),
(16640, 8192, 5120),
]
out_dtypes = [torch.bfloat16]
configs = []
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
m, n, k = mnk
configs.append(
ExperimentConfig(
out_dtype=out_dtype,
m=m,
n=n,
k=k,
)
)
return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
M, N, K = config.m, config.n, config.k
inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
fp8_triton_linear = Float8BlockwiseLinear(
K, N, dtype=config.out_dtype, device="cuda", use_triton=True
)
fp8_scaled_mm_linear = Float8BlockwiseLinear(
K, N, dtype=config.out_dtype, device="cuda", use_triton=False
)

def warmup(func, *args, **kwargs):
for _ in range(10):
func(*args, **kwargs)

def fwd_bwd(func, inputs, labels, *args, **kwargs):
out = func(inputs, *args, **kwargs)
loss = F.mse_loss(out, labels)
loss.backward()
torch.cuda.synchronize()

# Warm up then run bf16 linear bench
labels = inputs.new_empty(M, N).fill_(1.0)
warmup(fwd_bwd, bf16_linear, inputs, labels)

bf16_linear_us = benchmark_cuda_function_in_microseconds(
fwd_bwd, bf16_linear, inputs, labels
)

# Warm up then run triton bench
warmup(
fwd_bwd,
fp8_triton_linear,
inputs,
labels,
)

fp8_triton_linear_us = bench_fwd_bwd_microseconds(
fp8_triton_linear,
inputs,
labels=labels,
)

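# Warm up then run torch._scaled_mm bench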
warmup(
fwd_bwd,
fp8_scaled_mm_linear,
inputs,
labels,
)

fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds(
fp8_scaled_mm_linear,
inputs,
labels=labels,
)

return ExperimentResult(
bf16_linear_us=bf16_linear_us,
fp8_triton_linear_us=fp8_triton_linear_us,
fp8_scaled_mm_linear_us=fp8_scaled_mm_linear_us,
)


def print_results(experiments: List[Experiment]):
headers = [
"M",
"N",
"K",
"out_dtype",
"bf16_mm_linear_us",
"fp8_triton_linear_us",
"fp8_scaled_mm_linear_us",
]
rows = []
for experiment in experiments:
m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
rows.append(
[
m,
n,
k,
experiment.config.out_dtype,
experiment.result.bf16_linear_us,
experiment.result.fp8_triton_linear_us,
experiment.result.fp8_scaled_mm_linear_us,
]
)
print(tabulate(rows, headers=headers))


def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3


def main():
torch.random.manual_seed(123)
configs = get_configs()
results = []
for config in tqdm(configs):
result = run_experiment(config)
results.append(Experiment(config=config, result=result))

# Use Tabulate to print results
print_results(results)


if __name__ == "__main__":
main()
5 changes: 1 addition & 4 deletions benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
@@ -22,10 +22,7 @@
from torch.distributed._composable.fsdp import fully_shard
from torch.nn import functional as F

from benchmarks.prototype.moe_training.utils import (
bench_fwd_bwd_microseconds,
profile_fwd_bwd,
)
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
@@ -12,8 +12,8 @@
import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import bench_fwd_bwd_microseconds, profile_fwd_bwd

from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
from torchao.prototype.moe_training import _scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.utils import generate_jagged_offs
File renamed without changes.