Skip to content

Commit 69bc643

Browse files
committed
Add third draft of mm_fp4 backend -- no audotune
1 parent 2797e2b commit 69bc643

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

benchmarks/routines/gemm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def testMmFp4(args):
790790
run_refcheck = args.refcheck
791791
use_128x4_sf_layout = args.use_128x4_sf_layout
792792
use_nvfp4 = args.use_nvfp4
793-
autotune_supported_backends = ["cutlass", "trtllm"]
793+
autotune_supported_backends = ["cutlass", "trtllm", "auto"]
794794
res = []
795795

796796
backends = filter_backends_by_compute_capability(backends, args.routine, device)

flashinfer/gemm/gemm_base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1871,7 +1871,6 @@ def _auto_gemm_fp4_requirement(
18711871
checker, "is_compute_capability_supported"
18721872
) and checker.is_compute_capability_supported(cc_arch):
18731873
# At least one backend is supported
1874-
print(f"Backend {candidate} is supported on this device.")
18751874
return True
18761875

18771876
# No backend is supported on this device
@@ -1978,8 +1977,9 @@ def mm_fp4(
19781977
if backend == "auto":
19791978
cuda_major, _ = get_cuda_version(a.device)
19801979
cc_major, cc_minor = get_compute_capability(a.device)
1981-
# If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn.
1982-
if cuda_major >= 13: # to-do add cudnn version threshold
1980+
# If cuda version is 13 or greater:
1981+
# cudnn is more performant if cudnn version is 9.14 or greater.
1982+
if cuda_major >= 13 and cudnn.backend_version() >= 91400:
19831983
candidate_backends = ("cudnn", "cutlass")
19841984
# Otherwise, prioritize cutlass
19851985
else:
@@ -2010,11 +2010,7 @@ def mm_fp4(
20102010
supported_backends.append(candidate)
20112011
except Exception:
20122012
pass
2013-
print(f"Supported backends: {supported_backends}")
20142013
selected_backend = supported_backends[0]
2015-
print(
2016-
f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}"
2017-
)
20182014
else:
20192015
selected_backend = backend
20202016
if selected_backend == "cudnn":

0 commit comments

Comments
 (0)