
Conversation


@ChristinaZ ChristinaZ commented Oct 1, 2025

📌 Description

Update the routing code to align with the TRTLLM implementation and add support for Kimi K2 and Qwen.

Also revised the unit test based on the Kimi K2 config (https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/config.json).

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • MoE operations now support optional routing parameters with automatic defaults for streamlined model configuration.
  • Refactor

    • Optimized expert kernel routing and buffer management for improved flexibility across multiple routing strategies.
    • Enhanced top-K result handling with unified buffer interfaces.


Summary of Changes

Hello @ChristinaZ, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the Mixture-of-Experts (MoE) routing logic within TRTLLMGEN to enhance its compatibility and performance for new models like Kimi K2 and Qwen. The changes introduce dynamic resource allocation for kernels, standardize data handling for top-K expert selection, and improve the robustness of numerical operations, thereby broadening the framework's support for diverse MoE architectures.

Highlights

  • New Model Support: The routing code has been updated to support Kimi K2 and Qwen models, aligning with the existing TRTLLM implementation.
  • Dynamic Expert Count Handling: The routing kernels now dynamically determine thread and shared memory configurations based on the number of experts (MaxNumExperts), replacing previously fixed constants. This allows for more flexible and efficient handling of varying MoE configurations.
  • Unified Top-K Data Pointers: The data structures and kernel parameters have been refactored to use standardized pointer names like mPtrTopKPacked, mPtrTopKWeights, and mPtrTopKIds for managing top-K expert scores and indices, improving consistency across different routing methods.
  • Enhanced Top-K Reduction Logic: The reduceTopK mechanism has been improved to handle a broader range of candidate numbers more efficiently, including a new overload that processes larger inputs in chunks.
  • Softmax Calculation Improvements: Softmax calculations within the kernels now explicitly use float for intermediate computations, ensuring better numerical stability when dealing with half or bfloat16 input types (see the sketch below this list).
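
To make the last highlight concrete, here is a minimal CUDA sketch of the float-softmax pattern: bf16 logits are promoted to float before the warp-wide max and sum reductions. It assumes a single warp handling at most 32 experts, one logit per lane; the function name and layout are illustrative, not the FlashInfer kernel.

#include <cuda_bf16.h>
#include <math_constants.h>

// Returns the softmax probability for this lane's expert (assumes a full active warp).
__device__ float warpSoftmaxSketch(__nv_bfloat16 const* logits, int numExperts) {
  int lane = threadIdx.x % 32;
  // Promote to float before any arithmetic; lanes beyond numExperts contribute -inf / 0.
  float x = lane < numExperts ? __bfloat162float(logits[lane]) : -CUDART_INF_F;

  // Warp-wide max for numerical stability.
  float maxVal = x;
  for (int offset = 16; offset > 0; offset /= 2) {
    maxVal = fmaxf(maxVal, __shfl_xor_sync(0xffffffffu, maxVal, offset));
  }

  // Sum of exponentials, still accumulated in float.
  float e = lane < numExperts ? expf(x - maxVal) : 0.0f;
  float sum = e;
  for (int offset = 16; offset > 0; offset /= 2) {
    sum += __shfl_xor_sync(0xffffffffu, sum, offset);
  }
  return e / sum;
}
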
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant updates to the MoE routing kernels to support new models like Kimi K2 and Qwen, and to align with TRTLLM implementations. The changes involve making the kernels more generic to handle variable numbers of experts, refactoring pointer names for consistency (e.g., mPtrExpertIdx to mPtrTopKPacked), and adding new execution paths for pre-computed top-K indices. The use of __launch_bounds__ and replacing cudaMemsetAsync with a dedicated kernel are good improvements.

My review focuses on a few areas for improvement:

  • A typo in a variable name that affects readability.
  • A confusing static_assert comment in the top-K reduction logic.
  • A potential bug in the new reduceTopK implementation related to an unresolved @todo and suspicious index initialization, which could lead to incorrect behavior.

Overall, the changes are well-structured and move towards a more flexible and robust implementation. Addressing the identified issues will further improve the code quality.
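
As a concrete illustration of the cudaMemsetAsync replacement mentioned above, here is a minimal sketch of a dedicated buffer-init kernel; the name, signature, and sizing are illustrative and not the FlashInfer routingInitExpertCounts kernel introduced in this PR.

#include <cstdint>

// Zero the expert-count histogram from a small grid launched on the same stream
// as the routing kernels, so the init can participate in the same launch sequence.
__global__ void initExpertCountsSketch(int32_t* counts, int numEntries) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numEntries) {
    counts[idx] = 0;
  }
}

// Host side, assuming a histogram of 2 * numExperts entries:
//   int threads = 256;
//   int blocks = (2 * numExperts + threads - 1) / threads;
//   initExpertCountsSketch<<<blocks, threads, 0, stream>>>(histogram, 2 * numExperts);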


yzh119 commented Oct 8, 2025

/bot run

@flashinfer-bot

GitLab MR !71 has been created, and the CI pipeline #36228669 is currently running. I'll report back once the pipeline job completes.

@sricketts sricketts mentioned this pull request Oct 8, 2025
@ChristinaZ ChristinaZ force-pushed the update_routing branch 2 times, most recently from 533e2de to 1a94ecb on October 15, 2025 at 14:35

coderabbitai bot commented Oct 18, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR refactors MOE (Mixture of Experts) routing infrastructure to support variable expert counts and flexible routing configurations. Key changes include: making routing parameters (n_group, topk_group, routed_scaling_factor) optional in C++ launchers; parameterizing kernel launch bounds with KernelParams::MaxNumExperts; introducing new routing macros and kernel variants (LAUNCH_ROUTING_DEEPSEEK, LAUNCH_ROUTING_LLAMA4); renaming top-K buffers (mPtrExpertIdx → mPtrTopKPacked); and propagating corresponding Optional parameter types through Python MoE APIs.
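
A hedged sketch of the optional-parameter handling described above, using std::optional as a stand-in for the launcher's Optional type; the function, defaults, and error message are illustrative, not the FlashInfer code.

#include <cstdint>
#include <optional>
#include <stdexcept>

struct ResolvedRouting {
  int64_t n_group;
  int64_t topk_group;
  double routed_scaling_factor;
};

// Grouped fields must be supplied together; absent values fall back to defaults.
ResolvedRouting resolveRoutingArgs(std::optional<int64_t> n_group,
                                   std::optional<int64_t> topk_group,
                                   std::optional<double> routed_scaling_factor) {
  if (n_group.has_value() != topk_group.has_value()) {
    throw std::invalid_argument("n_group and topk_group must be provided together");
  }
  ResolvedRouting out;
  out.n_group = n_group.value_or(0);        // 0 = no expert grouping (assumed default)
  out.topk_group = topk_group.value_or(0);
  out.routed_scaling_factor = routed_scaling_factor.value_or(1.0);  // assumed default
  return out;
}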

Changes

Cohort / File(s) Summary
MOE kernel launcher entry points
csrc/trtllm_fused_moe_kernel_launcher.cu
Updated method signatures for trtllm_fp8_per_tensor_scale_moe_launcher, trtllm_fp8_block_scale_moe_launcher, trtllm_fp4_block_scale_moe_launcher and corresponding wrapper functions to accept Optional<int64_t> for n_group and topk_group, and Optional<double> for routed_scaling_factor; top_k remains required. Added validation branches for DeepSeekV3, Renormalize, and Llama4 routing methods with checks for Optional consistency.
DeepSeek routing kernel
csrc/trtllm_fused_moe_routing_deepseek.cu
Replaced fixed thread/warp/expert constants with parameterized expert-count model (NumKimiK2Experts: 384, NumDeepseekExperts: 256); replaced thread indexing with dynamic blockDim.x usage; updated kernel launch bounds to use KernelParams::MaxNumExperts; introduced getMaxNumExperts() helper and LAUNCH_ROUTING_DEEPSEEK macro; updated output paths to use TopKPacked/TopKIds.
Llama4 routing kernel
csrc/trtllm_fused_moe_routing_llama4.cu
Renamed MaxNumExperts to NumExpertsLimit; added getMaxNumExperts() helper and runImpl() entry point; extended input handling for multiple top-K formats (TopKPacked, TopKIds/TopKWeights, Scores); updated routing paths with conditional permutations based on top-K input type; generalized kernel computations to reference KernelParams::MaxNumExperts.
Renormalize routing kernel
csrc/trtllm_fused_moe_routing_renormalize.cu
Introduced routingIndicesBlockKernel for small token counts; parameterized launch bounds with KernelParams::MaxNumExperts; added getMaxNumExperts() helper and LAUNCH_ROUTING_RENORMALIZE macro; replaced fixed constants (MaxNumTopExperts: 8→10, added NumExpertsLimit: 512, BlockKernelMaxNumTokens: 4); updated validation to use KernelParams-based bounds.
MOE runner data structures
csrc/trtllm_fused_moe_runner.cu
Renamed routing output buffers across DeepSeek, Llama4, and Renormalize data structures: mPtrExpertIdx → mPtrTopKPacked, mPtrExpertWeights → mPtrTopKWeights.
Python MoE API layer
flashinfer/fused_moe/core.py
Updated MoE function signatures (trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe_op, trtllm_fp4_block_scale_moe_op and fake variants) to accept Optional[int] for n_group/topk_group and Optional[float] for routed_scaling_factor. Updated permutation caching to use composite cache keys (e.g., ("w3_w1", shape)) instead of shape-based keys.
Routing kernel declarations
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
Updated launch bounds for routingIndicesHistogramKernel and routingIndicesOffsetsKernel from fixed NumThreadsHist to KernelParams::MaxNumExperts; introduced routingInitExpertCounts kernel; replaced expert-based indexing with TopK buffer paths (TopKIds/TopKWeights, TopKPacked).
Routing kernel header definitions
include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Updated KernelParamsBase template to include int MaxNumExperts_ parameter; replaced mPtrExpertWeights with mPtrTopKWeights and mPtrTopKIds; replaced mPtrExpertIdx with mPtrTopKPacked; updated setBaseParams and setKernelParams wiring for TopK buffers across all specializations.
Top-K reduction utilities
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
Added constants MaxNumExpertsUnit (128) and MaxNumTopK (10); renamed reduceTopK → reduceTopKFunc (small-N path); added new reduceTopK overload supporting large-N buffering while delegating small-N to reduceTopKFunc.
Device kernel launcher macros
include/flashinfer/trtllm/fused_moe/DevKernel.h
Removed LAUNCH_ROUTING macro; added LAUNCH_ROUTING_LLAMA4, LAUNCH_ROUTING_WITH_NUM_EXPERTS, and LAUNCH_ROUTING_DEEPSEEK_IMPL macros to support explicit numExperts parameterization; updated all kernel launch invocations to use new macro variants.
MOE test configuration
tests/moe/test_trtllm_gen_fused_moe.py
Changed cache_permute_indices fixture return type from Dict[torch.Size, torch.Tensor] to Dict[tuple, torch.Tensor]; extended intermediate_size values; reset routing configurations (n_groups, top_k_groups to 1); adjusted compatible_moe_impls orderings; increased top_k upper bound assertion from 8 to 10; added conditional skips for large configurations.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

This PR encompasses substantial refactoring across heterogeneous components: C++ kernel parameterization with new template parameters (MaxNumExperts), introduction of multiple routing macros with non-trivial logic flow, data structure buffer renaming propagating through multiple layers, and corresponding Python API signature changes. The changes require understanding the interconnected flow from Python bindings through C++ kernel launchers to device kernels, plus validation of the routing configuration logic branches for each method variant.
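
To illustrate the MaxNumExperts parameterization, a hedged CUDA sketch follows; the struct, kernel, and fields are illustrative, not the FlashInfer RoutingKernel code.

#include <cstdint>

template <int MaxNumExperts_>
struct KernelParamsSketch {
  static constexpr int MaxNumExperts = MaxNumExperts_;
  int32_t const* topKIds;  // expert id per (token, k) pair, assumed < MaxNumExperts
  int32_t* expertCounts;   // global histogram, length MaxNumExperts
  int numTokens;
  int topK;
};

// The compile-time bound sizes both the launch bounds and the shared buffer,
// so each specialization gets its own kernel instance.
template <typename KernelParams>
__global__ void __launch_bounds__(KernelParams::MaxNumExperts)
    routingHistogramSketch(KernelParams params) {
  __shared__ int32_t smemCounts[KernelParams::MaxNumExperts];
  if (threadIdx.x < KernelParams::MaxNumExperts) smemCounts[threadIdx.x] = 0;
  __syncthreads();
  // Each thread accumulates a strided slice of the (token, topK) pairs.
  for (int i = threadIdx.x; i < params.numTokens * params.topK; i += blockDim.x) {
    atomicAdd(&smemCounts[params.topKIds[i]], 1);
  }
  __syncthreads();
  if (threadIdx.x < KernelParams::MaxNumExperts) {
    atomicAdd(&params.expertCounts[threadIdx.x], smemCounts[threadIdx.x]);
  }
}

A host-side getMaxNumExperts-style helper would then round the runtime expert count up to one of the compiled specializations (for example 256 for DeepSeek-V3 or 384 for Kimi K2) before dispatching.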

Possibly related issues

Poem

🐰 With experts now flexible, not bound by one size,
Through macros and templates, our routing device flies!
TopK buffers re-labeled, no more mPtrExplore,
Optional parameters grace every kernel's front door.
From Python to CUDA, the dance is complete—
A MOE refactor making all architectures sweet! 🌟

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 27.27% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title Check ❓ Inconclusive The title "Update the routing for TRTLLMGEN to support kimi k2 and qwen" is specific and refers to a real aspect of the changeset—adding support for new models (Kimi K2 and Qwen). However, the core technical changes documented in the summary are broader architectural refactoring efforts: introducing optional parameters for routing configuration (n_group, topk_group, routed_scaling_factor), parameterizing expert count handling across multiple routing kernels, and aligning with the TRTLLM implementation. The title captures the model-support outcome but does not convey the main architectural changes or the alignment with TRTLLM that the PR description identifies as primary objectives. Consider revising the title to reflect the primary architectural changes. A more descriptive title might be something like "Refactor routing to support optional parameters and expert count parameterization for TRTLLM alignment" or "Add optional routing parameters and expert count flexibility to support Kimi K2 and Qwen" to better communicate the main technical changes while still acknowledging the model support additions.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed The pull request description follows the provided template structure, including all major sections: Description, Related Issues, Pre-commit Checks, Tests, and Reviewer Notes. The Description section explains the main objectives (updating routing code to align with TRTLLM and adding support for Kimi K2 and Qwen) and references a linked configuration file. The Pre-commit Checks items are appropriately marked as completed. While the Related Issues section is empty and the Tests section items are unchecked (despite the description mentioning revised unit tests), the description is substantially complete and provides sufficient context for reviewers to understand the intent and scope of changes.

Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

165-168: Fix expert_count_histogram size; 2×256 is no longer sufficient (Kimi K2 = 384).

With num_experts > 256, the fixed 512-element buffer risks OOB writes in routing histograms.

Apply:

-  Tensor expert_count_histogram = alloc_tensor(
-      {2 * 256},
-      dl_int32,  // 256 is the max number of threads per block and max number of experts
-      routing_logits->device);
+  int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
+  Tensor expert_count_histogram = alloc_tensor(
+      {size_of_expert_count_histogram},
+      dl_int32,  // sized by 2 * max(num_experts, 256)
+      routing_logits->device);
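
For example, with Kimi K2's 384 experts the suggested sizing allocates max(2 · 384, 2 · 256) = 768 int32 entries instead of the old fixed 512.
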
flashinfer/fused_moe/core.py (3)

1118-1145: Fake op return type mismatch (should return Tensor, not list).

The custom op returns a single tensor, but the fake op returns a 1‑element list; breaks meta/inference paths.

Apply:

 @register_fake_op("flashinfer::trtllm_fp8_per_tensor_scale_moe")
 def _fake_trtllm_fp8_per_tensor_scale_moe(
@@
-        return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+        return hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)

1205-1234: Same issue: fp8 block-scale fake op must return Tensor.

Align fake with real op.

 @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe")
 def _fake_trtllm_fp8_block_scale_moe(
@@
-        return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+        return hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)

225-235: Even-row assertion contradicts odd-row handling.

Asserting M is even prevents the odd‑M branch that follows from ever running.

-    assert M % 2 == 0, f"x.shape[0] must be even, not {M}"
+    # Support both even and odd M.
+    # (Odd M is handled via the (M + 1) // 2 split below.)

Also consider validating behavior with a quick unit test for odd M.

♻️ Duplicate comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (1)

372-375: Address the TODO and fix buffer initialization.

As flagged in a previous review, the initialization topKBufferIdx[ii] = ii * WarpSize - 1 at line 374 is problematic:

When ii=0, this sets topKBufferIdx[0] = -1. This value is then used in RedType::makeCmpVal which calculates maxIdx - idx = 65535 - (-1) = 65536, overflowing the 0xFFFF mask used for the index part of the packed value. This could cause incorrect tie-breaking or other subtle bugs.

The @todo comment indicates this needs validation. Consider using RedType::maxIdx or another safe sentinel value for invalid indices.

Suggested fix:

     for (int ii = 0; ii < numResults; ++ii) {
       topKBufferValue[ii] = minValue;
-      topKBufferIdx[ii] = ii * WarpSize - 1;  //@todo: check if this is correct
+      topKBufferIdx[ii] = RedType::maxIdx;  // Use safe sentinel for invalid indices
     }
csrc/trtllm_fused_moe_routing_deepseek.cu (1)

192-199: Rename “intermidiate” to “intermediate”.

Typo hurts readability; please rename both arrays accordingly.

🧹 Nitpick comments (8)
csrc/trtllm_fused_moe_routing_deepseek.cu (2)

191-206: Correct inter-topK scratch sizing; current formula overshoots.

NumInterTopKPerThread should be ceil(NumInterTopK / WarpSize). The current expression scales with NumExpertWarps again, inflating per-lane storage and work.

-        int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1;
+        int constexpr NumInterTopKPerThread = (NumInterTopK + WarpSize - 1) / WarpSize;

96-99: Typos in comment (“sigmoig”).

Nit: s/sigmoig/sigmoid/ in the comment.

flashinfer/fused_moe/core.py (2)

1130-1136: Silence Ruff ARG001 for unused optional params in fake ops.

Keep signature for registry, but explicitly mark unused.

 def _fake_trtllm_fp8_per_tensor_scale_moe(
@@
-    ):
+    ):
+        # Unused in fake; keep signature for registry.
+        del n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor
@@
 def _fake_trtllm_fp8_block_scale_moe(
@@
-    ):
+    ):
+        # Unused in fake; keep signature for registry.
+        del n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor

Alternatively add “# noqa: ARG001” to the function defs.

Also applies to: 1218-1223


113-122: Duplicate entry in trtllm_gen_dtype_has_scale.

MxE4m3 is listed twice; harmless but noisy.

-    if dtype in [
-        DtypeTrtllmGen.MxE4m3,
-        DtypeTrtllmGen.E2m1,
-        DtypeTrtllmGen.MxE2m1,
-        DtypeTrtllmGen.MxE4m3,
-    ]:
+    if dtype in [DtypeTrtllmGen.MxE4m3, DtypeTrtllmGen.E2m1, DtypeTrtllmGen.MxE2m1]:
         return True
csrc/trtllm_fused_moe_routing_renormalize.cu (4)

25-30: Top‑K and expert limits: good, but add explicit guards.

MaxNumTopExperts=10 and NumExpertsLimit=512 look fine. Add compile‑time checks tying assumptions together.

 static constexpr int MaxNumTopExperts = 10;
 static constexpr int NumExpertsLimit = 512;
+static_assert(MaxNumTopExperts <= std::numeric_limits<int8_t>::max(),
+              "TopK index stored in int8_t; must fit.");
+static_assert(NumExpertsLimit % WarpSize == 0,
+              "Max experts must be multiple of warp size for VecSize.");

75-217: Block kernel: int8_t scratch indices rely on small token/top‑k; make it explicit.

smemKIdx/smemOffset use int8_t; safe with BlockKernelMaxNumTokens=4 and MaxNumTopExperts=10, but brittle if thresholds grow.

  • Add comments and static_asserts on bounds (expert counts per token per expert ≤ 127).
  • Consider uint16_t if future configs may exceed 127.
-  __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts];
-  __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts];
+  __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts];   // offsetWithinExpert ∈ [0, BlockKernelMaxNumTokens)
+  __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts];     // kIdx ∈ [0, MaxNumTopExperts)
+  static_assert(BlockKernelMaxNumTokens < 128 && MaxNumTopExperts < 128, "int8_t bounds");

370-381: Typo in macro name (RENORNALIZE).

Nit, but spreads quickly in call sites.

-#define LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize,  \
+#define LAUNCH_ROUTING_RENORMALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize,  \
                                    stream, extraFlag1)                                         \
@@
-    LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
+    LAUNCH_ROUTING_RENORMALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \

…and update invocations below.


382-464: run(): validation and dispatch look solid; minor clarity nits.

  • Good: enforce inputs for large token cases; dynamic numThreads via getMaxNumExperts.
  • Suggest rename “useSingleBlock” threshold comment to reference BlockKernelMaxNumTokens constant.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f25eee and ed83138.

📒 Files selected for processing (13)
  • csrc/trtllm_batched_gemm_runner.cu (0 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (13 hunks)
  • csrc/trtllm_fused_moe_routing_deepseek.cu (11 hunks)
  • csrc/trtllm_fused_moe_routing_llama4.cu (8 hunks)
  • csrc/trtllm_fused_moe_routing_renormalize.cu (8 hunks)
  • csrc/trtllm_fused_moe_runner.cu (3 hunks)
  • flashinfer/fused_moe/core.py (7 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (12 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h (9 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3 hunks)
  • tests/conftest.py (1 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (5 hunks)
💤 Files with no reviewable changes (1)
  • csrc/trtllm_batched_gemm_runner.cu
🧰 Additional context used
🧬 Code graph analysis (8)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
csrc/trtllm_fmha_kernel_launcher.cu (3)
  • namespace trtllm_cubin_loader { (303-305)
  • Context (34-312)
  • trtllm_paged_attention_launcher (75-165)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
csrc/trtllm_fused_moe_routing_llama4.cu (8)
  • void (67-280)
  • void (354-356)
  • void (363-423)
  • __launch_bounds__ (363-363)
  • getMaxNumExperts (426-433)
  • getMaxNumExperts (426-426)
  • routingIndicesClusterKernel (285-352)
  • routingIndicesClusterKernel (354-354)
csrc/trtllm_fused_moe_routing_llama4.cu (2)
csrc/trtllm_fused_moe_routing_renormalize.cu (6)
  • routingTopKExperts (32-37)
  • void (76-217)
  • void (288-291)
  • void (297-353)
  • getMaxNumExperts (357-366)
  • getMaxNumExperts (357-357)
csrc/trtllm_fused_moe_routing_deepseek.cu (5)
  • void (34-252)
  • void (276-278)
  • void (459-461)
  • getMaxNumExperts (464-475)
  • getMaxNumExperts (464-464)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
include/flashinfer/trtllm/fused_moe/DevKernel.h (4)
  • setKernelParams (219-235)
  • setKernelParams (273-284)
  • setKernelParams (335-350)
  • setKernelParams (415-434)
tests/moe/test_trtllm_gen_fused_moe.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/fused_moe/core.py (1)
  • RoutingMethodType (59-73)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • RoutingMethodType (37-135)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
  • top_k (269-269)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
  • num_experts (262-262)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
  • void (34-252)
  • void (276-278)
  • void (459-461)
  • getMaxNumExperts (464-475)
  • getMaxNumExperts (464-464)
  • routingIndicesClusterKernel (260-274)
  • routingIndicesClusterKernel (276-276)
csrc/trtllm_fused_moe_routing_llama4.cu (9)
  • void (67-280)
  • void (354-356)
  • void (363-423)
  • routingTopKExperts (40-44)
  • getMaxNumExperts (426-433)
  • getMaxNumExperts (426-426)
  • routingIndicesClusterKernel (285-352)
  • routingIndicesClusterKernel (354-354)
  • routingIndicesHistogramScoresKernel (364-364)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py

1130-1130: Unused function argument: n_group

(ARG001)


1131-1131: Unused function argument: topk_group

(ARG001)


1132-1132: Unused function argument: intermediate_size

(ARG001)


1133-1133: Unused function argument: local_expert_offset

(ARG001)


1134-1134: Unused function argument: local_num_experts

(ARG001)


1135-1135: Unused function argument: routed_scaling_factor

(ARG001)


1218-1218: Unused function argument: n_group

(ARG001)


1219-1219: Unused function argument: topk_group

(ARG001)


1220-1220: Unused function argument: intermediate_size

(ARG001)


1221-1221: Unused function argument: local_expert_offset

(ARG001)


1222-1222: Unused function argument: local_num_experts

(ARG001)


1223-1223: Unused function argument: routed_scaling_factor

(ARG001)

🔇 Additional comments (17)
tests/moe/test_trtllm_gen_fused_moe.py (4)

108-108: LGTM: Autotune disabled during warmup.

Disabling autotuning during CUDA graph warmup is correct, as autotuning can interfere with graph capture.


1860-1861: Verify kimi_k2 configuration parameters.

The n_groups and top_k_groups parameters have been changed from [12, 4] to [1, 1], which removes hierarchical expert grouping. According to the PR description, this is based on the Kimi K2 config from Hugging Face.

Please verify that n_groups=1 and top_k_groups=1 accurately reflect the Kimi K2 Instruct config. If the actual model uses hierarchical grouping, these values may need adjustment.


1908-1933: Test coverage expanded for Renormalize variants.

The configurations for Renorm and RenormalizeNaive have been updated to test larger configurations (512 experts, top_k=10) and the skip marks have been commented out to enable these tests. This aligns with the PR objective to support models like Qwen with larger routing configurations.


2091-2091: Verify top_k upper bound after removing assertion.

The assert top_k <= 8 has been commented out, presumably to support configurations with top_k=10. Please confirm:

  1. What is the new upper bound for top_k (if any)?
  2. Are there any kernel or hardware constraints that limit top_k values?
  3. Should this assertion be replaced with a higher limit or removed entirely?

Based on the test configurations using top_k=10, this change appears intentional to support larger top-k values.

include/flashinfer/trtllm/fused_moe/DevKernel.h (3)

34-34: LGTM: Logger include added.

The logger header is needed to support FLASHINFER_WARN calls in the new routing macros.


116-126: LGTM: Llama4-specific routing macro added.

The LAUNCH_ROUTING_LLAMA4 macro correctly implements llama4-specific routing with a fixed per-expert block size of 128. The hardcoded constant is well-documented in the comments.


128-171: LGTM: Parameterized expert count macros added.

The new LAUNCH_ROUTING_WITH_NUM_EXPERTS macros provide flexible expert count configuration with optional input type forcing. The implementation correctly handles both Fp32 and Bfloat16 expert weight types with appropriate flag combinations.

csrc/trtllm_fused_moe_runner.cu (1)

68-68: LGTM: Consistent pointer naming refactor.

The pointer renames from expert-centric (mPtrExpertIdx, mPtrExpertWeights) to top-k-centric (mPtrTopKPacked, mPtrTopKWeights) have been applied consistently across all three routing methods (DeepSeekV3, Llama4, and Renormalize). This improves semantic clarity and aligns with the broader top-k buffer architecture changes described in the PR.

Also applies to: 73-73, 105-105, 110-110, 147-147, 152-152

include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3)

34-35: LGTM: System limits defined as constants.

The new constants MaxNumExpertsUnit (128) and MaxNumTopK (10) establish clear system-wide limits. The MaxNumTopK=10 value aligns with the test configuration changes that allow top_k=10.


317-324: Refactored top-k reduction with clearer constraints.

The function has been renamed to reduceTopKFunc to distinguish it as an internal implementation, and the static assertion has been tightened from N <= 16 to N < 5. This clarifies that this function variant handles only small candidate counts (N ≤ 4), with larger counts delegated to the buffering path in the new public reduceTopK overload.

Note: The past review comment about the confusing static assert message is now partially addressed by the tighter constraint, though the comment text itself could still be updated.


376-400: Buffering algorithm is sound pending initialization fix.

The multi-pass buffering strategy correctly reduces large candidate counts (N > 4) by processing in chunks of 4 and consolidating results. The logic properly handles thread distribution and edge cases, though the buffer initialization issue at line 374 needs to be resolved.
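
For reference, a host-side C++ sketch of the chunked idea only (process candidates four at a time and keep a running best-K buffer); it illustrates the general algorithm, not the warp-parallel CUDA implementation, and all names are illustrative.

#include <algorithm>
#include <utility>
#include <vector>

// Returns the top-K (value, index) pairs; ties break toward the smaller index.
std::vector<std::pair<float, int>> chunkedTopK(std::vector<float> const& scores, int K) {
  std::vector<std::pair<float, int>> buffer;
  for (size_t start = 0; start < scores.size(); start += 4) {
    size_t end = std::min(start + 4, scores.size());
    for (size_t i = start; i < end; ++i) {
      buffer.emplace_back(scores[i], static_cast<int>(i));
    }
    // Consolidate: keep only the best K seen so far.
    std::sort(buffer.begin(), buffer.end(), [](auto const& a, auto const& b) {
      return a.first != b.first ? a.first > b.first : a.second < b.second;
    });
    if (static_cast<int>(buffer.size()) > K) buffer.resize(K);
  }
  return buffer;
}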

include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (3)

102-129: Softmax in float: good numeric stability tradeoff.

Computing in float and reducing via cg::reduce is correct and safer for bf16/half inputs.


472-701: Histogram/offset kernels parameterized by KernelParams::MaxNumExperts: LGTM.

Launch bounds and shared buffers consistently use MaxNumExperts; PDL sync points are placed correctly for the two-step path.

Please ensure LAUNCH_* macros always pass blockDim == KernelParams::MaxNumExperts for these kernels.


705-728: Init kernel addition is appropriate.

routingInitExpertCounts isolates zeroing; matches usage in run paths.

csrc/trtllm_fused_moe_routing_llama4.cu (1)

310-352: Llama4 routing refactor looks coherent.

  • MaxNumExperts limited to 128 with compile-time guards.
  • Supports TopKIds/TopKPacked/Scores inputs consistently across warp/cluster/histogram paths.

Confirm tests cover both scores-input and packed/ids-input paths for num_tokens near WarpKernelMaxNumTokens boundaries.

Also applies to: 362-423, 425-433, 476-525

include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)

59-75: Top-K bufferization (Ids/Weights/Packed) and MaxNumExperts param: sensible API evolution.

DataBase and KernelParams* now expose Top-K buffers and parameterized bounds; casting is consistent.

Ensure callers always supply mPtrTopKWeights when providing mPtrTopKIds (guards exist in runImpl), and that OutputT matches the allocated packed score type.

Also applies to: 103-153, 174-213, 229-247, 267-293

csrc/trtllm_fused_moe_routing_renormalize.cu (1)

357-366: getMaxNumExperts: OK; unreachable 0 guarded above.

Function returns 0 for unsupported values; run() already checks ≤ NumExpertsLimit. Keep as is; no action.

@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # skip OOM error and missing JIT cache errors
    # Wrap the test call so we don't invoke item.runtest() ourselves; yield lets pytest run it.


⚠️ Potential issue | 🟡 Minor

Misleading comment.

The comment states "yield lets pytest run it," but the code directly calls item.runtest() on line 144 without using yield. The comment should be updated to accurately reflect the implementation.

Apply this diff to correct the comment:

-    # Wrap the test call so we don't invoke item.runtest() ourselves; yield lets pytest run it.
+    # Execute the test and handle exceptions for OOM and missing JIT cache.
🤖 Prompt for AI Agents
In tests/conftest.py around line 142, the existing comment says "Wrap the test
call so we don't invoke item.runtest() ourselves; yield lets pytest run it." but
the implementation actually calls item.runtest() directly; update the comment to
accurately describe the behavior (e.g., state that the fixture wraps the test
call and invokes item.runtest() directly rather than yielding to pytest),
keeping it concise and reflecting the real control flow.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

145-146: Add bounds checking for routed_scaling_factor to match validation patterns used elsewhere.

The code lacks validation for routed_scaling_factor despite similar optional parameters (n_group, topk_group) receiving extensive bounds checks (lines 79–94). Consider adding validation to ensure the scaling factor is positive and within reasonable bounds, consistent with the codebase's defensive programming pattern and the guidance that scaling factors should avoid zero and negative values.
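
A minimal sketch of the suggested check, assuming std::optional and a plain exception; the real launcher may use its own Optional type and check macros.

#include <optional>
#include <stdexcept>

// Validate routed_scaling_factor only when the caller provides one.
// The negated comparison also rejects NaN values.
inline void checkRoutedScalingFactor(std::optional<double> routed_scaling_factor) {
  if (routed_scaling_factor.has_value() && !(routed_scaling_factor.value() > 0.0)) {
    throw std::invalid_argument("routed_scaling_factor must be positive when provided");
  }
}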

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ed83138 and 8e72f08.

📒 Files selected for processing (1)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
  • top_k (269-269)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
  • num_experts (262-262)
🔇 Additional comments (9)
csrc/trtllm_fused_moe_kernel_launcher.cu (9)

45-47: LGTM! Function signature updated to support optional routing parameters.

The changes to accept Optional<int64_t> for n_group and topk_group, and Optional<double> for routed_scaling_factor align with the PR objective to support multiple routing configurations for Kimi K2 and Qwen models.


79-94: LGTM! Comprehensive validation for grouped routing.

The validation logic properly ensures:

  • Grouped routing is only used with DeepSeekV3
  • All required parameters are present when using groups
  • Divisibility and bound constraints are satisfied
  • Sufficient experts exist in selected groups for the requested top_k

164-166: LGTM! Histogram sizing updated to handle variable expert counts.

The sizing calculation ensures the histogram buffer is large enough for both the number of experts and the maximum thread block size (256), which is appropriate for the routing kernel.


303-306: LGTM! Wrapper signature updated consistently.

The wrapper function signature properly reflects the launcher's updated parameters and forwards them correctly.


379-384: Renormalize methods supported here (contrast with per-tensor launcher).

This block-scale launcher correctly validates top_k bounds for Renormalize and RenormalizeNaive methods, while the per-tensor scale launcher at lines 95-100 rejects these methods entirely. This confirms the inconsistency flagged earlier.


731-758: LGTM! Validation logic properly structured.

The validation logic correctly handles optional parameters using value_or(0) and provides appropriate checks for different routing method configurations. The main concern is the commented-out checks, which have been flagged separately.


806-810: LGTM! Default values consistent across launchers.

The default values for optional parameters are consistent with the other launcher functions in this file.


834-836: LGTM! Histogram sizing consistent with other launchers.

The sizing calculation matches the approach used in the per-tensor and block-scale launchers.


95-100: Inconsistent Renormalize routing support between per-tensor and block-scale launchers.

This per-tensor scale launcher throws NotImplementedError for Renormalize and RenormalizeNaive routing methods (lines 95-100), while the block-scale launcher validates these methods with top_k <= 10 && top_k > 0 constraints instead of rejecting them. This inconsistency needs clarification:

  • Determine if per-tensor intentionally excludes Renormalize support or if it should match block-scale behavior
  • Update validation logic in one or both launchers to align

@jiahanc
Copy link
Collaborator

jiahanc commented Oct 18, 2025

/bot run

@flashinfer-bot

GitLab MR !71 has been updated with latest changes, and the CI pipeline #36861014 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (1)

152-160: Guard against actualK > K to prevent out-of-bounds writes.

Loop indexes out[kk] and outIdx[kk] assume kk < K. Clamp actualK to K.

-  for (int kk = 0; kk < actualK; ++kk)  //@todo: check if actualK is correct
+  int cappedK = actualK < K ? actualK : K;  // cap to template bound
+  for (int kk = 0; kk < cappedK; ++kk)  //@todo: check if actualK is correct
♻️ Duplicate comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (2)

170-171: Fix misleading static_assert message (copy-paste).

Condition enforces N < 5 (i.e., up to 4), but the message says “<= 128.” Update the message for clarity.

-  static_assert(N < 5, "Only support candidates number less than or equal to 128");
+  static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");

218-221: Invalid sentinel index initialization (-1) can corrupt tie-breaking.

Setting topKBufferIdx[ii] = ii*WarpSize - 1 yields -1 for ii=0. In makeCmpVal this becomes (maxIdx - idx) = 65536, which wraps to 0x0000 and collides with a valid index in the lower 16 bits. Initialize to a safe value.

-    for (int ii = 0; ii < numResults; ++ii) {
-      topKBufferValue[ii] = minValue;
-      topKBufferIdx[ii] = ii * WarpSize - 1;  //@todo: check if this is correct
-    }
+    using RedT = TopKRedType<Type>;
+    for (int ii = 0; ii < numResults; ++ii) {
+      topKBufferValue[ii] = minValue;
+      // Use a valid sentinel that won't overflow the 0xFFFF packed index space
+      topKBufferIdx[ii] = RedT::maxIdx;
+    }
🧹 Nitpick comments (4)
flashinfer/fused_moe/core.py (1)

1122-1149: Silence Ruff ARG001 in fake op: prefix unused Optional params.

The fake implementation ignores several params; prefix with “_” to satisfy lint without behavior change.

-    def _fake_trtllm_fp8_per_tensor_scale_moe(
+    def _fake_trtllm_fp8_per_tensor_scale_moe(
         routing_logits: torch.Tensor,
-        routing_bias: Optional[torch.Tensor],
+        routing_bias: Optional[torch.Tensor],
         hidden_states: torch.Tensor,
         gemm1_weights: torch.Tensor,
         output1_scales_scalar: torch.Tensor,
         output1_scales_gate_scalar: torch.Tensor,
         gemm2_weights: torch.Tensor,
         output2_scales_scalar: torch.Tensor,
         num_experts: int,
         top_k: int,
-        n_group: Optional[int],
-        topk_group: Optional[int],
+        _n_group: Optional[int],
+        _topk_group: Optional[int],
         intermediate_size: int,
-        local_expert_offset: int,
-        local_num_experts: int,
-        routed_scaling_factor: Optional[float],
-        use_routing_scales_on_input: bool,
-        tile_tokens_dim: int = 8,
-        routing_method_type: int = 0,
-        enable_pdl: Optional[bool] = None,
+        _local_expert_offset: int,
+        _local_num_experts: int,
+        _routed_scaling_factor: Optional[float],
+        _use_routing_scales_on_input: bool,
+        _tile_tokens_dim: int = 8,
+        _routing_method_type: int = 0,
+        _enable_pdl: Optional[bool] = None,
     ):

Apply analogous renames in _fake_trtllm_fp8_block_scale_moe for: n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim, routing_method_type, enable_pdl.

tests/moe/test_trtllm_gen_fused_moe.py (2)

1835-1837: Align cache key with parameters impacting permutation.

Tests document keys as (weight_type, shape). Given core depends on epilogue_tile_m and num_elts_per_sf, consider reflecting that here to avoid hidden coupling.


2065-2072: Nit: typo in comment.

“epxerts” → “experts”.

-    # Skip large intermediate size and hidden size for configurations with small epxerts
+    # Skip large intermediate size and hidden size for configurations with small experts
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

363-376: Inconsistent top_k bound for grouped routing across launchers.

fp8_block_scale enforces top_k <= 8 with groups; fp4_block_scale allows top_k <= 10. Is this intentional? If yes, add comments documenting per-kernel limits; if not, harmonize the checks.

Also applies to: 731-747

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8e72f08 and 9fb400f.

📒 Files selected for processing (5)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (14 hunks)
  • csrc/trtllm_fused_moe_routing_renormalize.cu (8 hunks)
  • flashinfer/fused_moe/core.py (11 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
  • void (34-252)
  • void (276-278)
  • void (459-461)
  • getMaxNumExperts (464-475)
  • getMaxNumExperts (464-464)
  • routingIndicesClusterKernel (260-274)
  • routingIndicesClusterKernel (276-276)
csrc/trtllm_fused_moe_routing_llama4.cu (9)
  • void (67-280)
  • void (354-356)
  • void (363-423)
  • routingTopKExperts (40-44)
  • getMaxNumExperts (426-433)
  • getMaxNumExperts (426-426)
  • routingIndicesClusterKernel (285-352)
  • routingIndicesClusterKernel (354-354)
  • routingIndicesHistogramScoresKernel (364-364)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
  • top_k (269-269)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
  • num_experts (262-262)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
  • RoutingMethodType (59-73)
include/flashinfer/trtllm/fused_moe/runner.h (4)
  • RoutingMethodType (37-135)
  • intermediate_size (274-274)
  • hidden_size (264-264)
  • top_k (269-269)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py

1134-1134: Unused function argument: n_group

(ARG001)


1135-1135: Unused function argument: topk_group

(ARG001)


1136-1136: Unused function argument: intermediate_size

(ARG001)


1137-1137: Unused function argument: local_expert_offset

(ARG001)


1138-1138: Unused function argument: local_num_experts

(ARG001)


1139-1139: Unused function argument: routed_scaling_factor

(ARG001)


1222-1222: Unused function argument: n_group

(ARG001)


1223-1223: Unused function argument: topk_group

(ARG001)


1224-1224: Unused function argument: intermediate_size

(ARG001)


1225-1225: Unused function argument: local_expert_offset

(ARG001)


1226-1226: Unused function argument: local_num_experts

(ARG001)


1227-1227: Unused function argument: routed_scaling_factor

(ARG001)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (1)

226-231: Local-expert check uses log2 value as bitmask instead of converting to actual mask—affects 5 locations, not 2.

The review correctly identifies the bug but underestimated scope. Using (localExpertIdx & params.mLocalExpertsStrideLog2) == 0 incorrectly uses a log2 value as a bitmask. Example: if mLocalExpertsStrideLog2=3 (stride=8), the code checks idx & 3 (divides by 4) instead of idx & 7 (divides by 8).

Apply the suggested fix to all 5 occurrences:

+    int32_t strideMask = (1 << params.mLocalExpertsStrideLog2) - 1;
     auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
-                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
+                         ((localExpertIdx & strideMask) == 0);

Locations: lines 228, 364, 434, 582, and 661.

csrc/trtllm_fused_moe_routing_llama4.cu (1)

264-267: Fix the local-expert mask bitwise AND operations across 6 locations.

The bug is confirmed: using a log2 value directly in bitwise AND instead of converting to a mask. This occurs in:

  1. csrc/trtllm_fused_moe_routing_llama4.cu line 266
  2. include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh lines 228, 364, 434, 582, 661

Apply the fix from the review comment to all 6 locations: create int32_t strideMask = (1 << params.mLocalExpertsStrideLog2) - 1; and use ((localExpertIdx & strideMask) == 0) instead of (localExpertIdx & params.mLocalExpertsStrideLog2) == 0.

♻️ Duplicate comments (5)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3)

218-221: Sentinel -1 overflows index packing; use a safe sentinel.

topKBufferIdx[ii] = ii * WarpSize - 1 can produce -1 → 65536 after (maxIdx - idx), overflowing 16-bit field and skewing tie-breaks. Initialize with RedType::maxIdx.

-    for (int ii = 0; ii < numResults; ++ii) {
-      topKBufferValue[ii] = minValue;
-      topKBufferIdx[ii] = ii * WarpSize - 1;
-    }
+    for (int ii = 0; ii < numResults; ++ii) {
+      topKBufferValue[ii] = minValue;
+      topKBufferIdx[ii] = TopKRedType<Type>::maxIdx;
+    }

228-233: OOB reads when N is not a multiple of 4.

Tail chunk reads past N. Guard and fill with neutral values.

-      for (int i = 0; i < 4; ++i) {
-        inValue[i] = value[start + i];
-        inIdx[i] = idx[start + i];
-      }
+      int rem = N - start;
+      for (int i = 0; i < 4; ++i) {
+        if (i < rem) {
+          inValue[i] = value[start + i];
+          inIdx[i] = idx[start + i];
+        } else {
+          inValue[i] = minValue;
+          inIdx[i] = TopKRedType<Type>::maxIdx;
+        }
+      }

170-171: Update misleading static_assert message.

N < 5 enforces up to 4 candidates per thread; fix message to match.

-  static_assert(N < 5, "Only support candidates number less than or equal to 128");
+  static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");
flashinfer/fused_moe/core.py (1)

174-194: Cache key still incomplete (previously flagged).

The cache key uses only (prefix, shape) but epilogue_tile_m and num_elts_per_sf also affect permutation indices. This was identified in past reviews but remains unaddressed.

-    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape, epilogue_tile_m, num_elts_per_sf)

Apply the same fix to get_w2_permute_indices_with_cache:

-    cache_key = ("w2", dst_w2_weight.shape)
+    cache_key = ("w2", dst_w2_weight.shape, epilogue_tile_m, num_elts_per_sf)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)

192-201: Fix variable name typos.

The variables intermidiateScore and intermidiateExpert should be renamed to intermediateScore and intermediateExpert.

-        float intermidiateScore[NumInterTopKPerThread];
-        int32_t intermidiateExpert[NumInterTopKPerThread];
+        float intermediateScore[NumInterTopKPerThread];
+        int32_t intermediateExpert[NumInterTopKPerThread];
         for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) {
           int ii = i / WarpSize;
           if (i < NumInterTopK) {
-            intermidiateScore[ii] = smemInterTopScores[i];
-            intermidiateExpert[ii] = smemInterTopExperts[i];
+            intermediateScore[ii] = smemInterTopScores[i];
+            intermediateExpert[ii] = smemInterTopExperts[i];
           } else {
-            intermidiateScore[ii] = invalidScoreFloat;
-            intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1;
+            intermediateScore[ii] = invalidScoreFloat;
+            intermediateExpert[ii] = KernelParams::MaxNumExperts - 1;
           }
         }
-        topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert,
+        topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert,
                          /* minValue */ invalidScoreFloat, params.mTopK);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9fb400f and 0f33230.

📒 Files selected for processing (11)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (14 hunks)
  • csrc/trtllm_fused_moe_routing_deepseek.cu (11 hunks)
  • csrc/trtllm_fused_moe_routing_llama4.cu (8 hunks)
  • csrc/trtllm_fused_moe_routing_renormalize.cu (8 hunks)
  • csrc/trtllm_fused_moe_runner.cu (3 hunks)
  • flashinfer/fused_moe/core.py (11 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (12 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h (9 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/trtllm_fused_moe_runner.cu
🧰 Additional context used
🧬 Code graph analysis (7)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (4)
  • getMaxNumExperts (357-366)
  • getMaxNumExperts (357-357)
  • routingIndicesClusterKernel (222-286)
  • routingIndicesClusterKernel (289-289)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
  • RoutingMethodType (59-73)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • RoutingMethodType (37-135)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
  • void (34-252)
  • void (276-278)
  • void (459-461)
  • getMaxNumExperts (464-475)
  • getMaxNumExperts (464-464)
  • routingIndicesClusterKernel (260-274)
  • routingIndicesClusterKernel (276-276)
csrc/trtllm_fused_moe_routing_llama4.cu (9)
  • void (67-280)
  • void (354-356)
  • void (363-423)
  • routingTopKExperts (40-44)
  • getMaxNumExperts (426-433)
  • getMaxNumExperts (426-426)
  • routingIndicesClusterKernel (285-352)
  • routingIndicesClusterKernel (354-354)
  • routingIndicesHistogramScoresKernel (364-364)
csrc/trtllm_fused_moe_routing_llama4.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (9)
  • routingTopKExperts (32-37)
  • void (76-217)
  • void (288-291)
  • void (297-353)
  • getMaxNumExperts (357-366)
  • getMaxNumExperts (357-357)
  • routingIndicesClusterKernel (222-286)
  • routingIndicesClusterKernel (289-289)
  • routingIndicesHistogramScoresKernel (298-298)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
include/flashinfer/trtllm/fused_moe/DevKernel.h (4)
  • setKernelParams (219-235)
  • setKernelParams (273-284)
  • setKernelParams (335-350)
  • setKernelParams (415-434)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
  • top_k (269-269)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
  • num_experts (262-262)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py

1134-1134: Unused function argument: n_group

(ARG001)


1135-1135: Unused function argument: topk_group

(ARG001)


1136-1136: Unused function argument: intermediate_size

(ARG001)


1137-1137: Unused function argument: local_expert_offset

(ARG001)


1138-1138: Unused function argument: local_num_experts

(ARG001)


1139-1139: Unused function argument: routed_scaling_factor

(ARG001)


1222-1222: Unused function argument: n_group

(ARG001)


1223-1223: Unused function argument: topk_group

(ARG001)


1224-1224: Unused function argument: intermediate_size

(ARG001)


1225-1225: Unused function argument: local_expert_offset

(ARG001)


1226-1226: Unused function argument: local_num_experts

(ARG001)


1227-1227: Unused function argument: routed_scaling_factor

(ARG001)

🔇 Additional comments (14)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

116-126: Confirm cooperative launch for cluster kernels.

LAUNCH_ROUTING_LLAMA4 exposes coopLaunch; cluster kernels using cluster_dims generally require cooperative launch attributes. Ensure callers pass coopLaunch=true for cluster paths; otherwise cluster barriers can deadlock.

Would you like me to scan callers and patch the relevant launches?

tests/moe/test_trtllm_gen_fused_moe.py (3)

2066-2071: Minor typo.

“epxerts” -> “experts”.

-    # Skip large intermediate size and hidden size for configurations with small epxerts
+    # Skip large intermediate size and hidden size for configurations with few experts

Likely an incorrect or invalid review comment.


1842-1843: Review comment assertion unverifiable — MaxNumTopK exists only once in codebase.

Verification found only one MaxNumTopK = 10 definition in RoutingKernelTopK.cuh (line 35). No separate kernel implementations with independent MaxNumTopK settings exist for the routing methods (Renormalize, DeepSeekV3, TopK, Llama4). The constraint is enforced globally via test assertion assert top_k <= 10 at line 2120, which applies uniformly to all routing methods.

The review's request to verify that each kernel type is "compiled with MaxNumTopK >= 10" assumes separate kernel implementations with independent compilation parameters. This architecture does not appear to exist based on the codebase structure. Recommend manual verification of whether runtime kernel configurations or external kernel libraries bypass this single definition.


1835-1837: Verified: cache key implementation is correct. Both helper functions properly construct tuple keys.

The verification confirms that _maybe_get_cached_w3_w1_permute_indices and get_w2_permute_indices_with_cache both correctly construct tuple cache keys:

  • _maybe_get_cached_w3_w1_permute_indices uses cache_key = ("w3_w1", dst_w3_w1_weight.shape) (core.py:175)
  • get_w2_permute_indices_with_cache uses cache_key = ("w2", dst_w2_weight.shape) (core.py:204)

Both functions consistently store and retrieve cached values using these tuple keys with the Dict[tuple, torch.Tensor] dictionary. The tuple structure ensures uniqueness by weight identifier and shape, preventing collisions or misses.

csrc/trtllm_fused_moe_routing_llama4.cu (1)

503-508: PDL score path vs init path: OK, but ensure TopKPacked is allocated.

HistogramScoresKernel writes mPtrTopKPacked; non-score path resets counts. Verify callers always allocate mPtrTopKPacked when mPtrScores is used, else later kernels read garbage.

Also applies to: 515-524
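One way to make that contract explicit is a precondition check at the launcher level. The struct and helper below are assumptions modeled on names from this thread, not the PR's code.

#include <cassert>

struct RoutingBuffersView {
  const void* mPtrScores;      // raw expert scores (score-path input)
  const void* mPtrTopKPacked;  // packed score+index scratch written by the score path
};

// If raw scores are supplied, the packed Top-K buffer must also be allocated so the
// later histogram/offset kernels read valid data instead of garbage.
inline void checkTopKPackedAllocated(const RoutingBuffersView& buffers) {
  assert(buffers.mPtrScores == nullptr || buffers.mPtrTopKPacked != nullptr);
}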

csrc/trtllm_fused_moe_routing_deepseek.cu (2)

26-31: LGTM: Constants updated for Kimi K2 and expanded routing.

The new constants correctly support 384 experts for Kimi K2 and increased group limits (MaxNumTopGroups=4, MaxNumGroups=8) for flexible routing configurations.


464-475: Well-structured expert-count routing function.

The getMaxNumExperts function cleanly maps runtime expert counts to compile-time kernel specializations with proper error handling.

include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)

103-152: LGTM: Clean parameterization with MaxNumExperts.

The template parameter MaxNumExperts_ is correctly threaded through KernelParamsBase and propagated to all derived KernelParams specializations, enabling compile-time kernel bounds. The addition of mPtrTopKWeights and mPtrTopKIds properly exposes the Top-K data flow.
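A reduced illustration of that pattern (assumed shapes, not the actual KernelParams definitions): the compile-time bound flows from the base type into fixed-size storage that kernels can index safely.

#include <cstdint>

template <typename OutputT, int MaxNumExperts_>
struct KernelParamsBaseSketch {
  static constexpr int MaxNumExperts = MaxNumExperts_;
  OutputT* mPtrTopKWeights = nullptr;  // per-token top-k weights
  int32_t* mPtrTopKIds = nullptr;      // per-token top-k expert ids
};

template <typename OutputT, int MaxNumExperts_>
struct KernelParamsSketch : KernelParamsBaseSketch<OutputT, MaxNumExperts_> {
  using Base = KernelParamsBaseSketch<OutputT, MaxNumExperts_>;
  // The compile-time bound lets kernels size shared/register arrays statically.
  int32_t mExpertCounts[Base::MaxNumExperts] = {};
};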

csrc/trtllm_fused_moe_routing_renormalize.cu (2)

25-29: Constants expanded for Qwen and small-batch optimization.

MaxNumTopExperts=10 supports higher top-k routing (e.g., Qwen), and BlockKernelMaxNumTokens=4 enables a fast single-block path for small batches.


405-425: Test coverage for single-block kernel path is incomplete.

The test file ./tests/moe/test_trtllm_gen_fused_moe.py parametrizes num_tokens with values [1, 8, 1024] (line 1840). This tests the single-block kernel path with 1 token, but does not explicitly test the boundary case of exactly 4 tokens with high expert counts, which the review comment specifically raises.

Verification results:

  • BlockKernelMaxNumTokens = 4 is correctly defined (line 29) and used consistently throughout
  • Shared memory allocation is safe: 4 * MaxNumExperts ≤ 4 * 128 = 512 int8_t elements ≈ 1KB
  • Kernel loop logic is correct: j < BlockKernelMaxNumTokens iterates over indices 0–3
  • Single-block path is exercised by the 1-token test case

However, the edge case at the upper boundary (exactly 4 tokens) with high expert counts is not explicitly covered in the test parametrization.

csrc/trtllm_fused_moe_kernel_launcher.cu (3)

79-104: LGTM: Routing method validation properly expanded.

The validation logic correctly distinguishes DeepSeekV3 (grouped), Renormalize/RenormalizeNaive, and Llama4 routing methods with appropriate constraints.
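A hedged sketch of what such per-method validation can look like; the limits (top_k <= 8 for grouped DeepSeekV3, top_k <= 10 otherwise) come from this thread, while the exact checks and error handling are assumptions.

#include <stdexcept>

enum class RoutingMethodType { DeepSeekV3, Renormalize, RenormalizeNaive, Llama4, TopK };

// Illustrative checks only; the launcher uses FLASHINFER_CHECK and inspects more fields.
void validateRoutingSketch(RoutingMethodType method, int top_k, int n_group, int topk_group) {
  if (method == RoutingMethodType::DeepSeekV3) {
    if (top_k > 8) throw std::invalid_argument("grouped DeepSeekV3 routing supports top_k <= 8");
    if (n_group <= 0 || topk_group <= 0 || topk_group > n_group)
      throw std::invalid_argument("grouped routing needs valid n_group/topk_group");
  } else if (top_k > 10) {
    throw std::invalid_argument("routing kernels support top_k <= 10");
  }
}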


164-167: Histogram sizing accommodates optional grouping.

The histogram buffer size correctly uses max(num_experts*2, 256*2) to handle both grouped and non-grouped routing configurations.
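As a minimal sketch of that sizing rule (the helper name is made up):

#include <algorithm>
#include <cstdint>

// The histogram must cover either the real expert count or the 256-expert fallback
// used by non-grouped routing, whichever is larger; the factor of 2 follows the
// max(num_experts * 2, 256 * 2) rule quoted above.
inline int64_t routingHistogramElems(int64_t num_experts) {
  return std::max<int64_t>(num_experts * 2, int64_t{256} * 2);
}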


738-739: No changes needed—top_k limits correctly reflect kernel capabilities.

The difference is intentional and correct. FP8 launchers enforce top_k <= 8 because they use the DeepSeek kernel (MaxNumTopExperts=8), while FP4 launchers allow top_k <= 10 because they use the Renormalize kernel (MaxNumTopExperts=10). The validation limits are properly aligned with each kernel's design constraints.

flashinfer/fused_moe/core.py (1)

1069-1090: LGTM: Optional routing parameters properly propagated.

The signature updates consistently use Optional[int] for n_group/topk_group and Optional[float] for routed_scaling_factor, enabling flexible routing configurations across all MoE operation variants.

Also applies to: 1154-1177, 1243-1276

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (1)

2059-2071: Test skip conditions improve CI performance.

The added skip conditions pragmatically reduce test execution time by avoiding expensive combinations:

  • Large expert counts (≥512) with large intermediate sizes (>512)
  • Small expert counts (<512) with large intermediate/hidden sizes

While this improves test speed, ensure that at least some tests still exercise these larger configurations to catch potential issues.

flashinfer/fused_moe/core.py (1)

1180-1180: Minor: Remove unnecessary blank line.

This blank line appears to be unintentionally added and can be removed for consistency.

Apply this diff:

         enable_pdl = device_support_pdl(hidden_states.device)
-
         # Call the C++ function for block scale MoE
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0f33230 and d779e40.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/core.py (11 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
  • RoutingMethodType (59-73)
include/flashinfer/trtllm/fused_moe/runner.h (4)
  • RoutingMethodType (37-135)
  • intermediate_size (274-274)
  • hidden_size (264-264)
  • top_k (269-269)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • n_group (270-270)
  • topk_group (272-272)
  • intermediate_size (274-274)
  • local_expert_offset (275-275)
  • local_num_experts (276-276)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py

1134-1134: Unused function argument: n_group

(ARG001)


1135-1135: Unused function argument: topk_group

(ARG001)


1136-1136: Unused function argument: intermediate_size

(ARG001)


1137-1137: Unused function argument: local_expert_offset

(ARG001)


1138-1138: Unused function argument: local_num_experts

(ARG001)


1139-1139: Unused function argument: routed_scaling_factor

(ARG001)


1222-1222: Unused function argument: n_group

(ARG001)


1223-1223: Unused function argument: topk_group

(ARG001)


1224-1224: Unused function argument: intermediate_size

(ARG001)


1225-1225: Unused function argument: local_expert_offset

(ARG001)


1226-1226: Unused function argument: local_num_experts

(ARG001)


1227-1227: Unused function argument: routed_scaling_factor

(ARG001)

🔇 Additional comments (9)
tests/moe/test_trtllm_gen_fused_moe.py (5)

1835-1836: LGTM! Cache key structure properly documented.

The cache key type change from Dict[torch.Size, torch.Tensor] to Dict[tuple, torch.Tensor] aligns with the fix in flashinfer/fused_moe/core.py where the cache key now includes epilogue_tile_m and num_elts_per_sf. The comment clearly documents the new key structure.


1842-1842: Test coverage expanded for smaller intermediate sizes.

Adding 512 to the intermediate_size parameter list provides better test coverage for configurations with smaller intermediate dimensions, which is relevant for the Qwen and Kimi K2 models mentioned in the PR.


1924-1937: New Qwen3_next configuration added successfully.

This new test case covers the Qwen model with:

  • 512 experts
  • top_k=10
  • Renormalize routing method

This aligns with the PR objective to support Qwen models. The configuration looks correct for a large-scale MoE setup.


2120-2120: top_k bound correctly increased for Qwen3 support.

The assertion change from top_k <= 8 to top_k <= 10 is necessary to accommodate the new Qwen3_next configuration which uses top_k=10. This aligns with the expanded routing support in the PR.


1856-1872: No changes needed — configuration already matches upstream.

The Kimi K2 Instruct model configuration has n_groups and top_k_groups both set to 1, which exactly matches the values in the test code. The configuration is already correct.

flashinfer/fused_moe/core.py (4)

174-177: Critical fix: Cache key now includes all permutation parameters.

This change addresses a previously identified bug where the cache key only used (weight_type, shape), which could return incorrect permutation indices when the same shape was used with different epilogue_tile_m or num_elts_per_sf values. Including these parameters in the cache key ensures correctness.

Based on learnings from past review comments.


190-193: Cache storage and retrieval correctly updated.

The cache storage at line 190 and retrieval at line 193 now use the composite cache_key that includes all relevant parameters. This ensures that cached permutation indices are only reused when all parameters match.


203-218: W2 permutation cache consistently updated.

The get_w2_permute_indices_with_cache function now uses the same comprehensive cache key structure as the W3/W1 permutation function. This ensures consistent caching behavior across both weight matrices.


1080-1082: Routing parameters correctly made optional.

Changing n_group, topk_group, and routed_scaling_factor to Optional types is appropriate because not all routing methods require these parameters:

  • RoutingMethodType.Renormalize and RoutingMethodType.RenormalizeNaive don't use group-based routing
  • RoutingMethodType.TopK doesn't use scaling factors

This change improves API flexibility while maintaining backward compatibility.

Also applies to: 1085-1085

@jiahanc
Copy link
Collaborator

jiahanc commented Oct 19, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !71 has been updated with latest changes, and the CI pipeline #36866363 is currently running. I'll report back once the pipeline job completes.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

1856-1872: KIMI K2 routing configuration values are incorrect and must be updated.

The test case currently uses n_groups=1 and top_k_groups=1, but the actual KIMI K2 Instruct model configuration specifies n_groups=8 and top_k_groups=4. Update these values to match the published model configuration:

"n_groups": 8,
"top_k_groups": 4,
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

2059-2063: Skip condition appears unreachable with current test parameters.

The skip condition checks for num_experts >= 512, but the maximum num_experts in the test parameters is 384 (kimi_k2 configuration). This means the skip will never be triggered by the current test matrix.

If this is intended for future-proofing, consider either:

  1. Lowering the threshold to a value that will actually trigger (e.g., >= 384 or >= 256)
  2. Adding a comment explaining this is for future test configurations
  3. Removing the condition if it's not needed

Apply this diff to make the condition effective for current tests:

-    if routing_config["num_experts"] >= 512 and intermediate_size > 512:
+    if routing_config["num_experts"] >= 256 and intermediate_size > 1024:
         pytest.skip(
-            f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts"
+            f"Skipping for testing speed: large config with intermediate_size={intermediate_size} and {routing_config['num_experts']} experts"
         )

Or add a clarifying comment:

+    # Future-proofing: skip very large configurations that may be added later
     if routing_config["num_experts"] >= 512 and intermediate_size > 512:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d779e40 and 1b8587c.

📒 Files selected for processing (1)
  • tests/moe/test_trtllm_gen_fused_moe.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
  • RoutingMethodType (59-73)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • RoutingMethodType (37-135)
🔇 Additional comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)

1835-1836: LGTM! Good documentation of cache key structure.

The comment clearly documents the tuple structure used as the cache key, which improves code maintainability.


1917-1917: LGTM! Implementation list reordering is benign.

The reordering of compatible_moe_impls doesn't affect functionality, only the test execution order.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (2)

169-169: Fix the misleading static_assert message.

The condition N < 5 enforces N ≤ 4, but the message incorrectly references "128". This creates confusion about the actual constraint.

Apply this diff:

-  static_assert(N < 5, "Only support candidates number less than or equal to 128");
+  static_assert(N < 5, "Only support up to 4 candidates per thread in this function");

220-223: Reconsider the buffer index initialization.

When ii=0, topKBufferIdx[0] = -1, which can cause arithmetic overflow in TopKRedType::makeCmpVal when computing maxIdx - idx = 65535 - (-1) = 65536, exceeding the 0xFFFF mask. While minValue ensures these entries won't be selected, the overflow is still technically incorrect.

Consider initializing with a safer sentinel:

     for (int ii = 0; ii < numResults; ++ii) {
       topKBufferValue[ii] = minValue;
-      topKBufferIdx[ii] = ii * WarpSize - 1;
+      topKBufferIdx[ii] = RedType::maxIdx;  // or another safe sentinel
     }
csrc/trtllm_fused_moe_routing_deepseek.cu (1)

521-523: Clarify the confusing error message.

The check data.mNumExperts >= MaxNumTopExperts ensures there are enough experts, but the message "expects %d to be at most #experts %d" reads incorrectly. The message should reflect that experts must be at least MaxNumTopExperts.

Apply this diff:

   FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts,
-                   "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts,
+                   "Routing kernel expects #experts >= %d, got %d", MaxNumTopExperts,
                    data.mNumExperts);
🧹 Nitpick comments (3)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)

478-494: Consider failing fast on unsupported expert counts.

The macro logs an error but continues execution when numExperts exceeds NumKimiK2Experts, potentially leaving kernels unlaunched and leading to incorrect results downstream.

Consider adding a check that terminates:

   } else {
     TLLM_LOG_ERROR("Unsupported numExperts");
+    FLASHINFER_CHECK(false, "Unsupported numExperts: %d exceeds maximum %d", 
+                     data.mNumExperts, NumKimiK2Experts);
   }
include/flashinfer/trtllm/fused_moe/runner.h (2)

76-78: Typo in serialization string (“InvalidRountingMethod”)

Minor, but this shows up in logs and can confuse debugging.

-      return "InvalidRountingMethod";  // TODO throw error
+      return "InvalidRoutingMethod";  // TODO throw error

268-273: Clarify routing-method-specific top_k limits; reference the central constant instead of hard-coded "10"

The comment oversimplifies: top_k support varies by routing method. DeepSeek/grouped kernels support only top_k<=8, while renormalize methods support top_k<=10. Avoid hard-coding "10" in a public header; instead reference the routing configuration or note the method-specific constraints.

Suggested update:

-  // Now support top_k<=10
+  // Maximum top_k depends on routing method: renormalize/TopK supports up to 10
+  // (see MaxNumTopK in RoutingKernelTopK.cuh), while DeepSeek/grouped limits to 8.

Validation confirmed: limits are correctly enforced per method in kernel_launcher.cu (lines 86–89, 370–373, 383–384, 738–741, 752–754). Tests cover top_k=10 for Qwen/renormalize (test_trtllm_gen_fused_moe.py:2112).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1b8587c and 721371d.

📒 Files selected for processing (4)
  • csrc/trtllm_fused_moe_routing_deepseek.cu (11 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (12 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (4 hunks)
  • include/flashinfer/trtllm/fused_moe/runner.h (1 hunks)
🔇 Additional comments (11)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3)

34-35: LGTM!

The new compile-time constants provide clear bounds for expert routing and top-K selection.


56-56: LGTM!

Direct assignment is cleaner and equivalent for this POD type.


204-207: Good defensive programming.

The static_assert at line 206-207 ensures N % 4 == 0 (or N <= 4), which prevents out-of-bounds reads in the loop at lines 230-233. This addresses the previous concern about reading past the array bounds when the tail chunk is smaller than 4.

csrc/trtllm_fused_moe_routing_deepseek.cu (3)

26-31: LGTM!

Constants updated appropriately to support the new Kimi K2 and Qwen model architectures.


77-78: LGTM!

Using blockDim.x makes the indexing flexible and correctly handles different thread counts.


579-592: LGTM!

The launch logic now correctly matches runtime thread counts to compile-time specializations, addressing the previous blockDim-specialization concern.

include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (5)

102-129: LGTM!

Computing softmax in float precision before casting back to DataType improves numerical stability for half/bfloat16 inputs.
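A simplified CUDA sketch of that pattern (one thread per token, serial over experts; names and layout are assumptions, and the real kernel uses warp/block reductions): scores arrive as __half, but max/exp/sum accumulate in float and only the final probability is cast back.

#include <cfloat>
#include <cuda_fp16.h>

__global__ void softmaxFloatAccumSketch(const __half* scores, __half* probs,
                                        int numTokens, int numExperts) {
  int token = blockIdx.x * blockDim.x + threadIdx.x;
  if (token >= numTokens) return;
  const __half* row = scores + token * numExperts;

  float maxVal = -FLT_MAX;
  for (int i = 0; i < numExperts; ++i) maxVal = fmaxf(maxVal, __half2float(row[i]));

  float sum = 0.f;
  for (int i = 0; i < numExperts; ++i) sum += expf(__half2float(row[i]) - maxVal);

  for (int i = 0; i < numExperts; ++i) {
    float e = expf(__half2float(row[i]) - maxVal);  // subtract max for stability
    probs[token * numExperts + i] = __float2half(e / sum);
  }
}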


211-217: LGTM!

Both code paths now properly initialize scoreIdx, eliminating the previous uninitialized variable concern.


230-232: LGTM!

The conditional weight writing correctly handles the case where weights need to be extracted from packed representation.


387-469: LGTM!

The kernel is now properly parameterized with KernelParams::MaxNumExperts, and the data loading logic correctly handles both TopKIds and TopKPacked paths.


706-729: LGTM!

The new initialization kernel properly sets up the expert counts buffer using a grid-stride loop and consistent PDL synchronization.
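In the spirit of that kernel, a minimal grid-stride initialization sketch (names illustrative; the PDL synchronization is omitted here):

#include <cstdint>

// Zero an expert-count buffer regardless of how many blocks the launcher chose:
// each thread strides over the buffer by the total number of launched threads.
__global__ void initExpertCountsSketch(int32_t* counts, int numElems) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numElems;
       i += gridDim.x * blockDim.x) {
    counts[i] = 0;
  }
}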

Comment on lines +464 to +475
int constexpr getMaxNumExperts(int32_t numExperts) {
if (numExperts <= topk::MaxNumExpertsUnit) {
return topk::MaxNumExpertsUnit;
} else if (numExperts <= NumDeepseekExperts) {
return NumDeepseekExperts;
} else if (numExperts <= NumKimiK2Experts) {
return NumKimiK2Experts;
} else {
TLLM_LOG_ERROR("Unsupported numExperts");
return 0;
}
}
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Returning 0 on error is unsafe.

When numExperts exceeds NumKimiK2Experts, the function logs an error but returns 0. Callers will use this value for kernel launches and array sizing, leading to undefined behavior or silent failures.

Consider asserting or throwing instead:

   } else {
-    TLLM_LOG_ERROR("Unsupported numExperts");
-    return 0;
+    FLASHINFER_CHECK(false, "Unsupported numExperts: %d exceeds maximum %d", 
+                     numExperts, NumKimiK2Experts);
+    return NumKimiK2Experts;  // unreachable after CHECK, but silences warnings
   }
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
int constexpr getMaxNumExperts(int32_t numExperts) {
if (numExperts <= topk::MaxNumExpertsUnit) {
return topk::MaxNumExpertsUnit;
} else if (numExperts <= NumDeepseekExperts) {
return NumDeepseekExperts;
} else if (numExperts <= NumKimiK2Experts) {
return NumKimiK2Experts;
} else {
TLLM_LOG_ERROR("Unsupported numExperts");
return 0;
}
}
int constexpr getMaxNumExperts(int32_t numExperts) {
if (numExperts <= topk::MaxNumExpertsUnit) {
return topk::MaxNumExpertsUnit;
} else if (numExperts <= NumDeepseekExperts) {
return NumDeepseekExperts;
} else if (numExperts <= NumKimiK2Experts) {
return NumKimiK2Experts;
} else {
FLASHINFER_CHECK(false, "Unsupported numExperts: %d exceeds maximum %d",
numExperts, NumKimiK2Experts);
return NumKimiK2Experts; // unreachable after CHECK, but silences warnings
}
}

ChristinaZ and others added 2 commits October 20, 2025 09:38
Remove the cudaDeviceSync;Add support numexpert16; Clean up unit test
Signed-off-by: Christina Zhang <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (10)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (2)

218-233: Fix OOB reads and invalid sentinel index in large-N TopK buffering.

  • Reads past end when N % 4 != 0.
  • Sentinel idx = -1 overflows 0xFFFF mask; breaks tie-breaking.

Apply:

-    for (int ii = 0; ii < numResults; ++ii) {
-      topKBufferValue[ii] = minValue;
-      topKBufferIdx[ii] = ii * WarpSize - 1;
-    }
+    using RedT = TopKRedType<Type>;
+    for (int ii = 0; ii < numResults; ++ii) {
+      topKBufferValue[ii] = minValue;
+      topKBufferIdx[ii] = RedT::maxIdx;  // safe sentinel, preserves ordering
+    }

@@
-      for (int i = 0; i < 4; ++i) {
-        inValue[i] = value[start + i];
-        inIdx[i] = idx[start + i];
-      }
+      int rem = N - start;
+      for (int i = 0; i < 4; ++i) {
+        if (i < rem) {
+          inValue[i] = value[start + i];
+          inIdx[i] = idx[start + i];
+        } else {
+          inValue[i] = minValue;
+          inIdx[i] = RedT::maxIdx;
+        }
+      }

Also applies to: 219-221


170-171: Static-assert message is misleading.

Condition enforces N <= 4, but message says “<= 128”.

-  static_assert(N < 5, "Only support candidates number less than or equal to 128");
+  static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");
csrc/trtllm_fused_moe_routing_llama4.cu (1)

482-487: Enable cooperative launch for cluster kernel.

Cluster kernel uses cluster_dims/cluster barriers; launch cooperatively to guarantee cluster scheduling.

-    LAUNCH_ROUTING_LLAMA4(data,
-                          /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster,
+    LAUNCH_ROUTING_LLAMA4(data,
+                          /*coopLaunch=*/true, routingIndicesClusterKernel, NumBlocksPerCluster,
                           NumThreads,
                           /*smemSize=*/0,  // No dynamic smem
                           stream);
csrc/trtllm_fused_moe_routing_deepseek.cu (4)

191-206: Fix inter-topK scratch sizing and typos.

  • NumInterTopKPerThread mistakenly multiplies by NumExpertWarps; it should cover NumInterTopK only.
  • Rename misspelled intermidiateScore/intermidiateExpert.
-        int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1;
-        float intermidiateScore[NumInterTopKPerThread];
-        int32_t intermidiateExpert[NumInterTopKPerThread];
+        int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1;
+        float intermediateScore[NumInterTopKPerThread];
+        int32_t intermediateExpert[NumInterTopKPerThread];
@@
-            intermidiateScore[ii] = smemInterTopScores[i];
-            intermidiateExpert[ii] = smemInterTopExperts[i];
+            intermediateScore[ii] = smemInterTopScores[i];
+            intermediateExpert[ii] = smemInterTopExperts[i];
@@
-            intermidiateScore[ii] = invalidScoreFloat;
-            intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1;
+            intermediateScore[ii] = invalidScoreFloat;
+            intermediateExpert[ii] = KernelParams::MaxNumExperts - 1;
@@
-        topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert,
+        topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert,
                          /* minValue */ invalidScoreFloat, params.mTopK);

518-523: Clarify the error message to match the check.

The check enforces numExperts ≥ MaxNumTopExperts; message says “at most”.

-  FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts,
-                   "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts,
-                   data.mNumExperts);
+  FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts,
+                   "Routing kernel expects #experts >= %d, got %d",
+                   MaxNumTopExperts, data.mNumExperts);

464-475: Avoid returning 0 from getMaxNumExperts.

Returning 0 can cascade into invalid grid sizes. Fail fast.

   } else {
-    TLLM_LOG_ERROR("Unsupported numExperts");
-    return 0;
+    FLASHINFER_CHECK(false, "Unsupported numExperts: %d (max %d)",
+                     numExperts, NumKimiK2Experts);
+    return NumKimiK2Experts; // unreachable; silences warnings
   }

574-589: Only run routingMainKernel when scores are provided; match blockDim to specialization.

Launching routingMainKernel with TopKPacked-only inputs is a no-op; also choose blockDim via getMaxNumExperts.

-  if (data.mPtrTopKIds == nullptr) {
-    int const numThreadsMain =
-        data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts;
+  if (data.mPtrScores != nullptr) {
+    int const numThreadsMain = getMaxNumExperts(data.mNumExperts);
     LAUNCH_ROUTING_DEEPSEEK(data,
                             /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain,
                             /*smemSize=*/0,  // No dynamic smem
                             stream, data.mNumExpertGroups > 1);
   } else {
     // Reset the global histograms.
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (2)

227-229: Use a proper stride mask instead of comparing with log2 value.

The current check uses mLocalExpertsStrideLog2 as a mask; compute mask = (1 << log2) - 1 and test lower bits.

-    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
-                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
+    int32_t strideMask = (1 << params.mLocalExpertsStrideLog2) - 1;
+    bool isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
+                         ((localExpertIdx & strideMask) == 0);

Apply similarly at the other referenced locations.

Also applies to: 364-365, 581-583, 659-662


419-431: Initialize idx when using TopKPacked; write weights conditionally.

If mPtrTopKIds == nullptr and mPtrTopKWeights == nullptr, idx remains uninitialized.

-  auto loopBody = [&](int expandedIdx) {
-    PackedScoreIdx<OutputT> scoreIdx;
-    int idx;
-    if (params.mPtrTopKIds != nullptr) {
-      idx = params.mPtrTopKIds[expandedIdx];
-    } else {
-      // If params.mPtrTopKIds != nullptr, we don't need to store the weights
-      if (params.mPtrTopKWeights != nullptr) {
-        scoreIdx = params.mPtrTopKPacked[expandedIdx];
-        idx = scoreIdx.idx;
-        params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
-      }
-    }
+  auto loopBody = [&](int expandedIdx) {
+    int idx;
+    if (params.mPtrTopKIds != nullptr) {
+      idx = params.mPtrTopKIds[expandedIdx];
+    } else {
+      auto scoreIdx = params.mPtrTopKPacked[expandedIdx];
+      idx = scoreIdx.idx;
+      if (params.mPtrTopKWeights != nullptr) {
+        params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
+      }
+    }
csrc/trtllm_fused_moe_routing_renormalize.cu (1)

115-146: Handle TopKPacked in block kernel (small-token path).

When only mPtrTopKPacked is provided, smemKIdx remains unset, producing invalid outputs.

   if (params.mPtrTopKIds != nullptr) {
     ...
   } else if (params.mPtrScores != nullptr) {
     ...
+  } else { // params.mPtrTopKPacked != nullptr
+    if (validToken && laneIdx < params.mTopK) {
+      TypePacked packed = params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx];
+      int offset = warpIdx * MaxNumExperts + static_cast<int>(packed.idx);
+      smemKIdx[offset] = static_cast<int8_t>(laneIdx);
+      if (params.mPtrTopKWeights != nullptr) {
+        params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = OutputT{packed.score};
+      }
+    }
   }
🧹 Nitpick comments (5)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

121-124: Unify BF16 type: prefer cutlass::bfloat16_t over __nv_bfloat16.

Other launch macros use cutlass types; mixing __nv_bfloat16 can add include/friction. Use cutlass::bfloat16_t for consistency.

-               LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel,
+               LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, 128 /* Always 128 for llama4*/), kernel,
-    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true),
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, numExperts, true),
                kernel, numBlocks, numThreads, smemSize, stream);
...
-    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false),
+    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, numExperts, false),
                kernel, numBlocks, numThreads, smemSize, stream);

Also applies to: 191-196

flashinfer/fused_moe/core.py (1)

1134-1147: Silence unused-arg warnings in fake ops.

Consume parameters locally to keep signatures stable while satisfying linters.

 def _fake_trtllm_fp8_per_tensor_scale_moe(
@@
-    ):
+    ):
+        # consume unused to appease linters while keeping signature stable
+        _ = (n_group, topk_group, intermediate_size, local_expert_offset,
+             local_num_experts, routed_scaling_factor, tile_tokens_dim,
+             routing_method_type, enable_pdl)
         seq_len = hidden_states.shape[0]
         hidden_size = hidden_states.shape[1]
@@
 def _fake_trtllm_fp8_block_scale_moe(
@@
-    ):
+    ):
+        # consume unused to appease linters while keeping signature stable
+        _ = (n_group, topk_group, intermediate_size, local_expert_offset,
+             local_num_experts, routed_scaling_factor, tile_tokens_dim,
+             routing_method_type, use_shuffled_weight, weight_layout, enable_pdl)
         seq_len = hidden_states.shape[0]
         hidden_size = hidden_states.shape[1]
@@
 def _fake_trtllm_fp4_block_scale_moe(
@@
-    ):
+    ):
+        # consume unused to appease linters while keeping signature stable
+        _ = (routing_logits, topk_ids, expert_weights, routing_bias,
+             hidden_states_scale, gemm1_bias, gemm1_alpha, gemm1_beta,
+             gemm1_clamp_limit, output1_scale_scalar, output1_scale_gate_scalar,
+             output2_scale_scalar, n_group, topk_group, intermediate_size,
+             local_expert_offset, routed_scaling_factor, tile_tokens_dim,
+             routing_method_type, do_finalize, enable_pdl, gated_act_type,
+             output, tune_max_num_tokens)
         seq_len = hidden_states.shape[0]
         hidden_size = hidden_states.shape[1]

Also applies to: 1222-1233, 1450-1463

include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)

103-153: Document/enforce TopK buffer invariants in setBaseParams.

Guard that when mPtrTopKIds is set, mPtrTopKWeights must also be set (aligns with run-time checks elsewhere). Add a brief comment or debug assert.

   void setBaseParams(DataType const& data) {
@@
     mPtrTopKWeights = static_cast<OutputT*>(data.mPtrTopKWeights);
     mPtrTopKIds = static_cast<int32_t*>(data.mPtrTopKIds);
+    // Invariant: if IDs are provided, weights must also be provided.
+    // assert((mPtrTopKIds == nullptr) || (mPtrTopKWeights != nullptr));
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

420-429: Pointer cast for routing_logits.

args.routing_logits is assigned as float* even when dtype is bfloat16 for non‑DeepSeek paths. Prefer void* or branch on dtype to avoid UB.

-  args.routing_logits = static_cast<float*>(routing_logits->data);
+  args.routing_logits = routing_logits->data; // keep as void*, let callee interpret via dtype
csrc/trtllm_fused_moe_routing_renormalize.cu (1)

370-381: Macro name typo (RENORNALIZE).

Consider renaming to LAUNCH_ROUTING_RENORMALIZE for clarity and grepability.

-#define LAUNCH_ROUTING_RENORNALIZE(...
+#define LAUNCH_ROUTING_RENORMALIZE(...

Also update all call sites in this file.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 721371d and 63233a3.

📒 Files selected for processing (11)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (14 hunks)
  • csrc/trtllm_fused_moe_routing_deepseek.cu (11 hunks)
  • csrc/trtllm_fused_moe_routing_llama4.cu (8 hunks)
  • csrc/trtllm_fused_moe_routing_renormalize.cu (8 hunks)
  • csrc/trtllm_fused_moe_runner.cu (3 hunks)
  • flashinfer/fused_moe/core.py (11 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (12 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h (9 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/trtllm_fused_moe_runner.cu
🧰 Additional context used
🧬 Code graph analysis (7)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
  • top_k (270-270)
  • n_group (271-271)
  • topk_group (273-273)
  • intermediate_size (275-275)
  • local_expert_offset (276-276)
  • local_num_experts (277-277)
  • num_experts (263-263)
csrc/trtllm_fused_moe_routing_deepseek.cu (2)
csrc/trtllm_fused_moe_routing_renormalize.cu (9)
  • void (76-217)
  • void (288-291)
  • void (297-353)
  • __launch_bounds__ (76-76)
  • __launch_bounds__ (297-297)
  • getMaxNumExperts (357-366)
  • getMaxNumExperts (357-357)
  • routingIndicesClusterKernel (222-286)
  • routingIndicesClusterKernel (289-289)
csrc/trtllm_fused_moe_routing_llama4.cu (8)
  • void (67-280)
  • void (354-356)
  • void (363-423)
  • __launch_bounds__ (363-363)
  • getMaxNumExperts (426-433)
  • getMaxNumExperts (426-426)
  • routingIndicesClusterKernel (285-352)
  • routingIndicesClusterKernel (354-354)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
  • RoutingMethodType (59-73)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • RoutingMethodType (37-136)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
  • void (34-252)
  • void (276-278)
  • void (459-461)
  • getMaxNumExperts (464-475)
  • getMaxNumExperts (464-464)
  • routingIndicesClusterKernel (260-274)
  • routingIndicesClusterKernel (276-276)
csrc/trtllm_fused_moe_routing_llama4.cu (9)
  • void (67-280)
  • void (354-356)
  • void (363-423)
  • routingTopKExperts (40-44)
  • getMaxNumExperts (426-433)
  • getMaxNumExperts (426-426)
  • routingIndicesClusterKernel (285-352)
  • routingIndicesClusterKernel (354-354)
  • routingIndicesHistogramScoresKernel (364-364)
csrc/trtllm_fused_moe_routing_llama4.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (9)
  • routingTopKExperts (32-37)
  • void (76-217)
  • void (288-291)
  • void (297-353)
  • getMaxNumExperts (357-366)
  • getMaxNumExperts (357-357)
  • routingIndicesClusterKernel (222-286)
  • routingIndicesClusterKernel (289-289)
  • routingIndicesHistogramScoresKernel (298-298)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
include/flashinfer/trtllm/fused_moe/DevKernel.h (4)
  • setKernelParams (246-262)
  • setKernelParams (300-311)
  • setKernelParams (362-377)
  • setKernelParams (442-461)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • n_group (271-271)
  • topk_group (273-273)
  • intermediate_size (275-275)
  • local_expert_offset (276-276)
  • local_num_experts (277-277)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py

1134-1134: Unused function argument: n_group

(ARG001)


1135-1135: Unused function argument: topk_group

(ARG001)


1136-1136: Unused function argument: intermediate_size

(ARG001)


1137-1137: Unused function argument: local_expert_offset

(ARG001)


1138-1138: Unused function argument: local_num_experts

(ARG001)


1139-1139: Unused function argument: routed_scaling_factor

(ARG001)


1222-1222: Unused function argument: n_group

(ARG001)


1223-1223: Unused function argument: topk_group

(ARG001)


1224-1224: Unused function argument: intermediate_size

(ARG001)


1225-1225: Unused function argument: local_expert_offset

(ARG001)


1226-1226: Unused function argument: local_num_experts

(ARG001)


1227-1227: Unused function argument: routed_scaling_factor

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
flashinfer/fused_moe/core.py (2)

174-194: Cache-key fix looks good.

Including epilogue_tile_m and num_elts_per_sf prevents cache collisions for permutations.


1070-1119: Verify that the pybind bindings accept Optional parameters, or convert None to sentinels before the C++ calls.

The published FlashInfer signatures are inconsistent across the MoE entry points: trtllm_fp4_block_scale_moe accepts n_group: int | None, topk_group: int | None, and routed_scaling_factor: float | None, while trtllm_fp8_block_scale_moe declares plain int and float. The signature of trtllm_fp8_per_tensor_scale_moe, which these lines call, could not be confirmed from the docs alone. pybind11 does support std::optional (via pybind11/stl.h) and recommends it for optional arguments of copied types, but if this particular binding does not declare std::optional<T>, passing None from Python will raise a TypeError at runtime. Either confirm the C++ binding signatures use std::optional, or convert None to sentinel values (e.g., -1) before the call.
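For illustration, a minimal pybind11 sketch (not FlashInfer's actual binding layer, which a later comment notes is TVM FFI) showing how declaring std::optional lets Python pass None; the module and function names are made up.

#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // provides the std::optional <-> None caster

namespace py = pybind11;

int resolve_n_group(std::optional<int> n_group) {
  // None from Python arrives as an empty optional; fall back to "no grouping".
  return n_group.value_or(0);
}

PYBIND11_MODULE(optional_demo, m) {
  m.def("resolve_n_group", &resolve_n_group, py::arg("n_group") = std::optional<int>());
}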

tests/moe/test_trtllm_gen_fused_moe.py (1)

2110-2113: Test guard aligns with kernel limit.

Asserting top_k <= 10 here matches the updated routing kernels.

include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

49-49: Review comment is incorrect and should be dismissed.

The LAUNCH_ROUTING macros are not removed—they remain actively defined and used throughout the codebase:

  • LAUNCH_ROUTING_LLAMA4, LAUNCH_ROUTING_DEEPSEEK_IMPL, and LAUNCH_ROUTING_WITH_NUM_EXPERTS are called 18+ times across csrc/trtllm_fused_moe_routing_*.cu files
  • The new LAUNCH_ESC macro (line 49) is a generic escape/passthrough wrapper, not a replacement for LAUNCH_ROUTING_*
  • The original verification script's regex pattern \bLAUNCH_ROUTING\s*\( was too narrow and did not match the actual macro variants being used (e.g., LAUNCH_ROUTING_LLAMA4)

No migration or verification is needed.

Likely an incorrect or invalid review comment.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)

174-194: Cache key is missing critical parameters—potential regression.

A previous review correctly identified that the cache key must include epilogue_tile_m and num_elts_per_sf (marked as addressed in commit d779e40). However, the current code at line 175 only uses ("w3_w1", dst_w3_w1_weight.shape). Since the permute indices computation on lines 179-188 depends on both parameters, omitting them can return incorrect cached permutations when the same weight shape is used with different tile/scaling parameters.

Apply this fix:

-    # Create a unique cache key (weight_type, weight_shape)
-    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    # Create a unique cache key (weight_type, weight_shape, epilogue_tile_m, num_elts_per_sf)
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape, epilogue_tile_m, num_elts_per_sf)

197-219: Same cache key issue for w2 permute indices.

The w2 permute indices cache at line 204 has the same problem: the cache key ("w2", dst_w2_weight.shape) omits epilogue_tile_m and num_elts_per_sf, which affect the permute indices computation on lines 206-215.

Apply this fix:

-    # Create a unique cache key (weight_type, weight_shape)
-    cache_key = ("w2", dst_w2_weight.shape)
+    # Create a unique cache key (weight_type, weight_shape, epilogue_tile_m, num_elts_per_sf)
+    cache_key = ("w2", dst_w2_weight.shape, epilogue_tile_m, num_elts_per_sf)
🧹 Nitpick comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)

1835-1836: Update comment to reflect actual cache key structure.

The comment states the cache key is (weight_type, shape), but the actual implementation uses a 4-tuple: (weight_type, shape, epilogue_tile_m, num_elts_per_sf). Update the comment to accurately document all key components.

-    # The cache key is now a tuple of (weight_type, shape)
+    # The cache key is now a tuple of (weight_type, shape, epilogue_tile_m, num_elts_per_sf)
+    # where num_elts_per_sf is optional (can be None)

1948-1948: Consider documenting why FP8PerTensorMoe is excluded from RenormNaive.

The RenormNaive routing configuration excludes FP8PerTensorMoe from compatible_moe_impls, while the similar Renorm configuration includes it. If this is an intentional limitation (e.g., RenormalizeNaive routing has specific implementation constraints), consider adding a comment to explain the exclusion for future maintainability.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 63233a3 and 4b38452.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/core.py (11 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
  • RoutingMethodType (59-73)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • RoutingMethodType (37-136)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • n_group (271-271)
  • topk_group (273-273)
  • intermediate_size (275-275)
  • local_expert_offset (276-276)
  • local_num_experts (277-277)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py

1134-1134: Unused function argument: n_group

(ARG001)


1135-1135: Unused function argument: topk_group

(ARG001)


1136-1136: Unused function argument: intermediate_size

(ARG001)


1137-1137: Unused function argument: local_expert_offset

(ARG001)


1138-1138: Unused function argument: local_num_experts

(ARG001)


1139-1139: Unused function argument: routed_scaling_factor

(ARG001)


1222-1222: Unused function argument: n_group

(ARG001)


1223-1223: Unused function argument: topk_group

(ARG001)


1224-1224: Unused function argument: intermediate_size

(ARG001)


1225-1225: Unused function argument: local_expert_offset

(ARG001)


1226-1226: Unused function argument: local_num_experts

(ARG001)


1227-1227: Unused function argument: routed_scaling_factor

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
tests/moe/test_trtllm_gen_fused_moe.py (5)

1842-1842: LGTM - Improved test coverage.

Adding intermediate_size=512 extends test coverage appropriately, and the new skip condition at lines 2059-2064 ensures this doesn't cause excessive test times for large expert counts.


1924-1937: Clarify why FP8PerTensorMoe is excluded from Qwen3_next.

The Qwen3_next configuration excludes FP8PerTensorMoe from compatible_moe_impls, while the similar "Renorm" routing configuration (line 1917) includes it. Both use the same RoutingMethodType.Renormalize routing method. Is this exclusion intentional, or should FP8PerTensorMoe be added to the compatible implementations list?


2059-2064: LGTM - Reasonable performance optimization.

The skip condition appropriately limits test execution time for configurations with many experts (≥512) and large intermediate sizes (>512) without affecting the current test matrix (max 384 experts in kimi_k2 config).


2112-2112: LGTM - Correctly reflects updated kernel capability.

The assertion change from top_k <= 8 to top_k <= 10 properly aligns with the kernel updates that increased MaxNumTopExperts to 10, enabling support for the new Qwen3_next configuration.


1856-1872: Verify whether routed_scaling should be 2.827 (per the official KIMI K2 config) instead of the 2.5 used in the test.

Verification confirms most of the KIMI K2 configuration is correct: 384 experts with 8 experts selected per token, and the official config has n_group: 1 and topk_group: 1 matching your test. However, the HuggingFace config shows routed_scaling_factor: 2.827, while your test uses routed_scaling: 2.5. This 14.9% difference may be intentional (e.g., for TRTLLM implementation specifics) or an oversight. Please confirm whether this difference is acceptable or if the test should use 2.827 to align with the official model specification.

flashinfer/fused_moe/core.py (3)

1134-1139: Static analysis warnings expected for fake implementations.

Ruff flags unused arguments in _fake_trtllm_fp8_per_tensor_scale_moe (lines 1134-1139) and _fake_trtllm_fp8_block_scale_moe (lines 1222-1227). These are shape inference stubs for torch.compile, so unused arguments are expected and not issues.

Also applies to: 1222-1227


1476-1546: Consistent Optional parameter propagation across public API.

The Optional parameter changes for n_group, topk_group, and routed_scaling_factor are consistently applied across all public MoE variants (FP8 per-tensor, FP8 block-scale, and FP4 block-scale). The docstrings correctly document that these parameters can be None for certain routing methods, aligning with the PR objective to support different routing configurations (KIMI K2, Qwen).

Also applies to: 1549-1623, 1626-1755


1069-1120: No changes needed. The C++ layer properly handles Optional parameters.

The C++ function signatures in trtllm_fused_moe_kernel_launcher.cu (lines 305-313) correctly declare Optional<int64_t> and Optional<double> types. The implementation validates these using .has_value() checks (e.g., line 81) and applies sensible defaults (0 for n_group/topk_group, 1.0 for routed_scaling_factor). The TVM FFI binding layer properly marshals Python None values to C++ Optional<> types, so Python can pass None directly without manual conversion.
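A minimal sketch of that defaulting pattern, using std::optional as a stand-in for the FFI Optional<> type; the 0/0/1.0 defaults are the ones described above, while the struct and function names are made up.

#include <optional>

struct ResolvedRoutingParams {
  int n_group;
  int topk_group;
  double routed_scaling_factor;
};

ResolvedRoutingParams applyRoutingDefaults(std::optional<int> n_group,
                                           std::optional<int> topk_group,
                                           std::optional<double> routed_scaling_factor) {
  ResolvedRoutingParams out;
  out.n_group = n_group.has_value() ? *n_group : 0;           // 0: no expert grouping
  out.topk_group = topk_group.has_value() ? *topk_group : 0;  // 0: no expert grouping
  out.routed_scaling_factor =
      routed_scaling_factor.has_value() ? *routed_scaling_factor : 1.0;  // neutral scaling
  return out;
}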

@jiahanc
Copy link
Collaborator

jiahanc commented Oct 20, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !71 has been updated with latest changes, and the CI pipeline #36935152 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Collaborator

[FAILED] Pipeline #36935152: 1/17 passed
