
Commit 0aee7af

feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' (#1979)
<!-- .github/pull_request_template.md -->

## 📌 Description

Current PR:

* Introduces an `auto` backend to `mm_fp4` that can be autotuned. **It replaces `cudnn` as the default.**
* Implementation matches `bmm_fp8`'s auto backend support.
* Allows the `cudnn` backend to be autotuned.
* Adds unit test cases for `backend='auto'`.

Behavior of the `auto` backend:

* Examines the CUDA and cuDNN versions and calls either the `cutlass` or the `cudnn` kernel backend. The `trtllm` kernel is not considered because its interface is not interchangeable with the other backends.
* The `auto` backend therefore only supports inputs runnable by `cutlass` and/or `cudnn`.
* Non-autotuned behavior (a minimal sketch of this ordering appears after this list):
  * Constructs an ordered list of backends, (cudnn, cutlass) or (cutlass, cudnn), where the ordering is based on previous microbenchmark study results:
    * CUDA 12 --> `cutlass` comes first.
    * CUDA 13 and cuDNN < 9.15 --> `cutlass` comes first.
    * CUDA 13 and cuDNN >= 9.15 --> `cudnn` comes first.
  * If a kernel is not available according to its support check, it is removed from the list.
* Autotune behavior:
  * If a backend is explicitly provided --> autotunes within that backend. Same as previous behavior, except autotuning is now also supported for `cudnn`.
  * If `backend='auto'` --> autotunes within and across backends (`cudnn` and `cutlass`) and chooses the best config of the best backend. The `trtllm` kernel is not considered.
* Many of `mm_fp4`'s helper functions were refactored to enable cross-backend autotuning, using the cross-backend autotune-enabled `bmm_fp8` as a reference.
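The ordering heuristic above can be summarized in a few lines of Python. This is only an illustrative sketch: the function name, its arguments, and the way the CUDA/cuDNN versions are obtained are assumptions for illustration, not the actual FlashInfer implementation.

```python
# Illustrative sketch of the backend-ordering heuristic described above.
# NOT the actual FlashInfer code: the function name and version arguments
# are placeholders; the version thresholds come from the description above.
def sketch_auto_backend_order(cuda_major: int, cudnn_version: tuple) -> list:
    """Return candidate mm_fp4 backends for backend='auto', best-first."""
    if cuda_major >= 13 and cudnn_version >= (9, 15):
        order = ["cudnn", "cutlass"]  # CUDA 13 with cuDNN >= 9.15: prefer cudnn
    else:
        order = ["cutlass", "cudnn"]  # CUDA 12, or CUDA 13 with older cuDNN
    # Backends that fail their support check are then dropped from the list
    # before dispatch.
    return order


print(sketch_auto_backend_order(13, (9, 15)))  # ['cudnn', 'cutlass']
print(sketch_auto_backend_order(12, (9, 13)))  # ['cutlass', 'cudnn']
```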
### Pytest outputs

`pytest tests/gemm/test_mm_fp4.py`

* SM100 (B200) CUDA 13 & cuDNN 9.15: `900 passed, 2532 skipped in 125.19s (0:02:05)`
* SM100 (B200) CUDA 12 & cuDNN 9.15: `900 passed, 2532 skipped in 125.67s (0:02:05)`
* SM120 (RTX 5090) CUDA 13 & cuDNN 9.15: `720 passed, 2712 skipped in 76.50s (0:01:16)`

### Example microbenchmark outputs

On SM100 (B200) CUDA 13 & cuDNN 9.15:

```
flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[PERF] cudnn :: median time 0.018 ms; std 0.000 ms; achieved tflops 3797.932 TFLOPs/sec; achieved tb_per_sec 1.884 TB/sec
[PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3440.640 TFLOPs/sec; achieved tb_per_sec 1.707 TB/sec
[PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec
[PERF] auto :: median time 0.018 ms; std 0.000 ms; achieved tflops 3840.714 TFLOPs/sec; achieved tb_per_sec 1.905 TB/sec
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3237.753 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
median time 0.009 ms; std 0.000 ms; achieved tflops 938.356 TFLOPs/sec; achieved tb_per_sec 2.069 TB/sec

## Autotune
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune
2025-11-11 23:43:23,715 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:25,789 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:25,790 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,251 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:26,251 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,327 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:26,327 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,335 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4129.171 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec
[PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3513.845 TFLOPs/sec; achieved tb_per_sec 1.743 TB/sec
[PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2613.338 TFLOPs/sec; achieved tb_per_sec 1.296 TB/sec
[PERF] auto_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4128.768 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
2025-11-11 23:43:37,942 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:43,116 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:43,116 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:43,124 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.154 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
[PERF] auto_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.692 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
```

On SM100 (B200) CUDA 12 & cuDNN 9.15:

```
flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[PERF] cudnn :: median time 0.023 ms; std 0.001 ms; achieved tflops 2975.898 TFLOPs/sec; achieved tb_per_sec 1.476 TB/sec
[PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.423 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
[PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec
[PERF] auto :: median time 0.020 ms; std 0.000 ms; achieved tflops 3371.229 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
(py312) root@84ef83abb1b5:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec

## Autotune
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune
2025-11-11 23:42:43,378 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,451 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,451 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,910 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,910 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,986 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,986 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,993 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3190.355 TFLOPs/sec; achieved tb_per_sec 1.583 TB/sec
[PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.330 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec
[PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2621.440 TFLOPs/sec; achieved tb_per_sec 1.300 TB/sec
[PERF] auto_autotune :: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.628 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec
flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
2025-11-11 23:42:55,176 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:58,600 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:58,601 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:58,608 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
```

On SM120 (RTX 5090) CUDA 13 & cuDNN 9.15:

```
/flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120
[PERF] cudnn :: median time 0.058 ms; std 0.000 ms; achieved tflops 1167.143 TFLOPs/sec; achieved tb_per_sec 0.579 TB/sec
[PERF] cutlass :: median time 0.060 ms; std 0.000 ms; achieved tflops 1135.056 TFLOPs/sec; achieved tb_per_sec 0.563 TB/sec
[PERF] auto :: median time 0.058 ms; std 0.000 ms; achieved tflops 1158.952 TFLOPs/sec; achieved tb_per_sec 0.575 TB/sec
/flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120
[PERF] cudnn :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec
[PERF] auto :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec
```

<!-- What does this PR do? Briefly describe the changes and why they're needed. -->

## 🔍 Related Issues

#1722

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer!
Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * "auto" backend selection for FP4 ops to choose the backend at runtime
  * cuDNN, CUTLASS, and TRTLLM selectable as FP4 GEMM backends
  * CUDA/cuDNN version awareness to guide auto-backend heuristics
* **Improvements**
  * Runtime capability checks replace static backend lists; unsupported backends are removed dynamically
  * Heuristic-driven auto-backend selection for automatic mode
  * Expanded autotuning/warmup across backends and relaxed FP4 validation tolerance
* **Tests**
  * Tests updated and added to exercise auto-backend scenarios and relaxed constraints

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 2628beb commit 0aee7af

File tree

6 files changed: +543 −381 lines changed


benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 1 addition & 11 deletions
```diff
@@ -235,17 +235,7 @@ def dtype_str_to_torch_dtype(dtype_str):
         "10.3": ["cudnn", "cublas", "cutlass"],
         "12.0": ["cudnn", "cublas"],
     },
-    "mm_fp4": {
-        "7.5": [],
-        "8.0": [],
-        "8.6": [],
-        "8.9": [],
-        "9.0": [],
-        "10.0": ["cudnn", "trtllm", "cutlass"],
-        "10.3": ["cudnn", "trtllm", "cutlass"],
-        "12.0": ["cudnn", "cutlass"],
-        "12.1": ["cudnn", "cutlass"],
-    },
+    # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
     # MOE
     "trtllm_fp4_block_scale_moe": {
         "7.5": [],
```

benchmarks/routines/gemm.py

Lines changed: 50 additions & 59 deletions
```diff
@@ -131,7 +131,7 @@ def parse_gemm_args(line, parser):
         required=False,
         nargs="+",
         default=["cudnn"],
-        choices=["cudnn", "cublas", "trtllm", "cutlass"],
+        choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"],
         help="Kernel backends to test. Default: cudnn",
     )
     parser.add_argument(
@@ -790,61 +790,14 @@ def testMmFp4(args):
     run_refcheck = args.refcheck
     use_128x4_sf_layout = args.use_128x4_sf_layout
     use_nvfp4 = args.use_nvfp4
-    autotune_supported_backends = ["cutlass", "trtllm"]
+    autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]
     res = []

-    backends = filter_backends_by_compute_capability(backends, args.routine, device)
-
     res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
     if res_dtype not in [torch.bfloat16, torch.float16]:
         raise ValueError(
             f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
         )
-    ## Done parsing input arguments
-
-    if "trtllm" in backends:
-        remove_trtllm = False
-        if res_dtype == torch.float16:
-            print("[INFO] trtllm backend does not support float16 output")
-            remove_trtllm = True
-        if remove_trtllm:
-            backends.remove("trtllm")
-        if not use_nvfp4:
-            print(
-                "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            backends.remove("trtllm")
-    if "cutlass" in backends:
-        remove_cutlass = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cutlass backend does not support use_128x4_sf_layout=False")
-            remove_cutlass = True
-        if not use_nvfp4:
-            print(
-                "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            backends.remove("cutlass")
-        if remove_cutlass:
-            backends.remove("cutlass")
-    if "cudnn" in backends:
-        remove_cudnn = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cudnn backend does not support use_128x4_sf_layout=False")
-            remove_cudnn = True
-        if remove_cudnn:
-            backends.remove("cudnn")
-    if getattr(args, "autotune", False):
-        backends_to_remove = []
-        for cur_backend in backends:
-            if cur_backend not in autotune_supported_backends:
-                print(f"[INFO] {cur_backend} backend does not support autotune")
-                backends_to_remove.append(cur_backend)
-        for cur_backend in backends_to_remove:
-            backends.remove(cur_backend)
-
-    if len(backends) == 0:
-        print("[ERROR] No backends to test. Exiting.")
-        return

     input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
     mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
@@ -886,11 +839,22 @@ def testMmFp4(args):
     print(f"[VVERBOSE] {mat2_fp4.dtype = }")

     alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
-    # res = torch.empty([m, n], device="cuda", dtype=res_dtype)
+    # Completed preparing inputs. Now programmatically filter backends
+    block_size = 16 if use_nvfp4 else 32
+    backends_to_remove = []

-    def run_backend(backend):
-        if backend in ["cudnn", "trtllm", "cutlass"]:
-            return flashinfer.gemm.mm_fp4(
+    for backend in backends:
+        # Skip autotune check for now (handled separately below)
+        if (
+            getattr(args, "autotune", False)
+            and backend not in autotune_supported_backends
+        ):
+            print(f"[INFO] {backend} backend does not support autotune")
+            backends_to_remove.append(backend)
+            continue
+
+        try:
+            flashinfer.gemm.mm_fp4(
                 a=input_fp4,
                 b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
                 a_descale=input_inv_s,
@@ -904,6 +868,34 @@ def run_backend(backend):
                 backend=backend,
                 use_nvfp4=use_nvfp4,
             )
+        except Exception as e:
+            print(
+                f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
+            )
+            backends_to_remove.append(backend)
+
+    # Remove unsupported backends
+    for backend in backends_to_remove:
+        backends.remove(backend)
+
+    if len(backends) == 0:
+        print("[ERROR] No backends passed validation. Exiting.")
+        return
+
+    def run_backend(backend):
+        if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
+            return flashinfer.gemm.mm_fp4(
+                a=input_fp4,
+                b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
+                a_descale=input_inv_s,
+                b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
+                alpha=alpha,
+                out_dtype=res_dtype,
+                block_size=block_size,
+                use_8x4_sf_layout=not use_128x4_sf_layout,
+                backend=backend,
+                use_nvfp4=use_nvfp4,
+            )
         else:
             raise ValueError(f"Unsupported backend: {backend}")

@@ -917,12 +909,11 @@ def run_backend(backend):
            args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
        )
        for cur_backend in backends:
-            if cur_backend in autotune_supported_backends:
-                if args.verbose >= 1:
-                    print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
-                with autotune(True):
-                    for _ in range(warmup_iters):
-                        run_backend(cur_backend)
+            if args.verbose >= 1:
+                print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
+            with autotune(True):
+                for _ in range(warmup_iters):
+                    run_backend(cur_backend)

     # Storage for timing results and outputs
     backend_times = {backend: [] for backend in backends}
```