
Conversation


andrewbriand commented Dec 1, 2025

Purpose

Support EPLB in combination with NVFP4.

Test Plan

Added a test, test_eplb_fused_moe_layer_dep_nvfp4.py, which ensures that the NVFP4 backends correctly route tokens to physical experts based on their logical expert IDs.
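
For illustration only, the property the test exercises can be sketched as below; the names used here (logical_to_physical_map, moe_forward) are assumptions for the sketch, not the actual test code:

    # Hedged sketch of the routing property under test, not the contents of the
    # actual test file. With EPLB, the router emits logical expert ids; the layer
    # must translate them to physical expert ids before dispatching tokens.
    import torch

    num_logical_experts = 8
    num_physical_experts = 12  # logical experts plus redundant replicas
    top_k = 2
    num_tokens = 16

    # Hypothetical EPLB mapping: each logical expert id points at one of its
    # physical replicas (real EPLB keeps a full replica table per rank).
    logical_to_physical_map = torch.randint(
        0, num_physical_experts, (num_logical_experts,)
    )

    logical_topk_ids = torch.randint(0, num_logical_experts, (num_tokens, top_k))
    physical_topk_ids = logical_to_physical_map[logical_topk_ids]

    # If the physical replicas hold identical weights, routing by physical ids
    # must reproduce the reference output computed from logical ids, e.g.
    #   torch.testing.assert_close(moe_forward(x, physical_topk_ids),
    #                              moe_forward(x, logical_topk_ids))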

Test Result

Tests pass on GB200.



@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.


gemini-code-assist bot left a comment


Code Review

This pull request adds support for Expert Parallel Load Balancing (EPLB) with NVFP4 quantization. The changes include a new test case for this functionality and modifications to ModelOptNvFp4FusedMoE to handle the EPLB path, along with a new kernel wrapper flashinfer_trtllm_fp4_routed_moe. The implementation is largely correct, but I've identified a critical issue where the routing method type is hardcoded in the new kernel wrapper. This would lead to incorrect behavior for MoE models that use different routing mechanisms. I have provided comments with suggestions to address this issue by dynamically determining the routing method.

Comment on lines +336 to +405
def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,  # Packed
    top_k: int,
    global_num_experts: int,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top k expert indices and scores rather than computing
    top k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        topk_ids: Ids of selected experts
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts across all ranks

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    # Quantize input to FP4
    a1_gscale = layer.w13_input_scale_quant
    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
        x,
        a1_gscale,
        is_sf_swizzled_layout=False,
    )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=topk_ids,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).flatten(),
        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=None,
        routing_method_type=1,
        do_finalize=True,
    )[0]

    return out

critical

The routing_method_type is hardcoded to 1, which corresponds to Renormalize. This will cause incorrect behavior for models that use different routing mechanisms, such as DeepSeek or Llama4. The routing method should be determined dynamically based on the model's configuration, similar to how it's done in flashinfer_trtllm_fp4_moe.

To fix this, please update the function to accept custom_routing_function and use it to determine the correct routing_method_type.

def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,  # Packed
    top_k: int,
    global_num_experts: int,
    custom_routing_function: object | None,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top k expert indices and scores rather than computing
    top k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        topk_ids: Ids of selected experts
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts across all ranks
        custom_routing_function: Custom routing function (e.g., for Llama4)

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer
    from vllm.model_executor.models.llama4 import Llama4MoE

    use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
    routing_method_type = layer.routing_method_type
    if use_llama4_routing:
        routing_method_type = flashinfer.RoutingMethodType.Llama4

    # Quantize input to FP4
    a1_gscale = layer.w13_input_scale_quant
    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
        x,
        a1_gscale,
        is_sf_swizzled_layout=False,
    )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=topk_ids,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).flatten(),
        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=None,
        routing_method_type=routing_method_type,
        do_finalize=True,
    )[0]

    return out
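
For context on the packed topk_ids consumed above and passed as packed_tensor below: a minimal sketch of constructing such a tensor follows, assuming a 16-bit expert id / 16-bit bfloat16 weight layout. The actual bit layout is defined by the FlashInfer kernel, so treat this as an illustration rather than the PR's code.

    # Hedged sketch: pack per-token expert ids and routing weights into one
    # int32 tensor, assuming expert id in the upper 16 bits and the bfloat16
    # routing weight reinterpreted as the lower 16 bits.
    import torch

    def pack_topk(topk_ids: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
        """topk_ids: [num_tokens, top_k] int, topk_weights: [num_tokens, top_k] float."""
        ids_hi = topk_ids.to(torch.int32) << 16
        weights_lo = (
            topk_weights.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
        )
        return ids_hi | weights_lo

    packed = pack_topk(
        torch.randint(0, 12, (16, 2)),           # expert ids
        torch.rand(16, 2, dtype=torch.float32),  # routing weights
    )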

Comment on lines +1515 to +1578
return flashinfer_trtllm_fp4_routed_moe(
    layer=layer,
    x=x,
    topk_ids=packed_tensor,
    top_k=top_k,
    global_num_experts=global_num_experts,
)

critical

To fix the hardcoded routing_method_type in flashinfer_trtllm_fp4_routed_moe, you need to pass the custom_routing_function to it. This will allow the function to dynamically determine the correct routing method.

Suggested change
return flashinfer_trtllm_fp4_routed_moe(
    layer=layer,
    x=x,
    topk_ids=packed_tensor,
    top_k=top_k,
    global_num_experts=global_num_experts,
)

return flashinfer_trtllm_fp4_routed_moe(
    layer=layer,
    x=x,
    topk_ids=packed_tensor,
    top_k=top_k,
    global_num_experts=global_num_experts,
    custom_routing_function=custom_routing_function,
)

@github-actions

github-actions bot commented Dec 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Signed-off-by: Andrew Briand <[email protected]>
