make DeepGEMM swapAB available for linear gemm SM90 #2101
base: main
Conversation
Walkthrough
Introduces a complete FP8 block-scale GEMM implementation for SM90 GPUs. Adds a CUDA binding module exposing a TVM FFI interface, Python wrapper functions with input validation, JIT compilation specifications, and comprehensive tests covering various dtype combinations, shapes, and error scenarios.
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant PyAPI as Python API<br/>(fp8_blockscale_gemm_swapab)
    participant Runner as CUDA Module<br/>(Fp8BlockScaleGemmRunner)
    participant Kernel as CUDA Kernels<br/>(BF16×BF16, BF16×FP8)
    User->>PyAPI: fp8_blockscale_gemm_swapab(input, weight,<br/>input_scale?, weight_scale?)
    activate PyAPI
    PyAPI->>PyAPI: Validate architecture (SM90)<br/>Validate shapes (K%128==0)<br/>Validate dtypes & scales
    PyAPI->>Runner: get_fp8_blockscale_gemm_runner_sm90()
    Runner->>Runner: Load/Build JIT module
    Runner-->>PyAPI: Module instance
    PyAPI->>Runner: get_workspace_size()
    Runner-->>PyAPI: Workspace size
    PyAPI->>PyAPI: Allocate output buffer
    PyAPI->>Runner: configure_workspace(workspace_ptr)
    activate Runner
    Runner->>Kernel: Initialize for selected<br/>dtype combination
    deactivate Runner
    PyAPI->>Runner: gemm(input_ptr, weight_ptr,<br/>output_ptr, scales...)
    activate Runner
    alt BF16 × BF16 path
        Runner->>Kernel: Invoke BF16 kernel
    else BF16 × FP8 path
        Runner->>Kernel: Invoke FP8 kernel with scales
    end
    Kernel-->>Runner: Compute complete
    deactivate Runner
    Runner-->>PyAPI: Output tensor
    deactivate PyAPI
    PyAPI-->>User: Result tensor
```
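For orientation, a minimal usage sketch of the new entry point, condensed from the docstring examples later in this PR (`per_token_cast_to_fp8` comes from `flashinfer.testing.utils`):

```python
import torch
from flashinfer.gemm import fp8_blockscale_gemm_swapab
from flashinfer.testing.utils import per_token_cast_to_fp8

M, N, K = 16, 4096, 4096  # small M is the case the swapAB kernel targets
input_bf16 = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
weight_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# BF16 x BF16: quantization happens internally, no scales needed
out = fp8_blockscale_gemm_swapab(input_bf16, weight_bf16)

# BF16 x FP8 weight: quantize the weight and pass its scales explicitly
weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)
out2 = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale)
assert out.shape == out2.shape == (M, N)
```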
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested reviewers
Poem
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 3
🧹 Nitpick comments (7)
csrc/fp8_blockscale_gemm_sm90_binding.cu (1)
125-134: Consider swallowing `getWorkspaceSizeBase` SMEM probe errors (SM90 config probing)

`getWorkspaceSize` directly calls `getWorkspaceSizeBase` on both runners and propagates any `std::runtime_error`. In other GEMM templates in this codebase, SMEM-limited configurations are probed and allowed to throw, with those exceptions explicitly swallowed so that other valid configs can still be considered.

It would be safer and more consistent to wrap each `getWorkspaceSizeBase` call in a `try { ... } catch (const std::runtime_error&) { /* Swallow errors when SMEM exceeds maximum allowed */ }`, using the same pattern as the other GEMM templates, so transient SMEM constraint failures don't bubble up and break the Python path. Based on learnings.

flashinfer/jit/gemm/fp8_blockscale.py (1)
12-19: Avoid empty-string flag and prefer iterable unpacking for `nvcc_flags`

The current expression:

```python
nvcc_flags = sm90a_nvcc_flags + [
    "-DCOMPILE_HOPPER_TMA_GEMMS",
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
]
```

injects an empty string into the flags list when CUDA < 12.8, and Ruff also flags this for style. You can avoid the empty flag and follow the unpacking pattern like:

```python
nvcc_flags = [
    *sm90a_nvcc_flags,
    "-DCOMPILE_HOPPER_TMA_GEMMS",
    "-DENABLE_BF16",
    "-DENABLE_FP8",
    *(
        ["-DENABLE_FP8_BLOCK_SCALE"]
        if is_cuda_version_at_least("12.8")
        else []
    ),
]
```

This keeps the behavior identical while producing a cleaner flag list.

flashinfer/gemm/gemm_base.py (1)
flashinfer/gemm/gemm_base.py (1)
3335-3337: Minor doc mismatch: swapAB threshold and backend naming

The docstring notes:
- “SwapAB kernel is automatically used when M < 32 (threshold)”
- “The function uses DeepGEMM backend with JIT compilation”
In this wrapper, there is no explicit `M < 32` branch; you always call the same `runner.gemm(...)`. If the M-based swapAB decision is implemented entirely inside the underlying runner, it might be better to soften the wording (e.g. "the underlying kernel may select a swapAB variant for small M") to avoid over-specifying the threshold in Python.

Also, this path is wired through the new `Fp8BlockScaleGemmRunner` JIT module (TensorRT-LLM / Cutlass blockscale implementation), not the existing DeepGEMM helpers in this file, so calling it a "DeepGEMM backend" may confuse users browsing the code.

This is non-blocking, but tightening the docstring will make it easier to maintain.
tests/gemm/test_fp8_blockscale_gemm.py (4)
163-168: Address the TODO comment about threshold verification.

The 0.967 threshold for the BF16+FP8 combination appears multiple times in the test suite (also line 215) and should be validated against real model performance.
Would you like me to open an issue to track threshold validation with real model benchmarks?
217-217: Consider removing debug print statement.

Print statements in test code can clutter CI output. Consider using pytest's capture or removing if not needed for debugging.

```diff
- print(f"✓ Per-block weight scales: cosine similarity = {cos_sim:.4f}")
```
256-307: Excellent error handling coverage.

All major validation scenarios are properly tested. However, consider using raw strings for regex patterns to avoid potential escape sequence issues.
Apply this diff to fix the regex pattern issues flagged by static analysis:
```diff
-    with pytest.raises(ValueError, match="FP8.*or BF16"):
+    with pytest.raises(ValueError, match=r"FP8.*or BF16"):
         fp8_blockscale_gemm_swapab(input, weight)

-    with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"):
+    with pytest.raises(ValueError, match=r"FP8 input.*BF16 weight.*not supported"):
         fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale, None)
```
55-60: Consider extracting repeated capability checks to a fixture.

The compute capability and SM90a checks are duplicated across all test functions. This could be simplified with a shared fixture, though the current approach is acceptable.
Example fixture at module level:
```python
@pytest.fixture(scope="module")
def check_sm90a_support():
    """Check and skip if SM90a is not supported."""
    compute_capability = get_compute_capability(torch.device("cuda"))
    if compute_capability[0] < 9:
        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
    if not is_sm90a_supported(torch.device("cuda")):
        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
```

Then use `check_sm90a_support` as a fixture parameter in each test.
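A hypothetical test using that fixture (only the signature matters; the body is elided):

```python
# Requesting the fixture as a parameter makes pytest run the skip checks
# before the test body executes; no explicit call is needed inside the test.
def test_fp8_blockscale_gemm_basic(check_sm90a_support):
    ...  # GEMM setup and assertions go here
```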
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/fp8_blockscale_gemm_sm90_binding.cu (1 hunks)
- flashinfer/gemm/__init__.py (2 hunks)
- flashinfer/gemm/gemm_base.py (2 hunks)
- flashinfer/jit/gemm/__init__.py (2 hunks)
- flashinfer/jit/gemm/fp8_blockscale.py (1 hunks)
- tests/gemm/test_fp8_blockscale_gemm.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/fp8_blockscale_gemm_sm90_binding.cu
🧬 Code graph analysis (6)
flashinfer/gemm/__init__.py (1)
- flashinfer/gemm/gemm_base.py (2): get_fp8_blockscale_gemm_runner_sm90 (3249-3251), fp8_blockscale_gemm_swapab (3254-3477)

flashinfer/jit/gemm/fp8_blockscale.py (2)
- flashinfer/jit/core.py (2): JitSpec (213-312), gen_jit_spec (315-381)
- flashinfer/jit/cpp_ext.py (1): is_cuda_version_at_least (86-87)

flashinfer/jit/gemm/__init__.py (1)
- flashinfer/jit/gemm/fp8_blockscale.py (1): gen_fp8_blockscale_gemm_sm90_module (12-56)

flashinfer/gemm/gemm_base.py (3)
- flashinfer/jit/gemm/fp8_blockscale.py (1): gen_fp8_blockscale_gemm_sm90_module (12-56)
- flashinfer/jit/core.py (1): build_and_load (300-312)
- csrc/fp8_blockscale_gemm_sm90_binding.cu (8): init (152-155), init (152-152), input (87-123), input (87-88), input_is_fp8 (75-85), input_is_fp8 (75-76), workspace (136-142), workspace (136-136)

tests/gemm/test_fp8_blockscale_gemm.py (7)
- flashinfer/gemm/gemm_base.py (1): fp8_blockscale_gemm_swapab (3254-3477)
- flashinfer/testing/utils.py (1): per_token_cast_to_fp8 (39-46)
- flashinfer/utils.py (2): get_compute_capability (253-256), is_sm90a_supported (526-528)
- flashinfer/jit/env.py (1): has_flashinfer_jit_cache (27-36)
- flashinfer/jit/gemm/fp8_blockscale.py (1): gen_fp8_blockscale_gemm_sm90_module (12-56)
- flashinfer/jit/core.py (1): build_jit_specs (392-414)
- csrc/fp8_blockscale_gemm_sm90_binding.cu (2): input (87-123), input (87-88)

csrc/fp8_blockscale_gemm_sm90_binding.cu (1)
- csrc/tvm_ffi_utils.h (2): encode_dlpack_dtype (29-31), get_stream (272-274)
🪛 Ruff (0.14.5)
flashinfer/jit/gemm/fp8_blockscale.py
1-1: The file is executable but no shebang is present
(EXE002)
14-19: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
flashinfer/gemm/gemm_base.py
3337-3339: Avoid specifying long messages outside the exception class
(TRY003)
3343-3343: Avoid specifying long messages outside the exception class
(TRY003)
3345-3345: Avoid specifying long messages outside the exception class
(TRY003)
3351-3353: Avoid specifying long messages outside the exception class
(TRY003)
3358-3360: Avoid specifying long messages outside the exception class
(TRY003)
3370-3372: Avoid specifying long messages outside the exception class
(TRY003)
3377-3379: Avoid specifying long messages outside the exception class
(TRY003)
3382-3385: Avoid specifying long messages outside the exception class
(TRY003)
3387-3389: Avoid specifying long messages outside the exception class
(TRY003)
3391-3394: Avoid specifying long messages outside the exception class
(TRY003)
3397-3400: Avoid specifying long messages outside the exception class
(TRY003)
3402-3405: Avoid specifying long messages outside the exception class
(TRY003)
3409-3411: Avoid specifying long messages outside the exception class
(TRY003)
3418-3422: Avoid specifying long messages outside the exception class
(TRY003)
3424-3426: Avoid specifying long messages outside the exception class
(TRY003)
3429-3432: Avoid specifying long messages outside the exception class
(TRY003)
3434-3437: Avoid specifying long messages outside the exception class
(TRY003)
3442-3444: Avoid specifying long messages outside the exception class
(TRY003)
3446-3448: Avoid specifying long messages outside the exception class
(TRY003)
3450-3452: Avoid specifying long messages outside the exception class
(TRY003)
3458-3460: Avoid specifying long messages outside the exception class
(TRY003)
tests/gemm/test_fp8_blockscale_gemm.py
1-1: The file is executable but no shebang is present
(EXE002)
277-277: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
306-306: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
🔇 Additional comments (10)
flashinfer/jit/gemm/__init__.py (1)
30-45: Export of `gen_fp8_blockscale_gemm_sm90_module` looks consistent

Import and `__all__` exposure match existing JIT GEMM module patterns; no issues here.

flashinfer/gemm/__init__.py (1)

16-41: Public export of SM90 FP8 blockscale GEMM entry points is correct

Re-exporting `get_fp8_blockscale_gemm_runner_sm90` and `fp8_blockscale_gemm_swapab` via this module is consistent with how other GEMM helpers are surfaced; no additional wiring concerns.

flashinfer/gemm/gemm_base.py (1)

3248-3252: JIT runner construction for SM90 FP8 blockscale GEMM looks good

Using `functools.cache` around `gen_fp8_blockscale_gemm_sm90_module().build_and_load().init()` matches the TVM FFI factory pattern used elsewhere and avoids repeated JIT builds. No issues here.

tests/gemm/test_fp8_blockscale_gemm.py (7)
1-29: LGTM - Clean imports and setup. The imports are well-organized and all appear necessary for the test suite.
32-41: LGTM - Efficient JIT warmup strategy. The conditional autouse with module scope ensures JIT compilation happens once only when needed, optimizing test execution time.
44-89: LGTM - Comprehensive basic functionality test. Well-parametrized test covering a good range of shapes with proper device capability checks and correctness verification.
199-201: LGTM - Good scale validation. Proper verification of scale shape and the reciprocal format (positive values).
220-253: LGTM - Good coverage of LLM inference shapes. The parametrized shapes represent typical decoder layer dimensions, providing realistic test coverage.
310-342: LGTM - Proper output buffer reuse test. The test correctly verifies both buffer identity (line 333) and correctness, ensuring the pre-allocated buffer path works as expected.
345-346: LGTM - Standard test entry point.
```cpp
kernels::CutlassFp8BlockScaleGemmRunnerInterface* selectRunner(
    bool input_is_fp8, bool weight_is_fp8) {
  if (!input_is_fp8 && !weight_is_fp8) {
    return runner_bf16_bf16_.get();
  } else if (!input_is_fp8 && weight_is_fp8) {
    return runner_bf16_fp8_.get();
  } else {
    return nullptr;
  }
}
```
C++ runner only supports BF16 input; any FP8 input combination is rejected
selectRunner deliberately only returns a runner for (!input_is_fp8 && !weight_is_fp8) and (!input_is_fp8 && weight_is_fp8). Any case with input_is_fp8 == true will yield nullptr and trip the "Unsupported dtype combination" check in runGemm. This is fine as long as the Python API never calls this path with FP8 inputs. The current Python wrapper (fp8_blockscale_gemm_swapab in gemm_base.py) allows FP8 inputs and advertises them in the docstring, which does not match what this binding actually supports.
I’d recommend either:
- Extending this binding and kernels to handle FP8 inputs, or
- Tightening the Python wrapper to reject FP8 inputs up front and updating the docstring accordingly, so users don't hit a runtime "Unsupported dtype combination" error from this C++ layer.
| """ | ||
| Perform FP8 block-scaled GEMM with automatic swapAB optimization. | ||
| This function automatically selects between normal and swapAB kernel based on | ||
| the M dimension. For small M (< 32), it uses the swapAB kernel for | ||
| better performance. | ||
| Supported Dtype Combinations | ||
| ----------------------------- | ||
| - **BF16 + BF16 → BF16**: Both inputs BF16, internal quantization (no scales needed) | ||
| - **BF16 + FP8 → BF16**: BF16 input, FP8 weight | ||
| Parameters | ||
| ---------- | ||
| input : torch.Tensor | ||
| Input activation tensor of shape (M, K). | ||
| - BF16 (torch.bfloat16) with internal quantization | ||
| weight : torch.Tensor | ||
| Weight tensor of shape (N, K). Can be: | ||
| - FP8 (torch.float8_e4m3fn) with weight_scale required | ||
| - BF16 (torch.bfloat16) for internal quantization | ||
| input_scale : torch.Tensor, optional | ||
| weight_scale : torch.Tensor, optional | ||
| Scaling factors for weight. Required if weight is FP8. | ||
| out : torch.Tensor, optional | ||
| Output tensor of shape (M, N). If None, will be allocated. | ||
| out_dtype : torch.dtype, optional | ||
| Output data type. Default is torch.bfloat16. | ||
| Returns | ||
| ------- | ||
| torch.Tensor | ||
| Output tensor of shape (M, N) with dtype `out_dtype`. | ||
| Examples | ||
| -------- | ||
| >>> import torch | ||
| >>> from flashinfer.gemm import fp8_blockscale_gemm_swapab | ||
| >>> | ||
| >>> M, N, K = 16, 4096, 4096 | ||
| >>> device = "cuda" | ||
| >>> | ||
| >>> # BF16 inputs | ||
| >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) | ||
| >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) | ||
| >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_bf16) | ||
| >>> print(output.shape) # torch.Size([16, 4096]) | ||
| >>> | ||
| >>> # Mixed: BF16 input + FP8 weight | ||
| >>> from flashinfer.testing.utils import per_token_cast_to_fp8 | ||
| >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) | ||
| >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) | ||
| >>> weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16) | ||
| >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) | ||
| >>> print(output.shape) # torch.Size([16, 4096]) | ||
| >>> | ||
| >>> # FP8 weight with 128x128 block scales | ||
| >>> from flashinfer.testing.utils import per_block_cast_to_fp8 | ||
| >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) | ||
| >>> weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16) | ||
| >>> # weight_scale has shape (N // 128, K // 128) | ||
| >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) | ||
| >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) | ||
| >>> print(output.shape) # torch.Size([16, 4096]) | ||
| Notes | ||
| ----- | ||
| - This function requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+ | ||
| - SwapAB kernel is automatically used when M < 32 (threshold) | ||
| - For FP8 inputs, scaling factors must be provided | ||
| - For BF16 inputs, quantization and scaling happen internally | ||
| - Weight scales support two granularities: | ||
| * Per-token (1x128 blocks): (N, K//128) | ||
| * Per-block (128x128 blocks): (N//128, K//128) | ||
| - Input scales only support per-token format: (M, K//128) | ||
| - The function uses DeepGEMM backend with JIT compilation | ||
| """ | ||
| # Validate architecture support | ||
| if not _match_sm_version(input.device, ["90", "90a"]): | ||
| raise ValueError( | ||
| "fp8_blockscale_gemm_swapab is only supported on SM90 (Hopper) architecture." | ||
| ) | ||
|
|
||
| # Validate tensor dimensions | ||
| if input.ndim != 2: | ||
| raise ValueError(f"Input must be 2D (M, K), got shape {input.shape}") | ||
| if weight.ndim != 2: | ||
| raise ValueError(f"Weight must be 2D (N, K), got shape {weight.shape}") | ||
|
|
||
| M, K = input.shape | ||
| N, K_weight = weight.shape | ||
|
|
||
| if K != K_weight: | ||
| raise ValueError( | ||
| f"K dimension mismatch: input has K={K}, weight has K={K_weight}" | ||
| ) | ||
|
|
||
| # Validate K is divisible by block size (128) | ||
| BLOCK_SIZE = 128 | ||
| if K % BLOCK_SIZE != 0: | ||
| raise ValueError( | ||
| f"K dimension must be divisible by block size ({BLOCK_SIZE}), got K={K}" | ||
| ) | ||
|
|
||
| # Validate dtype combinations | ||
| input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] | ||
| weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] | ||
| input_is_bf16 = input.dtype == torch.bfloat16 | ||
| weight_is_bf16 = weight.dtype == torch.bfloat16 | ||
|
|
||
| # Explicitly reject FP8 input + BF16 weight (missing kernel implementation) | ||
| if input_is_fp8 and weight_is_bf16: | ||
| raise ValueError( | ||
| "FP8 input + BF16 weight is not supported (missing kernel implementation). " | ||
| ) | ||
|
|
||
| # Validate scale requirements for FP8 inputs | ||
| if input_is_fp8: | ||
| if input_scale is None: | ||
| raise ValueError( | ||
| "input_scale is required when input is FP8. " | ||
| ) | ||
| expected_scale_shape = (M, K // BLOCK_SIZE) | ||
| if input_scale.shape != expected_scale_shape: | ||
| raise ValueError( | ||
| f"input_scale shape mismatch. Expected {expected_scale_shape}, " | ||
| f"got {input_scale.shape}" | ||
| ) | ||
| if input_scale.dtype != torch.float32: | ||
| raise ValueError( | ||
| f"input_scale must be float32, got {input_scale.dtype}" | ||
| ) | ||
| if input_scale.device != input.device: | ||
| raise ValueError( | ||
| f"input_scale device mismatch. Expected {input.device}, " | ||
| f"got {input_scale.device}" | ||
| ) | ||
| else: | ||
| if not input_is_bf16: | ||
| raise ValueError( | ||
| f"Input must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " | ||
| f"got {input.dtype}" | ||
| ) | ||
| if input_scale is not None: | ||
| raise ValueError( | ||
| "input_scale should not be provided for BF16 inputs. " | ||
| "Use FP8 inputs if you want to provide external scales." | ||
| ) | ||
|
|
||
| if weight_is_fp8: | ||
| if weight_scale is None: | ||
| raise ValueError( | ||
| "weight_scale is required when weight is FP8. " | ||
| ) | ||
| expected_per_token_shape = (N, K // BLOCK_SIZE) | ||
| expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE) | ||
| is_per_token = weight_scale.shape == expected_per_token_shape | ||
| is_per_block = weight_scale.shape == expected_per_block_shape | ||
|
|
||
| if not (is_per_token or is_per_block): | ||
| raise ValueError( | ||
| f"weight_scale shape mismatch. Expected either {expected_per_token_shape} " | ||
| f"(per-token, 1x128 blocks) or {expected_per_block_shape} " | ||
| f"(per-block, 128x128 blocks), got {weight_scale.shape}" | ||
| ) | ||
| if weight_scale.dtype != torch.float32: | ||
| raise ValueError( | ||
| f"weight_scale must be float32, got {weight_scale.dtype}" | ||
| ) | ||
| else: | ||
| if not weight_is_bf16: | ||
| raise ValueError( | ||
| f"Weight must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " | ||
| f"got {weight.dtype}" | ||
| ) | ||
| if weight_scale is not None: | ||
| raise ValueError( | ||
| "weight_scale should not be provided for BF16 weights. " | ||
| "Use FP8 weights if you want to provide external scales." | ||
| ) | ||
|
|
Dtype support mismatch: FP8 inputs and float8_e5m2 weights are not actually supported by the C++ runner
Within `fp8_blockscale_gemm_swapab`, the dtype logic is:

```python
input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
input_is_bf16 = input.dtype == torch.bfloat16
weight_is_bf16 = weight.dtype == torch.bfloat16
```

and you only explicitly reject FP8 input + BF16 weight. The docstring also describes FP8 inputs as supported with external scales.

However, the C++ binding (Fp8BlockScaleGemmRunner) only instantiates kernels for:

- BF16 input + BF16 weight, and
- BF16 input + FP8 e4m3 weight,

and its dispatch:

```cpp
bool input_is_fp8 = is_fp8_e4m3fn(input.dtype());
bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype());
auto* runner = selectRunner(input_is_fp8, weight_is_fp8);
```

only returns a runner when `input_is_fp8 == false`. Any FP8 input combination results in `runner == nullptr` and a runtime "Unsupported dtype combination" error, so FP8 inputs are not actually usable here despite being allowed and documented in Python.
More critically, for weight.dtype == torch.float8_e5m2:
- Python treats this as `weight_is_fp8 == True` and enforces FP8-scale semantics,
- but the C++ side's `is_fp8_e4m3fn` returns `false`, so the BF16–BF16 runner is used, which interprets the `weight` buffer as `__nv_bfloat16` instead of 1-byte FP8. That is undefined behavior and will produce corrupt results (and potentially out-of-bounds accesses).
To make this safe and predictable, I’d strongly recommend:
- Restrict supported dtypes to what the C++ runner actually implements:
  - Require `input.dtype == torch.bfloat16` (no FP8 inputs for now).
  - Require `weight.dtype` to be either `torch.bfloat16` or `torch.float8_e4m3fn` specifically.
  - Explicitly reject `torch.float8_e5m2` for both input and weight with a clear error message.
- Align the docstring with the real support:
  - Remove or qualify references to FP8 inputs until a corresponding C++ path exists.
  - Clarify that only FP8 e4m3fn weights are supported for now.
- Optionally, if FP8 inputs and e5m2 support are desired, extend the C++ binding to add the necessary runner instantiations and dispatch branches instead of allowing them through the Python-side checks.
Because the current behavior can silently misinterpret float8_e5m2 weights, this should be addressed before relying on this API in user code.
```python
if weight_is_fp8:
    if weight_scale is None:
        raise ValueError(
            "weight_scale is required when weight is FP8. "
        )
    expected_per_token_shape = (N, K // BLOCK_SIZE)
    expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE)
    is_per_token = weight_scale.shape == expected_per_token_shape
    is_per_block = weight_scale.shape == expected_per_block_shape

    if not (is_per_token or is_per_block):
        raise ValueError(
            f"weight_scale shape mismatch. Expected either {expected_per_token_shape} "
            f"(per-token, 1x128 blocks) or {expected_per_block_shape} "
            f"(per-block, 128x128 blocks), got {weight_scale.shape}"
        )
    if weight_scale.dtype != torch.float32:
        raise ValueError(
            f"weight_scale must be float32, got {weight_scale.dtype}"
        )
else:
```
Missing device check for weight_scale can hand a host pointer to a CUDA kernel
In the FP8‑weight branch you validate weight_scale’s shape and dtype, but not its device:
```python
if weight_is_fp8:
    ...
    if weight_scale.dtype != torch.float32:
        raise ValueError(...)
```

There's no check that `weight_scale.device == weight.device == input.device`. If a caller accidentally passes a CPU `weight_scale` while `input` and `weight` are CUDA tensors, the C++ runner will still receive `scales_b_ptr` and try to use it on the GPU, leading to undefined behavior.
Recommend adding a device check symmetric to the one you already have for input_scale:
```python
if weight_scale.device != weight.device:
    raise ValueError(
        f"weight_scale device mismatch. Expected {weight.device}, "
        f"got {weight_scale.device}"
    )
```

and optionally also enforce `weight.device == input.device` early for clarity.
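A hedged sketch of that optional early check, reusing the wrapper's parameter names:

```python
# Optional early consistency check (sketch): fail before any scale validation
# if the two main operands are not on the same device.
if weight.device != input.device:
    raise ValueError(
        f"weight device mismatch. Expected {input.device}, got {weight.device}"
    )
```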
🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around lines 3407 to 3427, the FP8-weight branch
validates weight_scale shape and dtype but omits a device check; add a guard
that weight_scale.device matches weight.device (and raise a ValueError with a
clear message if it does not), and also ensure earlier that weight.device ==
input.device (or check weight_scale.device == input.device) so no host (CPU)
tensor is passed when weight/input are on CUDA.
/bot run
```cpp
  return tvm::ffi::Module(ptr);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init);
```
Please fix the pre-commit issues by running `pre-commit run --all-files`.
[SUCCESS] Pipeline #38984962: 15/18 passed
📌 Description
In FlashInfer we already have the fp8_gemm_kernel_swapAB kernel for optimizing Mixture of Experts (MoE) GEMM and Dense GEMM operations (reference 1, reference 2, and reference 3).
This kernel improves performance in small batch scenarios by swapping the input order in matrix multiplication.
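Roughly (stated here as background, not quoted from the kernel code): for activations $A \in \mathbb{R}^{M \times K}$ and weights $B \in \mathbb{R}^{N \times K}$, the same result can be computed with the operands swapped,

$$
C = A B^{\top} \in \mathbb{R}^{M \times N}
\quad\Longleftrightarrow\quad
C^{\top} = B A^{\top} \in \mathbb{R}^{N \times M},
$$

so for small M the large N dimension can sit on the GEMM's M-tile side, which is why the swapped variant helps small-batch shapes.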
These kernels are currently used for:
- MoE operations (exposed via fused_moe module)
- Available in the codebase for Dense GEMM but not exposed for linear/dense layers
This PR aims to:
- csrc/fp8_blockscale_gemm_sm90_binding.cu
- flashinfer/jit/gemm/
- flashinfer/gemm/gemm_base.py
- flashinfer/
- tests/gemm/test_fp8_blockscale_gemm.py

TODO
🔍 Related Issues
vLLM 28427
vLLM 28316
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed, and all tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes