Update the routing for TRTLLMGEN to support kimi k2 and qwen #1831
Conversation
Summary of Changes
Hello @ChristinaZ, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly refactors the Mixture-of-Experts (MoE) routing logic within TRTLLMGEN to enhance its compatibility and performance for new models like Kimi K2 and Qwen. The changes introduce dynamic resource allocation for kernels, standardize data handling for top-K expert selection, and improve the robustness of numerical operations, thereby broadening the framework's support for diverse MoE architectures.
|
Code Review
This pull request introduces significant updates to the MoE routing kernels to support new models like Kimi K2 and Qwen, and to align with TRTLLM implementations. The changes involve making the kernels more generic to handle variable numbers of experts, refactoring pointer names for consistency (e.g., mPtrExpertIdx to mPtrTopKPacked), and adding new execution paths for pre-computed top-K indices. The use of __launch_bounds__ and the replacement of cudaMemsetAsync with a dedicated kernel are good improvements.
My review focuses on a few areas for improvement:
- A typo in a variable name that affects readability.
- A confusing static_assert comment in the top-K reduction logic.
- A potential bug in the new reduceTopK implementation related to an unresolved @todo and suspicious index initialization, which could lead to incorrect behavior.
Overall, the changes are well-structured and move towards a more flexible and robust implementation. Addressing the identified issues will further improve the code quality.
/bot run
Force-pushed e0efad8 to 80cdec5
Force-pushed 533e2de to 1a94ecb
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
This PR refactors MOE (Mixture of Experts) routing infrastructure to support variable expert counts and flexible routing configurations. Key changes include: making routing parameters (n_group, topk_group, routed_scaling_factor) optional in C++ launchers; parameterizing kernel launch bounds with KernelParams::MaxNumExperts; introducing new routing macros and kernel variants (LAUNCH_ROUTING_DEEPSEEK, LAUNCH_ROUTING_LLAMA4); renaming top-K buffers (mPtrExpertIdx → mPtrTopKPacked); and propagating corresponding Optional parameter types through Python MoE APIs.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
This PR encompasses substantial refactoring across heterogeneous components: C++ kernel parameterization with new template parameters (MaxNumExperts), introduction of multiple routing macros with non-trivial logic flow, data structure buffer renaming propagating through multiple layers, and corresponding Python API signature changes. The changes require understanding the interconnected flow from Python bindings through C++ kernel launchers to device kernels, plus validation of the routing configuration logic branches for each method variant.
Pre-merge checks: 1 passed, 1 warning, 1 inconclusive.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
165-168: Fix expert_count_histogram size; 2×256 is no longer sufficient (Kimi K2 = 384). With num_experts > 256, the fixed 512-element buffer risks OOB writes in routing histograms.
Apply:
```diff
- Tensor expert_count_histogram = alloc_tensor(
-     {2 * 256},
-     dl_int32,  // 256 is the max number of threads per block and max number of experts
-     routing_logits->device);
+ int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
+ Tensor expert_count_histogram = alloc_tensor(
+     {size_of_expert_count_histogram},
+     dl_int32,  // sized by 2 * max(num_experts, 256)
+     routing_logits->device);
```
flashinfer/fused_moe/core.py (3)
1118-1145: Fake op return type mismatch (should return Tensor, not list). The custom op returns a single tensor, but the fake op returns a 1-element list; this breaks meta/inference paths.
Apply:
```diff
 @register_fake_op("flashinfer::trtllm_fp8_per_tensor_scale_moe")
 def _fake_trtllm_fp8_per_tensor_scale_moe(
 @@
-    return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+    return hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)
```
1205-1234: Same issue: the fp8 block-scale fake op must return a Tensor. Align the fake with the real op.
```diff
 @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe")
 def _fake_trtllm_fp8_block_scale_moe(
 @@
-    return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+    return hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)
```
225-235: Even-row assertion contradicts odd-row handling. Asserting that M is even prevents the odd-M branch that follows from ever running.
```diff
-    assert M % 2 == 0, f"x.shape[0] must be even, not {M}"
+    # Support both even and odd M.
+    # (Odd M is handled via the (M + 1) // 2 split below.)
```
Also consider validating behavior with a quick unit test for odd M.
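A tiny sketch of the odd-M handling the comment refers to (hypothetical shapes; the real code splits a weight/activation tensor along dim 0):

```python
import torch

def split_rows(x: torch.Tensor):
    m = x.shape[0]
    half = (m + 1) // 2  # works for odd M too: 7 -> 4 + 3
    return x[:half], x[half:]

top, bottom = split_rows(torch.zeros(7, 16))
assert top.shape[0] == 4 and bottom.shape[0] == 3
```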
♻️ Duplicate comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (1)
372-375: Address the TODO and fix buffer initialization. As flagged in a previous review, the initialization topKBufferIdx[ii] = ii * WarpSize - 1 at line 374 is problematic: when ii=0, this sets topKBufferIdx[0] = -1. That value is then used in RedType::makeCmpVal, which calculates maxIdx - idx = 65535 - (-1) = 65536, overflowing the 0xFFFF mask used for the index part of the packed value. This could cause incorrect tie-breaking or other subtle bugs. The @todo comment indicates this needs validation. Consider using RedType::maxIdx or another safe sentinel value for invalid indices.
Suggested fix:
```diff
 for (int ii = 0; ii < numResults; ++ii) {
   topKBufferValue[ii] = minValue;
-  topKBufferIdx[ii] = ii * WarpSize - 1;  //@todo: check if this is correct
+  topKBufferIdx[ii] = RedType::maxIdx;  // Use safe sentinel for invalid indices
 }
```
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
192-199: Rename “intermidiate” to “intermediate”. The typo hurts readability; please rename both arrays accordingly.
🧹 Nitpick comments (8)
csrc/trtllm_fused_moe_routing_deepseek.cu (2)
191-206: Correct inter-topK scratch sizing; the current formula overshoots. NumInterTopKPerThread should be ceil(NumInterTopK / WarpSize). The current expression scales with NumExpertWarps again, inflating per-lane storage and work.
```diff
- int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1;
+ int constexpr NumInterTopKPerThread = (NumInterTopK + WarpSize - 1) / WarpSize;
```
96-99: Typo in comment (“sigmoig”). Nit: s/sigmoig/sigmoid/ in the comment.
flashinfer/fused_moe/core.py (2)
1130-1136: Silence Ruff ARG001 for unused optional params in fake ops. Keep the signature for the registry, but explicitly mark the arguments as unused.
```diff
 def _fake_trtllm_fp8_per_tensor_scale_moe(
 @@
-    ):
+    ):
+        # Unused in fake; keep signature for registry.
+        del n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor
 @@
 def _fake_trtllm_fp8_block_scale_moe(
 @@
-    ):
+    ):
+        # Unused in fake; keep signature for registry.
+        del n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor
```
Alternatively add “# noqa: ARG001” to the function defs.
Also applies to: 1218-1223
113-122: Duplicate entry in trtllm_gen_dtype_has_scale. MxE4m3 is listed twice; harmless but noisy.
```diff
-    if dtype in [
-        DtypeTrtllmGen.MxE4m3,
-        DtypeTrtllmGen.E2m1,
-        DtypeTrtllmGen.MxE2m1,
-        DtypeTrtllmGen.MxE4m3,
-    ]:
+    if dtype in [DtypeTrtllmGen.MxE4m3, DtypeTrtllmGen.E2m1, DtypeTrtllmGen.MxE2m1]:
         return True
```
csrc/trtllm_fused_moe_routing_renormalize.cu (4)
25-30: Top-K and expert limits: good, but add explicit guards. MaxNumTopExperts=10 and NumExpertsLimit=512 look fine. Add compile-time checks tying the assumptions together.
```diff
 static constexpr int MaxNumTopExperts = 10;
 static constexpr int NumExpertsLimit = 512;
+static_assert(MaxNumTopExperts <= std::numeric_limits<int8_t>::max(),
+              "TopK index stored in int8_t; must fit.");
+static_assert(NumExpertsLimit % WarpSize == 0,
+              "Max experts must be multiple of warp size for VecSize.");
```
75-217: Block kernel: int8_t scratch indices rely on small token/top-k limits; make this explicit. smemKIdx/smemOffset use int8_t; safe with BlockKernelMaxNumTokens=4 and MaxNumTopExperts=10, but brittle if the thresholds grow.
- Add comments and static_asserts on bounds (expert counts per token per expert ≤ 127).
- Consider uint16_t if future configs may exceed 127.
```diff
-  __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts];
-  __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts];
+  __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts];  // offsetWithinExpert ∈ [0, BlockKernelMaxNumTokens)
+  __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts];    // kIdx ∈ [0, MaxNumTopExperts)
+  static_assert(BlockKernelMaxNumTokens < 128 && MaxNumTopExperts < 128, "int8_t bounds");
```
370-381: Typo in macro name (RENORNALIZE). Nit, but it spreads quickly into call sites.
```diff
-#define LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
+#define LAUNCH_ROUTING_RENORMALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
     stream, extraFlag1)                                                                       \
 @@
-  LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
+  LAUNCH_ROUTING_RENORMALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
```
…and update the invocations below.
382-464: run(): validation and dispatch look solid; minor clarity nits.
- Good: inputs are enforced for large token cases; numThreads is chosen dynamically via getMaxNumExperts.
- Suggest renaming the “useSingleBlock” threshold comment to reference the BlockKernelMaxNumTokens constant.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- csrc/trtllm_batched_gemm_runner.cu (0 hunks)
- csrc/trtllm_fused_moe_kernel_launcher.cu (13 hunks)
- csrc/trtllm_fused_moe_routing_deepseek.cu (11 hunks)
- csrc/trtllm_fused_moe_routing_llama4.cu (8 hunks)
- csrc/trtllm_fused_moe_routing_renormalize.cu (8 hunks)
- csrc/trtllm_fused_moe_runner.cu (3 hunks)
- flashinfer/fused_moe/core.py (7 hunks)
- include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (12 hunks)
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h (9 hunks)
- include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3 hunks)
- tests/conftest.py (1 hunk)
- tests/moe/test_trtllm_gen_fused_moe.py (5 hunks)
💤 Files with no reviewable changes (1)
- csrc/trtllm_batched_gemm_runner.cu
🧰 Additional context used
🧬 Code graph analysis (8)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
csrc/trtllm_fmha_kernel_launcher.cu (3)
namespace trtllm_cubin_loader {
(303-305)Context
(34-312)trtllm_paged_attention_launcher
(75-165)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
csrc/trtllm_fused_moe_routing_llama4.cu (8)
void
(67-280)void
(354-356)void
(363-423)__launch_bounds__
(363-363)getMaxNumExperts
(426-433)getMaxNumExperts
(426-426)routingIndicesClusterKernel
(285-352)routingIndicesClusterKernel
(354-354)
csrc/trtllm_fused_moe_routing_llama4.cu (2)
csrc/trtllm_fused_moe_routing_renormalize.cu (6)
routingTopKExperts
(32-37)void
(76-217)void
(288-291)void
(297-353)getMaxNumExperts
(357-366)getMaxNumExperts
(357-357)csrc/trtllm_fused_moe_routing_deepseek.cu (5)
void
(34-252)void
(276-278)void
(459-461)getMaxNumExperts
(464-475)getMaxNumExperts
(464-464)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
include/flashinfer/trtllm/fused_moe/DevKernel.h (4)
setKernelParams
(219-235)setKernelParams
(273-284)setKernelParams
(335-350)setKernelParams
(415-434)
tests/moe/test_trtllm_gen_fused_moe.py (3)
flashinfer/autotuner.py (1)
autotune
(251-262)flashinfer/fused_moe/core.py (1)
RoutingMethodType
(59-73)include/flashinfer/trtllm/fused_moe/runner.h (1)
RoutingMethodType
(37-135)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
top_k
(269-269)n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)num_experts
(262-262)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
void
(34-252)void
(276-278)void
(459-461)getMaxNumExperts
(464-475)getMaxNumExperts
(464-464)routingIndicesClusterKernel
(260-274)routingIndicesClusterKernel
(276-276)csrc/trtllm_fused_moe_routing_llama4.cu (9)
void
(67-280)void
(354-356)void
(363-423)routingTopKExperts
(40-44)getMaxNumExperts
(426-433)getMaxNumExperts
(426-426)routingIndicesClusterKernel
(285-352)routingIndicesClusterKernel
(354-354)routingIndicesHistogramScoresKernel
(364-364)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py
1130-1130: Unused function argument: n_group
(ARG001)
1131-1131: Unused function argument: topk_group
(ARG001)
1132-1132: Unused function argument: intermediate_size
(ARG001)
1133-1133: Unused function argument: local_expert_offset
(ARG001)
1134-1134: Unused function argument: local_num_experts
(ARG001)
1135-1135: Unused function argument: routed_scaling_factor
(ARG001)
1218-1218: Unused function argument: n_group
(ARG001)
1219-1219: Unused function argument: topk_group
(ARG001)
1220-1220: Unused function argument: intermediate_size
(ARG001)
1221-1221: Unused function argument: local_expert_offset
(ARG001)
1222-1222: Unused function argument: local_num_experts
(ARG001)
1223-1223: Unused function argument: routed_scaling_factor
(ARG001)
🔇 Additional comments (17)
tests/moe/test_trtllm_gen_fused_moe.py (4)
108-108: LGTM: Autotune disabled during warmup. Disabling autotuning during CUDA graph warmup is correct, as autotuning can interfere with graph capture.
1860-1861: Verify kimi_k2 configuration parameters. The n_groups and top_k_groups parameters have been changed from [12, 4] to [1, 1], which removes hierarchical expert grouping. According to the PR description, this is based on the Kimi K2 config from Hugging Face. Please verify that n_groups=1 and top_k_groups=1 accurately reflect the Kimi K2 Instruct config. If the actual model uses hierarchical grouping, these values may need adjustment.
1908-1933: Test coverage expanded for Renormalize variants. The configurations for Renorm and RenormalizeNaive have been updated to test larger configurations (512 experts, top_k=10), and the skip marks have been commented out to enable these tests. This aligns with the PR objective to support models like Qwen with larger routing configurations.
2091-2091: Verify the top_k upper bound after removing the assertion. The assert top_k <= 8 has been commented out, presumably to support configurations with top_k=10. Please confirm:
- What is the new upper bound for top_k (if any)?
- Are there any kernel or hardware constraints that limit top_k values?
- Should this assertion be replaced with a higher limit or removed entirely?
Based on the test configurations using top_k=10, this change appears intentional to support larger top-k values.
include/flashinfer/trtllm/fused_moe/DevKernel.h (3)
34-34: LGTM: Logger include added. The logger header is needed to support FLASHINFER_WARN calls in the new routing macros.
116-126: LGTM: Llama4-specific routing macro added. The LAUNCH_ROUTING_LLAMA4 macro correctly implements llama4-specific routing with a fixed per-expert block size of 128. The hardcoded constant is well documented in the comments.
128-171: LGTM: Parameterized expert count macros added. The new LAUNCH_ROUTING_WITH_NUM_EXPERTS macros provide flexible expert count configuration with optional input type forcing. The implementation correctly handles both Fp32 and Bfloat16 expert weight types with appropriate flag combinations.
csrc/trtllm_fused_moe_runner.cu (1)
68-68: LGTM: Consistent pointer naming refactor. The pointer renames from expert-centric (mPtrExpertIdx, mPtrExpertWeights) to top-k-centric (mPtrTopKPacked, mPtrTopKWeights) have been applied consistently across all three routing methods (DeepSeekV3, Llama4, and Renormalize). This improves semantic clarity and aligns with the broader top-k buffer architecture changes described in the PR.
Also applies to: 73-73, 105-105, 110-110, 147-147, 152-152
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3)
34-35: LGTM: System limits defined as constants. The new constants MaxNumExpertsUnit (128) and MaxNumTopK (10) establish clear system-wide limits. The MaxNumTopK=10 value aligns with the test configuration changes that allow top_k=10.
317-324: Refactored top-k reduction with clearer constraints. The function has been renamed to reduceTopKFunc to distinguish it as an internal implementation, and the static assertion has been tightened from N <= 16 to N < 5. This clarifies that this function variant handles only small candidate counts (N ≤ 4), with larger counts delegated to the buffering path in the new public reduceTopK overload. Note: the past review comment about the confusing static_assert message is now partially addressed by the tighter constraint, though the message text itself could still be updated.
376-400: Buffering algorithm is sound pending the initialization fix. The multi-pass buffering strategy correctly reduces large candidate counts (N > 4) by processing in chunks of 4 and consolidating results. The logic properly handles thread distribution and edge cases, though the buffer initialization issue at line 374 needs to be resolved.
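To make the chunk-then-consolidate idea concrete, here is a host-side C++ sketch of the same strategy (illustrative only; the real implementation is a warp-level CUDA reduction over packed value/index pairs):

```cpp
#include <algorithm>
#include <utility>
#include <vector>

using Cand = std::pair<float, int>;  // (score, expert index)

// Top-k of a candidate list by partial sort (descending by score).
std::vector<Cand> topK(std::vector<Cand> v, size_t k) {
  k = std::min(k, v.size());
  std::partial_sort(v.begin(), v.begin() + k, v.end(),
                    [](const Cand& a, const Cand& b) { return a.first > b.first; });
  v.resize(k);
  return v;
}

// Reduce a long list by taking the top-k of each chunk of 4, then the top-k
// of the concatenated per-chunk winners. Every global top-k element survives
// its chunk's top-k, so the result matches a direct top-k.
std::vector<Cand> topKChunked(const std::vector<Cand>& cands, size_t k, size_t chunk = 4) {
  std::vector<Cand> partial;
  for (size_t start = 0; start < cands.size(); start += chunk) {
    size_t end = std::min(start + chunk, cands.size());
    std::vector<Cand> part = topK({cands.begin() + start, cands.begin() + end}, k);
    partial.insert(partial.end(), part.begin(), part.end());
  }
  return topK(std::move(partial), k);
}
```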
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (3)
102-129: Softmax in float: good numeric-stability tradeoff. Computing in float and reducing via cg::reduce is correct and safer for bf16/half inputs.
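A scalar sketch of the same numeric pattern (plain C++ rather than the warp-level cg::reduce code): accumulate in float and subtract the row max before exponentiating, even when the inputs/outputs are bf16/half.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Assumes a non-empty logits vector; accumulation stays in float throughout.
std::vector<float> softmaxStable(const std::vector<float>& logits) {
  float maxVal = *std::max_element(logits.begin(), logits.end());
  std::vector<float> out(logits.size());
  float sum = 0.0f;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - maxVal);  // exponentials stay in (0, 1]
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}
```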
472-701: Histogram/offset kernels parameterized by KernelParams::MaxNumExperts: LGTM. Launch bounds and shared buffers consistently use MaxNumExperts; PDL sync points are placed correctly for the two-step path.
Please ensure LAUNCH_* macros always pass blockDim == KernelParams::MaxNumExperts for these kernels.
705-728: Init kernel addition is appropriate. routingInitExpertCounts isolates the zeroing and matches usage in the run paths.
csrc/trtllm_fused_moe_routing_llama4.cu (1)
310-352: Llama4 routing refactor looks coherent.
- MaxNumExperts limited to 128 with compile-time guards.
- Supports TopKIds/TopKPacked/Scores inputs consistently across warp/cluster/histogram paths.
Confirm tests cover both scores-input and packed/ids-input paths for num_tokens near WarpKernelMaxNumTokens boundaries.
Also applies to: 362-423, 425-433, 476-525
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
59-75: Top-K bufferization (Ids/Weights/Packed) and the MaxNumExperts parameter: sensible API evolution. DataBase and the KernelParams* structs now expose Top-K buffers and parameterized bounds; casting is consistent.
Ensure callers always supply mPtrTopKWeights when providing mPtrTopKIds (guards exist in runImpl), and that OutputT matches the allocated packed score type.
Also applies to: 103-153, 174-213, 229-247, 267-293
csrc/trtllm_fused_moe_routing_renormalize.cu (1)
357-366: getMaxNumExperts: OK; the unreachable 0 is guarded above. The function returns 0 for unsupported values; run() already checks ≤ NumExpertsLimit. Keep as is; no action needed.
tests/conftest.py (Outdated)
```python
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # skip OOM error and missing JIT cache errors
    # Wrap the test call so we don't invoke item.runtest() ourselves; yield lets pytest run it.
```
Misleading comment.
The comment states "yield lets pytest run it," but the code directly calls item.runtest() on line 144 without using yield. The comment should be updated to accurately reflect the implementation.
Apply this diff to correct the comment:
- # Wrap the test call so we don't invoke item.runtest() ourselves; yield lets pytest run it.
+ # Execute the test and handle exceptions for OOM and missing JIT cache.
Actionable comments posted: 1
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
145-146: Add bounds checking for routed_scaling_factor to match the validation patterns used elsewhere. The code lacks validation for routed_scaling_factor even though similar optional parameters (n_group, topk_group) receive extensive bounds checks (lines 79–94). Consider adding validation to ensure the scaling factor is positive and within reasonable bounds, consistent with the codebase's defensive programming pattern and the guidance that scaling factors should avoid zero and negative values.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
csrc/trtllm_fused_moe_kernel_launcher.cu
(14 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
top_k
(269-269)n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)num_experts
(262-262)
🔇 Additional comments (9)
csrc/trtllm_fused_moe_kernel_launcher.cu (9)
45-47: LGTM! Function signature updated to support optional routing parameters. Accepting Optional<int64_t> for n_group and topk_group, and Optional<double> for routed_scaling_factor, aligns with the PR objective to support multiple routing configurations for the Kimi K2 and Qwen models.
79-94: LGTM! Comprehensive validation for grouped routing. The validation logic properly ensures:
- Grouped routing is only used with DeepSeekV3
- All required parameters are present when using groups
- Divisibility and bound constraints are satisfied
- Sufficient experts exist in selected groups for the requested top_k
164-166: LGTM! Histogram sizing updated to handle variable expert counts. The sizing calculation ensures the histogram buffer is large enough for both the number of experts and the maximum thread block size (256), which is appropriate for the routing kernel.
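For reference, a one-liner capturing the sizing rule being reviewed (an illustrative helper, not the launcher code): the histogram needs 2*num_experts entries, but never fewer than the 2*256 it used before.

```cpp
#include <algorithm>
#include <cstdint>

// e.g. histogramSize(384) == 768 for Kimi K2, histogramSize(128) == 512.
int64_t histogramSize(int64_t num_experts) {
  return std::max<int64_t>(2 * num_experts, 2 * 256);
}
```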
303-306: LGTM! Wrapper signature updated consistently. The wrapper function signature properly reflects the launcher's updated parameters and forwards them correctly.
379-384: Renormalize methods supported here (contrast with the per-tensor launcher). This block-scale launcher correctly validates top_k bounds for the Renormalize and RenormalizeNaive methods, while the per-tensor scale launcher at lines 95-100 rejects these methods entirely. This confirms the inconsistency flagged earlier.
731-758: LGTM! Validation logic properly structured. The validation logic correctly handles optional parameters using value_or(0) and provides appropriate checks for different routing method configurations. The main concern is the commented-out checks, which have been flagged separately.
806-810: LGTM! Default values consistent across launchers. The default values for optional parameters are consistent with the other launcher functions in this file.
834-836: LGTM! Histogram sizing consistent with other launchers. The sizing calculation matches the approach used in the per-tensor and block-scale launchers.
95-100: Inconsistent Renormalize routing support between the per-tensor and block-scale launchers. This per-tensor scale launcher throws NotImplementedError for the Renormalize and RenormalizeNaive routing methods (lines 95-100), while the block-scale launcher validates these methods with top_k <= 10 && top_k > 0 constraints instead of rejecting them. This inconsistency needs clarification:
- Determine if per-tensor intentionally excludes Renormalize support or if it should match block-scale behavior
- Update validation logic in one or both launchers to align
Force-pushed f49ee40 to 0f33230
/bot run
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (1)
152-160: Guard against actualK > K to prevent out-of-bounds writes. The loop indexes out[kk] and outIdx[kk] assume kk < K. Clamp actualK to K.
```diff
- for (int kk = 0; kk < actualK; ++kk)  //@todo: check if actualK is correct
+ int cappedK = actualK < K ? actualK : K;  // cap to template bound
+ for (int kk = 0; kk < cappedK; ++kk)  //@todo: check if actualK is correct
```
♻️ Duplicate comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (2)
170-171: Fix the misleading static_assert message (copy-paste). The condition enforces N < 5 (i.e., up to 4), but the message says “<= 128.” Update the message for clarity.
```diff
- static_assert(N < 5, "Only support candidates number less than or equal to 128");
+ static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");
```
218-221: Invalid sentinel index initialization (-1) can corrupt tie-breaking. Setting topKBufferIdx[ii] = ii*WarpSize - 1 yields -1 for ii=0. In makeCmpVal this becomes (maxIdx - idx) = 65536, which wraps to 0x0000 and collides with a valid index in the lower 16 bits. Initialize to a safe value.
```diff
- for (int ii = 0; ii < numResults; ++ii) {
-   topKBufferValue[ii] = minValue;
-   topKBufferIdx[ii] = ii * WarpSize - 1;
- }
+ using RedT = TopKRedType<Type>;
+ for (int ii = 0; ii < numResults; ++ii) {
+   topKBufferValue[ii] = minValue;
+   // Use a valid sentinel that won't overflow the 0xFFFF packed index space
+   topKBufferIdx[ii] = RedT::maxIdx;
+ }
```
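To see why the -1 sentinel is unsafe, here is a standalone sketch of the 16-bit index packing described above (a hypothetical packIdx helper, not the library's TopKRedType):

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint32_t kMaxIdx = 0xFFFF;

// The packed compare value keeps (kMaxIdx - idx) in its low 16 bits.
uint32_t packIdx(int32_t idx) {
  return (kMaxIdx - static_cast<uint32_t>(idx)) & 0xFFFF;
}

int main() {
  std::printf("idx=0     -> 0x%04X\n", packIdx(0));       // 0xFFFF
  std::printf("idx=65535 -> 0x%04X\n", packIdx(0xFFFF));  // 0x0000
  std::printf("idx=-1    -> 0x%04X\n", packIdx(-1));      // wraps to 0x0000, collides with idx=65535
  return 0;
}
```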
🧹 Nitpick comments (4)
flashinfer/fused_moe/core.py (1)
1122-1149: Silence Ruff ARG001 in the fake op: prefix unused Optional params. The fake implementation ignores several params; prefix them with “_” to satisfy lint without behavior change.
```diff
-def _fake_trtllm_fp8_per_tensor_scale_moe(
+def _fake_trtllm_fp8_per_tensor_scale_moe(
     routing_logits: torch.Tensor,
-    routing_bias: Optional[torch.Tensor],
+    routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     gemm1_weights: torch.Tensor,
     output1_scales_scalar: torch.Tensor,
     output1_scales_gate_scalar: torch.Tensor,
     gemm2_weights: torch.Tensor,
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: Optional[int],
-    topk_group: Optional[int],
+    _n_group: Optional[int],
+    _topk_group: Optional[int],
     intermediate_size: int,
-    local_expert_offset: int,
-    local_num_experts: int,
-    routed_scaling_factor: Optional[float],
-    use_routing_scales_on_input: bool,
-    tile_tokens_dim: int = 8,
-    routing_method_type: int = 0,
-    enable_pdl: Optional[bool] = None,
+    _local_expert_offset: int,
+    _local_num_experts: int,
+    _routed_scaling_factor: Optional[float],
+    _use_routing_scales_on_input: bool,
+    _tile_tokens_dim: int = 8,
+    _routing_method_type: int = 0,
+    _enable_pdl: Optional[bool] = None,
 ):
```
Apply analogous renames in _fake_trtllm_fp8_block_scale_moe for: n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim, routing_method_type, enable_pdl.
tests/moe/test_trtllm_gen_fused_moe.py (2)
1835-1837: Align the cache key with the parameters impacting permutation. The tests document keys as (weight_type, shape). Given that core depends on epilogue_tile_m and num_elts_per_sf, consider reflecting that here to avoid hidden coupling.
2065-2072: Nit: typo in comment. “epxerts” → “experts”.
```diff
- # Skip large intermediate size and hidden size for configurations with small epxerts
+ # Skip large intermediate size and hidden size for configurations with small experts
```
363-376
: Inconsistent top_k bound for grouped routing across launchers.fp8_block_scale enforces top_k <= 8 with groups; fp4_block_scale allows top_k <= 10. Is this intentional? If yes, add comments documenting per-kernel limits; if not, harmonize the checks.
Also applies to: 731-747
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
csrc/trtllm_fused_moe_kernel_launcher.cu
(14 hunks)csrc/trtllm_fused_moe_routing_renormalize.cu
(8 hunks)flashinfer/fused_moe/core.py
(11 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
(3 hunks)tests/moe/test_trtllm_gen_fused_moe.py
(6 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
void
(34-252)void
(276-278)void
(459-461)getMaxNumExperts
(464-475)getMaxNumExperts
(464-464)routingIndicesClusterKernel
(260-274)routingIndicesClusterKernel
(276-276)csrc/trtllm_fused_moe_routing_llama4.cu (9)
void
(67-280)void
(354-356)void
(363-423)routingTopKExperts
(40-44)getMaxNumExperts
(426-433)getMaxNumExperts
(426-426)routingIndicesClusterKernel
(285-352)routingIndicesClusterKernel
(354-354)routingIndicesHistogramScoresKernel
(364-364)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
top_k
(269-269)n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)num_experts
(262-262)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
RoutingMethodType
(59-73)include/flashinfer/trtllm/fused_moe/runner.h (4)
RoutingMethodType
(37-135)intermediate_size
(274-274)hidden_size
(264-264)top_k
(269-269)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py
1134-1134: Unused function argument: n_group
(ARG001)
1135-1135: Unused function argument: topk_group
(ARG001)
1136-1136: Unused function argument: intermediate_size
(ARG001)
1137-1137: Unused function argument: local_expert_offset
(ARG001)
1138-1138: Unused function argument: local_num_experts
(ARG001)
1139-1139: Unused function argument: routed_scaling_factor
(ARG001)
1222-1222: Unused function argument: n_group
(ARG001)
1223-1223: Unused function argument: topk_group
(ARG001)
1224-1224: Unused function argument: intermediate_size
(ARG001)
1225-1225: Unused function argument: local_expert_offset
(ARG001)
1226-1226: Unused function argument: local_num_experts
(ARG001)
1227-1227: Unused function argument: routed_scaling_factor
(ARG001)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (1)
226-231: The local-expert check uses the log2 value as a bitmask instead of converting it to an actual mask, and it affects 5 locations, not 2. The review correctly identifies the bug but underestimated its scope. Using (localExpertIdx & params.mLocalExpertsStrideLog2) == 0 incorrectly uses a log2 value as a bitmask. Example: if mLocalExpertsStrideLog2=3 (stride=8), the code checks idx & 3 (divides by 4) instead of idx & 7 (divides by 8).
Apply the suggested fix to all 5 occurrences:
```diff
+  int32_t strideMask = (1 << params.mLocalExpertsStrideLog2) - 1;
   auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
-                       (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
+                       ((localExpertIdx & strideMask) == 0);
```
Locations: lines 228, 364, 434, 582, and 661.
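A minimal standalone illustration of the mask bug just described (hypothetical helper names; the real code reads params.mLocalExpertsStrideLog2):

```cpp
#include <cassert>
#include <cstdint>

bool isStrideMultipleWrong(int32_t idx, int32_t strideLog2) {
  return (idx & strideLog2) == 0;              // bug: log2 value used as the mask
}

bool isStrideMultipleRight(int32_t idx, int32_t strideLog2) {
  int32_t strideMask = (1 << strideLog2) - 1;  // log2=3 -> stride 8 -> mask 0b111
  return (idx & strideMask) == 0;
}

int main() {
  // stride = 8 (log2 = 3): 4 is not a multiple of 8, but the buggy check accepts it.
  assert(isStrideMultipleWrong(4, 3));
  assert(!isStrideMultipleRight(4, 3));
  assert(isStrideMultipleRight(8, 3));
  return 0;
}
```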
csrc/trtllm_fused_moe_routing_llama4.cu (1)
264-267
: Fix the local-expert mask bitwise AND operations across 6 locations. The bug is confirmed: a log2 value is used directly in a bitwise AND instead of being converted to a mask. This occurs in:
csrc/trtllm_fused_moe_routing_llama4.cu
line 266include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
lines 228, 364, 434, 582, 661Apply the fix from the review comment to all 6 locations: create
int32_t strideMask = (1 << params.mLocalExpertsStrideLog2) - 1;
and use((localExpertIdx & strideMask) == 0)
instead of(localExpertIdx & params.mLocalExpertsStrideLog2) == 0
.
♻️ Duplicate comments (5)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3)
218-221
: Sentinel -1 overflows index packing; use a safe sentinel. topKBufferIdx[ii] = ii * WarpSize - 1 can produce -1 → 65536 after (maxIdx - idx), overflowing the 16-bit field and skewing tie-breaks. Initialize with RedType::maxIdx.
- for (int ii = 0; ii < numResults; ++ii) { - topKBufferValue[ii] = minValue; - topKBufferIdx[ii] = ii * WarpSize - 1; - } + for (int ii = 0; ii < numResults; ++ii) { + topKBufferValue[ii] = minValue; + topKBufferIdx[ii] = TopKRedType<Type>::maxIdx; + }
228-233
: OOB reads when N is not a multiple of 4. The tail chunk reads past N. Guard and fill with neutral values.
- for (int i = 0; i < 4; ++i) { - inValue[i] = value[start + i]; - inIdx[i] = idx[start + i]; - } + int rem = N - start; + for (int i = 0; i < 4; ++i) { + if (i < rem) { + inValue[i] = value[start + i]; + inIdx[i] = idx[start + i]; + } else { + inValue[i] = minValue; + inIdx[i] = TopKRedType<Type>::maxIdx; + } + }
170-171
: Update the misleading static_assert message. N < 5 enforces up to 4 candidates per thread; fix the message to match.
- static_assert(N < 5, "Only support candidates number less than or equal to 128"); + static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");flashinfer/fused_moe/core.py (1)
174-194
: Cache key still incomplete (previously flagged). The cache key uses only
(prefix, shape)
butepilogue_tile_m
andnum_elts_per_sf
also affect permutation indices. This was identified in past reviews but remains unaddressed.- cache_key = ("w3_w1", dst_w3_w1_weight.shape) + cache_key = ("w3_w1", dst_w3_w1_weight.shape, epilogue_tile_m, num_elts_per_sf)Apply the same fix to
get_w2_permute_indices_with_cache
:- cache_key = ("w2", dst_w2_weight.shape) + cache_key = ("w2", dst_w2_weight.shape, epilogue_tile_m, num_elts_per_sf)csrc/trtllm_fused_moe_routing_deepseek.cu (1)
192-201
: Fix variable name typos. The variables
intermidiateScore
andintermidiateExpert
should be renamed tointermediateScore
andintermediateExpert
.- float intermidiateScore[NumInterTopKPerThread]; - int32_t intermidiateExpert[NumInterTopKPerThread]; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { int ii = i / WarpSize; if (i < NumInterTopK) { - intermidiateScore[ii] = smemInterTopScores[i]; - intermidiateExpert[ii] = smemInterTopExperts[i]; + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; } else { - intermidiateScore[ii] = invalidScoreFloat; - intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1; + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; } } - topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, + topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, /* minValue */ invalidScoreFloat, params.mTopK);
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
csrc/trtllm_fused_moe_kernel_launcher.cu
(14 hunks)csrc/trtllm_fused_moe_routing_deepseek.cu
(11 hunks)csrc/trtllm_fused_moe_routing_llama4.cu
(8 hunks)csrc/trtllm_fused_moe_routing_renormalize.cu
(8 hunks)csrc/trtllm_fused_moe_runner.cu
(3 hunks)flashinfer/fused_moe/core.py
(11 hunks)include/flashinfer/trtllm/fused_moe/DevKernel.h
(2 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
(12 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernel.h
(9 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
(3 hunks)tests/moe/test_trtllm_gen_fused_moe.py
(6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/trtllm_fused_moe_runner.cu
🧰 Additional context used
🧬 Code graph analysis (7)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (4)
getMaxNumExperts
(357-366)getMaxNumExperts
(357-357)routingIndicesClusterKernel
(222-286)routingIndicesClusterKernel
(289-289)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
RoutingMethodType
(59-73)include/flashinfer/trtllm/fused_moe/runner.h (1)
RoutingMethodType
(37-135)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
void
(34-252)void
(276-278)void
(459-461)getMaxNumExperts
(464-475)getMaxNumExperts
(464-464)routingIndicesClusterKernel
(260-274)routingIndicesClusterKernel
(276-276)csrc/trtllm_fused_moe_routing_llama4.cu (9)
void
(67-280)void
(354-356)void
(363-423)routingTopKExperts
(40-44)getMaxNumExperts
(426-433)getMaxNumExperts
(426-426)routingIndicesClusterKernel
(285-352)routingIndicesClusterKernel
(354-354)routingIndicesHistogramScoresKernel
(364-364)
csrc/trtllm_fused_moe_routing_llama4.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (9)
routingTopKExperts
(32-37)void
(76-217)void
(288-291)void
(297-353)getMaxNumExperts
(357-366)getMaxNumExperts
(357-357)routingIndicesClusterKernel
(222-286)routingIndicesClusterKernel
(289-289)routingIndicesHistogramScoresKernel
(298-298)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
include/flashinfer/trtllm/fused_moe/DevKernel.h (4)
setKernelParams
(219-235)setKernelParams
(273-284)setKernelParams
(335-350)setKernelParams
(415-434)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
top_k
(269-269)n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)num_experts
(262-262)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py
1134-1134: Unused function argument: n_group
(ARG001)
1135-1135: Unused function argument: topk_group
(ARG001)
1136-1136: Unused function argument: intermediate_size
(ARG001)
1137-1137: Unused function argument: local_expert_offset
(ARG001)
1138-1138: Unused function argument: local_num_experts
(ARG001)
1139-1139: Unused function argument: routed_scaling_factor
(ARG001)
1222-1222: Unused function argument: n_group
(ARG001)
1223-1223: Unused function argument: topk_group
(ARG001)
1224-1224: Unused function argument: intermediate_size
(ARG001)
1225-1225: Unused function argument: local_expert_offset
(ARG001)
1226-1226: Unused function argument: local_num_experts
(ARG001)
1227-1227: Unused function argument: routed_scaling_factor
(ARG001)
🔇 Additional comments (14)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
116-126
: Confirm cooperative launch for cluster kernels. LAUNCH_ROUTING_LLAMA4 exposes coopLaunch; cluster kernels using cluster_dims generally require cooperative launch attributes. Ensure callers pass coopLaunch=true for cluster paths; otherwise cluster barriers can deadlock.
Would you like me to scan callers and patch the relevant launches?
tests/moe/test_trtllm_gen_fused_moe.py (3)
2066-2071
: Minor typo: “epxerts” -> “experts”.
```diff
- # Skip large intermediate size and hidden size for configurations with small epxerts
+ # Skip large intermediate size and hidden size for configurations with few experts
```
Likely an incorrect or invalid review comment.
1842-1843
: Review comment assertion unverifiable — MaxNumTopK exists only once in codebase.Verification found only one
MaxNumTopK = 10
definition inRoutingKernelTopK.cuh
(line 35). No separate kernel implementations with independentMaxNumTopK
settings exist for the routing methods (Renormalize, DeepSeekV3, TopK, Llama4). The constraint is enforced globally via test assertionassert top_k <= 10
at line 2120, which applies uniformly to all routing methods.The review's request to verify that each kernel type is "compiled with MaxNumTopK >= 10" assumes separate kernel implementations with independent compilation parameters. This architecture does not appear to exist based on the codebase structure. Recommend manual verification of whether runtime kernel configurations or external kernel libraries bypass this single definition.
1835-1837
: Verified: cache key implementation is correct. Both helper functions properly construct tuple keys.The verification confirms that
_maybe_get_cached_w3_w1_permute_indices
andget_w2_permute_indices_with_cache
both correctly construct tuple cache keys:
_maybe_get_cached_w3_w1_permute_indices
usescache_key = ("w3_w1", dst_w3_w1_weight.shape)
(core.py:175)get_w2_permute_indices_with_cache
usescache_key = ("w2", dst_w2_weight.shape)
(core.py:204)Both functions consistently store and retrieve cached values using these tuple keys with the
Dict[tuple, torch.Tensor]
dictionary. The tuple structure ensures uniqueness by weight identifier and shape, preventing collisions or misses.csrc/trtllm_fused_moe_routing_llama4.cu (1)
503-508
: PDL score path vs init path: OK, but ensure TopKPacked is allocated. HistogramScoresKernel writes mPtrTopKPacked; non-score path resets counts. Verify callers always allocate mPtrTopKPacked when mPtrScores is used, else later kernels read garbage.
Also applies to: 515-524
csrc/trtllm_fused_moe_routing_deepseek.cu (2)
26-31
: LGTM: Constants updated for Kimi K2 and expanded routing. The new constants correctly support 384 experts for Kimi K2 and increased group limits (MaxNumTopGroups=4, MaxNumGroups=8) for flexible routing configurations.
464-475
: Well-structured expert-count routing function. The
getMaxNumExperts
function cleanly maps runtime expert counts to compile-time kernel specializations with proper error handling.include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
103-152
: LGTM: Clean parameterization with MaxNumExperts.The template parameter
MaxNumExperts_
is correctly threaded throughKernelParamsBase
and propagated to all derivedKernelParams
specializations, enabling compile-time kernel bounds. The addition ofmPtrTopKWeights
andmPtrTopKIds
properly exposes the Top-K data flow.csrc/trtllm_fused_moe_routing_renormalize.cu (2)
25-29
: Constants expanded for Qwen and small-batch optimization.
MaxNumTopExperts=10
supports higher top-k routing (e.g., Qwen), andBlockKernelMaxNumTokens=4
enables a fast single-block path for small batches.
405-425
: Test coverage for single-block kernel path is incomplete. The test file
./tests/moe/test_trtllm_gen_fused_moe.py
parametrizesnum_tokens
with values[1, 8, 1024]
(line 1840). This tests the single-block kernel path with 1 token, but does not explicitly test the boundary case of exactly 4 tokens with high expert counts, which the review comment specifically raises.Verification results:
BlockKernelMaxNumTokens = 4
is correctly defined (line 29) and used consistently throughout- Shared memory allocation is safe:
4 * MaxNumExperts ≤ 4 * 128 = 512 int8_t
elements ≈ 1KB- Kernel loop logic is correct:
j < BlockKernelMaxNumTokens
iterates over indices 0–3- Single-block path is exercised by the 1-token test case
However, the edge case at the upper boundary (exactly 4 tokens) with high expert counts is not explicitly covered in the test parametrization.
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
79-104
: LGTM: Routing method validation properly expanded.The validation logic correctly distinguishes DeepSeekV3 (grouped), Renormalize/RenormalizeNaive, and Llama4 routing methods with appropriate constraints.
164-167
: Histogram sizing accommodates optional grouping.The histogram buffer size correctly uses
max(num_experts*2, 256*2)
to handle both grouped and non-grouped routing configurations.
738-739
: No changes needed—top_k limits correctly reflect kernel capabilities.The difference is intentional and correct. FP8 launchers enforce
top_k <= 8
because they use the DeepSeek kernel (MaxNumTopExperts=8), while FP4 launchers allowtop_k <= 10
because they use the Renormalize kernel (MaxNumTopExperts=10). The validation limits are properly aligned with each kernel's design constraints.flashinfer/fused_moe/core.py (1)
1069-1090
: LGTM: Optional routing parameters properly propagated.The signature updates consistently use
Optional[int]
forn_group
/topk_group
andOptional[float]
forrouted_scaling_factor
, enabling flexible routing configurations across all MoE operation variants.Also applies to: 1154-1177, 1243-1276
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (1)
2059-2071
: Test skip conditions improve CI performance. The added skip conditions pragmatically reduce test execution time by avoiding expensive combinations:
- Large expert counts (≥512) with large intermediate sizes (>512)
- Small expert counts (<512) with large intermediate/hidden sizes
While this improves test speed, ensure that at least some tests still exercise these larger configurations to catch potential issues.
flashinfer/fused_moe/core.py (1)
1180-1180
: Minor: Remove unnecessary blank line. This blank line appears to be unintentionally added and can be removed for consistency.
Apply this diff:
```diff
     enable_pdl = device_support_pdl(hidden_states.device)
-
     # Call the C++ function for block scale MoE
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/fused_moe/core.py
(11 hunks)tests/moe/test_trtllm_gen_fused_moe.py
(6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
RoutingMethodType
(59-73)include/flashinfer/trtllm/fused_moe/runner.h (4)
RoutingMethodType
(37-135)intermediate_size
(274-274)hidden_size
(264-264)top_k
(269-269)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
n_group
(270-270)topk_group
(272-272)intermediate_size
(274-274)local_expert_offset
(275-275)local_num_experts
(276-276)
🪛 Ruff (0.14.0)
flashinfer/fused_moe/core.py
1134-1134: Unused function argument: n_group
(ARG001)
1135-1135: Unused function argument: topk_group
(ARG001)
1136-1136: Unused function argument: intermediate_size
(ARG001)
1137-1137: Unused function argument: local_expert_offset
(ARG001)
1138-1138: Unused function argument: local_num_experts
(ARG001)
1139-1139: Unused function argument: routed_scaling_factor
(ARG001)
1222-1222: Unused function argument: n_group
(ARG001)
1223-1223: Unused function argument: topk_group
(ARG001)
1224-1224: Unused function argument: intermediate_size
(ARG001)
1225-1225: Unused function argument: local_expert_offset
(ARG001)
1226-1226: Unused function argument: local_num_experts
(ARG001)
1227-1227: Unused function argument: routed_scaling_factor
(ARG001)
🔇 Additional comments (9)
tests/moe/test_trtllm_gen_fused_moe.py (5)
1835-1836
: LGTM! Cache key structure properly documented. The cache key type change from
Dict[torch.Size, torch.Tensor]
toDict[tuple, torch.Tensor]
aligns with the fix inflashinfer/fused_moe/core.py
where the cache key now includesepilogue_tile_m
andnum_elts_per_sf
. The comment clearly documents the new key structure.
1842-1842
: Test coverage expanded for smaller intermediate sizes. Adding 512 to the intermediate_size parameter list provides better test coverage for configurations with smaller intermediate dimensions, which is relevant for the Qwen and Kimi K2 models mentioned in the PR.
1924-1937
: New Qwen3_next configuration added successfully. This new test case covers the Qwen model with:
- 512 experts
- top_k=10
- Renormalize routing method
This aligns with the PR objective to support Qwen models. The configuration looks correct for a large-scale MoE setup.
2120-2120
: top_k bound correctly increased for Qwen3 support. The assertion change from
top_k <= 8
totop_k <= 10
is necessary to accommodate the new Qwen3_next configuration which uses top_k=10. This aligns with the expanded routing support in the PR.
1856-1872
: No changes needed — configuration already matches upstream. The Kimi K2 Instruct model configuration has n_groups and top_k_groups both set to 1, which exactly matches the values in the test code. The configuration is already correct.
flashinfer/fused_moe/core.py (4)
174-177
: Critical fix: Cache key now includes all permutation parameters. This change addresses a previously identified bug where the cache key only used
(weight_type, shape)
, which could return incorrect permutation indices when the same shape was used with differentepilogue_tile_m
ornum_elts_per_sf
values. Including these parameters in the cache key ensures correctness. Based on learnings from past review comments.
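A minimal sketch of the caching rule being enforced here (a hypothetical helper, not the flashinfer function): every parameter that influences the computed permutation has to be part of the key, otherwise an entry cached for different tiling parameters can be returned for the same shape.

```python
import torch

_permute_cache: dict = {}

def cached_permute_indices(prefix: str, shape: torch.Size,
                           epilogue_tile_m: int, num_elts_per_sf: int) -> torch.Tensor:
    # All four components affect the result, so all four are in the key.
    key = (prefix, shape, epilogue_tile_m, num_elts_per_sf)
    if key not in _permute_cache:
        # Stand-in for the real permutation computation.
        _permute_cache[key] = torch.arange(shape[0])
    return _permute_cache[key]
```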
190-193
: Cache storage and retrieval correctly updated. The cache storage at line 190 and retrieval at line 193 now use the composite
cache_key
that includes all relevant parameters. This ensures that cached permutation indices are only reused when all parameters match.
203-218
: W2 permutation cache consistently updated. The
get_w2_permute_indices_with_cache
function now uses the same comprehensive cache key structure as the W3/W1 permutation function. This ensures consistent caching behavior across both weight matrices.
1080-1082
: Routing parameters correctly made optional. Changing
n_group
,topk_group
, androuted_scaling_factor
toOptional
types is appropriate because not all routing methods require these parameters:
RoutingMethodType.Renormalize
andRoutingMethodType.RenormalizeNaive
don't use group-based routingRoutingMethodType.TopK
doesn't use scaling factorsThis change improves API flexibility while maintaining backward compatibility.
Also applies to: 1085-1085
/bot run
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
1856-1872
: KIMI K2 routing configuration values are incorrect and must be updated. The test case currently uses
n_groups=1
andtop_k_groups=1
, but the actual KIMI K2 Instruct model configuration specifies n_groups=8 and top_k_groups=4. Update these values to match the published model configuration: "n_groups": 8, "top_k_groups": 4.
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
2059-2063
: Skip condition appears unreachable with the current test parameters. The skip condition checks for
num_experts >= 512
, but the maximumnum_experts
in the test parameters is 384 (kimi_k2 configuration). This means the skip will never be triggered by the current test matrix.If this is intended for future-proofing, consider either:
- Lowering the threshold to a value that will actually trigger (e.g.,
>= 384
or>= 256
)- Adding a comment explaining this is for future test configurations
- Removing the condition if it's not needed
Apply this diff to make the condition effective for current tests:
```diff
- if routing_config["num_experts"] >= 512 and intermediate_size > 512:
+ if routing_config["num_experts"] >= 256 and intermediate_size > 1024:
      pytest.skip(
-         f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts"
+         f"Skipping for testing speed: large config with intermediate_size={intermediate_size} and {routing_config['num_experts']} experts"
      )
```
Or add a clarifying comment:
```diff
+ # Future-proofing: skip very large configurations that may be added later
  if routing_config["num_experts"] >= 512 and intermediate_size > 512:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/moe/test_trtllm_gen_fused_moe.py
(6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
RoutingMethodType
(59-73)include/flashinfer/trtllm/fused_moe/runner.h (1)
RoutingMethodType
(37-135)
🔇 Additional comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
1835-1836
: LGTM! Good documentation of cache key structure. The comment clearly documents the tuple structure used as the cache key, which improves code maintainability.
1917-1917
: LGTM! Implementation list reordering is benign. The reordering of
compatible_moe_impls
doesn't affect functionality, only the test execution order.
Actionable comments posted: 1
♻️ Duplicate comments (3)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (2)
169-169
: Fix the misleading static_assert message. The condition
N < 5
enforces N ≤ 4, but the message incorrectly references "128". This creates confusion about the actual constraint.Apply this diff:
```diff
- static_assert(N < 5, "Only support candidates number less than or equal to 128");
+ static_assert(N < 5, "Only support up to 4 candidates per thread in this function");
```
220-223
: Reconsider the buffer index initialization. When
ii=0
,topKBufferIdx[0] = -1
, which can cause arithmetic overflow inTopKRedType::makeCmpVal
when computingmaxIdx - idx = 65535 - (-1) = 65536
, exceeding the0xFFFF
mask. WhileminValue
ensures these entries won't be selected, the overflow is still technically incorrect.Consider initializing with a safer sentinel:
for (int ii = 0; ii < numResults; ++ii) { topKBufferValue[ii] = minValue; - topKBufferIdx[ii] = ii * WarpSize - 1; + topKBufferIdx[ii] = RedType::maxIdx; // or another safe sentinel }csrc/trtllm_fused_moe_routing_deepseek.cu (1)
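As a self-contained illustration of why the sentinel matters, here is a hedged sketch of this style of score/index packing, assuming a 16-bit index field and the `maxIdx - idx` trick for tie-breaking; the layout of the real `TopKRedType::makeCmpVal` may differ:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical packing: high bits carry the score, low 16 bits carry
// (kMaxIdx - idx) so that ties on score prefer the smaller expert index
// after a max-reduction.
constexpr uint32_t kMaxIdx = 65535;

uint32_t makeCmpValSketch(uint16_t scoreBits, int32_t idx) {
  // If idx == -1, (kMaxIdx - idx) is 65536, which no longer fits in the
  // 16-bit field and silently wraps to 0 after masking with 0xFFFF.
  uint32_t idxField = (kMaxIdx - static_cast<uint32_t>(idx)) & 0xFFFF;
  return (static_cast<uint32_t>(scoreBits) << 16) | idxField;
}

int main() {
  printf("idx=0  -> 0x%08x\n", makeCmpValSketch(0x3C00, 0));   // well-formed
  printf("idx=-1 -> 0x%08x\n", makeCmpValSketch(0x3C00, -1));  // idx field wraps to 0
  return 0;
}
```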
csrc/trtllm_fused_moe_routing_deepseek.cu (1)

521-523: Clarify the confusing error message. The check `data.mNumExperts >= MaxNumTopExperts` ensures there are enough experts, but the message "expects %d to be at most #experts %d" reads incorrectly. The message should reflect that the expert count must be at least `MaxNumTopExperts`. Apply this diff:

```diff
 FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts,
-                 "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts,
+                 "Routing kernel expects #experts >= %d, got %d", MaxNumTopExperts,
                  data.mNumExperts);
```
🧹 Nitpick comments (3)
csrc/trtllm_fused_moe_routing_deepseek.cu (1)
478-494: Consider failing fast on unsupported expert counts. The macro logs an error but continues execution when `numExperts` exceeds `NumKimiK2Experts`, potentially leaving kernels unlaunched and leading to incorrect results downstream. Consider adding a check that terminates:

```diff
   } else {
     TLLM_LOG_ERROR("Unsupported numExperts");
+    FLASHINFER_CHECK(false, "Unsupported numExperts: %d exceeds maximum %d",
+                     data.mNumExperts, NumKimiK2Experts);
   }
```

include/flashinfer/trtllm/fused_moe/runner.h (2)
76-78: Typo in serialization string ("InvalidRountingMethod"). Minor, but this shows up in logs and can confuse debugging.

```diff
-      return "InvalidRountingMethod";  // TODO throw error
+      return "InvalidRoutingMethod";  // TODO throw error
```

268-273: Clarify routing-method-specific top_k limits; reference the central constant instead of a hard-coded "10". The comment oversimplifies: top_k support varies by routing method. DeepSeek/grouped kernels support only `top_k<=8`, while renormalize methods support `top_k<=10`. Avoid hard-coding "10" in a public header; instead reference the routing configuration or note the method-specific constraints. Suggested update:

```diff
-  // Now support top_k<=10
+  // Maximum top_k depends on routing method: renormalize/TopK supports up to 10
+  // (see MaxNumTopK in RoutingKernelTopK.cuh), while DeepSeek/grouped limits to 8.
```

Validation confirmed: limits are correctly enforced per method in kernel_launcher.cu (lines 86-89, 370-373, 383-384, 738-741, 752-754). Tests cover top_k=10 for Qwen/renormalize (test_trtllm_gen_fused_moe.py:2112).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
csrc/trtllm_fused_moe_routing_deepseek.cu (11 hunks)
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (12 hunks)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (4 hunks)
include/flashinfer/trtllm/fused_moe/runner.h (1 hunks)
🔇 Additional comments (11)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (3)
34-35
: LGTM!The new compile-time constants provide clear bounds for expert routing and top-K selection.
56-56
: LGTM!Direct assignment is cleaner and equivalent for this POD type.
204-207
: Good defensive programming.The
static_assert
at line 206-207 ensuresN % 4 == 0
(orN <= 4
), which prevents out-of-bounds reads in the loop at lines 230-233. This addresses the previous concern about reading past the array bounds when the tail chunk is smaller than 4.csrc/trtllm_fused_moe_routing_deepseek.cu (3)
26-31
: LGTM!Constants updated appropriately to support the new Kimi K2 and Qwen model architectures.
77-78
: LGTM!Using
blockDim.x
makes the indexing flexible and correctly handles different thread counts.
579-592
: LGTM!The launch logic now correctly matches runtime thread counts to compile-time specializations, addressing the previous blockDim-specialization concern.
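To make the dispatch pattern concrete, here is a minimal sketch of mapping a runtime expert count onto compile-time specializations with matching block sizes; the kernel, names, and thresholds are hypothetical stand-ins rather than the actual routing kernels:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel specialized on a compile-time expert bound.
template <int MaxNumExperts>
__global__ void __launch_bounds__(MaxNumExperts) routingStub(int numExperts) {
  if (threadIdx.x < numExperts) {
    // one thread per (padded) expert slot
  }
}

// Host-side dispatch: pick the smallest specialization covering numExperts
// and launch with a blockDim that matches it.
void launchRoutingStub(int numExperts, int numBlocks, cudaStream_t stream) {
  if (numExperts <= 128) {
    routingStub<128><<<numBlocks, 128, 0, stream>>>(numExperts);
  } else if (numExperts <= 256) {
    routingStub<256><<<numBlocks, 256, 0, stream>>>(numExperts);
  } else if (numExperts <= 384) {
    routingStub<384><<<numBlocks, 384, 0, stream>>>(numExperts);
  } else {
    fprintf(stderr, "Unsupported numExperts: %d\n", numExperts);
  }
}
```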
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (5)
102-129
: LGTM!Computing softmax in
float
precision before casting back toDataType
improves numerical stability for half/bfloat16 inputs.
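For reference, a minimal sketch of that pattern, assuming one block per token, `blockDim.x >= numExperts`, and float accumulation for a bfloat16 input; it is independent of the actual kernel's thread mapping:

```cuda
#include <cuda_bf16.h>
#include <math.h>

// Row-wise softmax: load bf16 scores, do max/sum in float, cast back at the end.
__global__ void softmaxFloatAccum(__nv_bfloat16 const* in, __nv_bfloat16* out,
                                  int numExperts) {
  int const token = blockIdx.x;
  int const e = threadIdx.x;
  float x = (e < numExperts) ? __bfloat162float(in[token * numExperts + e]) : -INFINITY;

  // Naive block-wide reductions through shared memory; enough for a sketch.
  __shared__ float red[1024];
  red[e] = x;
  __syncthreads();
  float maxVal = -INFINITY;
  for (int i = 0; i < numExperts; ++i) maxVal = fmaxf(maxVal, red[i]);

  float ex = (e < numExperts) ? expf(x - maxVal) : 0.0f;
  red[e] = ex;
  __syncthreads();
  float sum = 0.0f;
  for (int i = 0; i < numExperts; ++i) sum += red[i];

  if (e < numExperts) {
    out[token * numExperts + e] = __float2bfloat16(ex / sum);
  }
}
```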
211-217
: LGTM!Both code paths now properly initialize
scoreIdx
, eliminating the previous uninitialized variable concern.
230-232
: LGTM!The conditional weight writing correctly handles the case where weights need to be extracted from packed representation.
387-469
: LGTM!The kernel is now properly parameterized with
KernelParams::MaxNumExperts
, and the data loading logic correctly handles bothTopKIds
andTopKPacked
paths.
706-729
: LGTM!The new initialization kernel properly sets up the expert counts buffer using a grid-stride loop and consistent PDL synchronization.
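A stripped-down sketch of such a reset kernel with a grid-stride loop; the names are illustrative and the PDL handshake is only hinted at in a comment:

```cuda
#include <cstdint>

// Zero-initializes the per-expert count histogram; the grid-stride loop lets
// any grid size cover the whole buffer.
__global__ void initExpertCountsKernel(int32_t* expertCounts, int numElems) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numElems;
       i += gridDim.x * blockDim.x) {
    expertCounts[i] = 0;
  }
  // The real kernel additionally performs the programmatic dependent launch
  // (PDL) handshake so that the following kernel in the stream can start early.
}
```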
```cpp
int constexpr getMaxNumExperts(int32_t numExperts) {
  if (numExperts <= topk::MaxNumExpertsUnit) {
    return topk::MaxNumExpertsUnit;
  } else if (numExperts <= NumDeepseekExperts) {
    return NumDeepseekExperts;
  } else if (numExperts <= NumKimiK2Experts) {
    return NumKimiK2Experts;
  } else {
    TLLM_LOG_ERROR("Unsupported numExperts");
    return 0;
  }
}
```
Returning 0 on error is unsafe.
When `numExperts` exceeds `NumKimiK2Experts`, the function logs an error but returns 0. Callers will use this value for kernel launches and array sizing, leading to undefined behavior or silent failures.
Consider asserting or throwing instead:
} else {
- TLLM_LOG_ERROR("Unsupported numExperts");
- return 0;
+ FLASHINFER_CHECK(false, "Unsupported numExperts: %d exceeds maximum %d",
+ numExperts, NumKimiK2Experts);
+ return NumKimiK2Experts; // unreachable after CHECK, but silences warnings
}
}
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change:

```cpp
int constexpr getMaxNumExperts(int32_t numExperts) {
  if (numExperts <= topk::MaxNumExpertsUnit) {
    return topk::MaxNumExpertsUnit;
  } else if (numExperts <= NumDeepseekExperts) {
    return NumDeepseekExperts;
  } else if (numExperts <= NumKimiK2Experts) {
    return NumKimiK2Experts;
  } else {
    FLASHINFER_CHECK(false, "Unsupported numExperts: %d exceeds maximum %d",
                     numExperts, NumKimiK2Experts);
    return NumKimiK2Experts;  // unreachable after CHECK, but silences warnings
  }
}
```
Remove the cudaDeviceSync; add support for numExperts = 16; clean up the unit test. Signed-off-by: Christina Zhang <[email protected]> Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Signed-off-by: jiahanc <[email protected]>
721371d to 63233a3 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (10)
include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (2)
218-233: Fix OOB reads and invalid sentinel index in large-N TopK buffering.

- Reads past the end when N % 4 != 0.
- Sentinel idx = -1 overflows the 0xFFFF mask and breaks tie-breaking.

Apply:

```diff
-  for (int ii = 0; ii < numResults; ++ii) {
-    topKBufferValue[ii] = minValue;
-    topKBufferIdx[ii] = ii * WarpSize - 1;
-  }
+  using RedT = TopKRedType<Type>;
+  for (int ii = 0; ii < numResults; ++ii) {
+    topKBufferValue[ii] = minValue;
+    topKBufferIdx[ii] = RedT::maxIdx;  // safe sentinel, preserves ordering
+  }
@@
-    for (int i = 0; i < 4; ++i) {
-      inValue[i] = value[start + i];
-      inIdx[i] = idx[start + i];
-    }
+    int rem = N - start;
+    for (int i = 0; i < 4; ++i) {
+      if (i < rem) {
+        inValue[i] = value[start + i];
+        inIdx[i] = idx[start + i];
+      } else {
+        inValue[i] = minValue;
+        inIdx[i] = RedT::maxIdx;
+      }
+    }
```

Also applies to: 219-221
170-171: Static-assert message is misleading. The condition enforces N <= 4, but the message says "<= 128".

```diff
- static_assert(N < 5, "Only support candidates number less than or equal to 128");
+ static_assert(N < 5, "Only support up to 4 candidates per thread in this function.");
```

csrc/trtllm_fused_moe_routing_llama4.cu (1)
482-487: Enable cooperative launch for the cluster kernel. The cluster kernel uses cluster_dims/cluster barriers; launch it cooperatively to guarantee cluster scheduling.

```diff
-  LAUNCH_ROUTING_LLAMA4(data,
-                        /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster,
+  LAUNCH_ROUTING_LLAMA4(data,
+                        /*coopLaunch=*/true, routingIndicesClusterKernel, NumBlocksPerCluster,
                         NumThreads, /*smemSize=*/0,  // No dynamic smem
                         stream);
```
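Whether the macro achieves this via a cooperative launch or the extended launch API is an implementation detail; for context, a hedged sketch of making the cluster shape explicit with `cudaLaunchKernelEx` (illustrative names and dimensions, requires CUDA 12+ and a cluster-capable GPU):

```cuda
#include <cuda_runtime.h>

__global__ void clusterKernelStub() { /* would use cluster-wide barriers */ }

// Launch with an explicit cluster dimension so the scheduler is required to
// co-schedule the blocks of each cluster.
cudaError_t launchWithCluster(cudaStream_t stream) {
  cudaLaunchConfig_t config = {};
  config.gridDim = dim3(8, 1, 1);  // must be a multiple of the cluster size
  config.blockDim = dim3(256, 1, 1);
  config.dynamicSmemBytes = 0;
  config.stream = stream;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeClusterDimension;
  attrs[0].val.clusterDim.x = 8;  // blocks per cluster
  attrs[0].val.clusterDim.y = 1;
  attrs[0].val.clusterDim.z = 1;
  config.attrs = attrs;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, clusterKernelStub);
}
```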
csrc/trtllm_fused_moe_routing_deepseek.cu (4)

191-206: Fix inter-topK scratch sizing and typos.
- NumInterTopKPerThread mistakenly multiplies by NumExpertWarps; it should cover NumInterTopK only.
- Rename misspelled intermidiateScore/intermidiateExpert.
- int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1; - float intermidiateScore[NumInterTopKPerThread]; - int32_t intermidiateExpert[NumInterTopKPerThread]; + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; @@ - intermidiateScore[ii] = smemInterTopScores[i]; - intermidiateExpert[ii] = smemInterTopExperts[i]; + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; @@ - intermidiateScore[ii] = invalidScoreFloat; - intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1; + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; @@ - topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, + topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, /* minValue */ invalidScoreFloat, params.mTopK);
518-523: Clarify the error message to match the check. The check enforces numExperts >= MaxNumTopExperts; the message says "at most".

```diff
-  FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts,
-                   "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts,
+  FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts,
+                   "Routing kernel expects #experts >= %d, got %d",
+                   MaxNumTopExperts, data.mNumExperts);
```
464-475: Avoid returning 0 from getMaxNumExperts. Returning 0 can cascade into invalid grid sizes. Fail fast:

```diff
   } else {
-    TLLM_LOG_ERROR("Unsupported numExperts");
-    return 0;
+    FLASHINFER_CHECK(false, "Unsupported numExperts: %d (max %d)",
+                     numExperts, NumKimiK2Experts);
+    return NumKimiK2Experts;  // unreachable; silences warnings
   }
```
574-589
: Only run routingMainKernel when scores are provided; match blockDim to specialization.Launching routingMainKernel with TopKPacked-only inputs is a no-op; also choose blockDim via getMaxNumExperts.
- if (data.mPtrTopKIds == nullptr) { - int const numThreadsMain = - data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts; + if (data.mPtrScores != nullptr) { + int const numThreadsMain = getMaxNumExperts(data.mNumExperts); LAUNCH_ROUTING_DEEPSEEK(data, /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, /*smemSize=*/0, // No dynamic smem stream, data.mNumExpertGroups > 1); } else { // Reset the global histograms.include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (2)
227-229: Use a proper stride mask instead of comparing with the log2 value. The current check uses mLocalExpertsStrideLog2 as a mask; compute mask = (1 << log2) - 1 and test the lower bits.

```diff
-    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
-                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
+    int32_t strideMask = (1 << params.mLocalExpertsStrideLog2) - 1;
+    bool isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
+                         ((localExpertIdx & strideMask) == 0);
```

Apply similarly at the other referenced locations.
Also applies to: 364-365, 581-583, 659-662
419-431
: Initialize idx when using TopKPacked; write weights conditionally.If mPtrTopKIds == nullptr and mPtrTopKWeights == nullptr, idx remains uninitialized.
- auto loopBody = [&](int expandedIdx) { - PackedScoreIdx<OutputT> scoreIdx; - int idx; - if (params.mPtrTopKIds != nullptr) { - idx = params.mPtrTopKIds[expandedIdx]; - } else { - // If params.mPtrTopKIds != nullptr, we don't need to store the weights - if (params.mPtrTopKWeights != nullptr) { - scoreIdx = params.mPtrTopKPacked[expandedIdx]; - idx = scoreIdx.idx; - params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score); - } - } + auto loopBody = [&](int expandedIdx) { + int idx; + if (params.mPtrTopKIds != nullptr) { + idx = params.mPtrTopKIds[expandedIdx]; + } else { + auto scoreIdx = params.mPtrTopKPacked[expandedIdx]; + idx = scoreIdx.idx; + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score); + } + }csrc/trtllm_fused_moe_routing_renormalize.cu (1)
115-146
: Handle TopKPacked in block kernel (small-token path).When only mPtrTopKPacked is provided, smemKIdx remains unset, producing invalid outputs.
if (params.mPtrTopKIds != nullptr) { ... } else if (params.mPtrScores != nullptr) { ... + } else { // params.mPtrTopKPacked != nullptr + if (validToken && laneIdx < params.mTopK) { + TypePacked packed = params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx]; + int offset = warpIdx * MaxNumExperts + static_cast<int>(packed.idx); + smemKIdx[offset] = static_cast<int8_t>(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = OutputT{packed.score}; + } + } }
🧹 Nitpick comments (5)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
121-124
: Unify BF16 type: prefer cutlass::bfloat16_t over __nv_bfloat16.Other launch macros use cutlass types; mixing __nv_bfloat16 can add include/friction. Use cutlass::bfloat16_t for consistency.
- LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, + LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, 128 /* Always 128 for llama4*/), kernel,- LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, numExperts, true), kernel, numBlocks, numThreads, smemSize, stream); ... - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), + LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t, numExperts, false), kernel, numBlocks, numThreads, smemSize, stream);Also applies to: 191-196
flashinfer/fused_moe/core.py (1)
1134-1147
: Silence unused-arg warnings in fake ops.Consume parameters locally to keep signatures stable while satisfying linters.
def _fake_trtllm_fp8_per_tensor_scale_moe( @@ - ): + ): + # consume unused to appease linters while keeping signature stable + _ = (n_group, topk_group, intermediate_size, local_expert_offset, + local_num_experts, routed_scaling_factor, tile_tokens_dim, + routing_method_type, enable_pdl) seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ def _fake_trtllm_fp8_block_scale_moe( @@ - ): + ): + # consume unused to appease linters while keeping signature stable + _ = (n_group, topk_group, intermediate_size, local_expert_offset, + local_num_experts, routed_scaling_factor, tile_tokens_dim, + routing_method_type, use_shuffled_weight, weight_layout, enable_pdl) seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ def _fake_trtllm_fp4_block_scale_moe( @@ - ): + ): + # consume unused to appease linters while keeping signature stable + _ = (routing_logits, topk_ids, expert_weights, routing_bias, + hidden_states_scale, gemm1_bias, gemm1_alpha, gemm1_beta, + gemm1_clamp_limit, output1_scale_scalar, output1_scale_gate_scalar, + output2_scale_scalar, n_group, topk_group, intermediate_size, + local_expert_offset, routed_scaling_factor, tile_tokens_dim, + routing_method_type, do_finalize, enable_pdl, gated_act_type, + output, tune_max_num_tokens) seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1]Also applies to: 1222-1233, 1450-1463
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
103-153
: Document/enforce TopK buffer invariants in setBaseParams.Guard that when mPtrTopKIds is set, mPtrTopKWeights must also be set (aligns with run-time checks elsewhere). Add a brief comment or debug assert.
void setBaseParams(DataType const& data) { @@ mPtrTopKWeights = static_cast<OutputT*>(data.mPtrTopKWeights); mPtrTopKIds = static_cast<int32_t*>(data.mPtrTopKIds); + // Invariant: if IDs are provided, weights must also be provided. + // assert((mPtrTopKIds == nullptr) || (mPtrTopKWeights != nullptr));csrc/trtllm_fused_moe_kernel_launcher.cu (1)
420-429
: Pointer cast for routing_logits.args.routing_logits is assigned as float* even when dtype is bfloat16 for non‑DeepSeek paths. Prefer void* or branch on dtype to avoid UB.
- args.routing_logits = static_cast<float*>(routing_logits->data); + args.routing_logits = routing_logits->data; // keep as void*, let callee interpret via dtypecsrc/trtllm_fused_moe_routing_renormalize.cu (1)
370-381
: Macro name typo (RENORNALIZE).Consider renaming to LAUNCH_ROUTING_RENORMALIZE for clarity and grepability.
-#define LAUNCH_ROUTING_RENORNALIZE(... +#define LAUNCH_ROUTING_RENORMALIZE(...Also update all call sites in this file.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
csrc/trtllm_fused_moe_kernel_launcher.cu
(14 hunks)csrc/trtllm_fused_moe_routing_deepseek.cu
(11 hunks)csrc/trtllm_fused_moe_routing_llama4.cu
(8 hunks)csrc/trtllm_fused_moe_routing_renormalize.cu
(8 hunks)csrc/trtllm_fused_moe_runner.cu
(3 hunks)flashinfer/fused_moe/core.py
(11 hunks)include/flashinfer/trtllm/fused_moe/DevKernel.h
(2 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
(12 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernel.h
(9 hunks)include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh
(3 hunks)tests/moe/test_trtllm_gen_fused_moe.py
(6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/trtllm_fused_moe_runner.cu
🧰 Additional context used
🧬 Code graph analysis (7)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
include/flashinfer/trtllm/fused_moe/runner.h (7)
top_k
(270-270)n_group
(271-271)topk_group
(273-273)intermediate_size
(275-275)local_expert_offset
(276-276)local_num_experts
(277-277)num_experts
(263-263)
csrc/trtllm_fused_moe_routing_deepseek.cu (2)
csrc/trtllm_fused_moe_routing_renormalize.cu (9)
void
(76-217)void
(288-291)void
(297-353)__launch_bounds__
(76-76)__launch_bounds__
(297-297)getMaxNumExperts
(357-366)getMaxNumExperts
(357-357)routingIndicesClusterKernel
(222-286)routingIndicesClusterKernel
(289-289)csrc/trtllm_fused_moe_routing_llama4.cu (8)
void
(67-280)void
(354-356)void
(363-423)__launch_bounds__
(363-363)getMaxNumExperts
(426-433)getMaxNumExperts
(426-426)routingIndicesClusterKernel
(285-352)routingIndicesClusterKernel
(354-354)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
RoutingMethodType
(59-73)include/flashinfer/trtllm/fused_moe/runner.h (1)
RoutingMethodType
(37-136)
csrc/trtllm_fused_moe_routing_renormalize.cu (2)
csrc/trtllm_fused_moe_routing_deepseek.cu (7)
void
(34-252)void
(276-278)void
(459-461)getMaxNumExperts
(464-475)getMaxNumExperts
(464-464)routingIndicesClusterKernel
(260-274)routingIndicesClusterKernel
(276-276)csrc/trtllm_fused_moe_routing_llama4.cu (9)
void
(67-280)void
(354-356)void
(363-423)routingTopKExperts
(40-44)getMaxNumExperts
(426-433)getMaxNumExperts
(426-426)routingIndicesClusterKernel
(285-352)routingIndicesClusterKernel
(354-354)routingIndicesHistogramScoresKernel
(364-364)
csrc/trtllm_fused_moe_routing_llama4.cu (1)
csrc/trtllm_fused_moe_routing_renormalize.cu (9)
routingTopKExperts
(32-37)void
(76-217)void
(288-291)void
(297-353)getMaxNumExperts
(357-366)getMaxNumExperts
(357-357)routingIndicesClusterKernel
(222-286)routingIndicesClusterKernel
(289-289)routingIndicesHistogramScoresKernel
(298-298)
include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)
include/flashinfer/trtllm/fused_moe/DevKernel.h (4)
setKernelParams
(246-262)setKernelParams
(300-311)setKernelParams
(362-377)setKernelParams
(442-461)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
n_group
(271-271)topk_group
(273-273)intermediate_size
(275-275)local_expert_offset
(276-276)local_num_experts
(277-277)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py
1134-1134: Unused function argument: n_group
(ARG001)
1135-1135: Unused function argument: topk_group
(ARG001)
1136-1136: Unused function argument: intermediate_size
(ARG001)
1137-1137: Unused function argument: local_expert_offset
(ARG001)
1138-1138: Unused function argument: local_num_experts
(ARG001)
1139-1139: Unused function argument: routed_scaling_factor
(ARG001)
1222-1222: Unused function argument: n_group
(ARG001)
1223-1223: Unused function argument: topk_group
(ARG001)
1224-1224: Unused function argument: intermediate_size
(ARG001)
1225-1225: Unused function argument: local_expert_offset
(ARG001)
1226-1226: Unused function argument: local_num_experts
(ARG001)
1227-1227: Unused function argument: routed_scaling_factor
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
flashinfer/fused_moe/core.py (2)
174-194
: Cache-key fix looks good.Including epilogue_tile_m and num_elts_per_sf prevents cache collisions for permutations.
1070-1119
: Verification of the pybind11 bindings for these Optional parameters was inconclusive from a source search alone; the PYBIND11_MODULE declarations for these entry points could not be located directly. What could be established:

- pybind11 does support passing Python None through `std::optional` (the type caster is provided by pybind11/stl.h), and `std::optional` is the recommended way to expose optional arguments of copied types such as int and float.
- The published FlashInfer API docs are inconsistent across the MoE entry points: `trtllm_fp4_block_scale_moe` is documented with `n_group: int | None`, `topk_group: int | None`, and `routed_scaling_factor: float | None`, while `trtllm_fp8_block_scale_moe` is documented with plain `int`/`float` (non-optional) for the same parameters.
- The signature of `trtllm_fp8_per_tensor_scale_moe`, which is the function called at lines 1070-1119, could not be located explicitly; by analogy with the other FP8 path it likely follows the non-optional pattern.

So the concern is valid for the FP8 paths but not for the FP4 path, and it hinges on whether the C++ bindings declare `std::optional` for these parameters. Verify that the pybind11 bindings accept Optional parameters, or convert None to sentinels before the C++ calls.
The risk is legitimate: if the bindings don't declare `std::optional<T>`, passing None will cause a runtime TypeError. Convert None values to sentinel integers (e.g., -1) or verify that the C++ bindings support std::optional before merging.

tests/moe/test_trtllm_gen_fused_moe.py (1)
2110-2113
: Test guard aligns with kernel limit.Asserting top_k <= 10 here matches the updated routing kernels.
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
49-49: Review comment is incorrect and should be dismissed. The LAUNCH_ROUTING macros are not removed; they remain actively defined and used throughout the codebase:

- LAUNCH_ROUTING_LLAMA4, LAUNCH_ROUTING_DEEPSEEK_IMPL, and LAUNCH_ROUTING_WITH_NUM_EXPERTS are called 18+ times across csrc/trtllm_fused_moe_routing_*.cu files
- The new LAUNCH_ESC macro (line 49) is a generic escape/passthrough wrapper, not a replacement for LAUNCH_ROUTING_*
- The original verification script's regex pattern `\bLAUNCH_ROUTING\s*\(` was too narrow and did not match the actual macro variants being used (e.g., LAUNCH_ROUTING_LLAMA4)

No migration or verification is needed.
Likely an incorrect or invalid review comment.
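For context on the escape/passthrough wrapper mentioned above, a minimal sketch of the usual idiom; the actual `LAUNCH_ESC` definition in DevKernel.h may differ:

```cpp
#include <cstdio>

// A pass-through macro: it exists so that an argument containing commas
// (e.g. a list of template arguments) survives being forwarded into another
// macro as a single parameter.
#define LAUNCH_ESC_SKETCH(...) __VA_ARGS__

#define CALL_WITH_TYPES(types, fn) fn<types>()

template <typename A, typename B>
void printSizes() {
  printf("%zu %zu\n", sizeof(A), sizeof(B));
}

int main() {
  // Without the wrapper, "float, double" would be parsed as two separate
  // arguments by CALL_WITH_TYPES and the expansion would not compile.
  CALL_WITH_TYPES(LAUNCH_ESC_SKETCH(float, double), printSizes);
  return 0;
}
```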
Signed-off-by: jiahanc <[email protected]>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)
174-194: Cache key is missing critical parameters; potential regression. A previous review correctly identified that the cache key must include `epilogue_tile_m` and `num_elts_per_sf` (marked as addressed in commit d779e40). However, the current code at line 175 only uses `("w3_w1", dst_w3_w1_weight.shape)`. Since the permute indices computation on lines 179-188 depends on both parameters, omitting them can return incorrect cached permutations when the same weight shape is used with different tile/scaling parameters. Apply this fix:

```diff
-    # Create a unique cache key (weight_type, weight_shape)
-    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    # Create a unique cache key (weight_type, weight_shape, epilogue_tile_m, num_elts_per_sf)
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape, epilogue_tile_m, num_elts_per_sf)
```
197-219: Same cache key issue for the w2 permute indices. The w2 permute indices cache at line 204 has the same problem: the cache key `("w2", dst_w2_weight.shape)` omits `epilogue_tile_m` and `num_elts_per_sf`, which affect the permute indices computation on lines 206-215. Apply this fix:

```diff
-    # Create a unique cache key (weight_type, weight_shape)
-    cache_key = ("w2", dst_w2_weight.shape)
+    # Create a unique cache key (weight_type, weight_shape, epilogue_tile_m, num_elts_per_sf)
+    cache_key = ("w2", dst_w2_weight.shape, epilogue_tile_m, num_elts_per_sf)
```
🧹 Nitpick comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
1835-1836: Update the comment to reflect the actual cache key structure. The comment states the cache key is `(weight_type, shape)`, but the implementation uses a 4-tuple: `(weight_type, shape, epilogue_tile_m, num_elts_per_sf)`. Update the comment to accurately document all key components.

```diff
-    # The cache key is now a tuple of (weight_type, shape)
+    # The cache key is now a tuple of (weight_type, shape, epilogue_tile_m, num_elts_per_sf)
+    # where num_elts_per_sf is optional (can be None)
```
1948-1948: Consider documenting why FP8PerTensorMoe is excluded from RenormNaive. The `RenormNaive` routing configuration excludes `FP8PerTensorMoe` from `compatible_moe_impls`, while the similar `Renorm` configuration includes it. If this is an intentional limitation (e.g., RenormalizeNaive routing has specific implementation constraints), consider adding a comment to explain the exclusion for future maintainability.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/fused_moe/core.py
(11 hunks)tests/moe/test_trtllm_gen_fused_moe.py
(6 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (1)
RoutingMethodType
(59-73)include/flashinfer/trtllm/fused_moe/runner.h (1)
RoutingMethodType
(37-136)
flashinfer/fused_moe/core.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (5)
n_group
(271-271)topk_group
(273-273)intermediate_size
(275-275)local_expert_offset
(276-276)local_num_experts
(277-277)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py
1134-1134: Unused function argument: n_group
(ARG001)
1135-1135: Unused function argument: topk_group
(ARG001)
1136-1136: Unused function argument: intermediate_size
(ARG001)
1137-1137: Unused function argument: local_expert_offset
(ARG001)
1138-1138: Unused function argument: local_num_experts
(ARG001)
1139-1139: Unused function argument: routed_scaling_factor
(ARG001)
1222-1222: Unused function argument: n_group
(ARG001)
1223-1223: Unused function argument: topk_group
(ARG001)
1224-1224: Unused function argument: intermediate_size
(ARG001)
1225-1225: Unused function argument: local_expert_offset
(ARG001)
1226-1226: Unused function argument: local_num_experts
(ARG001)
1227-1227: Unused function argument: routed_scaling_factor
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
tests/moe/test_trtllm_gen_fused_moe.py (5)
1842-1842
: LGTM - Improved test coverage.Adding
intermediate_size=512
extends test coverage appropriately, and the new skip condition at lines 2059-2064 ensures this doesn't cause excessive test times for large expert counts.
1924-1937
: Clarify why FP8PerTensorMoe is excluded from Qwen3_next.The Qwen3_next configuration excludes
FP8PerTensorMoe
fromcompatible_moe_impls
, while the similar "Renorm" routing configuration (line 1917) includes it. Both use the sameRoutingMethodType.Renormalize
routing method. Is this exclusion intentional, or shouldFP8PerTensorMoe
be added to the compatible implementations list?
2059-2064
: LGTM - Reasonable performance optimization.The skip condition appropriately limits test execution time for configurations with many experts (≥512) and large intermediate sizes (>512) without affecting the current test matrix (max 384 experts in kimi_k2 config).
2112-2112
: LGTM - Correctly reflects updated kernel capability.The assertion change from
top_k <= 8
totop_k <= 10
properly aligns with the kernel updates that increasedMaxNumTopExperts
to 10, enabling support for the new Qwen3_next configuration.
1856-1872: Verify whether the `routed_scaling` discrepancy (2.5 in the test vs 2.827 in the official config) is intentional. Verification confirms most of the KIMI K2 configuration is correct: 384 experts with 8 experts selected per token, and the official config has `n_group: 1` and `topk_group: 1`, matching the test. However, the HuggingFace config shows `routed_scaling_factor: 2.827`, while the test uses `routed_scaling: 2.5`
. This 14.9% difference may be intentional (e.g., for TRTLLM implementation specifics) or an oversight. Please confirm whether this difference is acceptable or if the test should use 2.827 to align with the official model specification.flashinfer/fused_moe/core.py (3)
1134-1139
: Static analysis warnings expected for fake implementations.Ruff flags unused arguments in
_fake_trtllm_fp8_per_tensor_scale_moe
(lines 1134-1139) and_fake_trtllm_fp8_block_scale_moe
(lines 1222-1227). These are shape inference stubs for torch.compile, so unused arguments are expected and not issues.Also applies to: 1222-1227
1476-1546
: Consistent Optional parameter propagation across public API.The Optional parameter changes for
n_group
,topk_group
, androuted_scaling_factor
are consistently applied across all public MoE variants (FP8 per-tensor, FP8 block-scale, and FP4 block-scale). The docstrings correctly document that these parameters can be None for certain routing methods, aligning with the PR objective to support different routing configurations (KIMI K2, Qwen).Also applies to: 1549-1623, 1626-1755
1069-1120
: No changes needed. The C++ layer properly handles Optional parameters.The C++ function signatures in
trtllm_fused_moe_kernel_launcher.cu
(lines 305-313) correctly declareOptional<int64_t>
andOptional<double>
types. The implementation validates these using.has_value()
checks (e.g., line 81) and applies sensible defaults (0 forn_group
/topk_group
, 1.0 forrouted_scaling_factor
). The TVM FFI binding layer properly marshals PythonNone
values to C++Optional<>
types, so Python can passNone
directly without manual conversion.
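A minimal sketch of the defaulting pattern described above, using `std::optional` as a stand-in for the FFI Optional type; the parameter names mirror the API, but the helper itself is hypothetical:

```cpp
#include <cstdio>
#include <optional>

struct RoutingConfig {
  int n_group;
  int topk_group;
  double routed_scaling_factor;
};

// None on the Python side arrives as an empty optional; fall back to defaults
// meaning "no grouped routing" and "no extra scaling".
RoutingConfig resolveRouting(std::optional<int> n_group,
                             std::optional<int> topk_group,
                             std::optional<double> routed_scaling_factor) {
  RoutingConfig cfg;
  cfg.n_group = n_group.value_or(0);
  cfg.topk_group = topk_group.value_or(0);
  cfg.routed_scaling_factor = routed_scaling_factor.value_or(1.0);
  return cfg;
}

int main() {
  auto grouped = resolveRouting(8, 4, 2.5);  // e.g. a grouped (DeepSeek-style) config
  auto plain = resolveRouting(std::nullopt, std::nullopt, std::nullopt);  // e.g. Renormalize
  printf("%d %d %.3f\n", grouped.n_group, grouped.topk_group, grouped.routed_scaling_factor);
  printf("%d %d %.3f\n", plain.n_group, plain.topk_group, plain.routed_scaling_factor);
  return 0;
}
```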
/bot run
[FAILED] Pipeline #36935152: 1/17 passed
📌 Description
Update the routing code to align with the implementation in TRTLLM and add support for KIMI K2 and Qwen
Also revised the unit test based on the config of kimi k2 (https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/config.json)
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- All tests pass (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
Refactor