
Conversation

Collaborator

@wenscarl wenscarl commented Oct 14, 2025

📌 Description

This PR reverts #1774 and #1835, which have issues with some shapes under CUDA graph. The kernels ported in this PR come from SGLang: [NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm and [NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf, both by @kaixih.
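
For reference, a minimal pure-PyTorch sketch of what the fused path computes (ignoring E2M1 packing, FP8 scale-factor encoding, per-expert masking, and the swizzled scale layout; this is an illustration, not the kernel's exact numerics):

import torch
import torch.nn.functional as F

def silu_and_mul_nvfp4_ref(x: torch.Tensor, sf_vec_size: int = 16):
    # x: [m, 2*k] fp16/bf16; k must be a multiple of sf_vec_size.
    gate, up = x.float().chunk(2, dim=-1)
    y = F.silu(gate) * up                                    # SiLU-and-Mul
    blocks = y.view(y.shape[0], -1, sf_vec_size)             # one scale per 16-element block
    scale = blocks.abs().amax(dim=-1, keepdim=True) / 6.0    # 6.0 = max |E2M1| magnitude
    q = (blocks / scale.clamp(min=1e-12)).clamp(-6.0, 6.0)   # values representable in FP4
    return q.reshape(y.shape), scale.squeeze(-1)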

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added grouped FP4 quantization (scaled_fp4_grouped_quantize) and an NV-focused Silu+Mul expert quantization entry (silu_and_mul_scaled_nvfp4_experts_quantize).
  • API Changes

    • Replaced legacy batched APIs with new expert/grouped APIs; removed legacy mask parameter from FP4/MXFP8 quantization signatures and adjusted FP4 output layouts/types.
  • Documentation

    • Updated docs to list new functions and remove deprecated symbols.
  • Tests

    • Updated tests to validate new quantization paths, shapes, dtypes, and layouts.

@wenscarl wenscarl marked this pull request as ready for review October 14, 2025 03:42
@wenscarl wenscarl requested review from kaixih and yzh119 October 14, 2025 03:42
@yzh119
Collaborator

yzh119 commented Oct 18, 2025

/bot run

@yzh119
Collaborator

yzh119 commented Oct 18, 2025

/gemini review

@flashinfer-bot
Collaborator

GitLab MR !84 has been created, and the CI pipeline #36839981 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #36839981: 0/17 passed

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request reworks the silu_and_mul nvfp4 quantization fusion by reverting previous changes and porting kernels from SGLANG. The changes are extensive, affecting CUDA kernels, C++ wrappers, Python bindings, and documentation. My review focuses on improving code quality, maintainability, and correctness. I've identified several areas for improvement, including code duplication in new Python functions that could be refactored, a new CUDA header file that appears to be unused and contains duplicated code, and minor issues in CUDA kernels such as commented-out code, unused parameters, and hardcoded values. Additionally, the documentation for FP4 quantization functions seems to be out of sync with the code, and a test case for the new functionality could be improved to better cover masking logic.

Contributor

coderabbitai bot commented Oct 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR removes mask (and some stream params) from FP4/MXFP8 kernel APIs, adds an NV-specific per-expert Silu-and-Mul FP4 quantization path with opaque void* signatures, and updates CUDA headers, C++ bindings, Python APIs, docs, and tests to scaled/grouped quantize variants.

Changes

Cohort / File(s): Summary

  • CUDA kernels & implementation (csrc/nv_internal/cpp/kernels/quantization.cu, csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh): Dropped mask from FP4/MXFP8 invocation signatures and trimmed stream params; removed device silu/NoSiluPolicy; added layout-aware SF offset helpers, a multi-expert kernel cvt_fp16_to_fp4_expert, and the NV launcher invokeSiluAndMulNVFP4Quantization.
  • CUDA public header (csrc/nv_internal/tensorrt_llm/kernels/quantization.h): Updated invokeFP4Quantization to remove mask; replaced the typed invokeSiluAndMulFP4Quantization declaration with invokeSiluAndMulNVFP4Quantization using opaque void* parameters.
  • C++ bindings / THOP (csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp, csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h): Removed mask from fp4_batched_quantize; added/renamed the public entry silu_and_scaled_nvfp4_experts_quantize (exported as silu_and_mul_scaled_nvfp4_experts_quantize); adapted validation, dtypes, and kernel invocation to the NV opaque path.
  • Python package exports & activation (flashinfer/__init__.py, flashinfer/activation.py): Replaced silu_and_mul_nvfp4_batched_quantize with silu_and_mul_scaled_nvfp4_experts_quantize; added/exported scaled_fp4_grouped_quantize; updated the activation wrapper to call the new backend dispatch.
  • High-level FP4 module (flashinfer/fp4_quantization.py): Added scaled_fp4_grouped_quant_sm100 and a fake variant; renamed NV batched paths to scaled/expert variants; introduced grouped quantize wrappers and adjusted output layouts and scale reshaping.
  • Docs & Tests (docs/api/fp4_quantization.rst, docs/api/activation.rst, tests/utils/test_fp4_quantize.py): Updated docs to reflect removed/added public symbols; updated tests to use scaled_fp4_grouped_quantize and silu_and_mul_scaled_nvfp4_experts_quantize, removed mask usage, and adjusted expected dtypes/shapes.

Sequence Diagram(s)

sequenceDiagram
  participant Py as Python caller
  participant PyAPI as flashinfer API
  participant Cpp as C++ binding
  participant NVInv as NV launcher
  participant Kernel as CUDA kernel
  participant GPU as Device memory

  Note left of Py: high-level call
  Py->>PyAPI: scaled_fp4_grouped_quantize(...) / silu_and_mul_scaled_nvfp4_experts_quantize(...)
  PyAPI->>Cpp: forward tensors & params
  Cpp->>NVInv: prepare opaque pointers, m_topk/k/n_experts, stream
  NVInv->>Kernel: launch per-expert kernel (cvt_fp16_to_fp4_expert / quantize)
  Kernel->>GPU: read input, compute SF, write SFout, write quantized out
  Kernel-->>NVInv: kernel completes
  NVInv-->>Cpp: return
  Cpp-->>PyAPI: return outputs
  PyAPI-->>Py: (output, output_scale)
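In Python terms, the top of the flow in the diagram above looks roughly like the sketch below. The shapes, the argument order, and the public import path are inferred from the signatures quoted later in this review, so treat this as an illustrative sketch rather than the final API.

import torch
import flashinfer

# Illustrative sizes: n_experts experts, up to m rows each, hidden size k (2*k before SiLU-and-Mul).
n_experts, m, k = 8, 128, 512
mask = torch.randint(1, m + 1, (n_experts,), device="cuda", dtype=torch.int32)
global_sf = torch.ones(n_experts, device="cuda", dtype=torch.float32)

# Fused SiLU-and-Mul + per-expert NVFP4 quantization (argument order per the docstring fix below).
x = torch.randn(n_experts, m, 2 * k, device="cuda", dtype=torch.bfloat16)
out, out_sf = flashinfer.silu_and_mul_scaled_nvfp4_experts_quantize(x, mask, global_sf)

# Grouped NVFP4 quantization without the activation fusion.
y = torch.randn(n_experts, m, k, device="cuda", dtype=torch.bfloat16)
g_out, g_sf = flashinfer.scaled_fp4_grouped_quantize(y, global_sf, mask)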

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 I hopped through kernels, masks set free,

NV paths hum per-expert and neat,
Scales swizzled, grouped, in tidy rows,
Old signatures shrink as new ones grow,
A rabbit cheers — quantize, then eat.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 37.50%, which is below the required threshold of 80.00%. Resolution: You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check (✅ Passed): The PR title "silu_and_mul nvfp4 quanization fusion rework" clearly summarizes the main technical change: a reworking of the silu_and_mul NV FP4 quantization fusion. The title is specific enough that a teammate scanning commit history would understand the primary focus of the changes. However, there is a typo in the title: "quanization" should be "quantization". While the title captures the core change, it doesn't indicate that this PR is reverting previous problematic changes and importing kernels from SGLANG, though these are contextual details rather than alterations to the nature of the change itself.
  • Description Check (✅ Passed): The PR description follows the required template structure with all major sections present. The Description section is well filled, explaining what the PR does (reverts prior problematic PRs and ports kernels from SGLANG), why it's needed (addresses CUDA graph shape issues), and providing references to the source SGLANG PRs with the original contributor. The Related Issues section is empty, but this is not critical since the related PRs (#1774 and #1835) are explicitly referenced in the description itself. The checklist sections are properly formatted with placeholder text, which is standard for newly created pull requests awaiting completion by the author.
✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
    • Create PR with unit tests
    • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d972fbe and 48259af.

📒 Files selected for processing (2)
  • csrc/nv_internal/cpp/kernels/quantization.cu (5 hunks)
  • csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (3)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
csrc/nv_internal/cpp/kernels/quantization.cu (4)
  • invokeSiluAndMulNVFP4Quantization (298-332)
  • invokeSiluAndMulNVFP4Quantization (298-300)
  • invokeSiluAndMulNVFP4Quantization (349-352)
  • invokeSiluAndMulNVFP4Quantization (368-370)
flashinfer/activation.py (1)
  • silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
csrc/nv_internal/cpp/kernels/quantization.cu (1)
csrc/trtllm_gemm_runner.cu (8)
  • m (111-126)
  • m (111-111)
  • m (128-179)
  • m (128-130)
  • m (181-236)
  • m (181-181)
  • m (238-250)
  • m (238-238)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)

133-140: Enforce unsupported combo explicitly.

The batched path enforces sfVecSize==16 but still accepts sfUseUE8M0. Add a guard to reject UE8M0 here to avoid confusing, unsupported inputs.
Apply:

   CHECK_INPUT_TYPE(globalScale, fp32_dtype);
   TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16";
+  TVM_FFI_ICHECK(!sfUseUE8M0) << "UE8M0 (sfVecSize=32) is not supported in batched path.";

158-163: Critical: stream passed into enable_pdl; kernels run on default stream.

get_stream(self->device) is currently bound to enable_pdl (bool), and stream defaults to 0. This flips enable_pdl and launches on default stream, breaking CUDA Graph/PDL semantics and likely causing the CI failures.

Fix by passing both enable_pdl and stream (batched path can keep PDL disabled):

-  tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>(                                  \
-      b, m, k, reinterpret_cast<T*>(self->data), static_cast<float*>(globalScale->data),         \
-      reinterpret_cast<int64_t*>(valueE2M1->data), reinterpret_cast<int32_t*>(scaleFP8SF->data), \
-      sfUseUE8M0, layout, mMultiProcessorCount, get_stream(self->device));
+  tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>(                                  \
+      b, m, k, reinterpret_cast<T*>(self->data), static_cast<float*>(globalScale->data),         \
+      reinterpret_cast<int64_t*>(valueE2M1->data), reinterpret_cast<int32_t*>(scaleFP8SF->data), \
+      sfUseUE8M0, layout, mMultiProcessorCount, /*enable_pdl=*/false, get_stream(self->device));

I can scan the repo and patch other sites similarly.

flashinfer/fp4_quantization.py (1)

371-438: Fix dtype reinterpretation for scale factors in silu_and_mul_scaled_nvfp4_experts_quantize_sm100 and its fake op.

Lines 433 and 460 incorrectly view int32 tensors directly to torch.float8_e4m3fn. Replace both occurrences with proper byte-level reinterpretation:

-        output_scales = output_scales.view(torch.float8_e4m3fn).view(
-            l, padded_m // 128, padded_k // 4, 32, 4, 4
-        )
+        output_scales = output_scales.view(torch.uint8).view(
+            l, padded_m // 128, padded_k // 4, 32, 4, 4
+        ).view(torch.float8_e4m3fn)

Apply at:

  • Line 433 in silu_and_mul_scaled_nvfp4_experts_quantize_sm100
  • Line 460 in _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100
♻️ Duplicate comments (1)
tests/utils/test_fp4_quantize.py (1)

362-396: Good: randomized mask exercises masking logic.

Using randint(1, m+1) ensures some rows are masked; this improves coverage.

🧹 Nitpick comments (9)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1)

36-38: API clarity: function name vs boolean flag.

The name includes “silu_and_mul”, yet the API exposes use_silu_and_mul. Consider either removing the flag (always apply silu*mul) or renaming to reflect the optionality to avoid confusion at call sites.

csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)

65-69: Guard against size narrowing (m_topk/k/n_experts).

The NV path takes int; typical shapes come as int64_t. Add explicit range checks (<= INT_MAX) at FFI boundaries to prevent UB on large tensors.

flashinfer/activation.py (1)

193-218: Fix docstring and argument docs to match actual behavior (expects 2K input, mask before scale).

Current doc says input is [B, M, K] and lists a_global_sf before mask; code consumes [B, M, 2K] when fusing SiLU and passes (a, mask, a_global_sf). Update docs to avoid misuse.

 def silu_and_mul_scaled_nvfp4_experts_quantize(
     a,
     mask,
     a_global_sf,
 ):
-    """
-    Silu and multiply and quantize batched input tensor to NVFP4 format with mask.
+    """
+    Fused SiLU-and-Mul, then NVFP4 quantize with per-expert mask.
     Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        a (torch.Tensor): Input tensor of shape [B, M, 2*K] (SiLU-and-Mul fused) with dtype fp16/bf16.
+        mask (torch.Tensor): Per-expert valid rows, shape [B], dtype int32.
+        a_global_sf (torch.Tensor): Per-expert global scale, shape [B], dtype float32.
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
             - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
             - Scale factors tensor with shape determined by layout and sf_vec_size
     """
flashinfer/fp4_quantization.py (2)

439-465: Fake op returns int32 scales; should mirror real path’s float8 reshaped layout.

Keep dtype/shape parity to prevent downstream assumptions from breaking.

-        output_scales = torch.empty(
+        output_scales = torch.empty(
             l, padded_m, padded_k_int32, device=device, dtype=torch.int32
         )
-
-        output_scales = output_scales.view(torch.float8_e4m3fn).view(
-            l, padded_m // 128, padded_k // 4, 32, 4, 4
-        )
-        output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
-        return (output, output_scales)
+        scales_f8 = output_scales.view(torch.uint8).view(
+            l, padded_m // 128, padded_k // 4, 32, 4, 4
+        ).view(torch.float8_e4m3fn)
+        scales_f8 = scales_f8.permute(3, 4, 1, 5, 2, 0)
+        return (output, scales_f8)

466-527: Scaled grouped quantization reuses silu kernel with use_silu_and_mul=False; OK. Two tweaks:

  • Name locals descriptively (replace ambiguous l) to improve readability.
  • Apply the same dtype reinterpret for scales as above.
-        device = input_tensor.device
-        l, m, k = input_tensor.shape
+        device = input_tensor.device
+        num_groups, m, k = input_tensor.shape
@@
-        output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+        output = torch.empty(num_groups, m, k // 2, device=device, dtype=torch.uint8)
@@
-        module.silu_and_mul_scaled_nvfp4_experts_quantize(
-            output.view(l * m, k // 2),
-            output_scales.view(l * padded_m, padded_k_int32),
-            input_tensor.view(l * m, k),
+        module.silu_and_mul_scaled_nvfp4_experts_quantize(
+            output.view(num_groups * m, k // 2),
+            output_scales.view(num_groups * padded_m, padded_k_int32),
+            input_tensor.view(num_groups * m, k),
             input_global_scale,
             mask,
             False,
         )
@@
-        output = output.permute(1, 2, 0)
+        output = output.permute(1, 2, 0)
@@
-        output_scales = output_scales.view(torch.float8_e4m3fn).view(
-            l, padded_m // 128, padded_k // 4, 32, 4, 4
-        )
-        output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
+        output_scales = output_scales.view(torch.uint8).view(
+            num_groups, padded_m // 128, padded_k // 4, 32, 4, 4
+        ).view(torch.float8_e4m3fn)
+        output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
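
For context, the (padded_m // 128, padded_k // 4, 32, 4, 4) views in these snippets follow the 128x4 swizzled scale-factor layout. A small sketch of the sizing arithmetic, assuming the usual 16-element scale vectors (the helper name is illustrative, not part of the API):

def sf_buffer_shapes(num_groups: int, m: int, k: int, sf_vec_size: int = 16):
    # Illustrative helper: backing buffer and swizzled view sizes for the scale factors.
    def round_up(x: int, mult: int) -> int:
        return (x + mult - 1) // mult * mult
    k_sf = k // sf_vec_size                # one FP8 scale per sf_vec_size elements
    padded_m = round_up(m, 128)            # rows padded to the 128-row swizzle tile
    padded_k_sf = round_up(k_sf, 4)        # SF columns padded to a multiple of 4
    padded_k_int32 = padded_k_sf // 4      # backing buffer packs 4 FP8 scales per int32 word
    backing = (num_groups, padded_m, padded_k_int32)                       # dtype int32
    swizzled = (num_groups, padded_m // 128, padded_k_sf // 4, 32, 4, 4)   # fp8 view
    return backing, swizzled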
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (3)

741-838: Kernel template compiles, but a few correctness and maintainability nits.

  • Add __restrict__ to in/out pointers to help codegen.
  • Rename the local alias to avoid shadowing the template (PackedVecT vs PackedVec).
  • Comment the "numColThreadsForSf" derivation and enforce that it is a multiple of the warp size when using shuffles.
-template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
+template <BlockScaleQuantizationType quantization_type, class Type, int SF_VEC_SIZE, bool UE8M0_SF>
 __global__ void
@@
-  using PackedVec = PackedVec<Type>;
+  using PackedVecT = ::tensorrt_llm::kernels::PackedVec<Type>;
@@
-          PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+          PackedVecT in_vec = reinterpret_cast<PackedVecT const*>(in)[inOffset];

Verify adding __restrict__ doesn’t regress occupancy; check SASS for reduced ld/st hazards.


840-884: cvt_quant_to_fp4_get_sf_out_offset: constant uses CVT_FP4_SF_VEC_SIZE; good. Minor: document padding math.

Add a brief comment explaining why numKTiles rounds by (SF_VEC_SIZE*4).


886-910: silu_and_mul: OK. Consider making silu a constexpr __device__ function and using fast math for expf.

Micro-optimization; optional.

csrc/nv_internal/cpp/kernels/quantization.cu (1)

331-350: Explicit instantiations look correct. Consider adding BF16/FP8 guards to match the ENABLE_* macros.

Keeps ODR and compile time in check.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bea5949 and 856f918.

📒 Files selected for processing (10)
  • csrc/nv_internal/cpp/kernels/quantization.cu (5 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (3 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4 hunks)
  • csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1 hunks)
  • docs/api/fp4_quantization.rst (1 hunks)
  • flashinfer/__init__.py (2 hunks)
  • flashinfer/activation.py (1 hunks)
  • flashinfer/fp4_quantization.py (4 hunks)
  • tests/utils/test_fp4_quantize.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
flashinfer/activation.py (3)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
flashinfer/utils.py (1)
  • get_compute_capability (245-248)
flashinfer/fp4_quantization.py (1)
  • silu_and_mul_scaled_nvfp4_experts_quantize_sm100 (375-437)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4)
  • fp4_batched_quantize (133-186)
  • fp4_batched_quantize (133-134)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
flashinfer/fp4_quantization.py (3)
flashinfer/utils.py (5)
  • register_custom_op (266-275)
  • register_custom_op (285-304)
  • register_fake_op (277-281)
  • register_fake_op (306-311)
  • get_compute_capability (245-248)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
flashinfer/activation.py (1)
  • silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
flashinfer/__init__.py (3)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
flashinfer/activation.py (1)
  • silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
flashinfer/fp4_quantization.py (1)
  • scaled_fp4_grouped_quantize (961-986)
csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)
csrc/nv_internal/cpp/kernels/quantization.cu (6)
  • void (225-249)
  • void (251-269)
  • invokeSiluAndMulNVFP4Quantization (298-329)
  • invokeSiluAndMulNVFP4Quantization (298-300)
  • invokeSiluAndMulNVFP4Quantization (346-349)
  • invokeSiluAndMulNVFP4Quantization (365-367)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
csrc/nv_internal/cpp/kernels/quantization.cu (4)
  • invokeSiluAndMulNVFP4Quantization (298-329)
  • invokeSiluAndMulNVFP4Quantization (298-300)
  • invokeSiluAndMulNVFP4Quantization (346-349)
  • invokeSiluAndMulNVFP4Quantization (365-367)
flashinfer/fp4_quantization.py (1)
  • fp4_quantize (624-686)
flashinfer/activation.py (1)
  • silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
tests/utils/test_fp4_quantize.py (4)
flashinfer/fp4_quantization.py (3)
  • scaled_fp4_grouped_quantize (961-986)
  • nvfp4_batched_quantize (932-958)
  • fp4_quantize (624-686)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-244)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
  • fp4_quantize (35-122)
  • fp4_quantize (35-37)
flashinfer/activation.py (2)
  • silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
  • silu_and_mul (69-110)
flashinfer/utils.py (1)
  • is_sm100a_supported (483-485)
csrc/nv_internal/cpp/kernels/quantization.cu (2)
csrc/trtllm_gemm_runner.cu (8)
  • m (111-126)
  • m (111-111)
  • m (128-179)
  • m (128-130)
  • m (181-236)
  • m (181-181)
  • m (238-250)
  • m (238-238)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
🪛 Ruff (0.14.0)
flashinfer/fp4_quantization.py

361-361: Unused function argument: global_scale

(ARG001)


363-363: Unused function argument: sf_use_ue8m0

(ARG001)


410-410: Ambiguous variable name: l

(E741)


442-442: Unused function argument: mask

(ARG001)


443-443: Unused function argument: global_scale

(ARG001)


446-446: Ambiguous variable name: l

(E741)


495-495: Ambiguous variable name: l

(E741)


531-531: Unused function argument: input_global_scale

(ARG001)


532-532: Unused function argument: mask

(ARG001)


535-535: Ambiguous variable name: l

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (13)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h (1)

33-34: Signature update looks consistent with cpp.

Header matches the new fp4_batched_quantize(Tensor, ...) definition and mask removal. No issues spotted here.

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (1)

75-75: Correct param ordering (enable_pdl, stream).

fp4_quantize launches with enable_pdl followed by stream; this is correct post‑signature change.

flashinfer/__init__.py (2)

28-29: Export rename aligns with backend.

silu_and_mul_scaled_nvfp4_experts_quantize re-export looks correct and consistent with activation.py.


74-75: Good to expose grouped quantize publicly.

scaled_fp4_grouped_quantize export is consistent with fp4_quantization.py.

flashinfer/activation.py (1)

68-110: silu_and_mul API looks fine; PDL defaulting and shape checks are consistent.

No action needed.

flashinfer/fp4_quantization.py (2)

961-986: Public wrapper: doc and arg order match grouped path; good. Minor: keep wrapper names consistent across activation and quant modules.

Consider aligning function names and docstrings with activation.py to reduce API confusion.

Ensure init.py re-exports this symbol and old alias (if any) to avoid downstream breakage.


623-686: fp4_quantize path looks good; watch for is_column_major reshape when is_sf_swizzled_layout=True.

No blocking issues here.

csrc/nv_internal/cpp/kernels/quantization.cu (3)

105-106: SFOuput type cast: consistent with kernel signature. LGTM.

Casts to uint32_t* match cvt_quant_get_sf_out_offset expectation.


168-223: invokeFP4Quantization: good use of cudaLaunchKernelEx + PDL.

No issues.


271-295: Correction: the finding should target lines 237-238 and 256-257; std::optional in device kernels will fail with NVCC 12.8.

The review correctly identifies a critical issue, but targets the wrong lines. The host wrapper functions at lines 271-295 are fine. The actual problem is in the __global__ kernels they call:

  • block_scale_interleave_kernel (lines 225–249): Lines 237–238 use std::optional<int>
  • block_scale_interleave_reverse_kernel (lines 251–269): Lines 256–257 use std::optional<int>

NVCC 12.8 does not support std::optional in __device__ code (use cuda::std::optional instead). Replace with plain int and sentinel values (e.g., -1 for absent) as the review suggests.

Likely an incorrect or invalid review comment.

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

911-1004: Mask early-exit uses break; prefer continue for clarity and to avoid skipping subsequent tiles.

Using break exits the per-thread loop entirely. If stride mapping changes, it could skip valid work. Safer to continue when rowIdx_in_expert >= mask[expert_idx].

-    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
-      break;
-    }
+    if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
+      continue;
+    }

Also, rename alias to avoid template shadowing as above.

Likely an incorrect or invalid review comment.

tests/utils/test_fp4_quantize.py (2)

413-442: Incorrect Technical Claim About PyTorch API

The review comment incorrectly states that "Tensor.view takes sizes, not dtype." PyTorch's view() method does support dtype arguments for reinterpretation, as shown in official documentation with examples like x.view(torch.int32), x.view(torch.cfloat), and x.view(torch.uint8). The pattern .view(torch.float8_e4m3fn) is valid PyTorch syntax.

The one actionable improvement in the review is valid: the error message in the assertion should be corrected. It currently says "Expected uint8" but should say "Expected float8_e4m3fn" to match the assertion's actual check.

Likely an incorrect or invalid review comment.
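
For anyone double-checking the claim above: Tensor.view(dtype) reinterprets the underlying bytes and rescales the last dimension when element sizes differ. A quick sanity check, assuming a PyTorch build with float8_e4m3fn support:

import torch

x = torch.zeros(2, 8, dtype=torch.int32)
as_u8 = x.view(torch.uint8)            # same bytes; itemsize 4 -> 1, so last dim 8 -> 32
as_f8 = x.view(torch.float8_e4m3fn)    # also itemsize 1, so last dim 8 -> 32
print(as_u8.shape, as_f8.shape)        # torch.Size([2, 32]) torch.Size([2, 32])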


337-359: Proposed fix has logical error: unswizzles only one scale, but not the other.

The review correctly identifies that the test needs to handle scale buffer padding/swizzle differences for non-128-multiple shapes like (1, 120, 64). However, the proposed diff unswizzles out_scale[i] but compares it against single_scale, which is still swizzled (returned from fp4_quantize with is_sf_swizzled_layout=True).

The correct pattern already exists in the same file at lines 403-405 and 454-456: unswizzle both scales before comparing:

scale_ref = unswizzle_sf(single_scale.view(torch.float8_e4m3fn), m, n)
scale_ans = unswizzle_sf(out_scale[i], m, n)
torch.testing.assert_close(scale_ref, scale_ans, rtol=1e-5, atol=1e-5)

The test should follow this established pattern rather than the one-sided unswizzle proposed in the review.

Likely an incorrect or invalid review comment.

@yzh119
Collaborator

yzh119 commented Oct 20, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !84 has been updated with latest changes, and the CI pipeline #36904035 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 enabled auto-merge (squash) October 20, 2025 07:12
@flashinfer-bot
Collaborator

[FAILED] Pipeline #36904035: 1/17 passed

Signed-off-by: Shu Wang. <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fp4_quantization.py (1)

211-223: Fake fp4_quantize_sm100 returns wrong dtypes and ignores swizzled layout—fix required.

Verification confirms the issue: the function returns torch.int64 and torch.int32 instead of torch.uint8, and completely ignores the is_sf_swizzled_layout parameter. The scale factor size computation doesn't account for swizzled layout, which will cause mismatches with the real operation. The suggested fix should be applied as specified.
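
As a reference point, a fake (meta) implementation generally just mirrors the real op's output dtypes and shapes without doing any work. A minimal sketch under the assumptions above (the function name and swizzled sizing are illustrative, not the actual helpers in flashinfer/fp4_quantization.py):

import torch

def _fake_fp4_quantize_sketch(x: torch.Tensor, sf_vec_size: int = 16,
                              is_sf_swizzled_layout: bool = True):
    # Allocate outputs with the dtypes/shapes the real kernel would produce.
    m, k = x.shape
    out = x.new_empty((m, k // 2), dtype=torch.uint8)   # two E2M1 values packed per byte
    k_sf = k // sf_vec_size
    if is_sf_swizzled_layout:
        padded_m = (m + 127) // 128 * 128                # pad rows to 128
        padded_k_sf = (k_sf + 3) // 4 * 4                # pad SF columns to 4
        sf = x.new_empty((padded_m * padded_k_sf,), dtype=torch.uint8)
    else:
        sf = x.new_empty((m * k_sf,), dtype=torch.uint8)
    return out, sf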

🧹 Nitpick comments (4)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)

133-166: Add missing input/output validations in fp4_batched_quantize.

  • globalScale isn’t validated for device residency; add CHECK_CUDA.
  • Enforce output buffer dtypes to match expectations (packed E2M1 as uint8 bytes; SF as uint8 backing even if kernel writes int32 words).
  • Optionally assert output shapes to catch host-side allocation mistakes early.
 void fp4_batched_quantize(Tensor self, Tensor globalScale, Tensor valueE2M1, Tensor scaleFP8SF,
                           int64_t sfVecSize, bool sfUseUE8M0) {
   CHECK_CUDA(self);
   CHECK_CONTIGUOUS(self);
+  CHECK_CUDA(globalScale);
+  auto uint8_dtype = DLDataType{kDLUInt, 8, 1};
   auto fp32_dtype = DLDataType{kDLFloat, 32, 1};
   CHECK_INPUT_TYPE(globalScale, fp32_dtype);
   TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16";
@@
   int64_t b = inputShape[0];
   int64_t m = inputShape[1];
   int64_t k = inputShape[2];
 
   TVM_FFI_ICHECK_EQ(k % sfVecSize, 0);
 
   std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
   outputShape[rank - 1] = k / 2;
+  // Validate host-allocated outputs (dtype and minimal shape)
+  CHECK_INPUT_TYPE(valueE2M1, uint8_dtype);
+  CHECK_INPUT_TYPE(scaleFP8SF, uint8_dtype);
+  TVM_FFI_ICHECK_EQ(valueE2M1.ndim(), 3);
+  TVM_FFI_ICHECK_EQ(valueE2M1.shape()[0], b);
+  TVM_FFI_ICHECK_EQ(valueE2M1.shape()[1], m);
+  TVM_FFI_ICHECK_EQ(valueE2M1.shape()[2], k / 2);

Also applies to: 158-163


231-241: Range/robustness checks before kernel launch.

  • Add n_experts > 0 to avoid division by zero in grid rounding.
  • Optional: ensure m_topk, k fit into int (kernel takes int), and k >= CVT_ELTS_PER_THREAD to avoid block.x=0.
   auto in_dtype = input->dtype;
   const cudaStream_t stream = get_stream(input->device);
+  TVM_FFI_ICHECK_GT(n_experts, 0) << "n_experts must be > 0";
+  TVM_FFI_ICHECK_GT(m_topk, 0) << "m_topk must be > 0";
+  TVM_FFI_ICHECK_GT(k, 0) << "k must be > 0";
flashinfer/fp4_quantization.py (2)

374-441: Validate mask/global_scale early; avoid ambiguous variable name.

  • Add asserts for mask dtype (int32) and global_scale shape/dtype.
  • Consider renaming l → num_groups to avoid E741.
 def silu_and_mul_scaled_nvfp4_experts_quantize_sm100(
     input: torch.Tensor,
     mask: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    device = input.device
-    l, m, k_by_2 = input.shape
+    device = input.device
+    num_groups, m, k_by_2 = input.shape
+    assert mask.dtype == torch.int32, f"mask must be int32, got {mask.dtype}"
+    assert global_scale is not None and global_scale.dtype == torch.float32, "global_scale must be float32"
+    assert global_scale.numel() == num_groups, "global_scale shape mismatch with groups"
-    k = k_by_2 // 2
+    k = k_by_2 // 2
@@
-    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+    output = torch.empty(num_groups, m, k // 2, device=device, dtype=torch.uint8)
     output_scales = torch.empty(
-            l, padded_m, padded_k_int32, device=device, dtype=torch.int32
+            num_groups, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
@@
-        module.silu_and_mul_scaled_nvfp4_experts_quantize(
-            output.view(l * m, k // 2),
-            output_scales.view(l * padded_m, padded_k_int32),
-            input.view(l * m, k_by_2),
+        module.silu_and_mul_scaled_nvfp4_experts_quantize(
+            output.view(num_groups * m, k // 2),
+            output_scales.view(num_groups * padded_m, padded_k_int32),
+            input.view(num_groups * m, k_by_2),
             global_scale,
             mask,
             True,
         )
-        output = output.permute(1, 2, 0)
+        output = output.permute(1, 2, 0)

469-557: Do the same light validations for grouped path.

  • Check input_global_scale dtype/length and mask dtype.
  • Optional: rename l → num_groups.
 def scaled_fp4_grouped_quant_sm100(
     input_tensor: torch.Tensor,
     input_global_scale: torch.Tensor,
     mask: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    device = input_tensor.device
-    l, m, k = input_tensor.shape
+    device = input_tensor.device
+    num_groups, m, k = input_tensor.shape
+    assert mask.dtype == torch.int32, f"mask must be int32, got {mask.dtype}"
+    assert input_global_scale.dtype == torch.float32, f"input_global_scale must be float32, got {input_global_scale.dtype}"
+    assert input_global_scale.numel() == num_groups, "input_global_scale shape mismatch with groups"
@@
-    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
+    output = torch.empty(num_groups, m, k // 2, device=device, dtype=torch.uint8)
     output_scales = torch.empty(
-            l, padded_m, padded_k_int32, device=device, dtype=torch.int32
+            num_groups, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
@@
-        module.silu_and_mul_scaled_nvfp4_experts_quantize(
-            output.view(l * m, k // 2),
-            output_scales.view(l * padded_m, padded_k_int32),
-            input_tensor.view(l * m, k),
+        module.silu_and_mul_scaled_nvfp4_experts_quantize(
+            output.view(num_groups * m, k // 2),
+            output_scales.view(num_groups * padded_m, padded_k_int32),
+            input_tensor.view(num_groups * m, k),
             input_global_scale,
             mask,
             False,
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 856f918 and d972fbe.

📒 Files selected for processing (5)
  • csrc/nv_internal/cpp/kernels/quantization.cu (5 hunks)
  • csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (4 hunks)
  • docs/api/activation.rst (1 hunks)
  • docs/api/fp4_quantization.rst (1 hunks)
  • flashinfer/fp4_quantization.py (4 hunks)
✅ Files skipped from review due to trivial changes (1)
  • docs/api/activation.rst
🚧 Files skipped from review as they are similar to previous changes (1)
  • docs/api/fp4_quantization.rst
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (3)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
csrc/nv_internal/cpp/kernels/quantization.cu (4)
  • invokeSiluAndMulNVFP4Quantization (298-329)
  • invokeSiluAndMulNVFP4Quantization (298-300)
  • invokeSiluAndMulNVFP4Quantization (346-349)
  • invokeSiluAndMulNVFP4Quantization (365-367)
flashinfer/activation.py (1)
  • silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
flashinfer/fp4_quantization.py (4)
flashinfer/fp8_quantization.py (1)
  • _compute_swizzled_layout_sf_size (15-18)
flashinfer/utils.py (5)
  • register_custom_op (266-275)
  • register_custom_op (285-304)
  • register_fake_op (277-281)
  • register_fake_op (306-311)
  • get_compute_capability (245-248)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (2)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-245)
  • silu_and_mul_scaled_nvfp4_experts_quantize (188-190)
flashinfer/activation.py (1)
  • silu_and_mul_scaled_nvfp4_experts_quantize (193-219)
csrc/nv_internal/cpp/kernels/quantization.cu (2)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h (1)
  • layout (29-47)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
🪛 Ruff (0.14.1)
flashinfer/fp4_quantization.py

361-361: Unused function argument: global_scale

(ARG001)


363-363: Unused function argument: sf_use_ue8m0

(ARG001)


413-413: Ambiguous variable name: l

(E741)


445-445: Unused function argument: mask

(ARG001)


446-446: Unused function argument: global_scale

(ARG001)


449-449: Ambiguous variable name: l

(E741)


498-498: Ambiguous variable name: l

(E741)


534-534: Unused function argument: input_global_scale

(ARG001)


535-535: Unused function argument: mask

(ARG001)


538-538: Ambiguous variable name: l

(E741)

🔇 Additional comments (4)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (1)

247-250: Export list looks good.

Symbols match Python bindings and new API names.

flashinfer/fp4_quantization.py (2)

358-372: Good: fake batched path matches real dtypes and swizzled SF sizing.


964-989: Wrapper naming/dispatch looks consistent with new API.

Good adapter to scaled_fp4_grouped_quant_sm100.

csrc/nv_internal/cpp/kernels/quantization.cu (1)

168-173: invokeFP4Quantization interface changes look consistent.

enable_pdl plumbed through; call sites updated accordingly.

Also applies to: 190-191, 221-222

Signed-off-by: Shu Wang. <[email protected]>
@wenscarl
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

@wenscarl is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@wenscarl
Collaborator Author

wenscarl commented Oct 22, 2025

@yzh119 could you trigger the CI? It seems the AI comments are just nit-picking. Not sure what exactly blocks the merge.
