Commit 1b9ffbf

Rebase main
1 parent 038e46a commit 1b9ffbf

File tree

2 files changed: +55 -98 lines

benchmarks/routines/gemm.py

Lines changed: 13 additions & 42 deletions
```diff
@@ -854,49 +854,20 @@ def testMmFp4(args):
                 continue
 
             try:
-                from flashinfer.gemm import (
-                    _mm_fp4_backend_checkers,
-                    _check_mm_fp4_problem_size,
-                )
-
-                # Choose correct tensors for this backend
-                if backend == "trtllm":
-                    b_tensor = mat2_fp4_trtllm.T
-                    b_descale = mat2_inv_s_trtllm.T
-                else:
-                    b_tensor = mat2_fp4.T
-                    b_descale = mat2_inv_s.T
-
-                # Validate common requirements
-                _check_mm_fp4_problem_size(
-                    input_fp4,
-                    b_tensor,
-                    input_inv_s,
-                    b_descale,
-                    alpha,
-                    res_dtype,
-                    None,  # out
-                    block_size,
-                    not use_128x4_sf_layout,  # use_8x4_sf_layout
-                    backend,
-                    use_nvfp4,
+                flashinfer.gemm.mm_fp4(
+                    a=input_fp4,
+                    b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
+                    a_descale=input_inv_s,
+                    b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
+                    alpha=alpha,
+                    out_dtype=res_dtype,
+                    block_size=16
+                    if use_nvfp4
+                    else 32,  # nvfp4 only supports 16; mxfp4 only supports 32.
+                    use_8x4_sf_layout=not use_128x4_sf_layout,
+                    backend=backend,
+                    use_nvfp4=use_nvfp4,
                 )
-
-                # Validate backend-specific requirements
-                if backend in _mm_fp4_backend_checkers:
-                    _mm_fp4_backend_checkers[backend](
-                        input_fp4,
-                        b_tensor,
-                        input_inv_s,
-                        b_descale,
-                        alpha,
-                        res_dtype,
-                        None,  # out
-                        block_size,
-                        not use_128x4_sf_layout,
-                        backend,
-                        use_nvfp4,
-                    )
             except Exception as e:
                 print(
                     f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
```

flashinfer/gemm/gemm_base.py

Lines changed: 42 additions & 56 deletions
```diff
@@ -17,7 +17,7 @@
 import functools
 from enum import Enum
 from types import SimpleNamespace
-from typing import List, Literal, Optional, Tuple, cast
+from typing import List, Literal, Optional, Tuple
 
 from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm
 import torch
@@ -1989,16 +1989,48 @@ def _auto_gemm_fp4_requirement(
     return False
 
 
-_mm_fp4_backend_checkers = {
-    "cudnn": _cudnn_gemm_fp4_requirement,
-    "trtllm": _trtllm_gemm_fp4_requirement,
-    "cutlass": _cutlass_gemm_fp4_requirement,
-    "auto": _auto_gemm_fp4_requirement,
-}
+def _heuristic_func_mm_fp4(
+    suitable_backends: List[str],
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    cuda_major, _ = get_cuda_version(a.device)
+    cc_major, cc_minor = get_compute_capability(a.device)
+    # If cuda version is 13 or greater:
+    # cudnn is more performant if cudnn version is 9.14 or greater.
+    if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400:
+        candidate_backends = ("cudnn", "cutlass")
+    # Otherwise, prioritize cutlass
+    else:
+        candidate_backends = ("cutlass", "cudnn")
+
+    # Filter to only supported backends for this compute capability
+    # Note: The requirement function already validated that at least one backend is supported
+    heuristic_backends = []
+    for candidate in candidate_backends:
+        # mypy requires explicit type casting for the backend literal
+        if candidate in suitable_backends:
+            heuristic_backends.append(candidate)
+    return heuristic_backends
 
 
 @backend_requirement(
-    backend_checks=_mm_fp4_backend_checkers, common_check=_check_mm_fp4_problem_size
+    {
+        "cudnn": _cudnn_gemm_fp4_requirement,
+        "trtllm": _trtllm_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+    },
+    common_check=_check_mm_fp4_problem_size,
+    heuristic_func=_heuristic_func_mm_fp4,
 )
 def mm_fp4(
     a: torch.Tensor,
```
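
The heuristic only orders preferences; membership in `suitable_backends` has already been decided by the requirement checks. Assuming `cudnn.backend_version()` encodes 9.14.0 as 91400 (which the comparison above implies), the ordering works out as follows:

```python
# CUDA 13+ with cuDNN >= 9.14: prefer cudnn, fall back to cutlass.
#   _heuristic_func_mm_fp4(["cudnn", "cutlass"], a, b, ...)  -> ["cudnn", "cutlass"]
# Older CUDA or older cuDNN: prefer cutlass.
#   _heuristic_func_mm_fp4(["cudnn", "cutlass"], a, b, ...)  -> ["cutlass", "cudnn"]
# A backend missing from suitable_backends is dropped regardless of preference:
#   _heuristic_func_mm_fp4(["cutlass"], a, b, ...)           -> ["cutlass"]
```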
```diff
@@ -2010,7 +2042,7 @@ def mm_fp4(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn",
     use_nvfp4: bool = True,
 ) -> torch.Tensor:
     r"""MM FP4
```
```diff
@@ -2089,53 +2121,7 @@ def mm_fp4(
 
     # Auto-select the best backend
     if backend == "auto":
-        cuda_major, _ = get_cuda_version(a.device)
-        cc_major, cc_minor = get_compute_capability(a.device)
-        # If cuda version is 13 or greater:
-        # cudnn is more performant if cudnn version is 9.14 or greater.
-        if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400:
-            candidate_backends = ("cudnn", "cutlass")
-        # Otherwise, prioritize cutlass
-        else:
-            candidate_backends = ("cutlass", "cudnn")
-
-        # Filter to only supported backends for this compute capability
-        # Note: The requirement function already validated that at least one backend is supported
-        backends = []
-        for candidate in candidate_backends:
-            # mypy requires explicit type casting for the backend literal
-            backend_literal = cast(Literal["cudnn", "trtllm", "cutlass"], candidate)
-            try:
-                # Check both common constraints and backend-specific requirements
-                # to find all compatible backends for this problem instance
-                if _check_mm_fp4_problem_size(
-                    a,
-                    b,
-                    a_descale,
-                    b_descale,
-                    alpha,
-                    out_dtype,
-                    out,
-                    block_size,
-                    use_8x4_sf_layout,
-                    backend_literal,
-                    use_nvfp4,
-                ) and _mm_fp4_backend_checkers[candidate](
-                    a,
-                    b,
-                    a_descale,
-                    b_descale,
-                    alpha,
-                    out_dtype,
-                    out,
-                    block_size,
-                    use_8x4_sf_layout,
-                    backend_literal,
-                    use_nvfp4,
-                ):
-                    backends.append(candidate)
-            except Exception:
-                pass
+        backends = mm_fp4.suitable_auto_backends
     else:
         backends = [backend]
```
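
The body of `mm_fp4` now just reads `mm_fp4.suitable_auto_backends`, which the `backend_requirement` decorator is expected to populate before the wrapped function runs; that decorator's implementation is not part of this commit. A minimal sketch of the contract the new code appears to rely on, not the actual flashinfer implementation:

```python
import functools
from typing import Callable, Dict, List, Optional


def backend_requirement(
    backend_checks: Dict[str, Callable[..., bool]],
    common_check: Optional[Callable[..., bool]] = None,
    heuristic_func: Optional[Callable[..., List[str]]] = None,
):
    """Sketch: validate each candidate backend, let heuristic_func order the
    survivors, and expose them on the wrapper as `suitable_auto_backends`."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Simplification: assumes `backend` is passed as a keyword.
            backend = kwargs.get("backend", "cudnn")
            candidates = list(backend_checks) if backend == "auto" else [backend]
            suitable = []
            for cand in candidates:
                try:
                    probe = {**kwargs, "backend": cand}
                    # Common problem-size constraints, then backend-specific ones.
                    if common_check is not None:
                        common_check(*args, **probe)
                    backend_checks[cand](*args, **probe)
                    suitable.append(cand)
                except Exception:
                    if backend != "auto":
                        raise  # explicit backend: surface the failure
            if not suitable:
                raise ValueError("no backend supports this FP4 GEMM configuration")
            if heuristic_func is not None and backend == "auto":
                suitable = heuristic_func(suitable, *args, **kwargs)
            # The wrapped function reads this as mm_fp4.suitable_auto_backends.
            wrapper.suitable_auto_backends = suitable
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```

Under this contract, the try/except probing happens once inside the decorator, and `mm_fp4` itself only consumes the filtered, heuristic-ordered result.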
