 import functools
 from enum import Enum
 from types import SimpleNamespace
-from typing import List, Literal, Optional, Tuple, cast
+from typing import List, Literal, Optional, Tuple
 
 from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm
 import torch
@@ -1989,16 +1989,48 @@ def _auto_gemm_fp4_requirement(
     return False
 
 
-_mm_fp4_backend_checkers = {
-    "cudnn": _cudnn_gemm_fp4_requirement,
-    "trtllm": _trtllm_gemm_fp4_requirement,
-    "cutlass": _cutlass_gemm_fp4_requirement,
-    "auto": _auto_gemm_fp4_requirement,
-}
+def _heuristic_func_mm_fp4(
+    suitable_backends: List[str],
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    cuda_major, _ = get_cuda_version(a.device)
+    cc_major, cc_minor = get_compute_capability(a.device)
+    # On CUDA 13 or newer, cudnn is more performant when the cudnn version
+    # is 9.14 or greater, so try it first.
+    if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400:
+        candidate_backends = ("cudnn", "cutlass")
+    # Otherwise, prioritize cutlass.
+    else:
+        candidate_backends = ("cutlass", "cudnn")
+
+    # Filter to only the backends supported for this compute capability.
+    # Note: the requirement checks already validated that at least one backend is supported.
+    heuristic_backends = []
+    for candidate in candidate_backends:
+        if candidate in suitable_backends:
+            heuristic_backends.append(candidate)
+    return heuristic_backends
 
 
 @backend_requirement(
-    backend_checks=_mm_fp4_backend_checkers, common_check=_check_mm_fp4_problem_size
+    {
+        "cudnn": _cudnn_gemm_fp4_requirement,
+        "trtllm": _trtllm_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+    },
+    common_check=_check_mm_fp4_problem_size,
+    heuristic_func=_heuristic_func_mm_fp4,
 )
 def mm_fp4(
     a: torch.Tensor,
@@ -2010,7 +2042,7 @@ def mm_fp4(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn",
     use_nvfp4: bool = True,
 ) -> torch.Tensor:
     r"""MM FP4
@@ -2089,53 +2121,7 @@ def mm_fp4(
 
     # Auto-select the best backend
     if backend == "auto":
-        cuda_major, _ = get_cuda_version(a.device)
-        cc_major, cc_minor = get_compute_capability(a.device)
-        # If cuda version is 13 or greater:
-        # cudnn is more performant if cudnn version is 9.14 or greater.
-        if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400:
-            candidate_backends = ("cudnn", "cutlass")
-        # Otherwise, prioritize cutlass
-        else:
-            candidate_backends = ("cutlass", "cudnn")
-
-        # Filter to only supported backends for this compute capability
-        # Note: The requirement function already validated that at least one backend is supported
-        backends = []
-        for candidate in candidate_backends:
-            # mypy requires explicit type casting for the backend literal
-            backend_literal = cast(Literal["cudnn", "trtllm", "cutlass"], candidate)
-            try:
-                # Check both common constraints and backend-specific requirements
-                # to find all compatible backends for this problem instance
-                if _check_mm_fp4_problem_size(
-                    a,
-                    b,
-                    a_descale,
-                    b_descale,
-                    alpha,
-                    out_dtype,
-                    out,
-                    block_size,
-                    use_8x4_sf_layout,
-                    backend_literal,
-                    use_nvfp4,
-                ) and _mm_fp4_backend_checkers[candidate](
-                    a,
-                    b,
-                    a_descale,
-                    b_descale,
-                    alpha,
-                    out_dtype,
-                    out,
-                    block_size,
-                    use_8x4_sf_layout,
-                    backend_literal,
-                    use_nvfp4,
-                ):
-                    backends.append(candidate)
-            except Exception:
-                pass
+        backends = mm_fp4.suitable_auto_backends
     else:
         backends = [backend]
 
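For context on the mechanism this diff relies on: `@backend_requirement` now also receives a `heuristic_func`, and the `backend == "auto"` branch simply reads the result back from `mm_fp4.suitable_auto_backends`. The sketch below is not flashinfer's actual decorator; it is a minimal, hypothetical illustration (the name `backend_requirement_sketch` and its demo are invented) of how such a decorator could run the shared and per-backend checks, feed the surviving backends to the heuristic, and expose the ordered list as an attribute on the wrapped function.

import functools
from typing import Callable, Dict, List, Optional


def backend_requirement_sketch(
    backend_checks: Dict[str, Callable[..., bool]],
    common_check: Optional[Callable[..., bool]] = None,
    heuristic_func: Optional[Callable[..., List[str]]] = None,
):
    # Hypothetical stand-in for flashinfer's @backend_requirement decorator.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Run the shared problem-size check once, then keep every backend
            # whose specific requirement check accepts this problem instance.
            common_ok = common_check is None or common_check(*args, **kwargs)
            suitable = [
                name
                for name, check in backend_checks.items()
                if common_ok and check(*args, **kwargs)
            ]
            # The heuristic orders/filters the suitable backends; the result is
            # exposed on the wrapped function so the "auto" path can read it.
            wrapper.suitable_auto_backends = (
                heuristic_func(suitable, *args, **kwargs)
                if heuristic_func is not None
                else suitable
            )
            return func(*args, **kwargs)

        wrapper.suitable_auto_backends = []
        return wrapper

    return decorator


@backend_requirement_sketch(
    {"cudnn": lambda *a, **k: True, "cutlass": lambda *a, **k: True},
    heuristic_func=lambda suitable, *a, **k: [
        b for b in ("cudnn", "cutlass") if b in suitable
    ],
)
def demo_op(x):
    return x


demo_op(1)
print(demo_op.suitable_auto_backends)  # -> ['cudnn', 'cutlass']

The real decorator may cache or validate differently; the point here is only the flow this change introduces: checks, then heuristic, then `suitable_auto_backends` consumed by the "auto" branch.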