Skip to content

Commit bc0c8de

Browse files
aeng-openai authored and wdziurdz committed
[KERNELS] fix persistent matmul heuristics (#8791)
Any mxfp format, where natively supported, requires using the persistent matmul kernel. In these cases, do not use heuristics to resolve `is_persistent`.

Signed-off-by: Witold Dziurdz <[email protected]>
1 parent afa8452 commit bc0c8de

File tree

1 file changed

+6
-1
lines changed
  • python/triton_kernels/triton_kernels/matmul_ogs_details

1 file changed

+6
-1
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,12 @@
33
from dataclasses import dataclass
44

55
import triton
6+
from triton_kernels import target_info
67
from triton_kernels.target_info import get_cdna_version
78
from triton_kernels.tensor import FP4
89
import torch
910
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
10-
from triton_kernels.tensor import bitwidth
11+
from triton_kernels.tensor import bitwidth, get_layout
1112

1213

1314
@dataclass
@@ -297,8 +298,12 @@ def make_default_opt_flags_nvidia(
297298
n_sms = torch.cuda.get_device_properties(0).multi_processor_count
298299
tiles_per_sm = grid_size_tma / n_sms
299300
supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
301+
requires_persistent = (get_layout(precision_config.act_scale) is not None or get_layout(precision_config.weight_scale) is not None) and target_info.has_native_mxfp()
300302
if constraints.get("is_persistent", None) is not None:
301303
is_persistent = constraints["is_persistent"]
304+
elif requires_persistent:
305+
assert supports_persistent, "persistent kernel required but not supported"
306+
is_persistent = True
302307
else:
303308
has_simple_epilogue = precision_config.max_num_imprecise_acc is None
304309
is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4

0 commit comments

Comments
 (0)