Merged
12 changes: 1 addition & 11 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -235,17 +235,7 @@ def dtype_str_to_torch_dtype(dtype_str):
"10.3": ["cudnn", "cublas", "cutlass"],
"12.0": ["cudnn", "cublas"],
},
"mm_fp4": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cudnn", "trtllm", "cutlass"],
"10.3": ["cudnn", "trtllm", "cutlass"],
"12.0": ["cudnn", "cutlass"],
"12.1": ["cudnn", "cutlass"],
},
# Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
# MOE
"trtllm_fp4_block_scale_moe": {
"7.5": [],
109 changes: 50 additions & 59 deletions benchmarks/routines/gemm.py
@@ -131,7 +131,7 @@ def parse_gemm_args(line, parser):
required=False,
nargs="+",
default=["cudnn"],
choices=["cudnn", "cublas", "trtllm", "cutlass"],
choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"],
help="Kernel backends to test. Default: cudnn",
)
parser.add_argument(
@@ -790,61 +790,14 @@ def testMmFp4(args):
run_refcheck = args.refcheck
use_128x4_sf_layout = args.use_128x4_sf_layout
use_nvfp4 = args.use_nvfp4
autotune_supported_backends = ["cutlass", "trtllm"]
autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]
res = []

backends = filter_backends_by_compute_capability(backends, args.routine, device)

res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if res_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(
f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
)
## Done parsing input arguments

if "trtllm" in backends:
remove_trtllm = False
if res_dtype == torch.float16:
print("[INFO] trtllm backend does not support float16 output")
remove_trtllm = True
if remove_trtllm:
backends.remove("trtllm")
if not use_nvfp4:
print(
"[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)"
)
backends.remove("trtllm")
if "cutlass" in backends:
remove_cutlass = False
if not use_128x4_sf_layout:
print("[INFO] cutlass backend does not support use_128x4_sf_layout=False")
remove_cutlass = True
if not use_nvfp4:
print(
"[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
)
backends.remove("cutlass")
if remove_cutlass:
backends.remove("cutlass")
if "cudnn" in backends:
remove_cudnn = False
if not use_128x4_sf_layout:
print("[INFO] cudnn backend does not support use_128x4_sf_layout=False")
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")
if getattr(args, "autotune", False):
backends_to_remove = []
for cur_backend in backends:
if cur_backend not in autotune_supported_backends:
print(f"[INFO] {cur_backend} backend does not support autotune")
backends_to_remove.append(cur_backend)
for cur_backend in backends_to_remove:
backends.remove(cur_backend)

if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return

input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
@@ -886,11 +839,22 @@
print(f"[VVERBOSE] {mat2_fp4.dtype = }")

alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
# res = torch.empty([m, n], device="cuda", dtype=res_dtype)
# Completed preparing inputs. Now programmatically filter backends
block_size = 16 if use_nvfp4 else 32
backends_to_remove = []

def run_backend(backend):
if backend in ["cudnn", "trtllm", "cutlass"]:
return flashinfer.gemm.mm_fp4(
for backend in backends:
# Skip autotune check for now (handled separately below)
if (
getattr(args, "autotune", False)
and backend not in autotune_supported_backends
):
print(f"[INFO] {backend} backend does not support autotune")
backends_to_remove.append(backend)
continue

try:
flashinfer.gemm.mm_fp4(
a=input_fp4,
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
a_descale=input_inv_s,
@@ -904,6 +868,34 @@ def run_backend(backend):
backend=backend,
use_nvfp4=use_nvfp4,
)
except Exception as e:
print(
f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
)
backends_to_remove.append(backend)

Contributor

Is this try-except block trying to verify whether each backend supports the input configuration? I'm wondering whether it would be a good idea to expose the _is_problem_size_supported API in the decorator so that we can avoid doing something like this and instead make a call directly to query support, e.g. flashinfer.gemm.mm_fp4.is_problem_size_supported(args...).
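
For concreteness, here is a minimal sketch of what the suggested query-based filtering could look like in the benchmark, assuming a hypothetical flashinfer.gemm.mm_fp4.is_problem_size_supported that accepts roughly the same keyword arguments as the op itself (this API does not exist today; its name and signature are assumptions):

```python
# Hypothetical sketch only: is_problem_size_supported is not an existing
# FlashInfer API; its name and signature are assumed for illustration.
supported_backends = []
for backend in backends:
    if flashinfer.gemm.mm_fp4.is_problem_size_supported(
        a=input_fp4,
        b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
        a_descale=input_inv_s,
        b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
        out_dtype=res_dtype,
        block_size=block_size,
        use_nvfp4=use_nvfp4,
        backend=backend,
    ):
        supported_backends.append(backend)
    else:
        print(f"[INFO] {backend} backend does not support this configuration")
backends = supported_backends
```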

Collaborator Author
@bkryu bkryu Nov 18, 2025

This is an interesting suggestion, but I also wonder whether it is a common use case. Essentially, what we are doing here is reverse-engineering which backends are supported for the given inputs.

For a framework running end-to-end inference, this would not be a use case, so I'd say we probably don't need to worry about it. If anything, we should advocate using the "auto" backend and deliver good backend-selection accuracy, which could be done by enabling autotuning.

For a benchmarking script like the current one, I can see the use case for is_problem_size_supported, but I'd claim the current flow isn't too bad because the exception is raised very early, before any kernels are launched.
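
As a rough sketch of the "auto" flow being advocated here, based on the backend="auto" path and the autotune warmup this PR adds to the benchmark (variable names are borrowed from the script above; the exact autotune import path is an assumption):

```python
from flashinfer.autotuner import autotune  # import path assumed; same context manager the benchmark uses

# One-time warmup: let autotuning pick the best kernel for this shape.
with autotune(True):
    for _ in range(warmup_iters):
        flashinfer.gemm.mm_fp4(
            a=input_fp4,
            b=mat2_fp4.T,
            a_descale=input_inv_s,
            b_descale=mat2_inv_s.T,
            alpha=alpha,
            out_dtype=res_dtype,
            block_size=block_size,
            use_8x4_sf_layout=not use_128x4_sf_layout,
            backend="auto",
            use_nvfp4=use_nvfp4,
        )

# Steady-state calls reuse the tuned selection.
out = flashinfer.gemm.mm_fp4(
    a=input_fp4,
    b=mat2_fp4.T,
    a_descale=input_inv_s,
    b_descale=mat2_inv_s.T,
    alpha=alpha,
    out_dtype=res_dtype,
    block_size=block_size,
    use_8x4_sf_layout=not use_128x4_sf_layout,
    backend="auto",
    use_nvfp4=use_nvfp4,
)
```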

Contributor

Actually, when I designed the decorator I had exactly in mind what Jimmy is nudging at: calling is_problem_size_supported before we even run the operation.

Collaborator Author

I can see the initial appeal of this idea and how being able to call a support check seems nice at first glance, but I am still not able to see the use case.

If we hypothetically support this, how do we expect frameworks to use it? Would they have a preferred ordering of backends in mind, check them one by one, and pick the first one that succeeds? And then, if no backend is supported, maybe fall back to a non-FlashInfer choice?

My stance is:

  • Frameworks should be encouraged to use the auto option for optimal backend selection, not checking them one by one
  • If there is a support gap, we want the frameworks to come to us so we can fix it instead of falling back to some other kernel provider.

However, if I am missing something and there is a need for this, I would advocate for a followup PR, just to avoid the scope creep of this PR. The size of this PR has grown quite a bit since its inception and is becoming a bit daunting πŸ˜“

Contributor
@nvmbreughe nvmbreughe Nov 20, 2025

> However, if I am missing something and there is a need for this, I would advocate for a followup PR, just to avoid the scope creep of this PR. The size of this PR has grown quite a bit since its inception and is becoming a bit daunting 😓

Yes, that is not a problem.

Just to provide context on the use case: I think is_problem_size_supported by itself is not that big of a use case, but the complete picture (backend + problem size) is:

A DLFW may not want to use auto and may want to avoid any checks. They could determine for themselves which backend is most performant (perhaps using auto initially) and bake in which backends they want. They can then call the API with skip_check=True to avoid the overhead of any checks. CC @elfiegg
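
A rough sketch of that usage pattern; the skip_check flag and the idea of a pre-selected backend are taken from the comment above as assumptions about the API surface, not confirmed parameters:

```python
# Sketch of a framework baking in a backend chosen ahead of time
# (e.g. from an earlier backend="auto" tuning pass) and skipping
# per-call validation. `skip_check` is an assumption from the comment
# above, not a confirmed mm_fp4 parameter.
PRESELECTED_BACKEND = "cudnn"  # hypothetical offline choice for this shape/arch

out = flashinfer.gemm.mm_fp4(
    a=input_fp4,
    b=mat2_fp4.T,
    a_descale=input_inv_s,
    b_descale=mat2_inv_s.T,
    alpha=alpha,
    out_dtype=res_dtype,
    block_size=block_size,
    backend=PRESELECTED_BACKEND,
    use_nvfp4=use_nvfp4,
    skip_check=True,  # avoid check overhead in the hot path
)
```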


# Remove unsupported backends
for backend in backends_to_remove:
backends.remove(backend)

if len(backends) == 0:
print("[ERROR] No backends passed validation. Exiting.")
return

def run_backend(backend):
if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
return flashinfer.gemm.mm_fp4(
a=input_fp4,
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
a_descale=input_inv_s,
b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
alpha=alpha,
out_dtype=res_dtype,
block_size=block_size,
use_8x4_sf_layout=not use_128x4_sf_layout,
backend=backend,
use_nvfp4=use_nvfp4,
)
else:
raise ValueError(f"Unsupported backend: {backend}")

@@ -917,12 +909,11 @@ def run_backend(backend):
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
for cur_backend in backends:
if cur_backend in autotune_supported_backends:
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
with autotune(True):
for _ in range(warmup_iters):
run_backend(cur_backend)
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
with autotune(True):
for _ in range(warmup_iters):
run_backend(cur_backend)

# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}