Commit 038e46a

Address comments
1 parent a01b4b9 commit 038e46a

3 files changed: 123 additions, 114 deletions


benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 1 addition & 11 deletions
@@ -236,17 +236,7 @@ def dtype_str_to_torch_dtype(dtype_str):
         "10.3": ["cudnn", "cublas", "cutlass"],
         "12.0": ["cudnn", "cublas"],
     },
-    "mm_fp4": {
-        "7.5": [],
-        "8.0": [],
-        "8.6": [],
-        "8.9": [],
-        "9.0": [],
-        "10.0": ["cudnn", "trtllm", "cutlass", "auto"],
-        "10.3": ["cudnn", "trtllm", "cutlass", "auto"],
-        "12.0": ["cudnn", "cutlass", "auto"],
-        "12.1": ["cudnn", "cutlass", "auto"],
-    },
+    # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
     # MOE
     "trtllm_fp4_block_scale_moe": {
         "7.5": [],

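With mm_fp4 removed from this table, only the routines that remain are gated by the hard-coded compute-capability map; mm_fp4 support is now decided at runtime by the checkers shown in the next two files. As a rough illustration of how a table like this can be consumed, here is a minimal sketch with a made-up helper and routine entry; it is not the repository's actual filter_backends_by_compute_capability implementation:

```python
import torch

# Hypothetical, trimmed-down capability table in the same shape as the one above.
ROUTINE_BACKENDS_BY_CC = {
    "mm_fp8": {
        "9.0": ["cudnn", "cublas"],
        "10.0": ["cudnn", "cublas", "cutlass"],
    },
    # mm_fp4 is intentionally absent: its support is validated per backend at runtime.
}


def filter_by_compute_capability(requested_backends, routine, device):
    """Keep only backends listed for this routine on this GPU (illustrative sketch)."""
    major, minor = torch.cuda.get_device_capability(device)
    allowed = ROUTINE_BACKENDS_BY_CC.get(routine, {}).get(f"{major}.{minor}")
    if allowed is None:
        # Routine not in the table: in this sketch we leave the list untouched.
        # The mm_fp4 benchmark simply no longer calls the table-based filter at all.
        return list(requested_backends)
    return [b for b in requested_backends if b in allowed]


if __name__ == "__main__" and torch.cuda.is_available():
    print(filter_by_compute_capability(["cudnn", "cutlass"], "mm_fp8", torch.device("cuda:0")))
```
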
benchmarks/routines/gemm.py

Lines changed: 76 additions & 61 deletions
@@ -793,65 +793,11 @@ def testMmFp4(args):
     autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]
     res = []
 
-    backends = filter_backends_by_compute_capability(backends, args.routine, device)
-
     res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
     if res_dtype not in [torch.bfloat16, torch.float16]:
         raise ValueError(
             f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
         )
-    ## Done parsing input arguments
-
-    if "trtllm" in backends:
-        remove_trtllm = False
-        if res_dtype == torch.float16:
-            print("[INFO] trtllm backend does not support float16 output")
-            remove_trtllm = True
-        if remove_trtllm:
-            backends.remove("trtllm")
-        if not use_nvfp4:
-            print(
-                "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            backends.remove("trtllm")
-    if "cutlass" in backends:
-        remove_cutlass = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cutlass backend does not support use_128x4_sf_layout=False")
-            remove_cutlass = True
-        if not use_nvfp4:
-            print(
-                "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            remove_cutlass = True
-        if remove_cutlass:
-            backends.remove("cutlass")
-    if "cudnn" in backends:
-        remove_cudnn = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cudnn backend does not support use_128x4_sf_layout=False")
-            remove_cudnn = True
-        if remove_cudnn:
-            backends.remove("cudnn")
-    if "auto" in backends:
-        remove_auto = False
-        if not use_128x4_sf_layout:
-            print("[INFO] auto backend does not support use_128x4_sf_layout=False")
-            remove_auto = True
-        if remove_auto:
-            backends.remove("auto")
-    if getattr(args, "autotune", False):
-        backends_to_remove = []
-        for cur_backend in backends:
-            if cur_backend not in autotune_supported_backends:
-                print(f"[INFO] {cur_backend} backend does not support autotune")
-                backends_to_remove.append(cur_backend)
-        for cur_backend in backends_to_remove:
-            backends.remove(cur_backend)
-
-    if len(backends) == 0:
-        print("[ERROR] No backends to test. Exiting.")
-        return
 
     input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
     mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
@@ -893,7 +839,77 @@ def testMmFp4(args):
     print(f"[VVERBOSE] {mat2_fp4.dtype = }")
 
     alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
-    # res = torch.empty([m, n], device="cuda", dtype=res_dtype)
+    # Completed preparing inputs. Now programmatically filter backends
+    block_size = 16 if use_nvfp4 else 32
+    backends_to_remove = []
+
+    for backend in backends:
+        # Skip autotune check for now (handled separately below)
+        if (
+            getattr(args, "autotune", False)
+            and backend not in autotune_supported_backends
+        ):
+            print(f"[INFO] {backend} backend does not support autotune")
+            backends_to_remove.append(backend)
+            continue
+
+        try:
+            from flashinfer.gemm import (
+                _mm_fp4_backend_checkers,
+                _check_mm_fp4_problem_size,
+            )
+
+            # Choose correct tensors for this backend
+            if backend == "trtllm":
+                b_tensor = mat2_fp4_trtllm.T
+                b_descale = mat2_inv_s_trtllm.T
+            else:
+                b_tensor = mat2_fp4.T
+                b_descale = mat2_inv_s.T
+
+            # Validate common requirements
+            _check_mm_fp4_problem_size(
+                input_fp4,
+                b_tensor,
+                input_inv_s,
+                b_descale,
+                alpha,
+                res_dtype,
+                None,  # out
+                block_size,
+                not use_128x4_sf_layout,  # use_8x4_sf_layout
+                backend,
+                use_nvfp4,
+            )
+
+            # Validate backend-specific requirements
+            if backend in _mm_fp4_backend_checkers:
+                _mm_fp4_backend_checkers[backend](
+                    input_fp4,
+                    b_tensor,
+                    input_inv_s,
+                    b_descale,
+                    alpha,
+                    res_dtype,
+                    None,  # out
+                    block_size,
+                    not use_128x4_sf_layout,
+                    backend,
+                    use_nvfp4,
+                )
+        except Exception as e:
+            print(
+                f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
+            )
+            backends_to_remove.append(backend)
+
+    # Remove unsupported backends
+    for backend in backends_to_remove:
+        backends.remove(backend)
+
+    if len(backends) == 0:
+        print("[ERROR] No backends passed validation. Exiting.")
+        return
 
     def run_backend(backend):
         if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
@@ -924,12 +940,11 @@ def run_backend(backend):
        args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
    )
    for cur_backend in backends:
-        if cur_backend in autotune_supported_backends:
-            if args.verbose >= 1:
-                print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
-            with autotune(True):
-                for _ in range(warmup_iters):
-                    run_backend(cur_backend)
+        if args.verbose >= 1:
+            print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
+        with autotune(True):
+            for _ in range(warmup_iters):
+                run_backend(cur_backend)
 
    # Storage for timing results and outputs
    backend_times = {backend: [] for backend in backends}
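
The new filtering loop leans on exceptions: each checker is expected to raise on an unsupported configuration, and the benchmark drops any backend whose validation fails. A condensed, self-contained sketch of that pattern, using stand-in checker functions rather than flashinfer's _check_mm_fp4_problem_size / _mm_fp4_backend_checkers:

```python
def filter_backends(backends, cfg, common_check, backend_checkers):
    """Drop any backend whose validation raises; keep the rest (illustrative sketch)."""
    kept = []
    for backend in backends:
        try:
            common_check(cfg)                 # constraints shared by every backend
            checker = backend_checkers.get(backend)
            if checker is not None:
                checker(cfg)                  # backend-specific constraints
        except Exception as e:
            print(f"[INFO] {backend} does not support this configuration: {type(e).__name__}: {e}")
            continue
        kept.append(backend)
    return kept


# Stand-in checkers; the benchmark imports the real ones from flashinfer.gemm.
def common_check(cfg):
    if cfg["out_dtype"] not in ("bfloat16", "float16"):
        raise ValueError(f"unsupported output dtype: {cfg['out_dtype']}")


def trtllm_check(cfg):
    if not cfg["use_nvfp4"]:
        raise ValueError("trtllm requires nvfp4 quantization")


print(filter_backends(
    ["cudnn", "trtllm"],
    {"out_dtype": "bfloat16", "use_nvfp4": False},
    common_check,
    {"trtllm": trtllm_check},
))  # -> ['cudnn']
```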

flashinfer/gemm/gemm_base.py

Lines changed: 46 additions & 42 deletions
@@ -441,6 +441,10 @@ def forward(
             _,
             workspace_buffer,
         ) = inputs
+        if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
+            a_descale = a_descale.view(torch.uint8)
+        if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
+            b_descale = b_descale.view(torch.uint8)
         module.fp4_gemm(
             a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer, tactic
         )
@@ -1947,7 +1951,7 @@ def _cutlass_gemm_fp4_requirement(
     return True
 
 
-@supported_compute_capability([100, 103, 110, 120])
+@supported_compute_capability([100, 103, 110, 120, 121])
 def _auto_gemm_fp4_requirement(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1985,14 +1989,16 @@ def _auto_gemm_fp4_requirement(
     return False
 
 
+_mm_fp4_backend_checkers = {
+    "cudnn": _cudnn_gemm_fp4_requirement,
+    "trtllm": _trtllm_gemm_fp4_requirement,
+    "cutlass": _cutlass_gemm_fp4_requirement,
+    "auto": _auto_gemm_fp4_requirement,
+}
+
+
 @backend_requirement(
-    {
-        "cudnn": _cudnn_gemm_fp4_requirement,  # Each backend has its own requirement function
-        "trtllm": _trtllm_gemm_fp4_requirement,
-        "cutlass": _cutlass_gemm_fp4_requirement,
-        "auto": _auto_gemm_fp4_requirement,  # Auto backend requires at least one backend to be supported on the current device
-    },
-    common_check=_check_mm_fp4_problem_size,  # Shape checks common to all backends
+    backend_checks=_mm_fp4_backend_checkers, common_check=_check_mm_fp4_problem_size
 )
 def mm_fp4(
     a: torch.Tensor,
@@ -2087,7 +2093,7 @@ def mm_fp4(
         cc_major, cc_minor = get_compute_capability(a.device)
         # If cuda version is 13 or greater:
         # cudnn is more performant if cudnn version is 9.14 or greater.
-        if cuda_major >= 13 and cudnn.backend_version() >= 91400:
+        if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400:
             candidate_backends = ("cudnn", "cutlass")
         # Otherwise, prioritize cutlass
         else:
@@ -2098,11 +2104,11 @@
         backends = []
         for candidate in candidate_backends:
             # mypy requires explicit type casting for the backend literal
-            backend_literal = cast(
-                Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
-            )
+            backend_literal = cast(Literal["cudnn", "trtllm", "cutlass"], candidate)
             try:
-                _check_mm_fp4_problem_size(
+                # Check both common constraints and backend-specific requirements
+                # to find all compatible backends for this problem instance
+                if _check_mm_fp4_problem_size(
                     a,
                     b,
                     a_descale,
@@ -2114,41 +2120,39 @@
                     use_8x4_sf_layout,
                     backend_literal,
                     use_nvfp4,
-                )
-                backends.append(candidate)
+                ) and _mm_fp4_backend_checkers[candidate](
+                    a,
+                    b,
+                    a_descale,
+                    b_descale,
+                    alpha,
+                    out_dtype,
+                    out,
+                    block_size,
+                    use_8x4_sf_layout,
+                    backend_literal,
+                    use_nvfp4,
+                ):
+                    backends.append(candidate)
             except Exception:
                 pass
     else:
         backends = [backend]
 
     # At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'.
-    runners = []
-    for cur_backend in backends:
-        if cur_backend == "cudnn":
-            runners.append(_cudnn_gemm_fp4_runner())
-        elif cur_backend == "trtllm":
-            runners.append(
-                get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(use_8x4_sf_layout)
-            )
-        elif cur_backend == "cutlass":
-            if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
-                a_descale = a_descale.view(torch.uint8)
-            if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
-                b_descale = b_descale.view(torch.uint8)
-
-            # Dispatch to the correct module based on device architecture
-            major, _ = get_compute_capability(a.device)
-            if major == 12:
-                runners.append(
-                    get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner()
-                )
-            else:
-                runners.append(
-                    get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner()
-                )
-        else:
-            # Should not reach this
-            raise ValueError(f"Unsupported backend: {cur_backend}")
+    # Lazy initialization of runners to avoid overhead of creating a new runner that will not be used
+    major, _ = get_compute_capability(a.device)
+
+    backend_to_runner_factory = {
+        "cudnn": lambda: _cudnn_gemm_fp4_runner(),
+        "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(
+            use_8x4_sf_layout
+        ),
+        "cutlass": lambda: get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner()
+        if major == 12
+        else get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner(),
+    }
+    runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends]
 
     # Now we have a list of runners for desired & supported backends.
     tuner = AutoTuner.get()
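
The runner construction in mm_fp4 now goes through a dict of zero-argument factories, so only the backends that survived selection ever pay the cost of building a runner. A generic sketch of this lazy dict-of-factories pattern, with placeholder factories standing in for the real runner constructors:

```python
class Runner:
    """Placeholder runner; assume constructing one is expensive (module loading, JIT, etc.)."""

    def __init__(self, name):
        print(f"building {name} runner")
        self.name = name


# Map each backend to a zero-argument factory; nothing is constructed yet.
runner_factories = {
    "cudnn": lambda: Runner("cudnn"),
    "trtllm": lambda: Runner("trtllm"),
    "cutlass": lambda: Runner("cutlass"),
}

# Only the selected backends pay the construction cost.
selected = ["cutlass"]
runners = [runner_factories[name]() for name in selected]  # builds just the cutlass runner
```

In the real code the cutlass factory additionally picks the SM100 or SM120 CUTLASS module based on the device's compute capability major version.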
