-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[EPLB] Support EPLB w/ NVFP4 #29804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
andrewbriand
wants to merge
1
commit into
vllm-project:main
Choose a base branch
from
andrewbriand:abriand_eplb_nvfp4_2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+374
−5
Open
[EPLB] Support EPLB w/ NVFP4 #29804
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
276 changes: 276 additions & 0 deletions
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,276 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| # Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4 | ||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from tests.kernels.moe.utils import make_test_quant_config | ||
| from vllm.config import VllmConfig, set_current_vllm_config | ||
| from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace | ||
| from vllm.distributed.parallel_state import ( | ||
| ensure_model_parallel_initialized, | ||
| get_dp_group, | ||
| ) | ||
| from vllm.forward_context import set_forward_context | ||
| from vllm.model_executor.layers.fused_moe.layer import FusedMoE | ||
| from vllm.model_executor.layers.quantization.modelopt import ( | ||
| ModelOptNvFp4Config, | ||
| ModelOptNvFp4FusedMoE, | ||
| ) | ||
|
|
||
| from .eplb_utils import distributed_run, set_env_vars_and_device | ||
|
|
||
|
|
||
| @dataclass | ||
| class TestConfig: | ||
| num_layers: int | ||
| num_experts: int | ||
| num_local_experts: int | ||
| num_topk: int | ||
| hidden_size: int | ||
| intermediate_size: int | ||
| num_tokens: int | ||
|
|
||
|
|
||
| def make_fused_moe_layer( | ||
| rank: int, | ||
| layer_idx: int, | ||
| test_config: TestConfig, | ||
| ) -> FusedMoE: | ||
| quant_config = None | ||
|
|
||
| device = torch.device(f"cuda:{rank}") | ||
|
|
||
| quant_config = ModelOptNvFp4Config( | ||
| is_checkpoint_nvfp4_serialized=True, | ||
| kv_cache_quant_algo=None, | ||
| exclude_modules=[], | ||
| ) | ||
|
|
||
| fml = FusedMoE( | ||
| num_experts=test_config.num_experts, | ||
| top_k=test_config.num_topk, | ||
| hidden_size=test_config.hidden_size, | ||
| intermediate_size=test_config.intermediate_size, | ||
| prefix=f"dummy_layer_{layer_idx}", | ||
| activation="silu", | ||
| is_act_and_mul=True, | ||
| params_dtype=torch.bfloat16, | ||
| quant_config=quant_config, | ||
| ) | ||
|
|
||
| nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml) | ||
| nvfp4_fused_moe.create_weights( | ||
| fml, | ||
| test_config.num_local_experts, | ||
| test_config.hidden_size, | ||
| test_config.intermediate_size, | ||
| params_dtype=torch.uint8, | ||
| global_num_experts=test_config.num_experts, | ||
| ) | ||
|
|
||
| fml = fml.to(device) | ||
| w1_q, w2_q, quant_config = make_test_quant_config( | ||
| test_config.num_local_experts, | ||
| test_config.intermediate_size, | ||
| test_config.hidden_size, | ||
| in_dtype=torch.bfloat16, | ||
| quant_dtype="nvfp4", | ||
| block_shape=None, | ||
| per_act_token_quant=False, | ||
| ) | ||
|
|
||
| fml.w13_weight.data = w1_q | ||
| fml.w2_weight.data = w2_q | ||
|
|
||
| fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5 | ||
| fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5 | ||
| fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5 | ||
| fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5 | ||
| fml.w2_weight_scale.data = ( | ||
| torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5 | ||
| ).to(fml.w2_weight_scale.data.dtype) | ||
| fml.w13_weight_scale.data = ( | ||
| torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5 | ||
| ).to(fml.w13_weight_scale.data.dtype) | ||
|
|
||
| nvfp4_fused_moe.process_weights_after_loading(fml) | ||
|
|
||
| fml.maybe_init_modular_kernel() | ||
|
|
||
| return fml | ||
|
|
||
|
|
||
| def _test_eplb_fml(env, world_size: int, test_config: TestConfig): | ||
| set_env_vars_and_device(env) | ||
|
|
||
| vllm_config = VllmConfig() | ||
| vllm_config.parallel_config.data_parallel_size = world_size | ||
| vllm_config.parallel_config.enable_expert_parallel = True | ||
|
|
||
| with set_current_vllm_config(vllm_config): | ||
| ensure_model_parallel_initialized( | ||
| tensor_model_parallel_size=1, pipeline_model_parallel_size=1 | ||
| ) | ||
|
|
||
| ep_group = get_dp_group().cpu_group | ||
| ep_rank = torch.distributed.get_rank() | ||
|
|
||
| device = torch.device(f"cuda:{ep_rank}") | ||
|
|
||
| fml_layers = [ | ||
| make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device) | ||
| for layer_idx in range(test_config.num_layers) | ||
| ] | ||
| rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers] | ||
|
|
||
| hidden_states = [] | ||
| router_logits = [] | ||
| for layer_idx in range(test_config.num_layers): | ||
| hidden_states.append( | ||
| torch.randn( | ||
| (test_config.num_tokens, test_config.hidden_size), | ||
| dtype=torch.bfloat16, | ||
| device=device, | ||
| ) | ||
| ) | ||
| router_logits.append( | ||
| torch.randn( | ||
| (test_config.num_tokens, test_config.num_experts), | ||
| dtype=torch.bfloat16, | ||
| device=device, | ||
| ) | ||
| ) | ||
|
|
||
| out_before_shuffle = [] | ||
| with set_forward_context( | ||
| {}, | ||
| num_tokens=test_config.num_tokens, | ||
| num_tokens_across_dp=torch.tensor( | ||
| [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int | ||
| ), | ||
| vllm_config=vllm_config, | ||
| ): | ||
| for lidx, fml in enumerate(fml_layers): | ||
| out_before_shuffle.append( | ||
| fml(hidden_states[lidx].clone(), router_logits[lidx].clone()) | ||
| ) | ||
|
|
||
| indices = torch.zeros( | ||
| test_config.num_layers, test_config.num_experts, dtype=torch.long | ||
| ) | ||
| for lidx in range(test_config.num_layers): | ||
| indices[lidx] = torch.Tensor(range(test_config.num_experts)) | ||
|
|
||
| shuffled_indices = torch.zeros_like(indices) | ||
| for lidx in range(test_config.num_layers): | ||
| shuffled_indices[lidx] = torch.randperm(test_config.num_experts) | ||
|
|
||
| rearrange_expert_weights_inplace( | ||
| indices, | ||
| shuffled_indices, | ||
| rank_expert_weights, | ||
| ep_group, | ||
| is_profile=False, | ||
| ) | ||
|
|
||
| num_global_experts = test_config.num_experts | ||
|
|
||
| logical_to_physical_map_list = [] | ||
| for lidx, fml in enumerate(fml_layers): | ||
| physical_to_logical_map = shuffled_indices[lidx].to(device) | ||
| logical_to_physical_map = torch.empty( | ||
| (num_global_experts,), dtype=torch.int32, device=device | ||
| ) | ||
| logical_to_physical_map[physical_to_logical_map] = torch.arange( | ||
| 0, num_global_experts, dtype=torch.int32, device=device | ||
| ) | ||
| logical_to_physical_map_list.append( | ||
| logical_to_physical_map.reshape(num_global_experts, 1) | ||
| ) | ||
|
|
||
| logical_to_physical_map = torch.stack(logical_to_physical_map_list) | ||
|
|
||
| for lidx, fml in enumerate(fml_layers): | ||
| logical_replica_count = torch.ones( | ||
| (test_config.num_layers, num_global_experts), | ||
| dtype=torch.int32, | ||
| device=device, | ||
| ) | ||
| fml.enable_eplb = True | ||
| fml.set_eplb_state( | ||
| lidx, | ||
| torch.zeros( | ||
| (test_config.num_layers, num_global_experts), | ||
| dtype=torch.int32, | ||
| device=device, | ||
| ), | ||
| logical_to_physical_map, | ||
| logical_replica_count, | ||
| ) | ||
|
|
||
| out_after_shuffle = [] | ||
| with set_forward_context( | ||
| {}, | ||
| num_tokens=test_config.num_tokens, | ||
| num_tokens_across_dp=torch.tensor( | ||
| [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int | ||
| ), | ||
| vllm_config=vllm_config, | ||
| ): | ||
| for lidx, fml in enumerate(fml_layers): | ||
| out_after_shuffle.append( | ||
| fml(hidden_states[lidx].clone(), router_logits[lidx].clone()) | ||
| ) | ||
|
|
||
| for lidx in range(test_config.num_layers): | ||
| torch.testing.assert_close( | ||
| out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1 | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("world_size", [2, 4]) | ||
| @pytest.mark.parametrize("num_layers", [8]) | ||
| @pytest.mark.parametrize("num_experts", [32]) | ||
| @pytest.mark.parametrize("hidden_size", [256]) | ||
| @pytest.mark.parametrize("intermediate_size", [256]) | ||
| @pytest.mark.parametrize("num_tokens", [256]) | ||
| @pytest.mark.parametrize("backend", ["latency", "throughput"]) | ||
| def test_eplb_fml( | ||
| world_size: int, | ||
| num_layers: int, | ||
| num_experts: int, | ||
| hidden_size: int, | ||
| intermediate_size: int, | ||
| num_tokens: int, | ||
| backend: str, | ||
| monkeypatch, | ||
| ): | ||
| monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") | ||
| monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend) | ||
|
|
||
| if torch.cuda.device_count() < world_size: | ||
| pytest.skip(f"Need at least {world_size} GPUs to run the test") | ||
|
|
||
| num_local_experts = num_experts // world_size | ||
| num_topk = 4 | ||
|
|
||
| test_config = TestConfig( | ||
| num_layers=num_layers, | ||
| num_experts=num_experts, | ||
| num_local_experts=num_local_experts, | ||
| num_topk=num_topk, | ||
| hidden_size=hidden_size, | ||
| intermediate_size=intermediate_size, | ||
| num_tokens=num_tokens, | ||
| ) | ||
|
|
||
| distributed_run( | ||
| _test_eplb_fml, | ||
| world_size, | ||
| test_config, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -39,6 +39,7 @@ | |||||||||||||||||||||||||||||||
| from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( | ||||||||||||||||||||||||||||||||
| build_flashinfer_fp4_cutlass_moe_prepare_finalize, | ||||||||||||||||||||||||||||||||
| flashinfer_trtllm_fp4_moe, | ||||||||||||||||||||||||||||||||
| flashinfer_trtllm_fp4_routed_moe, | ||||||||||||||||||||||||||||||||
| prepare_static_weights_for_trtllm_fp4_moe, | ||||||||||||||||||||||||||||||||
| reorder_w1w3_to_w3w1, | ||||||||||||||||||||||||||||||||
| select_nvfp4_gemm_impl, | ||||||||||||||||||||||||||||||||
|
|
@@ -1342,7 +1343,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | |||||||||||||||||||||||||||||||
| "Accuracy may be affected." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] | ||||||||||||||||||||||||||||||||
| w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous() | ||||||||||||||||||||||||||||||||
| layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Common processing for input scales and alphas | ||||||||||||||||||||||||||||||||
|
|
@@ -1499,6 +1500,10 @@ def get_fused_moe_quant_config( | |||||||||||||||||||||||||||||||
| a2_gscale=layer.w2_input_scale_quant, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||||||
| def supports_eplb(self) -> bool: | ||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def apply( | ||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||
| layer: FusedMoE, | ||||||||||||||||||||||||||||||||
|
|
@@ -1534,11 +1539,8 @@ def apply( | |||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||
| self.allow_flashinfer | ||||||||||||||||||||||||||||||||
| and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM | ||||||||||||||||||||||||||||||||
| and not enable_eplb | ||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||
| if enable_eplb: | ||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||
| "EPLB not supported for `ModelOptNvFp4FusedMoE` yet." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| return flashinfer_trtllm_fp4_moe( | ||||||||||||||||||||||||||||||||
| layer=layer, | ||||||||||||||||||||||||||||||||
| x=x, | ||||||||||||||||||||||||||||||||
|
|
@@ -1556,6 +1558,25 @@ def apply( | |||||||||||||||||||||||||||||||
| router_logits=router_logits, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # EPLB path | ||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||
| self.allow_flashinfer | ||||||||||||||||||||||||||||||||
| and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM | ||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||
| # Pack top k ids and expert weights into a single int32 tensor, as | ||||||||||||||||||||||||||||||||
| # required by TRT-LLM | ||||||||||||||||||||||||||||||||
| packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( | ||||||||||||||||||||||||||||||||
| torch.bfloat16 | ||||||||||||||||||||||||||||||||
| ).view(torch.int16) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return flashinfer_trtllm_fp4_routed_moe( | ||||||||||||||||||||||||||||||||
| layer=layer, | ||||||||||||||||||||||||||||||||
| x=x, | ||||||||||||||||||||||||||||||||
| topk_ids=packed_tensor, | ||||||||||||||||||||||||||||||||
| top_k=top_k, | ||||||||||||||||||||||||||||||||
| global_num_experts=global_num_experts, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
Comment on lines
+1572
to
+1578
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To fix the hardcoded
Suggested change
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if self.use_marlin: | ||||||||||||||||||||||||||||||||
| return fused_marlin_moe( | ||||||||||||||||||||||||||||||||
| x, | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe hide this packing operation in the
flashinfer_trtllm_fp4_routed_moe. I.e., letflashinfer_trtllm_fp4_routed_moetaketopk_idsandtopk_weightsdirectly, making its interface closer to Marlin’s.Additionally, the packing will be removed in the flashinfer api in the near future so we can just pass
topk_idsandtopk_weightsto flashinfer.