
Conversation

@bkryu
Collaborator

@bkryu bkryu commented Oct 25, 2025

📌 Description

Current PR:

  • Introduces an auto backend to mm_fp4 that can be autotuned. It replaces cudnn as the default.
    • Implementation matches bmm_fp8's auto backend support.
  • Allows cudnn backend to be autotuned.
  • Added unit test cases for backend=auto

Behavior of auto backend:

  • Examines the CUDA and cuDNN versions and dispatches to either the cutlass or the cudnn kernel backend. The trtllm kernel is not considered due to its non-interchangeable interface with the other backends.
    • The auto backend therefore only supports inputs runnable by cutlass and/or cudnn.
  • Non-autotuned behavior:
    • Constructs an ordered list of backends, (cudnn, cutlass) or (cutlass, cudnn), where the ordering is based on previous microbenchmark study results (see the sketch after this list).
      • If CUDA 12 --> cutlass comes first.
      • If CUDA 13 and cuDNN version < 9.15 --> cutlass comes first.
      • If CUDA 13 and cuDNN version >= 9.15 --> cudnn comes first.
    • If a kernel fails its support check, it is removed from the list.
  • Autotune behavior:
    • If a backend is explicitly provided --> Autotunes within that backend. Same as previous behavior, but autotuning is now also supported for cudnn.
    • If backend='auto' --> Autotunes within and across backends (cudnn & cutlass) and chooses the best config of the best backend. The trtllm kernel is not considered.
  • Many of mm_fp4's helper functions were refactored to enable cross-backend autotuning, using the cross-backend autotune-enabled bmm_fp8 as a reference.
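
The ordering heuristic above can be summarized with a short sketch (illustrative only; the function and variable names here are hypothetical, not the actual helpers in gemm_base.py):

def order_mm_fp4_backends(cuda_major: int, cudnn_version: tuple) -> list:
    """Return candidate mm_fp4 backends in preferred order, best-first."""
    if cuda_major >= 13 and cudnn_version >= (9, 15):
        # CUDA 13 with a recent cuDNN: prefer the cudnn kernel.
        return ["cudnn", "cutlass"]
    # CUDA 12, or CUDA 13 with cuDNN older than 9.15: prefer cutlass.
    return ["cutlass", "cudnn"]

Backends that fail their per-input support check are then dropped from this list before dispatch or autotuning.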

Pytest outputs

pytest tests/gemm/test_mm_fp4.py

  • SM100 (B200) CUDA 13 & cuDNN 9.15: 900 passed, 2532 skipped in 125.19s (0:02:05)
  • SM100 (B200) CUDA 12 & cuDNN 9.15: 900 passed, 2532 skipped in 125.67s (0:02:05)
  • SM120 (RTX 5090) CUDA 13 & cuDNN 9.15: 720 passed, 2712 skipped in 76.50s (0:01:16)

Example microbenchmark outputs:

On SM100 (B200) CUDA 13 & cuDNN 9.15

flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[PERF] cudnn          :: median time 0.018 ms; std 0.000 ms; achieved tflops 3797.932 TFLOPs/sec; achieved tb_per_sec 1.884 TB/sec
[PERF] cutlass        :: median time 0.020 ms; std 0.000 ms; achieved tflops 3440.640 TFLOPs/sec; achieved tb_per_sec 1.707 TB/sec
[PERF] trtllm         :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec
[PERF] auto           :: median time 0.018 ms; std 0.000 ms; achieved tflops 3840.714 TFLOPs/sec; achieved tb_per_sec 1.905 TB/sec
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[PERF] cudnn          :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto           :: median time 0.021 ms; std 0.000 ms; achieved tflops 3237.753 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec

## Autotune
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune
2025-11-11 23:43:23,715 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:25,789 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:25,790 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,251 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:26,251 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,327 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:26,327 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,335 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4129.171 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec
[PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3513.845 TFLOPs/sec; achieved tb_per_sec 1.743 TB/sec
[PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2613.338 TFLOPs/sec; achieved tb_per_sec 1.296 TB/sec
[PERF] auto_autotune  :: median time 0.016 ms; std 0.000 ms; achieved tflops 4128.768 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec

/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
2025-11-11 23:43:37,942 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:43,116 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:43,116 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:43,124 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.154 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
[PERF] auto_autotune  :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.692 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec

On SM100 (B200) CUDA 12 & cuDNN 9.15

flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[PERF] cudnn          :: median time 0.023 ms; std 0.001 ms; achieved tflops 2975.898 TFLOPs/sec; achieved tb_per_sec 1.476 TB/sec
[PERF] cutlass        :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.423 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
[PERF] trtllm         :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec
[PERF] auto           :: median time 0.020 ms; std 0.000 ms; achieved tflops 3371.229 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
(py312) root@84ef83abb1b5:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[PERF] cudnn          :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto           :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec

## Autotune
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune
2025-11-11 23:42:43,378 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,451 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,451 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,910 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,910 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,986 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,986 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,993 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3190.355 TFLOPs/sec; achieved tb_per_sec 1.583 TB/sec
[PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.330 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec
[PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2621.440 TFLOPs/sec; achieved tb_per_sec 1.300 TB/sec
[PERF] auto_autotune  :: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.628 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec
flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
2025-11-11 23:42:55,176 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:58,600 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:58,601 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:58,608 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto_autotune  :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec

On SM120 (RTX 5090) CUDA 13 & cuDNN 9.15

/flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120
[PERF] cudnn          :: median time 0.058 ms; std 0.000 ms; achieved tflops 1167.143 TFLOPs/sec; achieved tb_per_sec 0.579 TB/sec
[PERF] cutlass        :: median time 0.060 ms; std 0.000 ms; achieved tflops 1135.056 TFLOPs/sec; achieved tb_per_sec 0.563 TB/sec
[PERF] auto           :: median time 0.058 ms; std 0.000 ms; achieved tflops 1158.952 TFLOPs/sec; achieved tb_per_sec 0.575 TB/sec
/flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120
[PERF] cudnn          :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec
[PERF] auto           :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec

🔍 Related Issues

#1722

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • "auto" backend selection for FP4 matrix ops; command-line backend option now accepts "auto"
    • cuDNN, CUTLASS and TRTLLM offered as selectable FP4 GEMM backends; backend choice guided by CUDA/cuDNN heuristics
  • Improvements

    • Runtime capability checks replace static backend lists; unsupported backends are pruned dynamically
    • Heuristic-driven auto selection required for automatic mode; autotune/warmup expanded and validation tolerance relaxed
  • Tests

    • Tests extended to cover auto-backend scenarios and relaxed constraints


@coderabbitai
Contributor

coderabbitai bot commented Oct 25, 2025

Walkthrough

Replaces static mm_fp4 backend listings with runtime support checks and an "auto" backend selector; adds cuDNN/CUTLASS FP4 runner factories, tactic-aware graph execution, runtime backend validation/pruning in benchmarks, CLI --backends "auto" choice, and test coverage for auto backend paths.

Changes

Cohort / File(s) Summary
FP4 GEMM core
flashinfer/gemm/gemm_base.py
Added CUDA-version use; introduced cuDNN FP4 graph/runner helpers (_get_cudnn_fp4_gemm_graph, _cudnn_gemm_fp4, _cudnn_gemm_fp4_runner); replaced direct CUTLASS op exposure with cutlass_fp4_gemm_runner and get_cutlass_fp4_gemm_module(sm_major); added _heuristic_func_mm_fp4; updated mm_fp4(...) signature to accept backend: Literal["cudnn","trtllm","cutlass","auto"]; added tactic-aware build/execute params and lazy runner initialization.
Backend utils
flashinfer/utils.py
Added heuristic_func: Optional[Callable] = None to backend_requirement; require/provide heuristic for auto selection and apply it in suitable_auto_backends.
Benchmark utilities
benchmarks/routines/flashinfer_benchmark_utils.py
Removed static mm_fp4 entry from routine_cc_to_supported_backends; added comment that mm_fp4 relies on runtime support checkers instead of static per-CC lists.
Benchmark GEMM routine
benchmarks/routines/gemm.py
Extended CLI --backends choices to include "auto"; removed static CC prefiltering; added runtime per-backend capability checks with try/except removal of unsupported backends; unified autotune/warmup across remaining backends; adjusted backend invocation and relaxed cross-backend cosine similarity threshold to 0.97.
FP4 tests
tests/gemm/test_mm_fp4.py
Extracted core test logic to _test_mm_fp4(...); made test_mm_fp4() a wrapper and added test_mm_fp4_backend_auto(...) for backend="auto"; relaxed some backend constraints and enabled separate auto-backend test path.
Decorator tests
tests/utils/test_decorators.py
Updated @backend_requirement usage to pass backend_checks={...} and a new heuristic_func that computes/prioritizes suitable auto backends based on inputs.
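
As a rough illustration of the decorator wiring described for flashinfer/utils.py and tests/utils/test_decorators.py above (a hedged sketch: the argument names follow the test snippet shown later on this page, the per-backend check signatures are assumptions, and real usages include additional checks such as compute-capability requirements):

from flashinfer.utils import backend_requirement

# Hypothetical per-backend requirement checks; real checks validate dtypes,
# layouts, and problem sizes.
def _cudnn_check(x, backend):
    return True

def _cutlass_check(x, backend):
    return x.shape[0] % 2 == 0

def _heuristic_func(suitable_backends, x, backend):
    # Rank candidates (e.g., by problem size), then keep only the suitable ones.
    candidates = ["cudnn", "cutlass"] if x.shape[0] > 5 else ["cutlass", "cudnn"]
    return [b for b in candidates if b in suitable_backends]

@backend_requirement(
    backend_checks={"cudnn": _cudnn_check, "cutlass": _cutlass_check},
    heuristic_func=_heuristic_func,
)
def my_op(x, backend="auto"):
    ...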

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant MM as mm_fp4(auto)
    participant Heu as _heuristic_func_mm_fp4
    participant RunnerC as cudnn_runner
    participant RunnerU as cutlass_runner
    participant Bench as Benchmark/Test

    User->>MM: call mm_fp4(..., backend="auto")
    MM->>Heu: evaluate shapes/CUDA/cuDNN to rank candidates
    Heu-->>MM: ordered backend candidates
    loop try candidates (lazy init)
        MM->>RunnerC: init & capability trial
        RunnerC-->>MM: success / fail
        MM->>RunnerU: init & capability trial
        RunnerU-->>MM: success / fail
    end
    alt some backends failed
        MM->>MM: prune unsupported backends
    end
    MM->>MM: autotune & warmup across remaining backends
    Bench->>MM: run cross-backend validation (cosine >= 0.97)
    MM-->>User: return execution result from chosen backend

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Focus areas:

  • flashinfer/gemm/gemm_base.py: heuristic selection, lazy runner factories, cuDNN graph/tactic handling, public API signature changes.
  • benchmarks/routines/gemm.py: runtime validation/pruning, autotune/warmup flow, handling when all backends are dropped.
  • tests/gemm/test_mm_fp4.py and tests/utils/test_decorators.py: correctness of new wrappers, heuristic usage, and altered assertions/thresholds.

Possibly related PRs

Suggested reviewers

  • cyx-6
  • wenscarl
  • aleozlx
  • Anerudhan
  • yzh119

Poem

🐇 I hop through code with eager paws,
I scent the best backend through heuristics' laws.
cudnn or Cutlass, or "auto" flair,
I warm, I tune, then pick with care.
Benchmarks hum — the rabbit’s there!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 20.00%, which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title accurately and concisely describes the main changes: adding backend='auto' to mm_fp4 and enabling autotune for the cudnn backend.
  • Description check ✅ Passed: The description comprehensively covers what the PR does, includes detailed behavior explanations, test results, benchmark outputs, and completes all required checklist items.


@bkryu bkryu changed the title feat: Add backend='auto' to mm_fp4 feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' Oct 25, 2025
@bkryu bkryu changed the title feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' feat: [DRAFT] Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' Oct 25, 2025
@bkryu bkryu changed the title feat: [DRAFT] Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' [wip] feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' Oct 25, 2025
@bkryu bkryu self-assigned this Oct 27, 2025
@bkryu bkryu force-pushed the mm_fp4_auto_backend branch 4 times, most recently from bc94c4c to 254827a Compare October 30, 2025 17:19
@bkryu bkryu marked this pull request as ready for review October 30, 2025 17:26
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/gemm.py (1)

2096-2134: Consider extracting auto-backend selection into a helper function.

The auto-backend selection logic (lines 2096-2134) is complex and involves:

  1. CUDA/cuDNN version inspection
  2. Backend ordering heuristics
  3. Problem size validation
  4. Exception handling for unsupported configurations

This logic could benefit from extraction into a dedicated helper function (e.g., _select_mm_fp4_backends) to improve readability and testability.

Additionally, the bare except Exception at lines 2131-2132 might hide unexpected errors. Consider either:

  1. Catching more specific exceptions (e.g., ValueError, RuntimeError)
  2. Adding logging to track which backends fail validation and why

Example refactoring:

def _select_mm_fp4_backends(
    cuda_major: int,
    cudnn_version: int,
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    alpha: Optional[torch.Tensor],
    out_dtype: torch.dtype,
    out: torch.Tensor,
    block_size: int,
    use_8x4_sf_layout: bool,
    use_nvfp4: bool,
) -> List[str]:
    """Select supported backends for mm_fp4 based on device capabilities."""
    # Backend ordering heuristics
    if cuda_major >= 13 and cudnn_version >= 91400:
        candidate_backends = ("cudnn", "cutlass")
    else:
        candidate_backends = ("cutlass", "cudnn")
    
    # Filter by problem size support
    backends = []
    for candidate in candidate_backends:
        try:
            _check_mm_fp4_problem_size(
                a, b, a_descale, b_descale, alpha, out_dtype,
                out, block_size, use_8x4_sf_layout,
                cast(Literal["cudnn", "trtllm", "cutlass", "auto"], candidate),
                use_nvfp4,
            )
            backends.append(candidate)
        except (ValueError, RuntimeError):
            pass  # Backend not supported for this problem
    
    return backends
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9287c9 and 254827a.

📒 Files selected for processing (4)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
  • benchmarks/routines/gemm.py (5 hunks)
  • flashinfer/gemm.py (17 hunks)
  • tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm.py (4)
flashinfer/jit/cpp_ext.py (1)
  • get_cuda_version (64-83)
flashinfer/autotuner.py (9)
  • TunableRunner (194-247)
  • get_valid_tactics (196-214)
  • OptimizationProfile (168-183)
  • forward (220-244)
  • AutoTuner (335-784)
  • get (362-365)
  • TuningConfig (101-141)
  • choose_one (400-529)
  • get_opt_shapes (177-183)
flashinfer/trtllm_low_latency_gemm.py (2)
  • get_valid_tactics (52-77)
  • forward (79-109)
flashinfer/utils.py (4)
  • supported_compute_capability (772-852)
  • get_compute_capability (251-254)
  • is_compute_capability_supported (966-972)
  • backend_requirement (855-1028)
🪛 Ruff (0.14.2)
flashinfer/gemm.py

96-96: Unused function argument: device

(ARG001)


432-432: Unused method argument: inputs

(ARG002)


433-433: Unused method argument: profile

(ARG002)


441-441: Unused method argument: do_preparation

(ARG002)


442-442: Unused method argument: kwargs

(ARG002)


1722-1722: Unused method argument: profile

(ARG002)


1733-1733: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1736-1736: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1772-1772: Unused method argument: do_preparation

(ARG002)


1773-1773: Unused method argument: kwargs

(ARG002)


1855-1855: Avoid specifying long messages outside the exception class

(TRY003)


1876-1876: Unused function argument: backend

(ARG001)


1934-1934: Unused function argument: backend

(ARG001)


1956-1956: Unused function argument: backend

(ARG001)


1957-1957: Unused function argument: use_nvfp4

(ARG001)


1965-1965: Unused function argument: b

(ARG001)


1966-1966: Unused function argument: a_descale

(ARG001)


1967-1967: Unused function argument: b_descale

(ARG001)


1968-1968: Unused function argument: alpha

(ARG001)


1969-1969: Unused function argument: out_dtype

(ARG001)


1970-1970: Unused function argument: out

(ARG001)


1971-1971: Unused function argument: block_size

(ARG001)


1972-1972: Unused function argument: use_8x4_sf_layout

(ARG001)


1973-1973: Unused function argument: backend

(ARG001)


1974-1974: Unused function argument: use_nvfp4

(ARG001)


2099-2099: Unpacked variable cc_major is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2099-2099: Unpacked variable cc_minor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2131-2132: try-except-pass detected, consider logging the exception

(S110)


2131-2131: Do not catch blind exception: Exception

(BLE001)


2163-2163: Avoid specifying long messages outside the exception class

(TRY003)


2509-2509: Unpacked variable a_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2510-2510: Unpacked variable b_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2511-2511: Unpacked variable alpha is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2513-2513: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2516-2516: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2530-2530: Unused method argument: do_preparation

(ARG002)


2531-2531: Unused method argument: kwargs

(ARG002)

🔇 Additional comments (12)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

241-243: LGTM! Auto backend addition is correct.

The addition of "auto" to the supported backends list for mm_fp4 at compute capabilities 10.0, 10.3, and 12.0 is consistent with the PR objectives and aligns with the auto-backend implementation in flashinfer/gemm.py.

benchmarks/routines/gemm.py (2)

134-134: LGTM! Backend choices updated correctly.

The addition of "auto" to the --backends argument choices is consistent with the auto-backend support introduced in this PR.


793-793: LGTM! Auto backend support properly integrated.

The changes correctly:

  1. Add "auto" to the list of autotune-supported backends for mm_fp4
  2. Implement backend filtering logic for "auto" that respects the use_128x4_sf_layout constraint
  3. Include "auto" in the run_backend execution path

The filtering logic at lines 836-842 appropriately mirrors the filtering done for other backends (cudnn, cutlass) and ensures "auto" is removed when layout constraints aren't met.

Also applies to: 836-842, 899-899

flashinfer/gemm.py (7)

425-465: LGTM! Runner refactoring improves consistency.

The refactoring of CUTLASS FP4 GEMM into cutlass_fp4_gemm_runner with the helper function _create_cutlass_fp4_gemm_module improves naming consistency and aligns with the pattern used for other runners (e.g., trtllm_fp4_gemm_runner).


1270-1294: LGTM! cuDNN tactic support enables fine-grained autotuning.

The addition of tactic parameter to build_plans_cudnn_fp4_gemm_graph and execute_cudnn_gemm_fp4_graph enables plan-specific execution for autotuning. The logic correctly:

  • Builds a specific plan when tactic != -1
  • Builds all plans when tactic == -1 (fallback)
  • Executes the selected plan or uses default execution

This aligns with the autotuning framework's expectations and follows the pattern established by other tunable runners.

Also applies to: 1306-1331


1665-1802: LGTM! cuDNN FP4 runner properly implements TunableRunner interface.

The new _cudnn_gemm_fp4 and _cudnn_gemm_fp4_runner functions correctly:

  1. Encapsulate cuDNN FP4 GEMM execution with tactic support
  2. Implement the TunableRunner interface with get_valid_tactics and forward methods
  3. Query available execution plans from the cuDNN graph
  4. Support tactic-specific execution for autotuning

The implementation follows the established pattern for tunable runners and integrates well with the autotuning framework.


1962-1997: LGTM! Auto backend requirement validation is well-implemented.

The _auto_gemm_fp4_requirement function correctly validates that the "auto" backend can be used by:

  1. Checking compute capability support for candidate backends (cudnn, cutlass)
  2. Explicitly excluding trtllm due to its different interface (as documented in the PR description)
  3. Returning True if at least one backend is supported

The implementation ensures that "auto" will only be accepted on devices where at least one compatible backend is available.


2136-2163: LGTM! Runner construction logic handles all backend cases correctly.

The runner construction for each backend (cudnn, trtllm, cutlass) correctly:

  1. Creates appropriate runner instances based on backend type
  2. Handles dtype conversions for cutlass backend (uint8 ↔ float8_e4m3fn)
  3. Dispatches to the correct module based on device architecture (SM120 vs SM100/SM103)
  4. Falls through to a clear error for unsupported backends

The logic is well-structured and handles all supported backend configurations.


2165-2217: LGTM! Autotuning integration is well-structured.

The autotuning setup correctly:

  1. Defines dynamic tensor specs for batch size variation (power-of-2 bucketing)
  2. Sets constraint specs to maintain shape relationships
  3. Prepares input tensors in the expected format
  4. Uses AutoTuner.choose_one to select the best (runner, tactic) combination
  5. Executes the chosen runner with the selected tactic

The integration follows the established autotuning framework patterns and enables cross-backend tuning when backend="auto".


2487-2563: LGTM! TRTLLM FP4 runner refactoring enables autotuning.

The refactoring of trtllm_fp4_gemm_runner to:

  1. Accept use_8x4_sf_layout as a parameter
  2. Implement the TunableRunner interface with tactic support
  3. Return a properly configured runner instance

This change aligns the TRTLLM backend with the autotuning framework and maintains consistency with other FP4 runners. The implementation correctly handles the use_8x4_sf_layout parameter throughout the runner lifecycle.

tests/gemm/test_mm_fp4.py (2)

15-95: LGTM! Test refactoring improves maintainability.

Extracting the test logic into _test_mm_fp4 is a good refactoring that:

  1. Eliminates code duplication between test functions
  2. Makes the test logic reusable and easier to maintain
  3. Consolidates backend support checks and skip conditions

The updated skip condition at lines 34-35 correctly limits mxfp4 support to cudnn and auto backends, which aligns with the implementation in flashinfer/gemm.py.


97-127: LGTM! Test split provides good coverage of auto backend.

The split between test_mm_fp4 (non-auto backends) and test_mm_fp4_backend_auto (auto backend) is well-designed:

  1. test_mm_fp4 maintains full parameter coverage for individual backends
  2. test_mm_fp4_backend_auto tests the auto backend with a reduced but representative parameter space
  3. The reduced parameter space (fewer m/n/k combinations, only use_128x4_sf_layout=True) is appropriate for auto backend testing and helps keep test execution time reasonable

This approach provides comprehensive coverage while avoiding combinatorial explosion of test cases.

"[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
)
backends.remove("cutlass")
remove_cutlass = True
Contributor

Another way to avoid these remove_backend_x bools is to call the related backend check (which should be annotated with the decorator), or have the decorator return a filtered list as I proposed. #2000 (comment)

Contributor

Regardless of whether you stuff it into the decorator, this is a pattern that will come up for all APIs, so we should think about encapsulating the "if backend and checks_dont_pass: filter_it_out" logic.
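
A minimal sketch of the kind of encapsulation being suggested here (hypothetical helper, not code from this PR): filter a candidate list through per-backend check callables instead of tracking remove_backend_x booleans.

def filter_supported_backends(candidates, checks, *args, **kwargs):
    """Keep only the backends whose check passes for the given inputs."""
    supported = []
    for name in candidates:
        check = checks.get(name)
        try:
            if check is None or check(*args, **kwargs):
                supported.append(name)
        except ValueError:
            # The check signals "unsupported for these inputs"; drop the backend.
            pass
    return supported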

Collaborator Author

This is a good idea. I have removed these hard-coded checks entirely and have started using the checkers in the latest commit.

Collaborator Author

Now this is addressed with the latest decorator update #2029

)
# Auto-select the best backend
if backend == "auto":
cuda_major, _ = get_cuda_version(a.device)
Contributor

These checks should be part of the _auto_gemm_fp4_requirement check.

I think a cleaner way would be to move the generation of the list of candidate_backends in the @backend_requirement decorator, where "auto" backend is treated specially. It lists the required checks for each backend already. An alternative is that we create a separate decorator that composes and uses the backend checks of the backend_requirement

Contributor

The danger here is that we may be repeating some checks, but not all of them.

Collaborator Author

When writing the code path for this PR, I noted that the following questions had to be answered at different times by the auto backend logic:

  1. Is there at least one runnable backend for the given input params -- for early error raising
  2. What are the runnable backends for the given input params -- to consider which backends to choose from
  3. In the current GPU/CUDA/cuDNN environment, what is the preferred ordering of backends -- for heuristics

The current implementation in the PR answers 1 in @backend_requirement and 2 & 3 in the body of mm_fp4, while you're suggesting putting 2 inside @backend_requirement. I agree that this helps us avoid repeating checks, but it will involve, as you raised, special treatment for the auto backend and a change to backend_requirement. We can discuss.

Collaborator Author

Now this is addressed with the latest decorator update #2029

candidate_backends = ("cutlass", "cudnn")

# Filter to only supported backends for this compute capability
# Note: The requirement function already validated that at least one backend is supported
Contributor

So this is the dangerous part: at this point, we know one backend replied that its check is OK. But we are considering all backends. Maybe cudnn supports it but not trtllm or cutlass.

Collaborator Author

You are correct here. In the latest commit, I now check whether the backend is supported both generally and for the given inputs.

Collaborator Author

Now this is addressed with the latest decorator update #2029

for candidate in candidate_backends:
# mypy requires explicit type casting for the backend literal
backend_literal = cast(
Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
Contributor

why is auto added back?

Collaborator Author

@bkryu bkryu Oct 31, 2025

Auto is actually not being added here since the cast() is telling pre-commit tests that backend_literal will be one of ["cudnn", "trtllm", "cutlass", "auto"] while candidate_backends will never contain auto.

However, there is no need for auto to be there, and I can see it being confusing, so I have removed it in the latest commit.

)
elif cur_backend == "cutlass":
if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
a_descale = a_descale.view(torch.uint8)
Contributor

This seems like an implementation detail, and maybe needs to be moved to the cutlass runner itself, just like we do with the cudnn_runner.

Collaborator Author

@bkryu bkryu Oct 31, 2025

Agreed, and this allows removal of the if-then-else structure above. Updated in the latest commit.

@bkryu bkryu force-pushed the mm_fp4_auto_backend branch from 254827a to c9f3d52 Compare October 31, 2025 18:09
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 254827a and c9f3d52.

📒 Files selected for processing (4)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
  • benchmarks/routines/gemm.py (4 hunks)
  • flashinfer/gemm.py (17 hunks)
  • tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/gemm.py (2)
flashinfer/gemm.py (1)
  • _check_mm_fp4_problem_size (1812-1870)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm.py (2)
flashinfer/jit/cpp_ext.py (1)
  • get_cuda_version (64-83)
flashinfer/utils.py (4)
  • supported_compute_capability (772-852)
  • get_compute_capability (251-254)
  • is_compute_capability_supported (966-972)
  • backend_requirement (855-1028)
🪛 Ruff (0.14.2)
benchmarks/routines/gemm.py

900-900: Do not catch blind exception: Exception

(BLE001)

flashinfer/gemm.py

96-96: Unused function argument: device

(ARG001)


436-436: Unused method argument: inputs

(ARG002)


437-437: Unused method argument: profile

(ARG002)


445-445: Unused method argument: do_preparation

(ARG002)


446-446: Unused method argument: kwargs

(ARG002)


1730-1730: Unused method argument: profile

(ARG002)


1741-1741: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1744-1744: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1780-1780: Unused method argument: do_preparation

(ARG002)


1781-1781: Unused method argument: kwargs

(ARG002)


1863-1863: Avoid specifying long messages outside the exception class

(TRY003)


1884-1884: Unused function argument: backend

(ARG001)


1942-1942: Unused function argument: backend

(ARG001)


1964-1964: Unused function argument: backend

(ARG001)


1965-1965: Unused function argument: use_nvfp4

(ARG001)


1973-1973: Unused function argument: b

(ARG001)


1974-1974: Unused function argument: a_descale

(ARG001)


1975-1975: Unused function argument: b_descale

(ARG001)


1976-1976: Unused function argument: alpha

(ARG001)


1977-1977: Unused function argument: out_dtype

(ARG001)


1978-1978: Unused function argument: out

(ARG001)


1979-1979: Unused function argument: block_size

(ARG001)


1980-1980: Unused function argument: use_8x4_sf_layout

(ARG001)


1981-1981: Unused function argument: backend

(ARG001)


1982-1982: Unused function argument: use_nvfp4

(ARG001)


2109-2109: Unpacked variable cc_major is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2109-2109: Unpacked variable cc_minor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2153-2154: try-except-pass detected, consider logging the exception

(S110)


2153-2153: Do not catch blind exception: Exception

(BLE001)


2517-2517: Unpacked variable a_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2518-2518: Unpacked variable b_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2519-2519: Unpacked variable alpha is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2521-2521: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2524-2524: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2538-2538: Unused method argument: do_preparation

(ARG002)


2539-2539: Unused method argument: kwargs

(ARG002)

@bkryu bkryu marked this pull request as draft October 31, 2025 22:19
@bkryu bkryu force-pushed the mm_fp4_auto_backend branch 2 times, most recently from fd4dfd6 to 1b9ffbf Compare November 11, 2025 22:14
@bkryu
Collaborator Author

bkryu commented Nov 11, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !128 has been created, and the CI pipeline #38303944 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu marked this pull request as ready for review November 11, 2025 23:51
@bkryu bkryu requested a review from jiahanc as a code owner November 11, 2025 23:51
@bkryu bkryu changed the title [wip] feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' Nov 11, 2025
@bkryu bkryu requested a review from nvmbreughe November 11, 2025 23:52
@bkryu
Collaborator Author

bkryu commented Nov 12, 2025

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #38303944 has been cancelled.

@flashinfer-bot
Collaborator

GitLab MR !128 has been updated with latest changes, and the CI pipeline #38894324 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/utils/test_decorators.py (1)

349-349: Remove unnecessary initialization.

The candidate_backends variable is initialized to None but immediately overwritten in both branches of the conditional. You can remove this line.

Apply this diff:

     def _heuristic_func(suitable_backends, x, backend):
-        candidate_backends = None
         if x.shape[0] > 5:
             candidate_backends = ["cudnn", "cutlass"]
         else:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f2bf2ec and c19ffc4.

📒 Files selected for processing (1)
  • tests/utils/test_decorators.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (1)
  • backend_requirement (897-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
tests/utils/test_decorators.py (1)

361-384: LGTM!

The updated decorator usage correctly demonstrates the new heuristic-driven auto backend selection. The test cases properly verify:

  • Backend filtering based on requirement checks and compute capability
  • Heuristic ordering of suitable backends
  • Error handling when no suitable backends are found

The logic correctly traces through both successful and failing scenarios.

@bkryu
Collaborator Author

bkryu commented Nov 20, 2025

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #38894324 has been cancelled.

@bkryu
Collaborator Author

bkryu commented Nov 20, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !128 has been created, and the CI pipeline #38894904 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 enabled auto-merge (squash) November 20, 2025 23:32
@bkryu bkryu force-pushed the mm_fp4_auto_backend branch from c19ffc4 to fe2070b Compare November 21, 2025 01:38
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/gemm/test_mm_fp4.py (1)

15-25: backend='auto' tests are always skipped due to is_backend_supported misuse.

mm_fp4.is_backend_supported("auto", cc) returns False because "auto" is not a real backend key in backend_checks. As a result, _test_mm_fp4 skips all test_mm_fp4_backend_auto parameterizations, so the new auto-backend tests never actually run.

You’re getting “coverage” numbers without exercising the backend="auto" path.

A minimal fix is to avoid using is_backend_supported for the synthetic "auto" backend and let the decorator logic drive error handling:

-    compute_capability = get_compute_capability(torch.device(device="cuda"))
-    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
-    if not mm_fp4.is_backend_supported(backend, compute_capability_number):
-        pytest.skip(
-            f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
-        )
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
+    # For concrete backends, pre-skip unsupported CCs. For backend='auto', rely on
+    # the decorator to raise `BackendSupportedError` if no candidate backend exists.
+    if backend != "auto" and not mm_fp4.is_backend_supported(
+        backend, compute_capability_number
+    ):
+        pytest.skip(
+            f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
+        )

Optionally, if you want an early skip for auto as well, you can instead check mm_fp4.is_compute_capability_supported(compute_capability_number).

This will ensure test_mm_fp4_backend_auto actually exercises both the heuristic and autotune behavior.

Also applies to: 124-127

flashinfer/gemm/gemm_base.py (1)

1979-2015: Fix tensor index mismatch in TRTLLM FP4 autotuning.

The issue is confirmed. In TrtllmFp4GemmRunner.get_valid_tactics() (lines 2553–2620), using a_tensor_index = 1 and b_tensor_index = 2 incorrectly maps to b and a_descale respectively, since the mm_fp4 inputs list places a at index 0 and b at index 1.

Correct this by setting a_tensor_index = 0 and b_tensor_index = 1:

             def get_valid_tactics(
                 self,
                 inputs: List[torch.Tensor],
                 profile: OptimizationProfile,
             ) -> List[int]:
-                a_tensor_index = 1
-                b_tensor_index = 2
+                a_tensor_index = 0
+                b_tensor_index = 1
 
+                opt_shapes = profile.get_opt_shapes()
-                a = profile.get_opt_shapes()[a_tensor_index]
-                b = profile.get_opt_shapes()[b_tensor_index]
+                a = opt_shapes[a_tensor_index]
+                b = opt_shapes[b_tensor_index]
                 m = a[0]
                 n = b[0]
                 k = a[1] * 2

This ensures m, n, k passed to trtllm_gemm_tactics match the actual problem dimensions during autotuning.

🧹 Nitpick comments (3)
tests/utils/test_decorators.py (1)

347-359: Heuristic test helper is correct; consider minor cleanups to avoid shadowing.

The _heuristic_func implementation matches the new backend_requirement contract and exercises the auto-backend path correctly. One small readability nit is reusing backend as the loop variable, which shadows the function argument and slightly hurts clarity; you can also simplify candidate_backends/heuristic_backends construction.

For example:

-    def _heuristic_func(suitable_backends, x, backend):
-        candidate_backends = None
-        if x.shape[0] > 5:
-            candidate_backends = ["cudnn", "cutlass"]
-        else:
-            candidate_backends = ["cutlass", "cudnn"]
-
-        heuristic_backends = []
-        for backend in candidate_backends:
-            if backend in suitable_backends:
-                heuristic_backends.append(backend)
-        return heuristic_backends
+    def _heuristic_func(suitable_backends, x, backend):
+        if x.shape[0] > 5:
+            candidate_backends = ["cudnn", "cutlass"]
+        else:
+            candidate_backends = ["cutlass", "cudnn"]
+
+        return [b for b in candidate_backends if b in suitable_backends]

This keeps behavior identical while making the helper a bit clearer.

Also applies to: 361-367

flashinfer/utils.py (1)

924-930: heuristic_func semantics look good; prefer explicit error over assert for robustness.

The extended docstring clearly defines heuristic_func’s contract, and the change to always run it in suitable_auto_backends enforces the intended rule that any API exposing backend="auto" must supply a heuristic.

One concern: using

assert heuristic_func is not None, "Heuristic function must be provided"

means this safety net disappears under python -O, and callers would then silently fall back to unordered suitable_backends. It’s safer to raise an explicit error when backend="auto" is used without a heuristic, e.g.:

-            assert heuristic_func is not None, "Heuristic function must be provided"
-            suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
+            if heuristic_func is None:
+                raise RuntimeError(
+                    f"backend='auto' requires a heuristic_func for {func.__name__}"
+                )
+            suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)

This keeps the behavior “loud” in all runtime modes while matching the intent discussed in earlier review threads.

Also applies to: 1086-1087

benchmarks/routines/gemm.py (1)

129-135: Runtime backend probing for mm_fp4 is sensible; narrow the catch to avoid hiding real bugs.

Letting testMmFp4 discover supported backends by actually calling flashinfer.gemm.mm_fp4 is a nice improvement over the old static CC map, and adding "auto" to --backends plus autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"] lines up with the new auto-backend behavior.

The main concern is this block:

for backend in backends:
    ...
    try:
        flashinfer.gemm.mm_fp4(...)
    except Exception as e:
        print(
            f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
        )
        backends_to_remove.append(backend)

Catching bare Exception means any failure in the runner (logic bug, memory issue, etc.) is treated as “backend unsupported” and the benchmark keeps going, which can silently mask regressions.

Given the decorator and kernels raise well-typed errors on unsupported configs, you can be more precise here, e.g.:

-    try:
-        flashinfer.gemm.mm_fp4(...)
-    except Exception as e:
+    try:
+        flashinfer.gemm.mm_fp4(...)
+    except (LibraryError, BackendSupportedError, ValueError) as e:
         print(
             f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
         )
         backends_to_remove.append(backend)
+    except Exception as e:
+        # Treat unexpected failures as real errors so they don't get hidden.
+        raise

This keeps the probing-based filtering but makes genuine bugs in mm_fp4 visible during benchmarking instead of being silently filtered away.

Also applies to: 793-799, 843-884, 907-917

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c19ffc4 and fe2070b.

📒 Files selected for processing (6)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
  • benchmarks/routines/gemm.py (5 hunks)
  • flashinfer/gemm/gemm_base.py (18 hunks)
  • flashinfer/utils.py (2 hunks)
  • tests/gemm/test_mm_fp4.py (3 hunks)
  • tests/utils/test_decorators.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • benchmarks/routines/gemm.py
  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (3)
benchmarks/routines/gemm.py (2)
flashinfer/gemm/gemm_base.py (1)
  • mm_fp4 (2027-2186)
flashinfer/autotuner.py (1)
  • autotune (251-262)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (1)
  • backend_requirement (897-1179)
flashinfer/gemm/gemm_base.py (2)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-786)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
flashinfer/utils.py (3)
  • backend_requirement (897-1179)
  • suitable_auto_backends (1071-1091)
  • get_compute_capability (253-256)
🪛 Ruff (0.14.5)
benchmarks/routines/gemm.py

871-871: Do not catch blind exception: Exception

(BLE001)

flashinfer/gemm/gemm_base.py

417-417: Unused method argument: inputs

(ARG002)


418-418: Unused method argument: profile

(ARG002)


426-426: Unused method argument: do_preparation

(ARG002)


427-427: Unused method argument: kwargs

(ARG002)


483-483: Avoid specifying long messages outside the exception class

(TRY003)


1671-1671: Unused function argument: out

(ARG001)


1748-1748: Unused method argument: profile

(ARG002)


1761-1761: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1785-1785: Unused method argument: do_preparation

(ARG002)


1786-1786: Unused method argument: kwargs

(ARG002)


1824-1824: Unused function argument: out

(ARG001)


1826-1826: Unused function argument: use_8x4_sf_layout

(ARG001)


1827-1827: Unused function argument: backend

(ARG001)


1881-1881: Unused function argument: out

(ARG001)


1884-1884: Unused function argument: backend

(ARG001)


1888-1888: Avoid specifying long messages outside the exception class

(TRY003)


1936-1936: Unused function argument: a

(ARG001)


1937-1937: Unused function argument: b

(ARG001)


1938-1938: Unused function argument: a_descale

(ARG001)


1939-1939: Unused function argument: b_descale

(ARG001)


1940-1940: Unused function argument: alpha

(ARG001)


1942-1942: Unused function argument: out

(ARG001)


1943-1943: Unused function argument: block_size

(ARG001)


1944-1944: Unused function argument: use_8x4_sf_layout

(ARG001)


1945-1945: Unused function argument: backend

(ARG001)


1949-1949: Avoid specifying long messages outside the exception class

(TRY003)


1960-1960: Unused function argument: a

(ARG001)


1961-1961: Unused function argument: b

(ARG001)


1962-1962: Unused function argument: a_descale

(ARG001)


1963-1963: Unused function argument: b_descale

(ARG001)


1964-1964: Unused function argument: alpha

(ARG001)


1965-1965: Unused function argument: out_dtype

(ARG001)


1966-1966: Unused function argument: out

(ARG001)


1967-1967: Unused function argument: block_size

(ARG001)


1969-1969: Unused function argument: backend

(ARG001)


1973-1973: Avoid specifying long messages outside the exception class

(TRY003)


1975-1975: Avoid specifying long messages outside the exception class

(TRY003)


1990-1990: Unused function argument: backend

(ARG001)


1991-1991: Unused function argument: use_nvfp4

(ARG001)


2569-2569: Unpacked variable a_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2570-2570: Unpacked variable b_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2571-2571: Unpacked variable alpha is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2573-2573: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2576-2576: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2590-2590: Unused method argument: do_preparation

(ARG002)


2591-2591: Unused method argument: kwargs

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

238-239: Comment correctly reflects mm_fp4’s new dynamic support handling.

Not listing mm_fp4 in routine_cc_to_supported_backends and documenting that it relies on runtime support checkers aligns with the new decorator-based validation in mm_fp4. This avoids stale hard-coded backend lists in the benchmark helper.

flashinfer/gemm/gemm_base.py (1)

410-447:

The original review comment incorrectly states that a_descale.view(torch.uint8) will raise a TypeError at runtime. In current PyTorch, torch.Tensor.view does support a dtype overload that accepts a torch.dtype argument (e.g., x.view(torch.uint8)), which reinterprets the underlying data with the given dtype without copying. This is the correct idiomatic approach for the FP8-to-uint8 reinterpretation shown in CutlassFp4GemmRunner.forward.

The code is valid as written. No changes needed on this section.

Likely an incorrect or invalid review comment.

@yzh119 yzh119 merged commit 0aee7af into flashinfer-ai:main Nov 21, 2025
4 checks passed