276 changes: 276 additions & 0 deletions tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
@@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4

from dataclasses import dataclass

import pytest
import torch

from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_dp_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config,
ModelOptNvFp4FusedMoE,
)

from .eplb_utils import distributed_run, set_env_vars_and_device


@dataclass
class TestConfig:
num_layers: int
num_experts: int
num_local_experts: int
num_topk: int
hidden_size: int
intermediate_size: int
num_tokens: int


def make_fused_moe_layer(
rank: int,
layer_idx: int,
test_config: TestConfig,
) -> FusedMoE:
quant_config = None

device = torch.device(f"cuda:{rank}")

quant_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
)

fml = FusedMoE(
num_experts=test_config.num_experts,
top_k=test_config.num_topk,
hidden_size=test_config.hidden_size,
intermediate_size=test_config.intermediate_size,
prefix=f"dummy_layer_{layer_idx}",
activation="silu",
is_act_and_mul=True,
params_dtype=torch.bfloat16,
quant_config=quant_config,
)

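    # Build the NVFP4 quant method by hand and let it allocate the expert
    # weight and scale parameters on the layer.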
nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
nvfp4_fused_moe.create_weights(
fml,
test_config.num_local_experts,
test_config.hidden_size,
test_config.intermediate_size,
params_dtype=torch.uint8,
global_num_experts=test_config.num_experts,
)

fml = fml.to(device)
w1_q, w2_q, quant_config = make_test_quant_config(
test_config.num_local_experts,
test_config.intermediate_size,
test_config.hidden_size,
in_dtype=torch.bfloat16,
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
)

fml.w13_weight.data = w1_q
fml.w2_weight.data = w2_q

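    # Randomize the input and weight scales so the quantized path runs with
    # nontrivial values rather than their initial values.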
fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
fml.w2_weight_scale.data = (
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
).to(fml.w2_weight_scale.data.dtype)
fml.w13_weight_scale.data = (
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
).to(fml.w13_weight_scale.data.dtype)

nvfp4_fused_moe.process_weights_after_loading(fml)

fml.maybe_init_modular_kernel()

return fml


def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
set_env_vars_and_device(env)

vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = world_size
vllm_config.parallel_config.enable_expert_parallel = True

with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1
)

ep_group = get_dp_group().cpu_group
ep_rank = torch.distributed.get_rank()

device = torch.device(f"cuda:{ep_rank}")

fml_layers = [
make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
for layer_idx in range(test_config.num_layers)
]
rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

hidden_states = []
router_logits = []
for layer_idx in range(test_config.num_layers):
hidden_states.append(
torch.randn(
(test_config.num_tokens, test_config.hidden_size),
dtype=torch.bfloat16,
device=device,
)
)
router_logits.append(
torch.randn(
(test_config.num_tokens, test_config.num_experts),
dtype=torch.bfloat16,
device=device,
)
)

out_before_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_before_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)

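        # `indices` is the current physical-to-logical expert placement (identity);
        # `shuffled_indices` is the target placement after a simulated rebalance.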
indices = torch.zeros(
test_config.num_layers, test_config.num_experts, dtype=torch.long
)
for lidx in range(test_config.num_layers):
indices[lidx] = torch.Tensor(range(test_config.num_experts))

shuffled_indices = torch.zeros_like(indices)
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
ep_group,
is_profile=False,
)

num_global_experts = test_config.num_experts

logical_to_physical_map_list = []
for lidx, fml in enumerate(fml_layers):
physical_to_logical_map = shuffled_indices[lidx].to(device)
logical_to_physical_map = torch.empty(
(num_global_experts,), dtype=torch.int32, device=device
)
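            # Invert the permutation: if physical slot p holds logical expert l,
            # then logical_to_physical_map[l] == p.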
logical_to_physical_map[physical_to_logical_map] = torch.arange(
0, num_global_experts, dtype=torch.int32, device=device
)
logical_to_physical_map_list.append(
logical_to_physical_map.reshape(num_global_experts, 1)
)

logical_to_physical_map = torch.stack(logical_to_physical_map_list)

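        # Register the new expert placement with each layer; every logical expert
        # has exactly one physical replica in this test.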
for lidx, fml in enumerate(fml_layers):
logical_replica_count = torch.ones(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
)
fml.enable_eplb = True
fml.set_eplb_state(
lidx,
torch.zeros(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
),
logical_to_physical_map,
logical_replica_count,
)

out_after_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_after_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)

for lidx in range(test_config.num_layers):
torch.testing.assert_close(
out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
)


@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
world_size: int,
num_layers: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
num_tokens: int,
backend: str,
monkeypatch,
):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)

if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")

num_local_experts = num_experts // world_size
num_topk = 4

test_config = TestConfig(
num_layers=num_layers,
num_experts=num_experts,
num_local_experts=num_local_experts,
num_topk=num_topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_tokens=num_tokens,
)

distributed_run(
_test_eplb_fml,
world_size,
test_config,
)
31 changes: 26 additions & 5 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -39,6 +39,7 @@
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
prepare_static_weights_for_trtllm_fp4_moe,
reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl,
@@ -1342,7 +1343,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"Accuracy may be affected."
)

-        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
+        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous()
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

# Common processing for input scales and alphas
@@ -1499,6 +1500,10 @@ def get_fused_moe_quant_config(
a2_gscale=layer.w2_input_scale_quant,
)

@property
def supports_eplb(self) -> bool:
return True

def apply(
self,
layer: FusedMoE,
@@ -1534,11 +1539,8 @@ def apply(
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+            and not enable_eplb
):
-            if enable_eplb:
-                raise NotImplementedError(
-                    "EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
-                )
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
@@ -1556,6 +1558,25 @@
router_logits=router_logits,
)

# EPLB path
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
):
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
            packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
                torch.bfloat16
            ).view(torch.int16)

@IwakuraRein (Contributor) commented on Dec 2, 2025:

Maybe hide this packing operation in flashinfer_trtllm_fp4_routed_moe, i.e. let flashinfer_trtllm_fp4_routed_moe take topk_ids and topk_weights directly, making its interface closer to Marlin's.

Additionally, the packing will be removed in the flashinfer API in the near future, so we can just pass topk_ids and topk_weights to flashinfer.
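As a rough sketch of that suggestion (the wrapper name below is invented for illustration; only flashinfer_trtllm_fp4_routed_moe and the keyword arguments shown in this diff are taken as given), a helper could accept plain topk_ids / topk_weights and do the TRT-LLM packing internally, so the call site mirrors the Marlin path:

import torch

from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    flashinfer_trtllm_fp4_routed_moe,
)


def trtllm_fp4_routed_moe_from_topk(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,  # integer expert ids, shape [num_tokens, top_k]
    topk_weights: torch.Tensor,  # routing weights, shape [num_tokens, top_k]
    top_k: int,
    global_num_experts: int,
) -> torch.Tensor:
    # Same layout as above: expert id in the high 16 bits, raw bfloat16 weight
    # bits in the low 16 bits of an int32. The 0xFFFF mask guards against
    # sign-extension when the bf16 bit pattern has its top bit set.
    packed = (topk_ids.to(torch.int32) << 16) | (
        topk_weights.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
    )
    return flashinfer_trtllm_fp4_routed_moe(
        layer=layer,
        x=x,
        topk_ids=packed,
        top_k=top_k,
        global_num_experts=global_num_experts,
    )

If FlashInfer later accepts topk_ids and topk_weights directly, as the comment anticipates, only the body of such a helper would need to change.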

return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=packed_tensor,
top_k=top_k,
global_num_experts=global_num_experts,
)
Comment on lines +1572 to +1578

Contributor commented:

critical

To fix the hardcoded routing_method_type in flashinfer_trtllm_fp4_routed_moe, you need to pass the custom_routing_function to it. This will allow the function to dynamically determine the correct routing method.

Suggested change:

-            return flashinfer_trtllm_fp4_routed_moe(
-                layer=layer,
-                x=x,
-                topk_ids=packed_tensor,
-                top_k=top_k,
-                global_num_experts=global_num_experts,
-            )
+            return flashinfer_trtllm_fp4_routed_moe(
+                layer=layer,
+                x=x,
+                topk_ids=packed_tensor,
+                top_k=top_k,
+                global_num_experts=global_num_experts,
+                custom_routing_function=custom_routing_function,
+            )


if self.use_marlin:
return fused_marlin_moe(
x,