[EPLB] Support EPLB w/ NVFP4 #29804
Conversation
Code Review
This pull request adds support for Expert Parallel Load Balancing (EPLB) with NVFP4 quantization. The changes include a new test case for this functionality and modifications to ModelOptNvFp4FusedMoE to handle the EPLB path, along with a new kernel wrapper flashinfer_trtllm_fp4_routed_moe. The implementation is largely correct, but I've identified a critical issue where the routing method type is hardcoded in the new kernel wrapper. This would lead to incorrect behavior for MoE models that use different routing mechanisms. I have provided comments with suggestions to address this issue by dynamically determining the routing method.
def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,  # Packed
    top_k: int,
    global_num_experts: int,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top-k expert indices and scores rather than computing
    top-k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        topk_ids: Ids of selected experts
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts across all ranks

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    # Quantize input to FP4
    a1_gscale = layer.w13_input_scale_quant
    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
        x,
        a1_gscale,
        is_sf_swizzled_layout=False,
    )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=topk_ids,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).flatten(),
        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=None,
        routing_method_type=1,
        do_finalize=True,
    )[0]

    return out
The routing_method_type is hardcoded to 1, which corresponds to Renormalize. This will cause incorrect behavior for models that use different routing mechanisms, such as DeepSeek or Llama4. The routing method should be determined dynamically based on the model's configuration, similar to how it's done in flashinfer_trtllm_fp4_moe.
To fix this, please update the function to accept custom_routing_function and use it to determine the correct routing_method_type.
def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,  # Packed
    top_k: int,
    global_num_experts: int,
    custom_routing_function: object | None,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top-k expert indices and scores rather than computing
    top-k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor
        topk_ids: Ids of selected experts
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts across all ranks
        custom_routing_function: Custom routing function (e.g., for Llama4)

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    from vllm.model_executor.models.llama4 import Llama4MoE

    use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
    routing_method_type = layer.routing_method_type
    if use_llama4_routing:
        routing_method_type = flashinfer.RoutingMethodType.Llama4

    # Quantize input to FP4
    a1_gscale = layer.w13_input_scale_quant
    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
        x,
        a1_gscale,
        is_sf_swizzled_layout=False,
    )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=topk_ids,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).flatten(),
        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
            torch.float8_e4m3fn
        ),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=None,
        routing_method_type=routing_method_type,
        do_finalize=True,
    )[0]
    return out

return flashinfer_trtllm_fp4_routed_moe(
    layer=layer,
    x=x,
    topk_ids=packed_tensor,
    top_k=top_k,
    global_num_experts=global_num_experts,
)
To fix the hardcoded routing_method_type in flashinfer_trtllm_fp4_routed_moe, you need to pass the custom_routing_function to it. This will allow the function to dynamically determine the correct routing method.
Suggested change:

return flashinfer_trtllm_fp4_routed_moe(
    layer=layer,
    x=x,
    topk_ids=packed_tensor,
    top_k=top_k,
    global_num_experts=global_num_experts,
    custom_routing_function=custom_routing_function,
)
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. You can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Signed-off-by: Andrew Briand <[email protected]>
Force-pushed from cf61fec to e642217.
Purpose
Support EPLB in combination with NVFP4.
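For reviewers unfamiliar with EPLB, here is a minimal sketch (illustrative only, not code from this PR) of how a logical-to-physical expert mapping can be applied to top-k routing results before the packed ids are handed to a routed MoE kernel; the tensor names and map layout below are assumptions.

import torch

# Hypothetical EPLB mapping helper (illustrative; not this PR's implementation).
# logical_to_physical: [num_logical_experts, max_replicas] table of physical
# expert ids (unused replica slots padded); replica_counts: replicas per logical expert.
def map_logical_to_physical(
    topk_logical_ids: torch.Tensor,     # [num_tokens, top_k] logical expert ids
    logical_to_physical: torch.Tensor,  # [num_logical_experts, max_replicas]
    replica_counts: torch.Tensor,       # [num_logical_experts]
) -> torch.Tensor:
    # Choose one replica per (token, slot); a random pick spreads load
    # evenly across redundant copies of hot experts.
    counts = replica_counts[topk_logical_ids].to(torch.float32)  # [T, K]
    replica_idx = (torch.rand_like(counts) * counts).long()      # [T, K]
    return logical_to_physical[topk_logical_ids, replica_idx]    # physical expert ids

The resulting physical ids, packed together with the routing scores, are what a routed kernel such as flashinfer_trtllm_fp4_routed_moe ultimately consumes.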
Test Plan
Added a test, test_eplb_fused_moe_layer_dep_nvfp4.py, which ensures that NVFP4 backends correctly route tokens to physical experts based on their logical expert ids.

Test Result

Tests pass on GB200.
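As a rough illustration of the kind of check involved (hypothetical layer handles, not the actual test code), such a test can compare the EPLB-mapped layer against a reference layer that uses the identity logical-to-physical mapping; since replicas hold identical weights, the outputs should agree:

import torch

# Illustrative assertion pattern (hypothetical layers; not this PR's test).
def check_eplb_matches_reference(layer_eplb, layer_ref, x: torch.Tensor) -> None:
    # Replicated physical experts carry the same weights as their logical
    # expert, so EPLB routing must not change the layer's output beyond
    # normal kernel tolerances.
    out_eplb = layer_eplb(x)
    out_ref = layer_ref(x)
    torch.testing.assert_close(out_eplb, out_ref, atol=1e-1, rtol=1e-1)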