
Conversation


@xuanzic xuanzic commented Nov 17, 2025

📌 Description

FlashInfer already has the fp8_gemm_kernel_swapAB kernel for optimizing Mixture of Experts (MoE) GEMM and dense GEMM operations (reference 1, reference 2, and reference 3).
This kernel improves performance in small-batch scenarios by swapping the input order in the matrix multiplication.
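
Conceptually, the swap relies on the transpose identity below (this is an illustration of the operand swap only, not the kernel's actual data path):

```python
import torch

M, N, K = 8, 4096, 4096  # small M, as in low-batch decode
a = torch.randn(M, K, dtype=torch.float64)  # activations
b = torch.randn(N, K, dtype=torch.float64)  # weights

normal = a @ b.t()         # C = A @ B^T, shape (M, N)
swapped = (b @ a.t()).t()  # C = (B @ A^T)^T: operands swapped, same result
assert torch.allclose(normal, swapped)
```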

These kernels are currently used for:

  • MoE operations (exposed via the fused_moe module)

  • Dense GEMM (available in the codebase but not exposed for linear/dense layers)


This PR aims to:

  • Add a Python binding to expose the linear GEMM operation (see the usage sketch after this list)
    • Create a dedicated binding for fp8_blockscale_gemm in csrc/fp8_blockscale_gemm_sm90_binding.cu
    • Add JIT module generation in flashinfer/jit/gemm/
    • Expose the API in flashinfer/gemm/gemm_base.py
    • Add extensive test cases in tests/gemm/test_fp8_blockscale_gemm.py
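
A minimal usage sketch of the exposed API (it mirrors the docstring examples added in this PR; treat it as a sketch rather than the final signature):

```python
import torch
from flashinfer.gemm import fp8_blockscale_gemm_swapab
from flashinfer.testing.utils import per_token_cast_to_fp8

M, N, K = 16, 4096, 4096  # small M favors the swapAB path
device = "cuda"
input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)

# BF16 x BF16: quantization happens internally, no scales needed
out = fp8_blockscale_gemm_swapab(input_bf16, weight_bf16)

# BF16 x FP8 weight: pass per-token (1x128 block) weight scales
weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)
out = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale)
print(out.shape)  # torch.Size([16, 4096])
```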

TODO

  • Benchmark with a real model in vLLM and compare performance against the Cutlass GEMM path

🔍 Related Issues

vLLM 28427
vLLM 28316

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features
    • Added FP8 Block-Scale GEMM support for SM90 (Hopper) GPUs
    • New API functions for efficient matrix multiplication with FP8 quantized weights and BF16 inputs
    • Supports per-token and per-block quantization modes
    • Includes workspace management and optional swapAB optimization for improved performance



coderabbitai bot commented Nov 17, 2025

Walkthrough

Introduces a complete FP8 block-scale GEMM implementation for SM90 GPUs. Adds a CUDA binding module exposing a TVM FFI interface, Python wrapper functions with input validation, JIT compilation specifications, and comprehensive tests covering various dtype combinations, shapes, and error scenarios.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| CUDA Binding Layer: csrc/fp8_blockscale_gemm_sm90_binding.cu | New TVM FFI module Fp8BlockScaleGemmRunner with gemm, get_workspace_size, and configure_workspace entry points. Handles BF16/FP8 dtype dispatch, kernel selection, workspace management, and runtime validation. |
| Python API Exports: flashinfer/gemm/__init__.py, flashinfer/jit/gemm/__init__.py | Exposes get_fp8_blockscale_gemm_runner_sm90 and fp8_blockscale_gemm_swapab from the gemm_base module, and re-exports gen_fp8_blockscale_gemm_sm90_module from the fp8_blockscale submodule. |
| Python Implementation: flashinfer/gemm/gemm_base.py | Adds get_fp8_blockscale_gemm_runner_sm90() to build the SM90 module and fp8_blockscale_gemm_swapab() for GEMM dispatch. Validates architecture (SM90/Hopper), shapes (K divisible by 128), dtypes (BF16/FP8 combinations), scales, and manages output allocation. |
| JIT Compilation: flashinfer/jit/gemm/fp8_blockscale.py | New gen_fp8_blockscale_gemm_sm90_module() function that assembles nvcc flags (HOPPER_TMA, FP8, FP8_BLOCK_SCALE), includes CUTLASS/TensorRT-LLM headers, and generates the JIT spec for SM90 FP8 block-scale GEMM compilation. |
| Test Suite: tests/gemm/test_fp8_blockscale_gemm.py | Comprehensive test module with six test functions covering swapAB correctness, dtype combinations (BF16/BF16, BF16/FP8), per-block weight quantization, shape coverage, error handling, and pre-allocated output buffers. Includes a JIT warmup fixture and parametrized test cases. |

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PyAPI as Python API<br/>(fp8_blockscale_gemm_swapab)
    participant Runner as CUDA Module<br/>(Fp8BlockScaleGemmRunner)
    participant Kernel as CUDA Kernels<br/>(BF16×BF16, BF16×FP8)
    
    User->>PyAPI: fp8_blockscale_gemm_swapab(input, weight,<br/>input_scale?, weight_scale?)
    activate PyAPI
    PyAPI->>PyAPI: Validate architecture (SM90)<br/>Validate shapes (K%128==0)<br/>Validate dtypes & scales
    PyAPI->>Runner: get_fp8_blockscale_gemm_runner_sm90()
    Runner->>Runner: Load/Build JIT module
    Runner-->>PyAPI: Module instance
    PyAPI->>Runner: get_workspace_size()
    Runner-->>PyAPI: Workspace size
    PyAPI->>PyAPI: Allocate output buffer
    PyAPI->>Runner: configure_workspace(workspace_ptr)
    activate Runner
    Runner->>Kernel: Initialize for selected<br/>dtype combination
    deactivate Runner
    PyAPI->>Runner: gemm(input_ptr, weight_ptr,<br/>output_ptr, scales...)
    activate Runner
    alt BF16 × BF16 path
        Runner->>Kernel: Invoke BF16 kernel
    else BF16 × FP8 path
        Runner->>Kernel: Invoke FP8 kernel with scales
    end
    Kernel-->>Runner: Compute complete
    deactivate Runner
    Runner-->>PyAPI: Output tensor
    deactivate PyAPI
    PyAPI-->>User: Result tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas requiring extra attention:
    • csrc/fp8_blockscale_gemm_sm90_binding.cu: CUDA FFI binding logic, pointer/stride handling, dtype dispatch logic, and workspace management semantics
    • flashinfer/gemm/gemm_base.py: Input validation logic (especially scale shape/dtype checks), dtype combination handling, and runner lifecycle management
    • flashinfer/jit/gemm/fp8_blockscale.py: Nvcc flag configuration, CUDA version conditionals, and include path correctness
    • tests/gemm/test_fp8_blockscale_gemm.py: Cosine similarity thresholds (0.99 vs. ~0.967), fixture scope, and JIT warmup correctness

Suggested reviewers

  • djmmoss
  • yzh119
  • cyx-6
  • wenscarl
  • jiahanc
  • aleozlx
  • yongwww

Poem

🐰 A bunny hops through CUDA code,
FP8 blocks reduce the load,
SM90 kernels, swift and true,
Block-scale GEMM—a hop-through-blue,
From bindings bound to tests that ✓,
This feature makes computing flex!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 45.00%, which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main change: exposing the DeepGEMM swapAB kernel for linear GEMM operations on SM90. |
| Description check | ✅ Passed | The description provides clear context about the existing kernel, its use cases, and the PR's objectives with relevant issue links, though the pre-commit and test checklists remain unchecked. |



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (7)
csrc/fp8_blockscale_gemm_sm90_binding.cu (1)

125-134: Consider swallowing getWorkspaceSizeBase SMEM probe errors (SM90 config probing)

getWorkspaceSize directly calls getWorkspaceSizeBase on both runners and propagates any std::runtime_error. In other GEMM templates in this codebase, SMEM-limited configurations are probed and allowed to throw, with those exceptions explicitly swallowed so that other valid configs can still be considered.

It would be safer and more consistent to wrap each getWorkspaceSizeBase call in a try { ... } catch (const std::runtime_error&) { /* Swallow errors when SMEM exceeds maximum allowed */ }, using the same pattern as the other GEMM templates, so transient SMEM constraint failures don’t bubble up and break the Python path. Based on learnings

flashinfer/jit/gemm/fp8_blockscale.py (1)

12-19: Avoid empty-string flag and prefer iterable unpacking for nvcc_flags

The current expression:

nvcc_flags = sm90a_nvcc_flags + [
    "-DCOMPILE_HOPPER_TMA_GEMMS",
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
]

injects an empty string into the flags list when CUDA < 12.8, and Ruff also flags this for style. You can avoid the empty flag and follow the unpacking pattern like:

nvcc_flags = [
    *sm90a_nvcc_flags,
    "-DCOMPILE_HOPPER_TMA_GEMMS",
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    *(
        ["-DENABLE_FP8_BLOCK_SCALE"]
        if is_cuda_version_at_least("12.8")
        else []
    ),
]

This keeps the behavior identical while producing a cleaner flag list.

flashinfer/gemm/gemm_base.py (1)

3335-3337: Minor doc mismatch: swapAB threshold and backend naming

The docstring notes:

  • “SwapAB kernel is automatically used when M < 32 (threshold)”
  • “The function uses DeepGEMM backend with JIT compilation”

In this wrapper, there is no explicit M < 32 branch; you always call the same runner.gemm(...). If the M‑based swapAB decision is implemented entirely inside the underlying runner, it might be better to soften the wording (e.g. “the underlying kernel may select a swapAB variant for small M”) to avoid over‑specifying the threshold in Python.

Also, this path is wired through the new Fp8BlockScaleGemmRunner JIT module (the TensorRT-LLM / CUTLASS block-scale implementation), not the existing DeepGEMM helpers in this file, so calling it a “DeepGEMM backend” may confuse users browsing the code.

This is non‑blocking, but tightening the docstring will make it easier to maintain.

tests/gemm/test_fp8_blockscale_gemm.py (4)

163-168: Address the TODO comment about threshold verification.

The 0.967 threshold for BF16+FP8 combination appears multiple times in the test suite (also line 215) and should be validated against real model performance.

Would you like me to open an issue to track threshold validation with real model benchmarks?
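
For context, a sketch of how such a threshold check is typically expressed (the actual test helpers may differ):

```python
import torch

def assert_cosine_close(out: torch.Tensor, ref: torch.Tensor, threshold: float = 0.967) -> None:
    # Compare flattened outputs by direction; 1.0 means identical up to a positive scale.
    cos_sim = torch.nn.functional.cosine_similarity(
        out.float().flatten(), ref.float().flatten(), dim=0
    ).item()
    assert cos_sim > threshold, f"cosine similarity {cos_sim:.4f} <= {threshold}"
```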


217-217: Consider removing debug print statement.

Print statements in test code can clutter CI output. Consider using pytest's capture or removing if not needed for debugging.

-    print(f"✓ Per-block weight scales: cosine similarity = {cos_sim:.4f}")

256-307: Excellent error handling coverage.

All major validation scenarios are properly tested. However, consider using raw strings for regex patterns to avoid potential escape sequence issues.

Apply this diff to fix the regex pattern issues flagged by static analysis:

-    with pytest.raises(ValueError, match="FP8.*or BF16"):
+    with pytest.raises(ValueError, match=r"FP8.*or BF16"):
         fp8_blockscale_gemm_swapab(input, weight)
-    with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
+    with pytest.raises(ValueError, match=r"FP8 input.*BF16 weight.*not supported"):
         fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale, None)

55-60: Consider extracting repeated capability checks to a fixture.

The compute capability and SM90a checks are duplicated across all test functions. This could be simplified with a shared fixture, though the current approach is acceptable.

Example fixture at module level:

@pytest.fixture(scope="module")
def check_sm90a_support():
    """Check and skip if SM90a is not supported."""
    compute_capability = get_compute_capability(torch.device("cuda"))
    if compute_capability[0] < 9:
        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
    if not is_sm90a_supported(torch.device("cuda")):
        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")

Then use check_sm90a_support as a fixture parameter in each test.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0aee7af and 6f3449d.

📒 Files selected for processing (6)
  • csrc/fp8_blockscale_gemm_sm90_binding.cu (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (2 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/fp8_blockscale.py (1 hunks)
  • tests/gemm/test_fp8_blockscale_gemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/fp8_blockscale_gemm_sm90_binding.cu
🧬 Code graph analysis (6)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • get_fp8_blockscale_gemm_runner_sm90 (3249-3251)
  • fp8_blockscale_gemm_swapab (3254-3477)
flashinfer/jit/gemm/fp8_blockscale.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/jit/cpp_ext.py (1)
  • is_cuda_version_at_least (86-87)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/fp8_blockscale.py (1)
  • gen_fp8_blockscale_gemm_sm90_module (12-56)
flashinfer/gemm/gemm_base.py (3)
flashinfer/jit/gemm/fp8_blockscale.py (1)
  • gen_fp8_blockscale_gemm_sm90_module (12-56)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/fp8_blockscale_gemm_sm90_binding.cu (8)
  • init (152-155)
  • init (152-152)
  • input (87-123)
  • input (87-88)
  • input_is_fp8 (75-85)
  • input_is_fp8 (75-76)
  • workspace (136-142)
  • workspace (136-136)
tests/gemm/test_fp8_blockscale_gemm.py (7)
flashinfer/gemm/gemm_base.py (1)
  • fp8_blockscale_gemm_swapab (3254-3477)
flashinfer/testing/utils.py (1)
  • per_token_cast_to_fp8 (39-46)
flashinfer/utils.py (2)
  • get_compute_capability (253-256)
  • is_sm90a_supported (526-528)
flashinfer/jit/env.py (1)
  • has_flashinfer_jit_cache (27-36)
flashinfer/jit/gemm/fp8_blockscale.py (1)
  • gen_fp8_blockscale_gemm_sm90_module (12-56)
flashinfer/jit/core.py (1)
  • build_jit_specs (392-414)
csrc/fp8_blockscale_gemm_sm90_binding.cu (2)
  • input (87-123)
  • input (87-88)
csrc/fp8_blockscale_gemm_sm90_binding.cu (1)
csrc/tvm_ffi_utils.h (2)
  • encode_dlpack_dtype (29-31)
  • get_stream (272-274)
🪛 Ruff (0.14.5)
flashinfer/jit/gemm/fp8_blockscale.py

1-1: The file is executable but no shebang is present

(EXE002)


14-19: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

flashinfer/gemm/gemm_base.py

3337-3339: Avoid specifying long messages outside the exception class

(TRY003)


3343-3343: Avoid specifying long messages outside the exception class

(TRY003)


3345-3345: Avoid specifying long messages outside the exception class

(TRY003)


3351-3353: Avoid specifying long messages outside the exception class

(TRY003)


3358-3360: Avoid specifying long messages outside the exception class

(TRY003)


3370-3372: Avoid specifying long messages outside the exception class

(TRY003)


3377-3379: Avoid specifying long messages outside the exception class

(TRY003)


3382-3385: Avoid specifying long messages outside the exception class

(TRY003)


3387-3389: Avoid specifying long messages outside the exception class

(TRY003)


3391-3394: Avoid specifying long messages outside the exception class

(TRY003)


3397-3400: Avoid specifying long messages outside the exception class

(TRY003)


3402-3405: Avoid specifying long messages outside the exception class

(TRY003)


3409-3411: Avoid specifying long messages outside the exception class

(TRY003)


3418-3422: Avoid specifying long messages outside the exception class

(TRY003)


3424-3426: Avoid specifying long messages outside the exception class

(TRY003)


3429-3432: Avoid specifying long messages outside the exception class

(TRY003)


3434-3437: Avoid specifying long messages outside the exception class

(TRY003)


3442-3444: Avoid specifying long messages outside the exception class

(TRY003)


3446-3448: Avoid specifying long messages outside the exception class

(TRY003)


3450-3452: Avoid specifying long messages outside the exception class

(TRY003)


3458-3460: Avoid specifying long messages outside the exception class

(TRY003)

tests/gemm/test_fp8_blockscale_gemm.py

1-1: The file is executable but no shebang is present

(EXE002)


277-277: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


306-306: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)

🔇 Additional comments (10)
flashinfer/jit/gemm/__init__.py (1)

30-45: Export of gen_fp8_blockscale_gemm_sm90_module looks consistent

Import and __all__ exposure match existing JIT GEMM module patterns; no issues here.

flashinfer/gemm/__init__.py (1)

16-41: Public export of SM90 FP8 blockscale GEMM entry points is correct

Re-exporting get_fp8_blockscale_gemm_runner_sm90 and fp8_blockscale_gemm_swapab via this module is consistent with how other GEMM helpers are surfaced; no additional wiring concerns.

flashinfer/gemm/gemm_base.py (1)

3248-3252: JIT runner construction for SM90 FP8 blockscale GEMM looks good

Using functools.cache around gen_fp8_blockscale_gemm_sm90_module().build_and_load().init() matches the TVM FFI factory pattern used elsewhere and avoids repeated JIT builds. No issues here.

tests/gemm/test_fp8_blockscale_gemm.py (7)

1-29: LGTM - Clean imports and setup.

The imports are well-organized and all appear necessary for the test suite.


32-41: LGTM - Efficient JIT warmup strategy.

The conditional autouse with module scope ensures JIT compilation happens once only when needed, optimizing test execution time.


44-89: LGTM - Comprehensive basic functionality test.

Well-parametrized test covering a good range of shapes with proper device capability checks and correctness verification.


199-201: LGTM - Good scale validation.

Proper verification of scale shape and the reciprocal format (positive values).


220-253: LGTM - Good coverage of LLM inference shapes.

The parametrized shapes represent typical decoder layer dimensions, providing realistic test coverage.


310-342: LGTM - Proper output buffer reuse test.

The test correctly verifies both buffer identity (line 333) and correctness, ensuring the pre-allocated buffer path works as expected.


345-346: LGTM - Standard test entry point.

Comment on lines +75 to +85
  kernels::CutlassFp8BlockScaleGemmRunnerInterface* selectRunner(
      bool input_is_fp8, bool weight_is_fp8) {
    if (!input_is_fp8 && !weight_is_fp8) {
      return runner_bf16_bf16_.get();
    } else if (!input_is_fp8 && weight_is_fp8) {
      return runner_bf16_fp8_.get();
    } else {
      return nullptr;
    }
  }

⚠️ Potential issue | 🟠 Major

C++ runner only supports BF16 input; any FP8 input combination is rejected

selectRunner deliberately only returns a runner for (!input_is_fp8 && !weight_is_fp8) and (!input_is_fp8 && weight_is_fp8). Any case with input_is_fp8 == true will yield nullptr and trip the "Unsupported dtype combination" check in runGemm. This is fine as long as the Python API never calls this path with FP8 inputs. The current Python wrapper (fp8_blockscale_gemm_swapab in gemm_base.py) allows FP8 inputs and advertises them in the docstring, which does not match what this binding actually supports.

I’d recommend either:

  • Extending this binding and kernels to handle FP8 inputs, or
  • Tightening the Python wrapper to reject FP8 inputs up front and updating the docstring accordingly, so users don’t hit a runtime "Unsupported dtype combination" from this C++ layer.

Comment on lines +3262 to +3438
"""
Perform FP8 block-scaled GEMM with automatic swapAB optimization.
This function automatically selects between normal and swapAB kernel based on
the M dimension. For small M (< 32), it uses the swapAB kernel for
better performance.
Supported Dtype Combinations
-----------------------------
- **BF16 + BF16 → BF16**: Both inputs BF16, internal quantization (no scales needed)
- **BF16 + FP8 → BF16**: BF16 input, FP8 weight
Parameters
----------
input : torch.Tensor
Input activation tensor of shape (M, K).
- BF16 (torch.bfloat16) with internal quantization
weight : torch.Tensor
Weight tensor of shape (N, K). Can be:
- FP8 (torch.float8_e4m3fn) with weight_scale required
- BF16 (torch.bfloat16) for internal quantization
input_scale : torch.Tensor, optional
weight_scale : torch.Tensor, optional
Scaling factors for weight. Required if weight is FP8.
out : torch.Tensor, optional
Output tensor of shape (M, N). If None, will be allocated.
out_dtype : torch.dtype, optional
Output data type. Default is torch.bfloat16.
Returns
-------
torch.Tensor
Output tensor of shape (M, N) with dtype `out_dtype`.
Examples
--------
>>> import torch
>>> from flashinfer.gemm import fp8_blockscale_gemm_swapab
>>>
>>> M, N, K = 16, 4096, 4096
>>> device = "cuda"
>>>
>>> # BF16 inputs
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_bf16)
>>> print(output.shape) # torch.Size([16, 4096])
>>>
>>> # Mixed: BF16 input + FP8 weight
>>> from flashinfer.testing.utils import per_token_cast_to_fp8
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)
>>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale)
>>> print(output.shape) # torch.Size([16, 4096])
>>>
>>> # FP8 weight with 128x128 block scales
>>> from flashinfer.testing.utils import per_block_cast_to_fp8
>>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16)
>>> weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16)
>>> # weight_scale has shape (N // 128, K // 128)
>>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16)
>>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale)
>>> print(output.shape) # torch.Size([16, 4096])
Notes
-----
- This function requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+
- SwapAB kernel is automatically used when M < 32 (threshold)
- For FP8 inputs, scaling factors must be provided
- For BF16 inputs, quantization and scaling happen internally
- Weight scales support two granularities:
* Per-token (1x128 blocks): (N, K//128)
* Per-block (128x128 blocks): (N//128, K//128)
- Input scales only support per-token format: (M, K//128)
- The function uses DeepGEMM backend with JIT compilation
"""
# Validate architecture support
if not _match_sm_version(input.device, ["90", "90a"]):
raise ValueError(
"fp8_blockscale_gemm_swapab is only supported on SM90 (Hopper) architecture."
)

# Validate tensor dimensions
if input.ndim != 2:
raise ValueError(f"Input must be 2D (M, K), got shape {input.shape}")
if weight.ndim != 2:
raise ValueError(f"Weight must be 2D (N, K), got shape {weight.shape}")

M, K = input.shape
N, K_weight = weight.shape

if K != K_weight:
raise ValueError(
f"K dimension mismatch: input has K={K}, weight has K={K_weight}"
)

# Validate K is divisible by block size (128)
BLOCK_SIZE = 128
if K % BLOCK_SIZE != 0:
raise ValueError(
f"K dimension must be divisible by block size ({BLOCK_SIZE}), got K={K}"
)

# Validate dtype combinations
input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
input_is_bf16 = input.dtype == torch.bfloat16
weight_is_bf16 = weight.dtype == torch.bfloat16

# Explicitly reject FP8 input + BF16 weight (missing kernel implementation)
if input_is_fp8 and weight_is_bf16:
raise ValueError(
"FP8 input + BF16 weight is not supported (missing kernel implementation). "
)

# Validate scale requirements for FP8 inputs
if input_is_fp8:
if input_scale is None:
raise ValueError(
"input_scale is required when input is FP8. "
)
expected_scale_shape = (M, K // BLOCK_SIZE)
if input_scale.shape != expected_scale_shape:
raise ValueError(
f"input_scale shape mismatch. Expected {expected_scale_shape}, "
f"got {input_scale.shape}"
)
if input_scale.dtype != torch.float32:
raise ValueError(
f"input_scale must be float32, got {input_scale.dtype}"
)
if input_scale.device != input.device:
raise ValueError(
f"input_scale device mismatch. Expected {input.device}, "
f"got {input_scale.device}"
)
else:
if not input_is_bf16:
raise ValueError(
f"Input must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), "
f"got {input.dtype}"
)
if input_scale is not None:
raise ValueError(
"input_scale should not be provided for BF16 inputs. "
"Use FP8 inputs if you want to provide external scales."
)

if weight_is_fp8:
if weight_scale is None:
raise ValueError(
"weight_scale is required when weight is FP8. "
)
expected_per_token_shape = (N, K // BLOCK_SIZE)
expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE)
is_per_token = weight_scale.shape == expected_per_token_shape
is_per_block = weight_scale.shape == expected_per_block_shape

if not (is_per_token or is_per_block):
raise ValueError(
f"weight_scale shape mismatch. Expected either {expected_per_token_shape} "
f"(per-token, 1x128 blocks) or {expected_per_block_shape} "
f"(per-block, 128x128 blocks), got {weight_scale.shape}"
)
if weight_scale.dtype != torch.float32:
raise ValueError(
f"weight_scale must be float32, got {weight_scale.dtype}"
)
else:
if not weight_is_bf16:
raise ValueError(
f"Weight must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), "
f"got {weight.dtype}"
)
if weight_scale is not None:
raise ValueError(
"weight_scale should not be provided for BF16 weights. "
"Use FP8 weights if you want to provide external scales."
)


⚠️ Potential issue | 🔴 Critical

Dtype support mismatch: FP8 inputs and float8_e5m2 weights are not actually supported by the C++ runner

Within fp8_blockscale_gemm_swapab, the dtype logic is:

input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
input_is_bf16 = input.dtype == torch.bfloat16
weight_is_bf16 = weight.dtype == torch.bfloat16

and you only explicitly reject FP8 input + BF16 weight. The docstring also describes FP8 inputs as supported with external scales.

However, the C++ binding (Fp8BlockScaleGemmRunner) only instantiates kernels for:

  • BF16 input + BF16 weight, and
  • BF16 input + FP8 e4m3 weight,

and its dispatch:

bool input_is_fp8 = is_fp8_e4m3fn(input.dtype());
bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype());
auto* runner = selectRunner(input_is_fp8, weight_is_fp8);

only returns a runner when input_is_fp8 == false. Any FP8 input combination results in runner == nullptr and a runtime "Unsupported dtype combination" error, so FP8 inputs are not actually usable here despite being allowed and documented in Python.

More critically, for weight.dtype == torch.float8_e5m2:

  • Python treats this as weight_is_fp8 == True and enforces FP8-scale semantics,
  • but the C++ side’s is_fp8_e4m3fn returns false, so the BF16–BF16 runner is used, which interprets the weight buffer as __nv_bfloat16 instead of 1‑byte FP8. That is undefined behavior and will produce corrupt results (and potentially out-of-bounds accesses).

To make this safe and predictable, I’d strongly recommend:

  1. Restrict supported dtypes to what the C++ runner actually implements:

    • Require input.dtype == torch.bfloat16 (no FP8 inputs for now).
    • Require weight.dtype to be either torch.bfloat16 or torch.float8_e4m3fn specifically.
    • Explicitly reject torch.float8_e5m2 for both input and weight with a clear error message.
  2. Align the docstring with the real support:

    • Remove or qualify references to FP8 inputs until a corresponding C++ path exists.
    • Clarify that only FP8 e4m3fn weights are supported for now.
  3. Optionally, if FP8 inputs and e5m2 support are desired, extend the C++ binding to add the necessary runner instantiations and dispatch branches instead of allowing them through the Python-side checks.

Because the current behavior can silently misinterpret float8_e5m2 weights, this should be addressed before relying on this API in user code.
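
A minimal sketch of the tightened validation suggested in point 1 (illustrative only; naming and placement inside fp8_blockscale_gemm_swapab are up to the authors):

```python
import torch

def _validate_supported_dtypes(input: torch.Tensor, weight: torch.Tensor) -> None:
    # Accept only the combinations the C++ runner instantiates:
    # BF16 x BF16 and BF16 x FP8 e4m3fn; FP8 inputs and any e5m2 tensor are rejected.
    if input.dtype != torch.bfloat16:
        raise ValueError(f"Input must be BF16 (torch.bfloat16), got {input.dtype}")
    if weight.dtype not in (torch.bfloat16, torch.float8_e4m3fn):
        raise ValueError(
            f"Weight must be BF16 (torch.bfloat16) or FP8 e4m3fn (torch.float8_e4m3fn), "
            f"got {weight.dtype}"
        )
```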

🧰 Tools
🪛 Ruff (0.14.5)

3337-3339: Avoid specifying long messages outside the exception class

(TRY003)


3343-3343: Avoid specifying long messages outside the exception class

(TRY003)


3345-3345: Avoid specifying long messages outside the exception class

(TRY003)


3351-3353: Avoid specifying long messages outside the exception class

(TRY003)


3358-3360: Avoid specifying long messages outside the exception class

(TRY003)


3370-3372: Avoid specifying long messages outside the exception class

(TRY003)


3377-3379: Avoid specifying long messages outside the exception class

(TRY003)


3382-3385: Avoid specifying long messages outside the exception class

(TRY003)


3387-3389: Avoid specifying long messages outside the exception class

(TRY003)


3391-3394: Avoid specifying long messages outside the exception class

(TRY003)


3397-3400: Avoid specifying long messages outside the exception class

(TRY003)


3402-3405: Avoid specifying long messages outside the exception class

(TRY003)


3409-3411: Avoid specifying long messages outside the exception class

(TRY003)


3418-3422: Avoid specifying long messages outside the exception class

(TRY003)


3424-3426: Avoid specifying long messages outside the exception class

(TRY003)


3429-3432: Avoid specifying long messages outside the exception class

(TRY003)


3434-3437: Avoid specifying long messages outside the exception class

(TRY003)

Comment on lines +3407 to +3427
    if weight_is_fp8:
        if weight_scale is None:
            raise ValueError(
                "weight_scale is required when weight is FP8. "
            )
        expected_per_token_shape = (N, K // BLOCK_SIZE)
        expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE)
        is_per_token = weight_scale.shape == expected_per_token_shape
        is_per_block = weight_scale.shape == expected_per_block_shape

        if not (is_per_token or is_per_block):
            raise ValueError(
                f"weight_scale shape mismatch. Expected either {expected_per_token_shape} "
                f"(per-token, 1x128 blocks) or {expected_per_block_shape} "
                f"(per-block, 128x128 blocks), got {weight_scale.shape}"
            )
        if weight_scale.dtype != torch.float32:
            raise ValueError(
                f"weight_scale must be float32, got {weight_scale.dtype}"
            )
    else:

⚠️ Potential issue | 🟠 Major

Missing device check for weight_scale can hand a host pointer to a CUDA kernel

In the FP8‑weight branch you validate weight_scale’s shape and dtype, but not its device:

if weight_is_fp8:
    ...
    if weight_scale.dtype != torch.float32:
        raise ValueError(...)

There’s no check that weight_scale.device == weight.device == input.device. If a caller accidentally passes a CPU weight_scale while input and weight are CUDA tensors, the C++ runner will still receive scales_b_ptr and try to use it on the GPU, leading to undefined behavior.

Recommend adding a device check symmetric to the one you already have for input_scale:

if weight_scale.device != weight.device:
    raise ValueError(
        f"weight_scale device mismatch. Expected {weight.device}, "
        f"got {weight_scale.device}"
    )

and optionally also enforce weight.device == input.device early for clarity.

🧰 Tools
🪛 Ruff (0.14.5)

3409-3411: Avoid specifying long messages outside the exception class

(TRY003)


3418-3422: Avoid specifying long messages outside the exception class

(TRY003)


3424-3426: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around lines 3407 to 3427, the FP8-weight branch
validates weight_scale shape and dtype but omits a device check; add a guard
that weight_scale.device matches weight.device (and raise a ValueError with a
clear message if it does not), and also ensure earlier that weight.device ==
input.device (or check weight_scale.device == input.device) so no host (CPU)
tensor is passed when weight/input are on CUDA.


yzh119 commented Nov 22, 2025

/bot run

@flashinfer-bot

GitLab MR !158 has been created, and the CI pipeline #38984962 is currently running. I'll report back once the pipeline job completes.

return tvm::ffi::Module(ptr);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init);
Collaborator


Please fix the pre-commit issues by running pre-commit run --all-files

@flashinfer-bot

[SUCCESS] Pipeline #38984962: 15/18 passed
