|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4 |
| 5 | + |
| 6 | +from dataclasses import dataclass |
| 7 | + |
| 8 | +import pytest |
| 9 | +import torch |
| 10 | + |
| 11 | +from tests.kernels.moe.utils import make_test_quant_config |
| 12 | +from vllm.config import VllmConfig, set_current_vllm_config |
| 13 | +from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace |
| 14 | +from vllm.distributed.parallel_state import ( |
| 15 | + ensure_model_parallel_initialized, |
| 16 | + get_dp_group, |
| 17 | +) |
| 18 | +from vllm.forward_context import set_forward_context |
| 19 | +from vllm.model_executor.layers.fused_moe.layer import FusedMoE |
| 20 | +from vllm.model_executor.layers.quantization.modelopt import ( |
| 21 | + ModelOptNvFp4Config, |
| 22 | + ModelOptNvFp4FusedMoE, |
| 23 | +) |
| 24 | + |
| 25 | +from .eplb_utils import distributed_run, set_env_vars_and_device |
| 26 | + |
| 27 | + |
| 28 | +@dataclass |
| 29 | +class TestConfig: |
| 30 | + num_layers: int |
| 31 | + num_experts: int |
| 32 | + num_local_experts: int |
| 33 | + num_topk: int |
| 34 | + hidden_size: int |
| 35 | + intermediate_size: int |
| 36 | + num_tokens: int |
| 37 | + |
| 38 | + |
| 39 | +def make_fused_moe_layer( |
| 40 | + rank: int, |
| 41 | + layer_idx: int, |
| 42 | + test_config: TestConfig, |
| 43 | +) -> FusedMoE: |
| 44 | + quant_config = None |
| 45 | + |
| 46 | + device = torch.device(f"cuda:{rank}") |
| 47 | + |
| 48 | + quant_config = ModelOptNvFp4Config( |
| 49 | + is_checkpoint_nvfp4_serialized=True, |
| 50 | + kv_cache_quant_algo=None, |
| 51 | + exclude_modules=[], |
| 52 | + ) |
| 53 | + |
| 54 | + fml = FusedMoE( |
| 55 | + num_experts=test_config.num_experts, |
| 56 | + top_k=test_config.num_topk, |
| 57 | + hidden_size=test_config.hidden_size, |
| 58 | + intermediate_size=test_config.intermediate_size, |
| 59 | + prefix=f"dummy_layer_{layer_idx}", |
| 60 | + activation="silu", |
| 61 | + is_act_and_mul=True, |
| 62 | + params_dtype=torch.bfloat16, |
| 63 | + quant_config=quant_config, |
| 64 | + ) |
| 65 | + |
| 66 | + nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml) |
| 67 | + nvfp4_fused_moe.create_weights( |
| 68 | + fml, |
| 69 | + test_config.num_local_experts, |
| 70 | + test_config.hidden_size, |
| 71 | + test_config.intermediate_size, |
| 72 | + params_dtype=torch.uint8, |
| 73 | + global_num_experts=test_config.num_experts, |
| 74 | + ) |
| 75 | + |
| 76 | + fml = fml.to(device) |
| 77 | + w1_q, w2_q, quant_config = make_test_quant_config( |
| 78 | + test_config.num_local_experts, |
| 79 | + test_config.intermediate_size, |
| 80 | + test_config.hidden_size, |
| 81 | + in_dtype=torch.bfloat16, |
| 82 | + quant_dtype="nvfp4", |
| 83 | + block_shape=None, |
| 84 | + per_act_token_quant=False, |
| 85 | + ) |
| 86 | + |
| 87 | + fml.w13_weight.data = w1_q |
| 88 | + fml.w2_weight.data = w2_q |
| 89 | + |
| 90 | + fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5 |
| 91 | + fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5 |
| 92 | + fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5 |
| 93 | + fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5 |
| 94 | + fml.w2_weight_scale.data = ( |
| 95 | + torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5 |
| 96 | + ).to(fml.w2_weight_scale.data.dtype) |
| 97 | + fml.w13_weight_scale.data = ( |
| 98 | + torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5 |
| 99 | + ).to(fml.w13_weight_scale.data.dtype) |
| 100 | + |
| 101 | + nvfp4_fused_moe.process_weights_after_loading(fml) |
| 102 | + |
| 103 | + fml.maybe_init_modular_kernel() |
| 104 | + |
| 105 | + return fml |
| 106 | + |
| 107 | + |
| 108 | +def _test_eplb_fml(env, world_size: int, test_config: TestConfig): |
| 109 | + set_env_vars_and_device(env) |
| 110 | + |
| 111 | + vllm_config = VllmConfig() |
| 112 | + vllm_config.parallel_config.data_parallel_size = world_size |
| 113 | + vllm_config.parallel_config.enable_expert_parallel = True |
| 114 | + |
| 115 | + with set_current_vllm_config(vllm_config): |
| 116 | + ensure_model_parallel_initialized( |
| 117 | + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 |
| 118 | + ) |
| 119 | + |
| 120 | + ep_group = get_dp_group().cpu_group |
| 121 | + ep_rank = torch.distributed.get_rank() |
| 122 | + |
| 123 | + device = torch.device(f"cuda:{ep_rank}") |
| 124 | + |
| 125 | + fml_layers = [ |
| 126 | + make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device) |
| 127 | + for layer_idx in range(test_config.num_layers) |
| 128 | + ] |
| 129 | + rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers] |
| 130 | + |
| 131 | + hidden_states = [] |
| 132 | + router_logits = [] |
| 133 | + for layer_idx in range(test_config.num_layers): |
| 134 | + hidden_states.append( |
| 135 | + torch.randn( |
| 136 | + (test_config.num_tokens, test_config.hidden_size), |
| 137 | + dtype=torch.bfloat16, |
| 138 | + device=device, |
| 139 | + ) |
| 140 | + ) |
| 141 | + router_logits.append( |
| 142 | + torch.randn( |
| 143 | + (test_config.num_tokens, test_config.num_experts), |
| 144 | + dtype=torch.bfloat16, |
| 145 | + device=device, |
| 146 | + ) |
| 147 | + ) |
| 148 | + |
| 149 | + out_before_shuffle = [] |
| 150 | + with set_forward_context( |
| 151 | + {}, |
| 152 | + num_tokens=test_config.num_tokens, |
| 153 | + num_tokens_across_dp=torch.tensor( |
| 154 | + [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int |
| 155 | + ), |
| 156 | + vllm_config=vllm_config, |
| 157 | + ): |
| 158 | + for lidx, fml in enumerate(fml_layers): |
| 159 | + out_before_shuffle.append( |
| 160 | + fml(hidden_states[lidx].clone(), router_logits[lidx].clone()) |
| 161 | + ) |
| 162 | + |
| 163 | + indices = torch.zeros( |
| 164 | + test_config.num_layers, test_config.num_experts, dtype=torch.long |
| 165 | + ) |
| 166 | + for lidx in range(test_config.num_layers): |
| 167 | + indices[lidx] = torch.Tensor(range(test_config.num_experts)) |
| 168 | + |
| 169 | + shuffled_indices = torch.zeros_like(indices) |
| 170 | + for lidx in range(test_config.num_layers): |
| 171 | + shuffled_indices[lidx] = torch.randperm(test_config.num_experts) |
| 172 | + |
| 173 | + rearrange_expert_weights_inplace( |
| 174 | + indices, |
| 175 | + shuffled_indices, |
| 176 | + rank_expert_weights, |
| 177 | + ep_group, |
| 178 | + is_profile=False, |
| 179 | + ) |
| 180 | + |
| 181 | + num_global_experts = test_config.num_experts |
| 182 | + |
| 183 | + logical_to_physical_map_list = [] |
| 184 | + for lidx, fml in enumerate(fml_layers): |
| 185 | + physical_to_logical_map = shuffled_indices[lidx].to(device) |
| 186 | + logical_to_physical_map = torch.empty( |
| 187 | + (num_global_experts,), dtype=torch.int32, device=device |
| 188 | + ) |
| 189 | + logical_to_physical_map[physical_to_logical_map] = torch.arange( |
| 190 | + 0, num_global_experts, dtype=torch.int32, device=device |
| 191 | + ) |
| 192 | + logical_to_physical_map_list.append( |
| 193 | + logical_to_physical_map.reshape(num_global_experts, 1) |
| 194 | + ) |
| 195 | + |
| 196 | + logical_to_physical_map = torch.stack(logical_to_physical_map_list) |
| 197 | + |
| 198 | + for lidx, fml in enumerate(fml_layers): |
| 199 | + logical_replica_count = torch.ones( |
| 200 | + (test_config.num_layers, num_global_experts), |
| 201 | + dtype=torch.int32, |
| 202 | + device=device, |
| 203 | + ) |
| 204 | + fml.enable_eplb = True |
| 205 | + fml.set_eplb_state( |
| 206 | + lidx, |
| 207 | + torch.zeros( |
| 208 | + (test_config.num_layers, num_global_experts), |
| 209 | + dtype=torch.int32, |
| 210 | + device=device, |
| 211 | + ), |
| 212 | + logical_to_physical_map, |
| 213 | + logical_replica_count, |
| 214 | + ) |
| 215 | + |
| 216 | + out_after_shuffle = [] |
| 217 | + with set_forward_context( |
| 218 | + {}, |
| 219 | + num_tokens=test_config.num_tokens, |
| 220 | + num_tokens_across_dp=torch.tensor( |
| 221 | + [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int |
| 222 | + ), |
| 223 | + vllm_config=vllm_config, |
| 224 | + ): |
| 225 | + for lidx, fml in enumerate(fml_layers): |
| 226 | + out_after_shuffle.append( |
| 227 | + fml(hidden_states[lidx].clone(), router_logits[lidx].clone()) |
| 228 | + ) |
| 229 | + |
| 230 | + for lidx in range(test_config.num_layers): |
| 231 | + torch.testing.assert_close( |
| 232 | + out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1 |
| 233 | + ) |
| 234 | + |
| 235 | + |
| 236 | +@pytest.mark.parametrize("world_size", [2, 4]) |
| 237 | +@pytest.mark.parametrize("num_layers", [8]) |
| 238 | +@pytest.mark.parametrize("num_experts", [32]) |
| 239 | +@pytest.mark.parametrize("hidden_size", [256]) |
| 240 | +@pytest.mark.parametrize("intermediate_size", [256]) |
| 241 | +@pytest.mark.parametrize("num_tokens", [256]) |
| 242 | +@pytest.mark.parametrize("backend", ["latency", "throughput"]) |
| 243 | +def test_eplb_fml( |
| 244 | + world_size: int, |
| 245 | + num_layers: int, |
| 246 | + num_experts: int, |
| 247 | + hidden_size: int, |
| 248 | + intermediate_size: int, |
| 249 | + num_tokens: int, |
| 250 | + backend: str, |
| 251 | + monkeypatch, |
| 252 | +): |
| 253 | + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") |
| 254 | + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend) |
| 255 | + |
| 256 | + if torch.cuda.device_count() < world_size: |
| 257 | + pytest.skip(f"Need at least {world_size} GPUs to run the test") |
| 258 | + |
| 259 | + num_local_experts = num_experts // world_size |
| 260 | + num_topk = 4 |
| 261 | + |
| 262 | + test_config = TestConfig( |
| 263 | + num_layers=num_layers, |
| 264 | + num_experts=num_experts, |
| 265 | + num_local_experts=num_local_experts, |
| 266 | + num_topk=num_topk, |
| 267 | + hidden_size=hidden_size, |
| 268 | + intermediate_size=intermediate_size, |
| 269 | + num_tokens=num_tokens, |
| 270 | + ) |
| 271 | + |
| 272 | + distributed_run( |
| 273 | + _test_eplb_fml, |
| 274 | + world_size, |
| 275 | + test_config, |
| 276 | + ) |
0 commit comments