
Commit e642217

Author: Andrew Briand (committed)
Enable EPLB for NVFP4
Signed-off-by: Andrew Briand <[email protected]>
1 parent cabc77c commit e642217

File tree: 3 files changed, +374 -5 lines changed

Lines changed: 276 additions & 0 deletions (new file)

@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Test that the interaction between EPLB and the FusedMoE layer works for DP with NVFP4

from dataclasses import dataclass

import pytest
import torch

from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized,
    get_dp_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptNvFp4Config,
    ModelOptNvFp4FusedMoE,
)

from .eplb_utils import distributed_run, set_env_vars_and_device


@dataclass
class TestConfig:
    num_layers: int
    num_experts: int
    num_local_experts: int
    num_topk: int
    hidden_size: int
    intermediate_size: int
    num_tokens: int


def make_fused_moe_layer(
    rank: int,
    layer_idx: int,
    test_config: TestConfig,
) -> FusedMoE:
    quant_config = None

    device = torch.device(f"cuda:{rank}")

    quant_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )

    fml = FusedMoE(
        num_experts=test_config.num_experts,
        top_k=test_config.num_topk,
        hidden_size=test_config.hidden_size,
        intermediate_size=test_config.intermediate_size,
        prefix=f"dummy_layer_{layer_idx}",
        activation="silu",
        is_act_and_mul=True,
        params_dtype=torch.bfloat16,
        quant_config=quant_config,
    )

    nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
    nvfp4_fused_moe.create_weights(
        fml,
        test_config.num_local_experts,
        test_config.hidden_size,
        test_config.intermediate_size,
        params_dtype=torch.uint8,
        global_num_experts=test_config.num_experts,
    )

    fml = fml.to(device)
    # Fill the layer with random NVFP4 weights and scales.
    w1_q, w2_q, quant_config = make_test_quant_config(
        test_config.num_local_experts,
        test_config.intermediate_size,
        test_config.hidden_size,
        in_dtype=torch.bfloat16,
        quant_dtype="nvfp4",
        block_shape=None,
        per_act_token_quant=False,
    )

    fml.w13_weight.data = w1_q
    fml.w2_weight.data = w2_q

    fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
    fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
    fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
    fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
    fml.w2_weight_scale.data = (
        torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
    ).to(fml.w2_weight_scale.data.dtype)
    fml.w13_weight_scale.data = (
        torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
    ).to(fml.w13_weight_scale.data.dtype)

    nvfp4_fused_moe.process_weights_after_loading(fml)

    fml.maybe_init_modular_kernel()

    return fml


def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
    set_env_vars_and_device(env)

    vllm_config = VllmConfig()
    vllm_config.parallel_config.data_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    with set_current_vllm_config(vllm_config):
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
        )

        ep_group = get_dp_group().cpu_group
        ep_rank = torch.distributed.get_rank()

        device = torch.device(f"cuda:{ep_rank}")

        # Build one FusedMoE layer per model layer on this rank.
        fml_layers = [
            make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
            for layer_idx in range(test_config.num_layers)
        ]
        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

        hidden_states = []
        router_logits = []
        for layer_idx in range(test_config.num_layers):
            hidden_states.append(
                torch.randn(
                    (test_config.num_tokens, test_config.hidden_size),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )
            router_logits.append(
                torch.randn(
                    (test_config.num_tokens, test_config.num_experts),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )

        # Forward pass through every layer before shuffling expert weights.
        out_before_shuffle = []
        with set_forward_context(
            {},
            num_tokens=test_config.num_tokens,
            num_tokens_across_dp=torch.tensor(
                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
            ),
            vllm_config=vllm_config,
        ):
            for lidx, fml in enumerate(fml_layers):
                out_before_shuffle.append(
                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
                )

        # Current placement is the identity mapping; the target placement is a
        # random permutation per layer.
        indices = torch.zeros(
            test_config.num_layers, test_config.num_experts, dtype=torch.long
        )
        for lidx in range(test_config.num_layers):
            indices[lidx] = torch.Tensor(range(test_config.num_experts))

        shuffled_indices = torch.zeros_like(indices)
        for lidx in range(test_config.num_layers):
            shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

        # Physically move the expert weights to match the new placement.
        rearrange_expert_weights_inplace(
            indices,
            shuffled_indices,
            rank_expert_weights,
            ep_group,
            is_profile=False,
        )

        num_global_experts = test_config.num_experts

        # Invert each layer's physical-to-logical permutation to obtain the
        # logical-to-physical map.
        logical_to_physical_map_list = []
        for lidx, fml in enumerate(fml_layers):
            physical_to_logical_map = shuffled_indices[lidx].to(device)
            logical_to_physical_map = torch.empty(
                (num_global_experts,), dtype=torch.int32, device=device
            )
            logical_to_physical_map[physical_to_logical_map] = torch.arange(
                0, num_global_experts, dtype=torch.int32, device=device
            )
            logical_to_physical_map_list.append(
                logical_to_physical_map.reshape(num_global_experts, 1)
            )

        logical_to_physical_map = torch.stack(logical_to_physical_map_list)

        # Register the new placement with each layer and enable EPLB.
        for lidx, fml in enumerate(fml_layers):
            logical_replica_count = torch.ones(
                (test_config.num_layers, num_global_experts),
                dtype=torch.int32,
                device=device,
            )
            fml.enable_eplb = True
            fml.set_eplb_state(
                lidx,
                torch.zeros(
                    (test_config.num_layers, num_global_experts),
                    dtype=torch.int32,
                    device=device,
                ),
                logical_to_physical_map,
                logical_replica_count,
            )

        # Forward pass again after shuffling; outputs should match the
        # pre-shuffle run.
        out_after_shuffle = []
        with set_forward_context(
            {},
            num_tokens=test_config.num_tokens,
            num_tokens_across_dp=torch.tensor(
                [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
            ),
            vllm_config=vllm_config,
        ):
            for lidx, fml in enumerate(fml_layers):
                out_after_shuffle.append(
                    fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
                )

        for lidx in range(test_config.num_layers):
            torch.testing.assert_close(
                out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
            )


@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
    world_size: int,
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    num_tokens: int,
    backend: str,
    monkeypatch,
):
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
    monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)

    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    num_local_experts = num_experts // world_size
    num_topk = 4

    test_config = TestConfig(
        num_layers=num_layers,
        num_experts=num_experts,
        num_local_experts=num_local_experts,
        num_topk=num_topk,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_tokens=num_tokens,
    )

    distributed_run(
        _test_eplb_fml,
        world_size,
        test_config,
    )
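The test builds each layer's logical-to-physical map by inverting the shuffled physical-to-logical permutation with an indexed assignment. For reference, here is a minimal standalone sketch of that inversion trick, mirroring what the test does per layer (the 8-expert layer below is hypothetical, and dtypes/devices are simplified relative to the test):

import torch

# Hypothetical single layer with 8 physical expert slots.
num_experts = 8

# physical_to_logical_map[p] = logical expert currently stored in physical slot p.
physical_to_logical_map = torch.randperm(num_experts)

# Invert by indexed assignment: if slot p holds logical expert l,
# then logical expert l lives in slot p.
logical_to_physical_map = torch.empty(num_experts, dtype=torch.long)
logical_to_physical_map[physical_to_logical_map] = torch.arange(num_experts)

# Composing the two maps gives back the identity, so the inversion is correct.
assert torch.equal(
    physical_to_logical_map[logical_to_physical_map],
    torch.arange(num_experts),
)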

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 26 additions & 5 deletions

@@ -39,6 +39,7 @@
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize,
     flashinfer_trtllm_fp4_moe,
+    flashinfer_trtllm_fp4_routed_moe,
     prepare_static_weights_for_trtllm_fp4_moe,
     reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl,
@@ -1342,7 +1343,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 "Accuracy may be affected."
             )

-        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
+        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

         # Common processing for input scales and alphas
@@ -1499,6 +1500,10 @@ def get_fused_moe_quant_config(
             a2_gscale=layer.w2_input_scale_quant,
         )

+    @property
+    def supports_eplb(self) -> bool:
+        return True
+
     def apply(
         self,
         layer: FusedMoE,
@@ -1534,11 +1539,8 @@ def apply(
         if (
             self.allow_flashinfer
             and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+            and not enable_eplb
         ):
-            if enable_eplb:
-                raise NotImplementedError(
-                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
-                )
             return flashinfer_trtllm_fp4_moe(
                 layer=layer,
                 x=x,
@@ -1556,6 +1558,25 @@
                 router_logits=router_logits,
             )

+        # EPLB path
+        if (
+            self.allow_flashinfer
+            and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        ):
+            # Pack top k ids and expert weights into a single int32 tensor, as
+            # required by TRT-LLM
+            packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
+                torch.bfloat16
+            ).view(torch.int16)
+
+            return flashinfer_trtllm_fp4_routed_moe(
+                layer=layer,
+                x=x,
+                topk_ids=packed_tensor,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+            )
+
         if self.use_marlin:
             return fused_marlin_moe(
                 x,
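The new EPLB branch packs each top-k expert id and its routing weight into a single int32 before calling flashinfer_trtllm_fp4_routed_moe. Below is a minimal sketch of that packing together with an illustrative unpacking; the unpack step is only for demonstration and is not part of this change. The sketch assumes non-negative routing weights (as a router softmax produces), so the bfloat16 bit pattern stays in the low 15 bits and the OR does not disturb the expert id in the high bits:

import torch

# Hypothetical routing output: 4 tokens, top-2 experts each (ids < 32).
topk_ids = torch.randint(0, 32, (4, 2))
topk_weights = torch.rand(4, 2)

# Pack: expert id in the high 16 bits, bf16 weight bits in the low 16 bits,
# mirroring the packing in the new EPLB branch above.
packed = (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view(
    torch.int16
)

# Illustrative unpack: high bits give the id, low bits reinterpret as bfloat16.
unpacked_ids = packed >> 16
unpacked_weights = (packed & 0xFFFF).to(torch.int16).view(torch.bfloat16)

assert torch.equal(unpacked_ids, topk_ids.to(torch.int32))
assert torch.equal(unpacked_weights, topk_weights.to(torch.bfloat16))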
