silu_and_mul nvfp4 quantization fusion rework #1927
base: main
Conversation
…atched quantize (flashinfer-ai#1835)" This reverts commit f765a2a. Port kernels from sglang
912e796 to b979ee2 (Compare)
/bot run
/gemini review
[FAILED] Pipeline #36839981: 0/17 passed
Code Review
This pull request reworks the silu_and_mul nvfp4 quantization fusion by reverting previous changes and porting kernels from SGLang. The changes are extensive, affecting CUDA kernels, C++ wrappers, Python bindings, and documentation. My review focuses on improving code quality, maintainability, and correctness. I've identified several areas for improvement, including code duplication in new Python functions that could be refactored, a new CUDA header file that appears to be unused and contains duplicated code, and minor issues in CUDA kernels such as commented-out code, unused parameters, and hardcoded values. Additionally, the documentation for FP4 quantization functions seems to be out of sync with the code, and a test case for the new functionality could be improved to better cover masking logic.
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
This PR removes
Changes
Sequence Diagram(s)
sequenceDiagram
participant Py as Python caller
participant PyAPI as flashinfer API
participant Cpp as C++ binding
participant NVInv as NV launcher
participant Kernel as CUDA kernel
participant GPU as Device memory
Note left of Py: high-level call
Py->>PyAPI: scaled_fp4_grouped_quantize(...) / silu_and_mul_scaled_nvfp4_experts_quantize(...)
PyAPI->>Cpp: forward tensors & params
Cpp->>NVInv: prepare opaque pointers, m_topk/k/n_experts, stream
NVInv->>Kernel: launch per-expert kernel (cvt_fp16_to_fp4_expert / quantize)
Kernel->>GPU: read input, compute SF, write SFout, write quantized out
Kernel-->>NVInv: kernel completes
NVInv-->>Cpp: return
Cpp-->>PyAPI: return outputs
PyAPI-->>Py: (output, output_scale)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (3)
csrc/nv_internal/cpp/kernels/quantization.cu (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
133-140: Enforce unsupported combo explicitly.
The batched path enforces sfVecSize==16 but still accepts sfUseUE8M0. Add a guard to reject UE8M0 here to avoid confusing, unsupported inputs.
Apply:
  CHECK_INPUT_TYPE(globalScale, fp32_dtype);
  TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16";
+ TVM_FFI_ICHECK(!sfUseUE8M0) << "UE8M0 (sfVecSize=32) is not supported in batched path.";
158-163: Critical: stream passed into enable_pdl; kernels run on the default stream.
get_stream(self->device) is currently bound to enable_pdl (bool), and stream defaults to 0. This flips enable_pdl and launches on the default stream, breaking CUDA Graph/PDL semantics and likely causing the CI failures.
Fix by passing both enable_pdl and stream (batched path can keep PDL disabled):
-  tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>(                                  \
-      b, m, k, reinterpret_cast<T*>(self->data), static_cast<float*>(globalScale->data),         \
-      reinterpret_cast<int64_t*>(valueE2M1->data), reinterpret_cast<int32_t*>(scaleFP8SF->data), \
-      sfUseUE8M0, layout, mMultiProcessorCount, get_stream(self->device));
+  tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>(                                  \
+      b, m, k, reinterpret_cast<T*>(self->data), static_cast<float*>(globalScale->data),         \
+      reinterpret_cast<int64_t*>(valueE2M1->data), reinterpret_cast<int32_t*>(scaleFP8SF->data), \
+      sfUseUE8M0, layout, mMultiProcessorCount, /*enable_pdl=*/false, get_stream(self->device));
I can scan the repo and patch other sites similarly.
flashinfer/fp4_quantization.py (1)
371-438: Fix dtype reinterpretation for scale factors in silu_and_mul_scaled_nvfp4_experts_quantize_sm100 and its fake op.
Lines 433 and 460 incorrectly view int32 tensors directly to torch.float8_e4m3fn. Replace both occurrences with proper byte-level reinterpretation:
- output_scales = output_scales.view(torch.float8_e4m3fn).view(
-     l, padded_m // 128, padded_k // 4, 32, 4, 4
- )
+ output_scales = output_scales.view(torch.uint8).view(
+     l, padded_m // 128, padded_k // 4, 32, 4, 4
+ ).view(torch.float8_e4m3fn)
Apply at:
- Line 433 in silu_and_mul_scaled_nvfp4_experts_quantize_sm100
- Line 460 in _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100
♻️ Duplicate comments (1)
tests/utils/test_fp4_quantize.py (1)
362-396: Good: randomized mask exercises masking logic.
Using randint(1, m+1) ensures some rows are masked; this improves coverage.
🧹 Nitpick comments (9)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1)
36-38: API clarity: function name vs boolean flag.
The name includes "silu_and_mul", yet the API exposes use_silu_and_mul. Consider either removing the flag (always apply silu*mul) or renaming it to reflect the optionality, to avoid confusion at call sites.
csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)
65-69: Guard against size narrowing (m_topk/k/n_experts).
The NV path takes int; typical shapes come as int64_t. Add explicit range checks (<= INT_MAX) at FFI boundaries to prevent UB on large tensors.
flashinfer/activation.py (1)
193-218: Fix docstring and argument docs to match actual behavior (expects 2K input, mask before scale).
The current doc says the input is [B, M, K] and lists a_global_sf before mask; the code consumes [B, M, 2K] when fusing SiLU and passes (a, mask, a_global_sf). Update the docs to avoid misuse.
 def silu_and_mul_scaled_nvfp4_experts_quantize(
     a,
     mask,
     a_global_sf,
 ):
-    """
-    Silu and multiply and quantize batched input tensor to NVFP4 format with mask.
+    """
+    Fused SiLU-and-Mul, then NVFP4 quantize with per-expert mask.
     Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        a (torch.Tensor): Input tensor of shape [B, M, 2*K] (SiLU-and-Mul fused) with dtype fp16/bf16.
+        mask (torch.Tensor): Per-expert valid rows, shape [B], dtype int32.
+        a_global_sf (torch.Tensor): Per-expert global scale, shape [B], dtype float32.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
             - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
             - Scale factors tensor with shape determined by layout and sf_vec_size
     """
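For reference, a minimal usage sketch of the calling convention described in this docstring. The shapes, dtypes, and values below are illustrative only (not taken from the PR's tests), and it assumes an SM100-class GPU since the op dispatches to the sm100 path:

```python
import torch
import flashinfer

B, M, K = 8, 64, 256  # experts, max tokens per expert, output feature dim
# Gate and up halves fused along the last dim: [B, M, 2*K], fp16/bf16
a = torch.randn(B, M, 2 * K, dtype=torch.bfloat16, device="cuda")
# Valid rows per expert; rows at or beyond mask[e] are skipped by the kernel
mask = torch.randint(1, M + 1, (B,), dtype=torch.int32, device="cuda")
# Per-expert global scale factors
a_global_sf = torch.ones(B, dtype=torch.float32, device="cuda")

out, out_sf = flashinfer.silu_and_mul_scaled_nvfp4_experts_quantize(a, mask, a_global_sf)
# out packs two E2M1 values per byte (K/2 bytes per row); out_sf holds the FP8 block scales
```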
flashinfer/fp4_quantization.py (2)
439-465: Fake op returns int32 scales; should mirror the real path's float8 reshaped layout.
Keep dtype/shape parity to prevent downstream assumptions from breaking.
-    output_scales = torch.empty(
+    output_scales = torch.empty(
         l, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
-
-    output_scales = output_scales.view(torch.float8_e4m3fn).view(
-        l, padded_m // 128, padded_k // 4, 32, 4, 4
-    )
-    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
-    return (output, output_scales)
+    scales_f8 = output_scales.view(torch.uint8).view(
+        l, padded_m // 128, padded_k // 4, 32, 4, 4
+    ).view(torch.float8_e4m3fn)
+    scales_f8 = scales_f8.permute(3, 4, 1, 5, 2, 0)
+    return (output, scales_f8)
466-527: Scaled grouped quantization reuses the silu kernel with use_silu_and_mul=False; OK. Two tweaks:
- Name locals descriptively (replace ambiguous l) to improve readability.
- Apply the same dtype reinterpret for scales as above.
-    device = input_tensor.device
-    l, m, k = input_tensor.shape
+    device = input_tensor.device
+    num_groups, m, k = input_tensor.shape
@@
-    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+    output = torch.empty(num_groups, m, k // 2, device=device, dtype=torch.uint8)
@@
-    module.silu_and_mul_scaled_nvfp4_experts_quantize(
-        output.view(l * m, k // 2),
-        output_scales.view(l * padded_m, padded_k_int32),
-        input_tensor.view(l * m, k),
+    module.silu_and_mul_scaled_nvfp4_experts_quantize(
+        output.view(num_groups * m, k // 2),
+        output_scales.view(num_groups * padded_m, padded_k_int32),
+        input_tensor.view(num_groups * m, k),
         input_global_scale,
         mask,
         False,
     )
@@
-    output = output.permute(1, 2, 0)
+    output = output.permute(1, 2, 0)
@@
-    output_scales = output_scales.view(torch.float8_e4m3fn).view(
-        l, padded_m // 128, padded_k // 4, 32, 4, 4
-    )
-    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
+    output_scales = output_scales.view(torch.uint8).view(
+        num_groups, padded_m // 128, padded_k // 4, 32, 4, 4
+    ).view(torch.float8_e4m3fn)
+    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (3)
741-838: Kernel template compiles, but a few correctness and maintainability nits.
- Add __restrict__ to in/out pointers to help codegen.
- Rename the local alias to avoid shadowing the template (PackedVecT vs PackedVec).
- Comment the "numColThreadsForSf" derivation and enforce that it is a multiple of the warp size when using shuffles.
-template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
+template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
 __global__ void
@@
-  using PackedVec = PackedVec<Type>;
+  using PackedVecT = ::tensorrt_llm::kernels::PackedVec<Type>;
@@
-  PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+  PackedVecT in_vec = reinterpret_cast<PackedVecT const*>(in)[inOffset];
Verify that adding __restrict__ doesn't regress occupancy; check SASS for reduced ld/st hazards.
840-884: cvt_quant_to_fp4_get_sf_out_offset: constant uses CVT_FP4_SF_VEC_SIZE; good. Minor: document the padding math.
Add a brief comment explaining why numKTiles rounds by (SF_VEC_SIZE*4).
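For intuition, a rough Python sketch of that padding math. Rounding k by SF_VEC_SIZE*4 elements is the same as rounding the per-row scale-factor count to a multiple of 4; the 128-row / 4-column tile assumption follows the (padded_m // 128, padded_k // 4, 32, 4, 4) views used elsewhere in this PR, and the helper name here is made up for illustration:

```python
def swizzled_sf_buffer_shape(m: int, k: int, sf_vec_size: int = 16) -> tuple[int, int]:
    """Illustrative only: padded shape of the swizzled scale-factor buffer for one group."""
    sf_per_row = k // sf_vec_size              # one FP8 scale per sf_vec_size elements
    padded_m = (m + 127) // 128 * 128          # rows padded to whole 128-row tiles
    padded_k = (sf_per_row + 3) // 4 * 4       # SF columns padded to whole 4-column tiles
    return padded_m, padded_k

# e.g. m=120, k=64 -> (128, 4): this padding is why the (1, 120, 64) test
# needs to unswizzle before comparing scale factors.
print(swizzled_sf_buffer_shape(120, 64))
```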
886-910: silu_and_mul: OK. Consider making silu a constexpr __device__ function and using fast math for expf.
Micro-optimization; optional.
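For context, a reference PyTorch sketch of the fused activation being quantized here, assuming the common convention that the first half of the last dimension is the SiLU-gated half (the function name is illustrative):

```python
import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: [..., 2*d] -> [..., d]; silu(a) = a * sigmoid(a)
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    return torch.nn.functional.silu(gate) * up
```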
csrc/nv_internal/cpp/kernels/quantization.cu (1)
331-350: Explicit instantiations look correct. Consider adding BF16/FP8 guards to match the ENABLE_* macros.
Keeps ODR and compile time in check.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- csrc/nv_internal/cpp/kernels/quantization.cu (5 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (3 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4 hunks)
- csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1 hunks)
- docs/api/fp4_quantization.rst (1 hunks)
- flashinfer/__init__.py (2 hunks)
- flashinfer/activation.py (1 hunks)
- flashinfer/fp4_quantization.py (4 hunks)
- tests/utils/test_fp4_quantize.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
flashinfer/activation.py (3)
  csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
  flashinfer/utils.py (1)
    get_compute_capability (245-248)
  flashinfer/fp4_quantization.py (1)
    silu_and_mul_scaled_nvfp4_experts_quantize_sm100 (375-437)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1)
  csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4)
    fp4_batched_quantize (133-186)
    fp4_batched_quantize (133-134)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
flashinfer/fp4_quantization.py (3)
  flashinfer/utils.py (5)
    register_custom_op (266-275)
    register_custom_op (285-304)
    register_fake_op (277-281)
    register_fake_op (306-311)
    get_compute_capability (245-248)
  csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
  flashinfer/activation.py (1)
    silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
flashinfer/__init__.py (3)
  csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
  flashinfer/activation.py (1)
    silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
  flashinfer/fp4_quantization.py (1)
    scaled_fp4_grouped_quantize (961-986)
csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)
  csrc/nv_internal/cpp/kernels/quantization.cu (6)
    void (225-249)
    void (251-269)
    invokeSiluAndMulNVFP4Quantization (298-329)
    invokeSiluAndMulNVFP4Quantization (298-300)
    invokeSiluAndMulNVFP4Quantization (346-349)
    invokeSiluAndMulNVFP4Quantization (365-367)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4)
  csrc/tvm_ffi_utils.h (1)
    get_stream (272-274)
  csrc/nv_internal/cpp/kernels/quantization.cu (4)
    invokeSiluAndMulNVFP4Quantization (298-329)
    invokeSiluAndMulNVFP4Quantization (298-300)
    invokeSiluAndMulNVFP4Quantization (346-349)
    invokeSiluAndMulNVFP4Quantization (365-367)
  flashinfer/fp4_quantization.py (1)
    fp4_quantize (624-686)
  flashinfer/activation.py (1)
    silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
tests/utils/test_fp4_quantize.py (4)
  flashinfer/fp4_quantization.py (3)
    scaled_fp4_grouped_quantize (961-986)
    nvfp4_batched_quantize (932-958)
    fp4_quantize (624-686)
  csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
    fp4_quantize (35-122)
    fp4_quantize (35-37)
  flashinfer/activation.py (2)
    silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
    silu_and_mul (69-110)
  flashinfer/utils.py (1)
    is_sm100a_supported (483-485)
csrc/nv_internal/cpp/kernels/quantization.cu (2)
  csrc/trtllm_gemm_runner.cu (8)
    m (111-126)
    m (111-111)
    m (128-179)
    m (128-130)
    m (181-236)
    m (181-181)
    m (238-250)
    m (238-238)
  include/flashinfer/trtllm/common.h (1)
    device (83-90)
🪛 Ruff (0.14.0)
flashinfer/fp4_quantization.py
361-361: Unused function argument: global_scale (ARG001)
363-363: Unused function argument: sf_use_ue8m0 (ARG001)
410-410: Ambiguous variable name: l (E741)
442-442: Unused function argument: mask (ARG001)
443-443: Unused function argument: global_scale (ARG001)
446-446: Ambiguous variable name: l (E741)
495-495: Ambiguous variable name: l (E741)
531-531: Unused function argument: input_global_scale (ARG001)
532-532: Unused function argument: mask (ARG001)
535-535: Ambiguous variable name: l (E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (13)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1)
33-34: Signature update looks consistent with the cpp.
The header matches the new fp4_batched_quantize(Tensor, ...) definition and mask removal. No issues spotted here.
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (1)
75-75: Correct param ordering (enable_pdl, stream).
fp4_quantize launches with enable_pdl followed by stream; this is correct post-signature change.
flashinfer/__init__.py (2)
28-29: Export rename aligns with the backend.
The silu_and_mul_scaled_nvfp4_experts_quantize re-export looks correct and consistent with activation.py.
74-75: Good to expose grouped quantize publicly.
The scaled_fp4_grouped_quantize export is consistent with fp4_quantization.py.
flashinfer/activation.py (1)
68-110: The silu_and_mul API looks fine; PDL defaulting and shape checks are consistent.
No action needed.
flashinfer/fp4_quantization.py (2)
961-986: Public wrapper: doc and arg order match the grouped path; good. Minor: keep wrapper names consistent across the activation and quantization modules.
Consider aligning function names and docstrings with activation.py to reduce API confusion.
Ensure __init__.py re-exports this symbol and the old alias (if any) to avoid downstream breakage.
623-686: The fp4_quantize path looks good; watch the is_column_major reshape when is_sf_swizzled_layout=True.
No blocking issues here.
csrc/nv_internal/cpp/kernels/quantization.cu (3)
105-106: SFOutput type cast: consistent with the kernel signature. LGTM.
Casts to uint32_t* match the cvt_quant_get_sf_out_offset expectation.
168-223: invokeFP4Quantization: good use of cudaLaunchKernelEx + PDL.
No issues.
271-295: Correct lines 237-238 and 256-257: std::optional in device kernels will fail with NVCC 12.8.
The review correctly identifies a critical issue, but targets the wrong lines. The host wrapper functions at lines 271-295 are fine. The actual problem is in the __global__ kernels they call:
- block_scale_interleave_kernel (lines 225-249): lines 237-238 use std::optional<int>
- block_scale_interleave_reverse_kernel (lines 251-269): lines 256-257 use std::optional<int>
NVCC 12.8 does not support std::optional in __device__ code (use cuda::std::optional instead). Replace with plain int and sentinel values (e.g., -1 for absent) as the review suggests.
Likely an incorrect or invalid review comment.
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)
911-1004: Mask early-exit uses break; prefer continue for clarity and to avoid skipping subsequent tiles.
Using break exits the per-thread loop entirely. If the stride mapping changes, it could skip valid work. Safer to continue when rowIdx_in_expert >= mask[expert_idx].
- if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
-   break;
- }
+ if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
+   continue;
+ }
Also, rename the alias to avoid template shadowing as above.
Likely an incorrect or invalid review comment.
tests/utils/test_fp4_quantize.py (2)
413-442: Incorrect technical claim about the PyTorch API.
The review comment incorrectly states that "Tensor.view takes sizes, not dtype." PyTorch's view() method does support dtype arguments for reinterpretation, as shown in official documentation with examples like x.view(torch.int32), x.view(torch.cfloat), and x.view(torch.uint8). The pattern .view(torch.float8_e4m3fn) is valid PyTorch syntax.
The one actionable improvement in the review is valid: the error message in the assertion should be corrected. It currently says "Expected uint8" but should say "Expected float8_e4m3fn" to match the assertion's actual check.
Likely an incorrect or invalid review comment.
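A minimal demonstration of the view(dtype) reinterpretation behavior referenced above (assumes a PyTorch build with float8_e4m3fn support; shapes are illustrative):

```python
import torch

x = torch.zeros(2, 4, dtype=torch.int32)
# view(dtype) reinterprets the underlying bytes; a smaller element type
# scales the last dimension proportionally (one 4-byte int32 -> four 1-byte values).
as_u8 = x.view(torch.uint8)            # shape [2, 16]
as_fp8 = x.view(torch.float8_e4m3fn)   # shape [2, 16], same bytes reinterpreted
assert as_u8.shape == as_fp8.shape == (2, 16)
```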
337-359
: Proposed fix has logical error: unswizzles only one scale, but not the other.The review correctly identifies that the test needs to handle scale buffer padding/swizzle differences for non-128-multiple shapes like (1, 120, 64). However, the proposed diff unswizzles
out_scale[i]
but compares it againstsingle_scale
, which is still swizzled (returned fromfp4_quantize
withis_sf_swizzled_layout=True
).The correct pattern already exists in the same file at lines 403-405 and 454-456: unswizzle both scales before comparing:
scale_ref = unswizzle_sf(single_scale.view(torch.float8_e4m3fn), m, n) scale_ans = unswizzle_sf(out_scale[i], m, n) torch.testing.assert_close(scale_ref, scale_ans, rtol=1e-5, atol=1e-5)The test should follow this established pattern rather than the one-sided unswizzle proposed in the review.
Likely an incorrect or invalid review comment.
/bot run
[FAILED] Pipeline #36904035: 1/17 passed
Signed-off-by: Shu Wang. <[email protected]>
6b55e4d to d972fbe (Compare)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/fp4_quantization.py (1)
211-223: Fake fp4_quantize_sm100 returns wrong dtypes and ignores the swizzled layout; fix required.
Verification confirms the issue: the function returns torch.int64 and torch.int32 instead of torch.uint8, and completely ignores the is_sf_swizzled_layout parameter. The scale factor size computation doesn't account for the swizzled layout, which will cause mismatches with the real operation. The suggested fix should be applied as specified.
🧹 Nitpick comments (4)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
133-166: Add missing input/output validations in fp4_batched_quantize.
- globalScale isn’t validated for device residency; add CHECK_CUDA.
- Enforce output buffer dtypes to match expectations (packed E2M1 as uint8 bytes; SF as uint8 backing even if kernel writes int32 words).
- Optionally assert output shapes to catch host-side allocation mistakes early.
 void fp4_batched_quantize(Tensor self, Tensor globalScale, Tensor valueE2M1, Tensor scaleFP8SF,
                           int64_t sfVecSize, bool sfUseUE8M0) {
   CHECK_CUDA(self);
   CHECK_CONTIGUOUS(self);
+  CHECK_CUDA(globalScale);
+  auto uint8_dtype = DLDataType{kDLUInt, 8, 1};
   auto fp32_dtype = DLDataType{kDLFloat, 32, 1};
   CHECK_INPUT_TYPE(globalScale, fp32_dtype);
   TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16";
@@
   int64_t b = inputShape[0];
   int64_t m = inputShape[1];
   int64_t k = inputShape[2];
   TVM_FFI_ICHECK_EQ(k % sfVecSize, 0);
   std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
   outputShape[rank - 1] = k / 2;
+  // Validate host-allocated outputs (dtype and minimal shape)
+  CHECK_INPUT_TYPE(valueE2M1, uint8_dtype);
+  CHECK_INPUT_TYPE(scaleFP8SF, uint8_dtype);
+  TVM_FFI_ICHECK_EQ(valueE2M1.ndim(), 3);
+  TVM_FFI_ICHECK_EQ(valueE2M1.shape()[0], b);
+  TVM_FFI_ICHECK_EQ(valueE2M1.shape()[1], m);
+  TVM_FFI_ICHECK_EQ(valueE2M1.shape()[2], k / 2);
Also applies to: 158-163
231-241: Range/robustness checks before kernel launch.
- Add n_experts > 0 to avoid division by zero in grid rounding.
- Optional: ensure m_topk, k fit into int (kernel takes int), and k >= CVT_ELTS_PER_THREAD to avoid block.x=0.
   auto in_dtype = input->dtype;
   const cudaStream_t stream = get_stream(input->device);
+  TVM_FFI_ICHECK_GT(n_experts, 0) << "n_experts must be > 0";
+  TVM_FFI_ICHECK_GT(m_topk, 0) << "m_topk must be > 0";
+  TVM_FFI_ICHECK_GT(k, 0) << "k must be > 0";
flashinfer/fp4_quantization.py (2)
374-441: Validate mask/global_scale early; avoid the ambiguous variable name.
- Add asserts for mask dtype (int32) and global_scale shape/dtype.
- Consider renaming l → num_groups to avoid E741.
 def silu_and_mul_scaled_nvfp4_experts_quantize_sm100(
     input: torch.Tensor,
     mask: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    device = input.device
-    l, m, k_by_2 = input.shape
+    device = input.device
+    num_groups, m, k_by_2 = input.shape
+    assert mask.dtype == torch.int32, f"mask must be int32, got {mask.dtype}"
+    assert global_scale is not None and global_scale.dtype == torch.float32, "global_scale must be float32"
+    assert global_scale.numel() == num_groups, "global_scale shape mismatch with groups"
-    k = k_by_2 // 2
+    k = k_by_2 // 2
@@
-    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+    output = torch.empty(num_groups, m, k // 2, device=device, dtype=torch.uint8)
     output_scales = torch.empty(
-        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
+        num_groups, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
@@
-    module.silu_and_mul_scaled_nvfp4_experts_quantize(
-        output.view(l * m, k // 2),
-        output_scales.view(l * padded_m, padded_k_int32),
-        input.view(l * m, k_by_2),
+    module.silu_and_mul_scaled_nvfp4_experts_quantize(
+        output.view(num_groups * m, k // 2),
+        output_scales.view(num_groups * padded_m, padded_k_int32),
+        input.view(num_groups * m, k_by_2),
         global_scale,
         mask,
         True,
     )
-    output = output.permute(1, 2, 0)
+    output = output.permute(1, 2, 0)
469-557: Do the same light validations for the grouped path.
- Check input_global_scale dtype/length and mask dtype.
- Optional: rename l → num_groups.
 def scaled_fp4_grouped_quant_sm100(
     input_tensor: torch.Tensor,
     input_global_scale: torch.Tensor,
     mask: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    device = input_tensor.device
-    l, m, k = input_tensor.shape
+    device = input_tensor.device
+    num_groups, m, k = input_tensor.shape
+    assert mask.dtype == torch.int32, f"mask must be int32, got {mask.dtype}"
+    assert input_global_scale.dtype == torch.float32, f"input_global_scale must be float32, got {input_global_scale.dtype}"
+    assert input_global_scale.numel() == num_groups, "input_global_scale shape mismatch with groups"
@@
-    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+    output = torch.empty(num_groups, m, k // 2, device=device, dtype=torch.uint8)
     output_scales = torch.empty(
-        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
+        num_groups, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
@@
-    module.silu_and_mul_scaled_nvfp4_experts_quantize(
-        output.view(l * m, k // 2),
-        output_scales.view(l * padded_m, padded_k_int32),
-        input_tensor.view(l * m, k),
+    module.silu_and_mul_scaled_nvfp4_experts_quantize(
+        output.view(num_groups * m, k // 2),
+        output_scales.view(num_groups * padded_m, padded_k_int32),
+        input_tensor.view(num_groups * m, k),
         input_global_scale,
         mask,
         False,
     )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- csrc/nv_internal/cpp/kernels/quantization.cu (5 hunks)
- csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4 hunks)
- docs/api/activation.rst (1 hunks)
- docs/api/fp4_quantization.rst (1 hunks)
- flashinfer/fp4_quantization.py (4 hunks)
✅ Files skipped from review due to trivial changes (1)
- docs/api/activation.rst
🚧 Files skipped from review as they are similar to previous changes (1)
- docs/api/fp4_quantization.rst
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (3)
  csrc/tvm_ffi_utils.h (1)
    get_stream (272-274)
  csrc/nv_internal/cpp/kernels/quantization.cu (4)
    invokeSiluAndMulNVFP4Quantization (298-329)
    invokeSiluAndMulNVFP4Quantization (298-300)
    invokeSiluAndMulNVFP4Quantization (346-349)
    invokeSiluAndMulNVFP4Quantization (365-367)
  flashinfer/activation.py (1)
    silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
flashinfer/fp4_quantization.py (4)
  flashinfer/fp8_quantization.py (1)
    _compute_swizzled_layout_sf_size (15-18)
  flashinfer/utils.py (5)
    register_custom_op (266-275)
    register_custom_op (285-304)
    register_fake_op (277-281)
    register_fake_op (306-311)
    get_compute_capability (245-248)
  csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-245)
    silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
  flashinfer/activation.py (1)
    silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
csrc/nv_internal/cpp/kernels/quantization.cu (2)
  csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h (1)
    layout (29-47)
  include/flashinfer/trtllm/common.h (1)
    device (83-90)
🪛 Ruff (0.14.1)
flashinfer/fp4_quantization.py
361-361: Unused function argument: global_scale (ARG001)
363-363: Unused function argument: sf_use_ue8m0 (ARG001)
413-413: Ambiguous variable name: l (E741)
445-445: Unused function argument: mask (ARG001)
446-446: Unused function argument: global_scale (ARG001)
449-449: Ambiguous variable name: l (E741)
498-498: Ambiguous variable name: l (E741)
534-534: Unused function argument: input_global_scale (ARG001)
535-535: Unused function argument: mask (ARG001)
538-538: Ambiguous variable name: l (E741)
🔇 Additional comments (4)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (1)
247-250: Export list looks good.
Symbols match the Python bindings and new API names.
flashinfer/fp4_quantization.py (2)
358-372: Good: the fake batched path matches the real dtypes and swizzled SF sizing.
964-989: Wrapper naming/dispatch looks consistent with the new API.
Good adapter to scaled_fp4_grouped_quant_sm100.
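For orientation, a hypothetical call sketch of the public wrapper. The argument order (input, global scale, mask) mirrors scaled_fp4_grouped_quant_sm100 as shown in the diffs above and is an assumption about the wrapper, as are the shapes and dtypes:

```python
import torch
import flashinfer

num_experts, M, K = 8, 64, 256
x = torch.randn(num_experts, M, K, dtype=torch.bfloat16, device="cuda")
global_sf = torch.ones(num_experts, dtype=torch.float32, device="cuda")
mask = torch.randint(1, M + 1, (num_experts,), dtype=torch.int32, device="cuda")

# No SiLU-and-Mul here: the grouped path reuses the same kernel with use_silu_and_mul=False.
out, out_sf = flashinfer.scaled_fp4_grouped_quantize(x, global_sf, mask)
```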
csrc/nv_internal/cpp/kernels/quantization.cu (1)
168-173: invokeFP4Quantization interface changes look consistent.
enable_pdl is plumbed through; call sites updated accordingly.
Also applies to: 190-191, 221-222
Signed-off-by: Shu Wang. <[email protected]>
/bot run
@wenscarl is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
@yzh119 could you trigger the CI? It seems the AI comments are just nit-picking. Not sure what exactly blocks the merge.
📌 Description
This PR reverts #1774 and #1835, which had issues with certain shapes under CUDA graph. The kernels ported in this PR come from SGLang: "[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm" and "[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf" by @kaixih.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
API Changes
Documentation
Tests