feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' #1979
Conversation
Walkthrough

Replaces static mm_fp4 backend listings with runtime support checks and an "auto" backend selector; adds cuDNN/CUTLASS FP4 runner factories, tactic-aware graph execution, and runtime backend validation/pruning in benchmarks and the CLI.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant MM as mm_fp4(auto)
    participant Heu as _heuristic_func_mm_fp4
    participant RunnerC as cudnn_runner
    participant RunnerU as cutlass_runner
    participant Bench as Benchmark/Test
    User->>MM: call mm_fp4(..., backend="auto")
    MM->>Heu: evaluate shapes/CUDA/cuDNN to rank candidates
    Heu-->>MM: ordered backend candidates
    loop try candidates (lazy init)
        MM->>RunnerC: init & capability trial
        RunnerC-->>MM: success / fail
        MM->>RunnerU: init & capability trial
        RunnerU-->>MM: success / fail
    end
    alt some backends failed
        MM->>MM: prune unsupported backends
    end
    MM->>MM: autotune & warmup across remaining backends
    Bench->>MM: run cross-backend validation (cosine >= 0.97)
    MM-->>User: return execution result from chosen backend
```
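In code terms, the trial-and-prune stage of this flow might look like the following minimal sketch; `make_runner` and the exception types are illustrative assumptions, not the actual flashinfer internals.

```python
from typing import Callable, Dict, List

def prune_candidates(
    candidates: List[str],
    make_runner: Dict[str, Callable[[], object]],  # hypothetical factories
) -> List[object]:
    """Lazily initialize each candidate backend, keeping only those whose
    init/capability trial succeeds; raise if every candidate is pruned."""
    runners = []
    for name in candidates:
        try:
            runners.append(make_runner[name]())  # init & capability trial
        except (ValueError, RuntimeError):
            continue  # prune unsupported backend
    if not runners:
        raise RuntimeError("no supported backend for this problem")
    return runners
```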
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Force-pushed from bc94c4c to 254827a
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/gemm.py (1)
2096-2134: Consider extracting auto-backend selection into a helper function.

The auto-backend selection logic (lines 2096-2134) is complex and involves:
- CUDA/cuDNN version inspection
- Backend ordering heuristics
- Problem size validation
- Exception handling for unsupported configurations
This logic could benefit from extraction into a dedicated helper function (e.g., `_select_mm_fp4_backends`) to improve readability and testability. Additionally, the bare `except Exception` at lines 2131-2132 might hide unexpected errors. Consider either:
- Catching more specific exceptions (e.g., `ValueError`, `RuntimeError`)
- Adding logging to track which backends fail validation and why
Example refactoring:
```python
def _select_mm_fp4_backends(
    cuda_major: int,
    cudnn_version: int,
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    alpha: Optional[torch.Tensor],
    out_dtype: torch.dtype,
    out: torch.Tensor,
    block_size: int,
    use_8x4_sf_layout: bool,
    use_nvfp4: bool,
) -> List[str]:
    """Select supported backends for mm_fp4 based on device capabilities."""
    # Backend ordering heuristics
    if cuda_major >= 13 and cudnn_version >= 91400:
        candidate_backends = ("cudnn", "cutlass")
    else:
        candidate_backends = ("cutlass", "cudnn")

    # Filter by problem size support
    backends = []
    for candidate in candidate_backends:
        try:
            _check_mm_fp4_problem_size(
                a,
                b,
                a_descale,
                b_descale,
                alpha,
                out_dtype,
                out,
                block_size,
                use_8x4_sf_layout,
                cast(Literal["cudnn", "trtllm", "cutlass", "auto"], candidate),
                use_nvfp4,
            )
            backends.append(candidate)
        except (ValueError, RuntimeError):
            pass  # Backend not supported for this problem
    return backends
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- benchmarks/routines/flashinfer_benchmark_utils.py (1 hunk)
- benchmarks/routines/gemm.py (5 hunks)
- flashinfer/gemm.py (17 hunks)
- tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm.py (4)
flashinfer/jit/cpp_ext.py (1)
- get_cuda_version (64-83)
flashinfer/autotuner.py (9)
- TunableRunner (194-247)
- get_valid_tactics (196-214)
- OptimizationProfile (168-183)
- forward (220-244)
- AutoTuner (335-784)
- get (362-365)
- TuningConfig (101-141)
- choose_one (400-529)
- get_opt_shapes (177-183)
flashinfer/trtllm_low_latency_gemm.py (2)
- get_valid_tactics (52-77)
- forward (79-109)
flashinfer/utils.py (4)
- supported_compute_capability (772-852)
- get_compute_capability (251-254)
- is_compute_capability_supported (966-972)
- backend_requirement (855-1028)
🪛 Ruff (0.14.2)
flashinfer/gemm.py
96-96: Unused function argument: device
(ARG001)
432-432: Unused method argument: inputs
(ARG002)
433-433: Unused method argument: profile
(ARG002)
441-441: Unused method argument: do_preparation
(ARG002)
442-442: Unused method argument: kwargs
(ARG002)
1722-1722: Unused method argument: profile
(ARG002)
1733-1733: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1736-1736: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1772-1772: Unused method argument: do_preparation
(ARG002)
1773-1773: Unused method argument: kwargs
(ARG002)
1855-1855: Avoid specifying long messages outside the exception class
(TRY003)
1876-1876: Unused function argument: backend
(ARG001)
1934-1934: Unused function argument: backend
(ARG001)
1956-1956: Unused function argument: backend
(ARG001)
1957-1957: Unused function argument: use_nvfp4
(ARG001)
1965-1965: Unused function argument: b
(ARG001)
1966-1966: Unused function argument: a_descale
(ARG001)
1967-1967: Unused function argument: b_descale
(ARG001)
1968-1968: Unused function argument: alpha
(ARG001)
1969-1969: Unused function argument: out_dtype
(ARG001)
1970-1970: Unused function argument: out
(ARG001)
1971-1971: Unused function argument: block_size
(ARG001)
1972-1972: Unused function argument: use_8x4_sf_layout
(ARG001)
1973-1973: Unused function argument: backend
(ARG001)
1974-1974: Unused function argument: use_nvfp4
(ARG001)
2099-2099: Unpacked variable cc_major is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2099-2099: Unpacked variable cc_minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2131-2132: try-except-pass detected, consider logging the exception
(S110)
2131-2131: Do not catch blind exception: Exception
(BLE001)
2163-2163: Avoid specifying long messages outside the exception class
(TRY003)
2509-2509: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2510-2510: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2511-2511: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2513-2513: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2516-2516: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2530-2530: Unused method argument: do_preparation
(ARG002)
2531-2531: Unused method argument: kwargs
(ARG002)
🔇 Additional comments (12)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

241-243: LGTM! Auto backend addition is correct. The addition of "auto" to the supported backends list for `mm_fp4` at compute capabilities 10.0, 10.3, and 12.0 is consistent with the PR objectives and aligns with the auto-backend implementation in `flashinfer/gemm.py`.

benchmarks/routines/gemm.py (2)

134-134: LGTM! Backend choices updated correctly. The addition of "auto" to the --backends argument choices is consistent with the auto-backend support introduced in this PR.

793-793: LGTM! Auto backend support properly integrated. The changes correctly:
- Add "auto" to the list of autotune-supported backends for `mm_fp4`
- Implement backend filtering logic for "auto" that respects the `use_128x4_sf_layout` constraint
- Include "auto" in the `run_backend` execution path

The filtering logic at lines 836-842 appropriately mirrors the filtering done for other backends (cudnn, cutlass) and ensures "auto" is removed when layout constraints aren't met; a sketch of this kind of filtering follows below.

Also applies to: 836-842, 899-899
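As a rough illustration of the filtering described above, with flag names taken from the review text but the helper itself hypothetical:

```python
from typing import List

def filter_mm_fp4_backends(
    backends: List[str], use_128x4_sf_layout: bool, use_nvfp4: bool
) -> List[str]:
    """Drop backends whose constraints the requested config violates."""
    kept = list(backends)
    if not use_128x4_sf_layout:
        # Per the review, "auto" is removed when the 128x4
        # scale-factor layout constraint is not met.
        kept = [b for b in kept if b != "auto"]
    if not use_nvfp4:
        # cutlass does not support mxfp4 quantization (see thread below).
        kept = [b for b in kept if b != "cutlass"]
    return kept

print(filter_mm_fp4_backends(["cudnn", "cutlass", "auto"], False, True))
# -> ['cudnn', 'cutlass']
```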
flashinfer/gemm.py (7)
425-465: LGTM! Runner refactoring improves consistency. The refactoring of CUTLASS FP4 GEMM into `cutlass_fp4_gemm_runner` with the helper function `_create_cutlass_fp4_gemm_module` improves naming consistency and aligns with the pattern used for other runners (e.g., `trtllm_fp4_gemm_runner`).
1270-1294: LGTM! cuDNN tactic support enables fine-grained autotuning. The addition of a `tactic` parameter to `build_plans_cudnn_fp4_gemm_graph` and `execute_cudnn_gemm_fp4_graph` enables plan-specific execution for autotuning. The logic correctly:
- Builds a specific plan when `tactic != -1`
- Builds all plans when `tactic == -1` (fallback)
- Executes the selected plan or uses default execution

This aligns with the autotuning framework's expectations and follows the pattern established by other tunable runners (see the toy sketch below).

Also applies to: 1306-1331
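To make the `-1` tactic semantics concrete, here is a toy model of the dispatch described above; it is pure illustration, not the cuDNN graph API:

```python
class TacticGraph:
    """Toy stand-in for a graph that can run one of several execution plans."""

    def __init__(self, plans):
        self.plans = plans  # one callable per candidate execution plan

    def get_valid_tactics(self):
        return list(range(len(self.plans)))

    def execute(self, x, tactic: int = -1):
        # tactic == -1: default execution (fallback, all plans built);
        # otherwise run exactly the plan index the autotuner selected.
        plan = self.plans[0] if tactic == -1 else self.plans[tactic]
        return plan(x)

g = TacticGraph([lambda x: x + 1, lambda x: x * 2])
assert g.execute(3) == 4            # default plan
assert g.execute(3, tactic=1) == 6  # autotuner-selected plan
```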
1665-1802: LGTM! cuDNN FP4 runner properly implements the TunableRunner interface. The new `_cudnn_gemm_fp4` and `_cudnn_gemm_fp4_runner` functions correctly:
- Encapsulate cuDNN FP4 GEMM execution with tactic support
- Implement the `TunableRunner` interface with `get_valid_tactics` and `forward` methods
- Query available execution plans from the cuDNN graph
- Support tactic-specific execution for autotuning

The implementation follows the established pattern for tunable runners and integrates well with the autotuning framework.
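The shape of such a runner, reduced to its essentials; the real implementation subclasses `TunableRunner` from `flashinfer/autotuner.py`, and the `graph` attribute and its methods here are illustrative assumptions:

```python
from typing import Any, List

import torch

class CudnnFp4RunnerSketch:
    """Minimal sketch of a tactic-aware runner (not the real class)."""

    def __init__(self, graph: Any):
        self.graph = graph  # assumed to expose a plan count and execution

    def get_valid_tactics(self, inputs: List[torch.Tensor], profile: Any) -> List[int]:
        # One tactic per execution plan the prebuilt graph offers.
        return list(range(self.graph.plan_count()))

    def forward(self, inputs: List[torch.Tensor], tactic: int = -1) -> torch.Tensor:
        # tactic == -1 lets the backend use its default plan.
        return self.graph.run(inputs, tactic=tactic)
```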
1962-1997: LGTM! Auto backend requirement validation is well-implemented. The `_auto_gemm_fp4_requirement` function correctly validates that the "auto" backend can be used by:
- Checking compute capability support for candidate backends (cudnn, cutlass)
- Explicitly excluding trtllm due to its different interface (as documented in the PR description)
- Returning True if at least one backend is supported

The implementation ensures that "auto" will only be accepted on devices where at least one compatible backend is available.
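Reduced to pseudocode, the requirement boils down to an `any()` over the candidate backends; the support predicate is passed in here to keep the sketch self-contained:

```python
from typing import Callable

def auto_requirement_sketch(
    is_backend_supported: Callable[[str, int], bool], cc: int
) -> bool:
    """backend='auto' is accepted iff some candidate backend is usable.
    trtllm is deliberately absent: its interface is not interchangeable."""
    candidates = ("cudnn", "cutlass")
    return any(is_backend_supported(b, cc) for b in candidates)

# e.g. on a device where only the cutlass check passes:
assert auto_requirement_sketch(lambda b, cc: b == "cutlass", 100)
```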
2136-2163: LGTM! Runner construction logic handles all backend cases correctly. The runner construction for each backend (cudnn, trtllm, cutlass) correctly:
- Creates appropriate runner instances based on backend type
- Handles dtype conversions for the cutlass backend (uint8 ↔ float8_e4m3fn)
- Dispatches to the correct module based on device architecture (SM120 vs SM100/SM103)
- Falls through to a clear error for unsupported backends

The logic is well-structured and handles all supported backend configurations.
2165-2217: LGTM! Autotuning integration is well-structured. The autotuning setup correctly:
- Defines dynamic tensor specs for batch size variation (power-of-2 bucketing)
- Sets constraint specs to maintain shape relationships
- Prepares input tensors in the expected format
- Uses `AutoTuner.choose_one` to select the best (runner, tactic) combination
- Executes the chosen runner with the selected tactic

The integration follows the established autotuning framework patterns and enables cross-backend tuning when `backend="auto"` (a sketch of the flow follows below).
2487-2563: LGTM! TRTLLM FP4 runner refactoring enables autotuning. The refactoring of `trtllm_fp4_gemm_runner` to:
- Accept `use_8x4_sf_layout` as a parameter
- Implement the `TunableRunner` interface with tactic support
- Return a properly configured runner instance

This change aligns the TRTLLM backend with the autotuning framework and maintains consistency with other FP4 runners. The implementation correctly handles the `use_8x4_sf_layout` parameter throughout the runner lifecycle.

tests/gemm/test_mm_fp4.py (2)
15-95: LGTM! Test refactoring improves maintainability. Extracting the test logic into `_test_mm_fp4` is a good refactoring that:
- Eliminates code duplication between test functions
- Makes the test logic reusable and easier to maintain
- Consolidates backend support checks and skip conditions

The updated skip condition at lines 34-35 correctly limits mxfp4 support to the cudnn and auto backends, which aligns with the implementation in `flashinfer/gemm.py`.
97-127: LGTM! Test split provides good coverage of the auto backend. The split between `test_mm_fp4` (non-auto backends) and `test_mm_fp4_backend_auto` (auto backend) is well-designed:
- `test_mm_fp4` maintains full parameter coverage for individual backends
- `test_mm_fp4_backend_auto` tests the auto backend with a reduced but representative parameter space
- The reduced parameter space (fewer m/n/k combinations, only `use_128x4_sf_layout=True`) is appropriate for auto backend testing and helps keep test execution time reasonable

This approach provides comprehensive coverage while avoiding combinatorial explosion of test cases.
benchmarks/routines/gemm.py (outdated)

```python
            "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
        )
        backends.remove("cutlass")
        remove_cutlass = True
```
Another way to avoid these remove_backend_x bools is to call the related backend check (which should be annotated with the decorator), or have the decorator return a filtered list as I proposed. #2000 (comment)
Regardless whether you stuff it into the decorator, this will be a pattern that will happen for all APIs, so we should think about encapsulating the "if backend and checks_dont_pass: filter_it_out".
This is a good idea. I have removed these hard-coded checks entirely and have started using the checkers in the latest commit.
Now this is addressed with the latest decorator update #2029
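One way to encapsulate the "if backend and checks_dont_pass: filter_it_out" pattern discussed in this thread, sketched with hypothetical names (the real checkers would be the functions registered with `backend_requirement`):

```python
from typing import Callable, Dict, List

def filter_supported_backends(
    backends: List[str],
    checks: Dict[str, Callable[..., bool]],
    *args,
    **kwargs,
) -> List[str]:
    """Keep only the backends whose registered check passes for these args."""
    kept = []
    for backend in backends:
        check = checks.get(backend)
        if check is None or check(*args, **kwargs):
            kept.append(backend)
        else:
            print(f"[INFO] {backend} does not support this configuration")
    return kept
```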
flashinfer/gemm/gemm_base.py (outdated)

```python
)
# Auto-select the best backend
if backend == "auto":
    cuda_major, _ = get_cuda_version(a.device)
```
These checks should be part of the `_auto_gemm_fp4_requirement` check.
I think a cleaner way would be to move the generation of the list of `candidate_backends` into the `@backend_requirement` decorator, where the "auto" backend is treated specially; it lists the required checks for each backend already. An alternative is that we create a separate decorator that composes and uses the backend checks of `backend_requirement`.
The danger here is that we may be repeating some checks, but not all of them.
When writing the code path for this PR, I noted that the following questions had to be answered at different times by the auto backend logic:
- Is there at least one runnable backend for the given input params -- for early error raising
- What are the runnable backends for the given input params -- to consider which backends to choose from
- In the current GPU/CUDA/cuDNN environment, what is the preferred ordering of backends -- for heuristics
The current implementation in the PR answers 1 in `@backend_requirement`, and 2 & 3 in the body of `mm_fp4`, while you're suggesting putting 2 inside `@backend_requirement`. I agree that this helps us avoid repeating checks, but it will involve, as you raised, special treatment for the auto backend and a change to `backend_requirement`. We can discuss.
Now this is addressed with the latest decorator update #2029
flashinfer/gemm/gemm_base.py (outdated)

```python
candidate_backends = ("cutlass", "cudnn")

# Filter to only supported backends for this compute capability
# Note: The requirement function already validated that at least one backend is supported
```
So this is the dangerous part: at this point, we know one backend replied that its check is OK, but we are considering all backends. Maybe cudnn supports it but not trtllm or cutlass.
You are correct here. In the latest commit, I now check whether each backend is supported both generally and for the given inputs.
Now this is addressed with the latest decorator update #2029
flashinfer/gemm.py (outdated)

```python
for candidate in candidate_backends:
    # mypy requires explicit type casting for the backend literal
    backend_literal = cast(
        Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
```
why is auto added back?
Auto is actually not being added here, since the `cast()` is telling pre-commit checks that `backend_literal` will be one of ["cudnn", "trtllm", "cutlass", "auto"], while `candidate_backends` will never contain auto.
However, there is no need for auto to be there, and I can see it being confusing, so I have removed it in the latest commit.
flashinfer/gemm.py (outdated)

```python
)
elif cur_backend == "cutlass":
    if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
        a_descale = a_descale.view(torch.uint8)
```
This seems like an implementation detail, and maybe needs to be moved to the cutlass runner itself, just like we do with the cudnn_runner.
Agree, and this allows removal of the if-then-else structure above. Updated in the latest commit.
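For illustration, internalizing the conversion could look roughly like this; the class and `_run` are stand-ins, not the actual flashinfer code:

```python
import torch

class CutlassRunnerSketch:
    """Keeps the FP8->uint8 reinterpretation inside the runner so the
    call site needs no backend-specific branching (illustrative only)."""

    def forward(self, a, b, a_descale, b_descale):
        if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
            # Same element size (1 byte), so view() reinterprets without copying.
            a_descale = a_descale.view(torch.uint8)
        return self._run(a, b, a_descale, b_descale)

    def _run(self, a, b, a_descale, b_descale):
        raise NotImplementedError  # hypothetical kernel dispatch
```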
Force-pushed from 254827a to c9f3d52
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- benchmarks/routines/flashinfer_benchmark_utils.py (1 hunk)
- benchmarks/routines/gemm.py (4 hunks)
- flashinfer/gemm.py (17 hunks)
- tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/gemm.py (2)
flashinfer/gemm.py (1)
- _check_mm_fp4_problem_size (1812-1870)
flashinfer/autotuner.py (1)
- autotune (251-262)

flashinfer/gemm.py (2)
flashinfer/jit/cpp_ext.py (1)
- get_cuda_version (64-83)
flashinfer/utils.py (4)
- supported_compute_capability (772-852)
- get_compute_capability (251-254)
- is_compute_capability_supported (966-972)
- backend_requirement (855-1028)
🪛 Ruff (0.14.2)
benchmarks/routines/gemm.py
900-900: Do not catch blind exception: Exception
(BLE001)
flashinfer/gemm.py
96-96: Unused function argument: device
(ARG001)
436-436: Unused method argument: inputs
(ARG002)
437-437: Unused method argument: profile
(ARG002)
445-445: Unused method argument: do_preparation
(ARG002)
446-446: Unused method argument: kwargs
(ARG002)
1730-1730: Unused method argument: profile
(ARG002)
1741-1741: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1744-1744: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1780-1780: Unused method argument: do_preparation
(ARG002)
1781-1781: Unused method argument: kwargs
(ARG002)
1863-1863: Avoid specifying long messages outside the exception class
(TRY003)
1884-1884: Unused function argument: backend
(ARG001)
1942-1942: Unused function argument: backend
(ARG001)
1964-1964: Unused function argument: backend
(ARG001)
1965-1965: Unused function argument: use_nvfp4
(ARG001)
1973-1973: Unused function argument: b
(ARG001)
1974-1974: Unused function argument: a_descale
(ARG001)
1975-1975: Unused function argument: b_descale
(ARG001)
1976-1976: Unused function argument: alpha
(ARG001)
1977-1977: Unused function argument: out_dtype
(ARG001)
1978-1978: Unused function argument: out
(ARG001)
1979-1979: Unused function argument: block_size
(ARG001)
1980-1980: Unused function argument: use_8x4_sf_layout
(ARG001)
1981-1981: Unused function argument: backend
(ARG001)
1982-1982: Unused function argument: use_nvfp4
(ARG001)
2109-2109: Unpacked variable cc_major is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2109-2109: Unpacked variable cc_minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2153-2154: try-except-pass detected, consider logging the exception
(S110)
2153-2153: Do not catch blind exception: Exception
(BLE001)
2517-2517: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2518-2518: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2519-2519: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2521-2521: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2524-2524: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2538-2538: Unused method argument: do_preparation
(ARG002)
2539-2539: Unused method argument: kwargs
(ARG002)
Force-pushed from fd4dfd6 to 1b9ffbf
/bot run

/bot stop

The GitLab CI pipeline #38303944 has been cancelled.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/utils/test_decorators.py (1)
349-349: Remove unnecessary initialization. The `candidate_backends` variable is initialized to `None` but immediately overwritten in both branches of the conditional. You can remove this line. Apply this diff:

```diff
 def _heuristic_func(suitable_backends, x, backend):
-    candidate_backends = None
     if x.shape[0] > 5:
         candidate_backends = ["cudnn", "cutlass"]
     else:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/utils/test_decorators.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/utils/test_decorators.py (1)
flashinfer/utils.py (1)
- backend_requirement (897-1179)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
tests/utils/test_decorators.py (1)
361-384: LGTM! The updated decorator usage correctly demonstrates the new heuristic-driven auto backend selection. The test cases properly verify:
- Backend filtering based on requirement checks and compute capability
- Heuristic ordering of suitable backends
- Error handling when no suitable backends are found
The logic correctly traces through both successful and failing scenarios.
/bot stop

The GitLab CI pipeline #38894324 has been cancelled.

/bot run
…uto, but no cudnn autotune
Force-pushed from c19ffc4 to fe2070b
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/gemm/test_mm_fp4.py (1)
15-25: `backend='auto'` tests are always skipped due to `is_backend_supported` misuse. `mm_fp4.is_backend_supported("auto", cc)` returns `False` because `"auto"` is not a real backend key in `backend_checks`. As a result, `_test_mm_fp4` skips all `test_mm_fp4_backend_auto` parameterizations, so the new auto-backend tests never actually run. You're getting "coverage" numbers without exercising the `backend="auto"` path.

A minimal fix is to avoid using `is_backend_supported` for the synthetic `"auto"` backend and let the decorator logic drive error handling:

```diff
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
-    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
-    if not mm_fp4.is_backend_supported(backend, compute_capability_number):
-        pytest.skip(
-            f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
-        )
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
+    # For concrete backends, pre-skip unsupported CCs. For backend='auto', rely on
+    # the decorator to raise `BackendSupportedError` if no candidate backend exists.
+    if backend != "auto" and not mm_fp4.is_backend_supported(
+        backend, compute_capability_number
+    ):
+        pytest.skip(
+            f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
+        )
```

Optionally, if you want an early skip for auto as well, you can instead check `mm_fp4.is_compute_capability_supported(compute_capability_number)`.

This will ensure `test_mm_fp4_backend_auto` actually exercises both the heuristic and autotune behavior.

Also applies to: 124-127
flashinfer/gemm/gemm_base.py (1)
1979-2015: Fix tensor index mismatch in TRTLLM FP4 autotuning. The issue is confirmed. In `TrtllmFp4GemmRunner.get_valid_tactics()` (lines 2553-2620), using `a_tensor_index = 1` and `b_tensor_index = 2` incorrectly maps to `b` and `a_descale` respectively, since the `mm_fp4` inputs list places `a` at index 0 and `b` at index 1.

Correct this by setting `a_tensor_index = 0` and `b_tensor_index = 1`:

```diff
 def get_valid_tactics(
     self,
     inputs: List[torch.Tensor],
     profile: OptimizationProfile,
 ) -> List[int]:
-    a_tensor_index = 1
-    b_tensor_index = 2
+    a_tensor_index = 0
+    b_tensor_index = 1
+    opt_shapes = profile.get_opt_shapes()
-    a = profile.get_opt_shapes()[a_tensor_index]
-    b = profile.get_opt_shapes()[b_tensor_index]
+    a = opt_shapes[a_tensor_index]
+    b = opt_shapes[b_tensor_index]
     m = a[0]
     n = b[0]
     k = a[1] * 2
```

This ensures the `m, n, k` passed to `trtllm_gemm_tactics` match the actual problem dimensions during autotuning.
🧹 Nitpick comments (3)
tests/utils/test_decorators.py (1)
347-359: Heuristic test helper is correct; consider minor cleanups to avoid shadowing. The `_heuristic_func` implementation matches the new `backend_requirement` contract and exercises the auto-backend path correctly. One small readability nit is reusing `backend` as the loop variable, which shadows the function argument and slightly hurts clarity; you can also simplify the `candidate_backends`/`heuristic_backends` construction.

For example:

```diff
-    def _heuristic_func(suitable_backends, x, backend):
-        candidate_backends = None
-        if x.shape[0] > 5:
-            candidate_backends = ["cudnn", "cutlass"]
-        else:
-            candidate_backends = ["cutlass", "cudnn"]
-
-        heuristic_backends = []
-        for backend in candidate_backends:
-            if backend in suitable_backends:
-                heuristic_backends.append(backend)
-        return heuristic_backends
+    def _heuristic_func(suitable_backends, x, backend):
+        if x.shape[0] > 5:
+            candidate_backends = ["cudnn", "cutlass"]
+        else:
+            candidate_backends = ["cutlass", "cudnn"]
+
+        return [b for b in candidate_backends if b in suitable_backends]
```

This keeps behavior identical while making the helper a bit clearer.

Also applies to: 361-367
flashinfer/utils.py (1)
924-930: `heuristic_func` semantics look good; prefer an explicit error over `assert` for robustness. The extended docstring clearly defines `heuristic_func`'s contract, and the change to always run it in `suitable_auto_backends` enforces the intended rule that any API exposing `backend="auto"` must supply a heuristic.

One concern: using `assert heuristic_func is not None, "Heuristic function must be provided"` means this safety net disappears under `python -O`, and callers would then silently fall back to unordered `suitable_backends`. It's safer to raise an explicit error when `backend="auto"` is used without a heuristic, e.g.:

```diff
-    assert heuristic_func is not None, "Heuristic function must be provided"
-    suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
+    if heuristic_func is None:
+        raise RuntimeError(
+            f"backend='auto' requires a heuristic_func for {func.__name__}"
+        )
+    suitable_backends = heuristic_func(suitable_backends, *args, **kwargs)
```

This keeps the behavior "loud" in all runtime modes while matching the intent discussed in earlier review threads.

Also applies to: 1086-1087
benchmarks/routines/gemm.py (1)
129-135: Runtime backend probing for mm_fp4 is sensible; narrow the catch to avoid hiding real bugs. Letting `testMmFp4` discover supported backends by actually calling `flashinfer.gemm.mm_fp4` is a nice improvement over the old static CC map, and adding `"auto"` to `--backends` plus `autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]` lines up with the new auto-backend behavior.

The main concern is this block:

```python
for backend in backends:
    ...
    try:
        flashinfer.gemm.mm_fp4(...)
    except Exception as e:
        print(
            f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
        )
        backends_to_remove.append(backend)
```

Catching bare `Exception` means any failure in the runner (logic bug, memory issue, etc.) is treated as "backend unsupported" and the benchmark keeps going, which can silently mask regressions.

Given the decorator and kernels raise well-typed errors on unsupported configs, you can be more precise here, e.g.:

```diff
-            try:
-                flashinfer.gemm.mm_fp4(...)
-            except Exception as e:
+            try:
+                flashinfer.gemm.mm_fp4(...)
+            except (LibraryError, BackendSupportedError, ValueError) as e:
                 print(
                     f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
                 )
                 backends_to_remove.append(backend)
+            except Exception:
+                # Treat unexpected failures as real errors so they don't get hidden.
+                raise
```

This keeps the probing-based filtering but makes genuine bugs in mm_fp4 visible during benchmarking instead of being silently filtered away.

Also applies to: 793-799, 843-884, 907-917
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/routines/flashinfer_benchmark_utils.py (1 hunk)
- benchmarks/routines/gemm.py (5 hunks)
- flashinfer/gemm/gemm_base.py (18 hunks)
- flashinfer/utils.py (2 hunks)
- tests/gemm/test_mm_fp4.py (3 hunks)
- tests/utils/test_decorators.py (1 hunk)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- benchmarks/routines/gemm.py
- flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (3)
benchmarks/routines/gemm.py (2)
flashinfer/gemm/gemm_base.py (1)
- mm_fp4 (2027-2186)
flashinfer/autotuner.py (1)
- autotune (251-262)

tests/utils/test_decorators.py (1)
flashinfer/utils.py (1)
- backend_requirement (897-1179)

flashinfer/gemm/gemm_base.py (2)
flashinfer/autotuner.py (7)
- TunableRunner (194-247)
- OptimizationProfile (168-183)
- AutoTuner (335-786)
- TuningConfig (101-141)
- DynamicTensorSpec (41-82)
- ConstraintSpec (86-97)
- choose_one (400-529)
flashinfer/utils.py (3)
- backend_requirement (897-1179)
- suitable_auto_backends (1071-1091)
- get_compute_capability (253-256)
🪛 Ruff (0.14.5)
benchmarks/routines/gemm.py
871-871: Do not catch blind exception: Exception
(BLE001)
flashinfer/gemm/gemm_base.py
417-417: Unused method argument: inputs
(ARG002)
418-418: Unused method argument: profile
(ARG002)
426-426: Unused method argument: do_preparation
(ARG002)
427-427: Unused method argument: kwargs
(ARG002)
483-483: Avoid specifying long messages outside the exception class
(TRY003)
1671-1671: Unused function argument: out
(ARG001)
1748-1748: Unused method argument: profile
(ARG002)
1761-1761: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1785-1785: Unused method argument: do_preparation
(ARG002)
1786-1786: Unused method argument: kwargs
(ARG002)
1824-1824: Unused function argument: out
(ARG001)
1826-1826: Unused function argument: use_8x4_sf_layout
(ARG001)
1827-1827: Unused function argument: backend
(ARG001)
1881-1881: Unused function argument: out
(ARG001)
1884-1884: Unused function argument: backend
(ARG001)
1888-1888: Avoid specifying long messages outside the exception class
(TRY003)
1936-1936: Unused function argument: a
(ARG001)
1937-1937: Unused function argument: b
(ARG001)
1938-1938: Unused function argument: a_descale
(ARG001)
1939-1939: Unused function argument: b_descale
(ARG001)
1940-1940: Unused function argument: alpha
(ARG001)
1942-1942: Unused function argument: out
(ARG001)
1943-1943: Unused function argument: block_size
(ARG001)
1944-1944: Unused function argument: use_8x4_sf_layout
(ARG001)
1945-1945: Unused function argument: backend
(ARG001)
1949-1949: Avoid specifying long messages outside the exception class
(TRY003)
1960-1960: Unused function argument: a
(ARG001)
1961-1961: Unused function argument: b
(ARG001)
1962-1962: Unused function argument: a_descale
(ARG001)
1963-1963: Unused function argument: b_descale
(ARG001)
1964-1964: Unused function argument: alpha
(ARG001)
1965-1965: Unused function argument: out_dtype
(ARG001)
1966-1966: Unused function argument: out
(ARG001)
1967-1967: Unused function argument: block_size
(ARG001)
1969-1969: Unused function argument: backend
(ARG001)
1973-1973: Avoid specifying long messages outside the exception class
(TRY003)
1975-1975: Avoid specifying long messages outside the exception class
(TRY003)
1990-1990: Unused function argument: backend
(ARG001)
1991-1991: Unused function argument: use_nvfp4
(ARG001)
2569-2569: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2570-2570: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2571-2571: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2573-2573: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2576-2576: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2590-2590: Unused method argument: do_preparation
(ARG002)
2591-2591: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
238-239: Comment correctly reflects mm_fp4's new dynamic support handling. Not listing `mm_fp4` in `routine_cc_to_supported_backends` and documenting that it relies on runtime support checkers aligns with the new decorator-based validation in `mm_fp4`. This avoids stale hard-coded backend lists in the benchmark helper.

flashinfer/gemm/gemm_base.py (1)
410-447: The original review comment incorrectly states that `a_descale.view(torch.uint8)` will raise a `TypeError` at runtime. In current PyTorch, `torch.Tensor.view` does support a dtype overload that accepts a `torch.dtype` argument (e.g., `x.view(torch.uint8)`), which reinterprets the underlying data with the given dtype without copying. This is the correct idiomatic approach for the FP8-to-uint8 reinterpretation shown in `CutlassFp4GemmRunner.forward`.

The code is valid as written. No changes needed on this section.
Likely an incorrect or invalid review comment.
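A quick check of that claim (assumes a PyTorch build with float8 support):

```python
import torch

scale = torch.zeros(16, dtype=torch.float8_e4m3fn)
raw = scale.view(torch.uint8)  # reinterpret bytes, no copy
# Element sizes match (1 byte each), so the shape is unchanged
# and the two tensors share storage.
assert raw.dtype == torch.uint8 and raw.shape == scale.shape
assert raw.data_ptr() == scale.data_ptr()
```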
📌 Description
Current PR:
- Adds an `auto` backend to `mm_fp4` that can be autotuned. It replaces `cudnn` as the default.
- Matches `bmm_fp8`'s auto backend support.
- Enables the `cudnn` backend to be autotuned.

Behavior of the `auto` backend:
- Chooses between the `cutlass` and `cudnn` kernel backends. The `trtllm` kernel is not considered due to a non-interchangeable interface with other backends; the `auto` backend therefore only supports inputs runnable by `cutlass` and/or `cudnn`.
- `backend='auto'` autotunes within and across backends (cudnn & cutlass) and chooses the best config of the best backend. The `trtllm` kernel is not considered.
- Runners for `mm_fp4` were refactored to enable cross-backend autotuning, using the cross-backend autotune-enabled `bmm_fp8` as a reference.

A minimal usage sketch follows below.
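This sketch assumes the `mm_fp4` argument names shown in the review context and that `autotune` from `flashinfer.autotuner` is usable as a context manager; creating the FP4-quantized inputs requires the quantization helpers and is elided here:

```python
import torch
from flashinfer.gemm import mm_fp4
from flashinfer.autotuner import autotune

# a, b: FP4-quantized operands; a_descale, b_descale: their scale factors;
# alpha: optional output scale. (Setup of these tensors is elided.)
with autotune():  # tune within and across the cudnn & cutlass backends
    out = mm_fp4(
        a,
        b,
        a_descale,
        b_descale,
        alpha,
        out_dtype=torch.bfloat16,
        backend="auto",  # heuristic ordering + cross-backend autotune
    )
```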
Pytest outputs (`pytest tests/gemm/test_mm_fp4.py`):
- 900 passed, 2532 skipped in 125.19s (0:02:05)
- 900 passed, 2532 skipped in 125.67s (0:02:05)
- 720 passed, 2712 skipped in 76.50s (0:01:16)

Example microbenchmark outputs:
On SM100 (B200) CUDA 13 & cuDNN 9.15
On SM100 (B200) CUDA 12 & cuDNN 9.15
On SM120 (RTX 5090) CUDA 13 & cuDNN 9.15
🔍 Related Issues
#1722
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests