diff --git a/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
new file mode 100644
index 000000000000..951b692e1eda
--- /dev/null
+++ b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
@@ -0,0 +1,276 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Test that EPLB interacts correctly with the FusedMoE layer for DP with NVFP4.
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+
+from tests.kernels.moe.utils import make_test_quant_config
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
+from vllm.distributed.parallel_state import (
+    ensure_model_parallel_initialized,
+    get_dp_group,
+)
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.quantization.modelopt import (
+    ModelOptNvFp4Config,
+    ModelOptNvFp4FusedMoE,
+)
+
+from .eplb_utils import distributed_run, set_env_vars_and_device
+
+
+@dataclass
+class TestConfig:
+    num_layers: int
+    num_experts: int
+    num_local_experts: int
+    num_topk: int
+    hidden_size: int
+    intermediate_size: int
+    num_tokens: int
+
+
+def make_fused_moe_layer(
+    rank: int,
+    layer_idx: int,
+    test_config: TestConfig,
+) -> FusedMoE:
+    device = torch.device(f"cuda:{rank}")
+
+    quant_config = ModelOptNvFp4Config(
+        is_checkpoint_nvfp4_serialized=True,
+        kv_cache_quant_algo=None,
+        exclude_modules=[],
+    )
+
+    fml = FusedMoE(
+        num_experts=test_config.num_experts,
+        top_k=test_config.num_topk,
+        hidden_size=test_config.hidden_size,
+        intermediate_size=test_config.intermediate_size,
+        prefix=f"dummy_layer_{layer_idx}",
+        activation="silu",
+        is_act_and_mul=True,
+        params_dtype=torch.bfloat16,
+        quant_config=quant_config,
+    )
+
+    nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
+    nvfp4_fused_moe.create_weights(
+        fml,
+        test_config.num_local_experts,
+        test_config.hidden_size,
+        test_config.intermediate_size,
+        params_dtype=torch.uint8,
+        global_num_experts=test_config.num_experts,
+    )
+
+    fml = fml.to(device)
+    w1_q, w2_q, _ = make_test_quant_config(
+        test_config.num_local_experts,
+        test_config.intermediate_size,
+        test_config.hidden_size,
+        in_dtype=torch.bfloat16,
+        quant_dtype="nvfp4",
+        block_shape=None,
+        per_act_token_quant=False,
+    )
+
+    fml.w13_weight.data = w1_q
+    fml.w2_weight.data = w2_q
+
+    # Randomize all scales so the shuffle test exercises non-trivial values.
+    fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
+    fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
+    fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
+    fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
+    fml.w2_weight_scale.data = (
+        torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
+    ).to(fml.w2_weight_scale.data.dtype)
+    fml.w13_weight_scale.data = (
+        torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
+    ).to(fml.w13_weight_scale.data.dtype)
+
+    nvfp4_fused_moe.process_weights_after_loading(fml)
+
+    fml.maybe_init_modular_kernel()
+
+    return fml
+
+
+def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
+    set_env_vars_and_device(env)
+
+    vllm_config = VllmConfig()
+    vllm_config.parallel_config.data_parallel_size = world_size
+    vllm_config.parallel_config.enable_expert_parallel = True
+
+    with set_current_vllm_config(vllm_config):
+        ensure_model_parallel_initialized(
+            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
+        )
+
+        ep_group = get_dp_group().cpu_group
+        ep_rank = torch.distributed.get_rank()
+
+        device = torch.device(f"cuda:{ep_rank}")
+
+        fml_layers = [
+            make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
+            for layer_idx in range(test_config.num_layers)
+        ]
+        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
+
+        hidden_states = []
+        router_logits = []
+        for layer_idx in range(test_config.num_layers):
+            hidden_states.append(
+                torch.randn(
+                    (test_config.num_tokens, test_config.hidden_size),
+                    dtype=torch.bfloat16,
+                    device=device,
+                )
+            )
+            router_logits.append(
+                torch.randn(
+                    (test_config.num_tokens, test_config.num_experts),
+                    dtype=torch.bfloat16,
+                    device=device,
+                )
+            )
+
+        out_before_shuffle = []
+        with set_forward_context(
+            {},
+            num_tokens=test_config.num_tokens,
+            num_tokens_across_dp=torch.tensor(
+                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
+            ),
+            vllm_config=vllm_config,
+        ):
+            for lidx, fml in enumerate(fml_layers):
+                out_before_shuffle.append(
+                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
+                )
+
+        # Identity physical -> logical mapping before the shuffle.
+        indices = torch.zeros(
+            test_config.num_layers, test_config.num_experts, dtype=torch.long
+        )
+        for lidx in range(test_config.num_layers):
+            indices[lidx] = torch.arange(test_config.num_experts)
+
+        shuffled_indices = torch.zeros_like(indices)
+        for lidx in range(test_config.num_layers):
+            shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
+
+        rearrange_expert_weights_inplace(
+            indices,
+            shuffled_indices,
+            rank_expert_weights,
+            ep_group,
+            is_profile=False,
+        )
+
+        num_global_experts = test_config.num_experts
+
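+        # Build per-layer logical -> physical expert maps by inverting the
+        # shuffle. `shuffled_indices[lidx]` maps physical slot -> logical
+        # expert, so for a permutation p the inverse satisfies
+        # inv[p] = arange(n). Illustrative example (not executed):
+        #   p   = [2, 0, 1]  # slot 0 now holds logical expert 2, ...
+        #   inv = [1, 2, 0]  # logical expert 0 now lives in slot 1, ...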
+        logical_to_physical_map_list = []
+        for lidx, fml in enumerate(fml_layers):
+            physical_to_logical_map = shuffled_indices[lidx].to(device)
+            logical_to_physical_map = torch.empty(
+                (num_global_experts,), dtype=torch.int32, device=device
+            )
+            logical_to_physical_map[physical_to_logical_map] = torch.arange(
+                0, num_global_experts, dtype=torch.int32, device=device
+            )
+            logical_to_physical_map_list.append(
+                logical_to_physical_map.reshape(num_global_experts, 1)
+            )
+
+        logical_to_physical_map = torch.stack(logical_to_physical_map_list)
+
+        # Every logical expert has exactly one replica in this test.
+        logical_replica_count = torch.ones(
+            (test_config.num_layers, num_global_experts),
+            dtype=torch.int32,
+            device=device,
+        )
+        for lidx, fml in enumerate(fml_layers):
+            fml.enable_eplb = True
+            fml.set_eplb_state(
+                lidx,
+                torch.zeros(
+                    (test_config.num_layers, num_global_experts),
+                    dtype=torch.int32,
+                    device=device,
+                ),
+                logical_to_physical_map,
+                logical_replica_count,
+            )
+
+        out_after_shuffle = []
+        with set_forward_context(
+            {},
+            num_tokens=test_config.num_tokens,
+            num_tokens_across_dp=torch.tensor(
+                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
+            ),
+            vllm_config=vllm_config,
+        ):
+            for lidx, fml in enumerate(fml_layers):
+                out_after_shuffle.append(
+                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
+                )
+
+        # Outputs must match the pre-shuffle run; tolerances are loose to
+        # absorb low-precision NVFP4 arithmetic.
+        for lidx in range(test_config.num_layers):
+            torch.testing.assert_close(
+                out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
+            )
+
+
+@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.parametrize("num_layers", [8])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("hidden_size", [256])
+@pytest.mark.parametrize("intermediate_size", [256])
+@pytest.mark.parametrize("num_tokens", [256])
+@pytest.mark.parametrize("backend", ["latency", "throughput"])
+def test_eplb_fml(
+    world_size: int,
+    num_layers: int,
+    num_experts: int,
+    hidden_size: int,
+    intermediate_size: int,
+    num_tokens: int,
+    backend: str,
+    monkeypatch,
+):
+    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
+    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
+
+    if torch.cuda.device_count() < world_size:
+        pytest.skip(f"Need at least {world_size} GPUs to run the test")
+
+    num_local_experts = num_experts // world_size
+    num_topk = 4
+
+    test_config = TestConfig(
+        num_layers=num_layers,
+        num_experts=num_experts,
+        num_local_experts=num_local_experts,
+        num_topk=num_topk,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_tokens=num_tokens,
+    )
+
+    distributed_run(
+        _test_eplb_fml,
+        world_size,
+        test_config,
+    )
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 034e97a713cd..d311fb1a7244 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -39,6 +39,7 @@
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize,
     flashinfer_trtllm_fp4_moe,
+    flashinfer_trtllm_fp4_routed_moe,
     prepare_static_weights_for_trtllm_fp4_moe,
     reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl,
@@ -1342,7 +1343,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 "Accuracy may be affected."
             )
 
-        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
+        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
         # Common processing for input scales and alphas
@@ -1499,6 +1500,10 @@ def get_fused_moe_quant_config(
             a2_gscale=layer.w2_input_scale_quant,
         )
 
+    @property
+    def supports_eplb(self) -> bool:
+        return True
+
     def apply(
         self,
         layer: FusedMoE,
@@ -1534,11 +1539,8 @@
         if (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+            and not enable_eplb
         ):
-            if enable_eplb:
-                raise NotImplementedError(
-                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
-                )
             return flashinfer_trtllm_fp4_moe(
                 layer=layer,
                 x=x,
@@ -1556,6 +1558,25 @@
                 router_logits=router_logits,
             )
 
+        # EPLB path
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            # Pack the top-k ids and expert weights into a single int32
+            # tensor, as required by TRT-LLM.
+            packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
+                torch.bfloat16
+            ).view(torch.int16)
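+            # Layout sketch (illustrative): bits [31:16] hold the expert id,
+            # bits [15:0] hold the raw bfloat16 weight bits, e.g. id=3 with
+            # weight 1.0 (bf16 0x3F80) packs to 0x00033F80. Post-softmax
+            # routing weights are non-negative, so the int16 view cannot
+            # sign-extend into the id bits.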
+@pytest.mark.parametrize("backend", ["latency", "throughput"]) +def test_eplb_fml( + world_size: int, + num_layers: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, + num_tokens: int, + backend: str, + monkeypatch, +): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend) + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + num_local_experts = num_experts // world_size + num_topk = 4 + + test_config = TestConfig( + num_layers=num_layers, + num_experts=num_experts, + num_local_experts=num_local_experts, + num_topk=num_topk, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + ) + + distributed_run( + _test_eplb_fml, + world_size, + test_config, + ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 034e97a713cd..d311fb1a7244 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, + flashinfer_trtllm_fp4_routed_moe, prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl, @@ -1342,7 +1343,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: "Accuracy may be affected." ) - w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous() layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) # Common processing for input scales and alphas @@ -1499,6 +1500,10 @@ def get_fused_moe_quant_config( a2_gscale=layer.w2_input_scale_quant, ) + @property + def supports_eplb(self) -> bool: + return True + def apply( self, layer: FusedMoE, @@ -1534,11 +1539,8 @@ def apply( if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + and not enable_eplb ): - if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `ModelOptNvFp4FusedMoE` yet." - ) return flashinfer_trtllm_fp4_moe( layer=layer, x=x, @@ -1556,6 +1558,25 @@ def apply( router_logits=router_logits, ) + # EPLB path + if ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): + # Pack top k ids and expert weights into a single int32 tensor, as + # required by TRT-LLM + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).view(torch.int16) + + return flashinfer_trtllm_fp4_routed_moe( + layer=layer, + x=x, + topk_ids=packed_tensor, + top_k=top_k, + global_num_experts=global_num_experts, + ) + if self.use_marlin: return fused_marlin_moe( x, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index eda40657b1e3..b39bc4020d73 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -331,3 +331,75 @@ def flashinfer_trtllm_fp4_moe( )[0] return out + + +def flashinfer_trtllm_fp4_routed_moe( + layer: torch.nn.Module, + x: torch.Tensor, + topk_ids: torch.Tensor, # Packed + top_k: int, + global_num_experts: int, +) -> torch.Tensor: + """ + Apply FlashInfer TensorRT-LLM FP4 MoE kernel. 
+    import flashinfer
+
+    # Quantize the input activations to FP4.
+    a1_gscale = layer.w13_input_scale_quant
+    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+        x,
+        a1_gscale,
+        is_sf_swizzled_layout=False,
+    )
+
+    # Call the TRT-LLM FP4 block-scale MoE kernel with precomputed routing.
+    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
+        topk_ids=topk_ids,
+        routing_bias=None,
+        hidden_states=hidden_states_fp4,
+        hidden_states_scale=hidden_states_scale_linear_fp4.view(
+            torch.float8_e4m3fn
+        ).flatten(),
+        gemm1_weights=layer.gemm1_weights_fp4_shuffled.data,
+        gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view(
+            torch.float8_e4m3fn
+        ),
+        gemm1_bias=None,
+        gemm1_alpha=None,
+        gemm1_beta=None,
+        gemm1_clamp_limit=None,
+        gemm2_weights=layer.gemm2_weights_fp4_shuffled.data,
+        gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view(
+            torch.float8_e4m3fn
+        ),
+        gemm2_bias=None,
+        output1_scale_scalar=layer.g1_scale_c.data,
+        output1_scale_gate_scalar=layer.g1_alphas.data,
+        output2_scale_scalar=layer.g2_alphas.data,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        n_group=0,
+        topk_group=0,
+        intermediate_size=layer.intermediate_size_per_partition,
+        local_expert_offset=layer.ep_rank * layer.local_num_experts,
+        local_num_experts=layer.local_num_experts,
+        routed_scaling_factor=None,
+        tile_tokens_dim=None,
+        routing_method_type=1,
+        do_finalize=True,
+    )[0]
+
+    return out
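+
+
+# Illustrative caller-side sketch (comment only, mirroring the EPLB path in
+# ModelOptNvFp4FusedMoE.apply): the routing tensor must be packed before
+# calling flashinfer_trtllm_fp4_routed_moe:
+#
+#   packed = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
+#       torch.bfloat16
+#   ).view(torch.int16)
+#   out = flashinfer_trtllm_fp4_routed_moe(
+#       layer=layer, x=x, topk_ids=packed, top_k=top_k,
+#       global_num_experts=global_num_experts,
+#   )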