Skip to content
11 changes: 3 additions & 8 deletions benchmarks/float8/bench_grouped_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def run(

# Run bf16 torch._grouped_mm baseline.
A = torch.randn(M, K, device=device, dtype=dtype)
- B = torch.randn(E, K, N, device=device, dtype=dtype)
+ B = torch.randn(E, N, K, device=device, dtype=dtype)
offs = generate_jagged_offs(E, M)
print(f"offs: {offs}")
ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
Expand All @@ -73,7 +73,7 @@ def run(
use_gpu_kernel_time,
torch._grouped_mm,
A,
- B,
+ B.transpose(-2, -1),
offs,
)
print(
Expand All @@ -84,12 +84,7 @@ def run(

# Run scaled_grouped_mm.
A_hp = torch.randn(M, K, device=device)
- B_hp_t = (
- torch.randn(E, K, N, device=device)
- .transpose(-2, -1)
- .contiguous()
- .transpose(-2, -1)
- )
+ B_hp_t = torch.randn(E, N, K, device=device).transpose(-2, -1)

if recipe == "rowwise":
# TODO: add e5m2
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/float8/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def get_name_to_moe_shapes_iter(
N: Optional[int] = None,
E: Optional[int] = None,
):
- M = 8192 if M is None else M
+ M = 16640 if M is None else M
if shape_gen_name == "llama4_17bx16e":
# num_experts=16, dim=5120
names_to_shapes = {
Expand All @@ -232,8 +232,8 @@ def get_name_to_moe_shapes_iter(
# num_experts=128, dim=5120
names_to_shapes = {
# M, K, N, E
- "moe.experts.w1": (M, 5120, 8192, 128),
- "moe.experts.w2": (M, 8192, 5120, 128),
+ "moe.experts.w1": (M, 5120, 4 * 5120, 128),
+ "moe.experts.w2": (M, 4 * 5120, 5120, 128),
}
return names_to_shapes.items()
elif shape_gen_name == "custom":
Expand Down
190 changes: 0 additions & 190 deletions benchmarks/prototype/moe_training/benchmark_kernels.py

This file was deleted.

Loading
Loading