From 8098bff321d247c0879806242f3f404bb0bf64f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=99=E5=A2=A8?= Date: Tue, 16 Jun 2026 20:35:01 +0800 Subject: [PATCH 1/3] feat(moe): add bf16 DeepEP-normal MoE path via DeepGEMM grouped GEMM Add an opt-in expert path for unquantized (bf16) MoE under DeepEP normal mode that uses DeepGEMM grouped GEMM instead of the Triton fused_moe_kernel. - DeepGemmBf16HybridExecutor: runtime-dispatches between a masked 3D layout (small token count / decode) and a contiguous flat layout (large token count / prefill) for better memory utilization. - ep_scatter_bf16 / ep_scatter_v2_bf16: bf16 variants of the existing fp8 scatter kernels (flat -> contiguous, flat -> 3D masked). - CudaNoQuantDpNormalDeepGemmStrategy: opt-in only, selected via --moe_strategy no_quant_dp_normal_deepgemm, gated on bf16 + has_deep_gemm + SM>=9 + no CUDA graph. It is NOT part of "auto" selection, so the default MoE path on existing CUDA deployments is unchanged. deepgemm_wrapper.py changes are backward-compatible and do not affect existing fp8/bf16 callers: - has_deep_gemm() re-checks until the first successful import (then caches True) instead of caching the first result; for normal processes where deep_gemm is importable at import time it returns True on the first call exactly as before. Needed for spawned subprocesses whose sys.path is set up after module import. - Symbol resolution is deferred from import-time to first use (_ensure_initialized); functionally identical, only lazy. - bf16 grouped-GEMM legacy fallback names corrected to the real deep_gemm symbols (gemm_bf16_bf16_bf16_nt*). resolve_symbol() tries the standard name first, so existing resolution is unchanged; this only makes the previously dormant bf16 path resolvable. Co-Authored-By: Claude Opus 4.8 --- .../kernels/cuda/deepgemm_wrapper.py | 59 ++- .../modules/factory/fused_moe/__init__.py | 2 + .../deepgemm_bf16_hybrid_executor.py | 354 ++++++++++++++++++ .../fused_moe/impl/cuda/strategy/__init__.py | 2 + .../fused_moe/impl/cuda/strategy/no_quant.py | 41 ++ .../triton_kernels/moe/ep_kernels.py | 216 +++++++++++ .../models_py/triton_kernels/moe/test/BUILD | 10 + .../moe/test/test_ep_scatter_bf16.py | 326 ++++++++++++++++ rtp_llm/server/server_args/moe_group_args.py | 1 + 9 files changed, 1001 insertions(+), 10 deletions(-) create mode 100644 rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py create mode 100644 rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py diff --git a/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py b/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py index 401656de2b..1f7b859275 100644 --- a/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py +++ b/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py @@ -1,4 +1,5 @@ import functools +import importlib.util from contextlib import contextmanager from typing import Any, Callable, Generator, List, NoReturn, Optional, Tuple @@ -6,7 +7,7 @@ import triton import triton.language as tl -from rtp_llm.utils.module_util import has_module, resolve_symbol +from rtp_llm.utils.module_util import resolve_symbol __all__ = [ "fp8_gemm_nt", @@ -30,13 +31,18 @@ "m_grouped_bf16_gemm_nt_masked": "m_grouped_bf16_gemm_nt_masked", } +# Legacy/fallback symbol names. resolve_symbol() tries _new_map first and only +# falls back here, so existing fp8 resolution is unchanged. The bf16 entries are +# the real deep_gemm symbols (gemm_bf16_bf16_bf16_nt*); they only make the bf16 +# grouped-GEMM path (previously dormant) resolvable and do not affect any caller +# that already resolved via _new_map. _deep_gemm_impl_old_map = { "fp8_gemm_nt": "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous": "m_grouped_fp8_gemm_nt_contiguous", "m_grouped_fp8_gemm_nt_masked": "fp8_m_grouped_gemm_nt_masked", - "bf16_gemm_nt": "bf16_gemm_nt", - "m_grouped_bf16_gemm_nt_contiguous": "m_grouped_bf16_gemm_nt_contiguous", - "m_grouped_bf16_gemm_nt_masked": "m_grouped_bf16_gemm_nt_masked", + "bf16_gemm_nt": "gemm_bf16_bf16_bf16_nt", + "m_grouped_bf16_gemm_nt_contiguous": "m_grouped_gemm_bf16_bf16_bf16_nt_contiguous", + "m_grouped_bf16_gemm_nt_masked": "m_grouped_gemm_bf16_bf16_bf16_nt_masked", } @@ -48,10 +54,23 @@ _m_grouped_bf16_gemm_nt_masked_impl: Callable[..., Any] | None = None -@functools.cache +_deep_gemm_available: bool | None = None + + def has_deep_gemm() -> bool: - """Whether the optional `deep_gemm` package is available.""" - return has_module("deep_gemm") + """Whether the optional `deep_gemm` package is available. + + Re-checks until first successful detection, then caches True. + This handles late sys.path setup in spawned subprocesses where + deep_gemm may not be importable at module-load time. + """ + global _deep_gemm_available + if _deep_gemm_available is True: + return True + available = importlib.util.find_spec("deep_gemm") is not None + if available: + _deep_gemm_available = True + return available @functools.cache @@ -62,6 +81,7 @@ def is_deep_gemm_e8m0_used() -> bool: @contextmanager def configure_deep_gemm_num_sms(num_sms: int) -> Generator[None, None, None]: """Configure the number of sms for deep gemm.""" + _ensure_initialized() if not has_deep_gemm(): raise RuntimeError( "DeepGEMM is not available. Please install the `deep_gemm` package to enable DeepGEMM kernels." @@ -135,7 +155,22 @@ def _lazy_init_deep_gemm_once(): ) -_lazy_init_deep_gemm_once() +_symbols_initialized = False + + +def _ensure_initialized(): + """Resolve deep_gemm symbols on first actual use (not at import time). + + Retries until deep_gemm becomes available on sys.path, which handles + spawned subprocesses where path setup happens after module import. + """ + global _symbols_initialized + if _symbols_initialized: + return + if not has_deep_gemm(): + return + _lazy_init_deep_gemm_once() + _symbols_initialized = True @triton.jit @@ -387,6 +422,7 @@ def fp8_gemm_nt( None """ global _fp8_gemm_nt_impl + _ensure_initialized() if _fp8_gemm_nt_impl is None: return _missing_deep_gemm() _fp8_gemm_nt_impl( @@ -424,6 +460,7 @@ def m_grouped_fp8_gemm_nt_contiguous( """ global _m_grouped_fp8_gemm_nt_contiguous_impl + _ensure_initialized() if _m_grouped_fp8_gemm_nt_contiguous_impl is None: return _missing_deep_gemm() _m_grouped_fp8_gemm_nt_contiguous_impl( @@ -487,6 +524,7 @@ def m_grouped_fp8_gemm_nt_masked( Defaults to None, which will be set to False if E8M0 scale is used, otherwise True. """ global _m_grouped_fp8_gemm_nt_masked_impl + _ensure_initialized() if _m_grouped_fp8_gemm_nt_masked_impl is None: return _missing_deep_gemm() @@ -527,6 +565,7 @@ def bf16_gemm_nt( compiled_dims (str, optional): Compiled dimensions. Defaults to "nk". """ global _bf16_gemm_nt_impl + _ensure_initialized() if _bf16_gemm_nt_impl is None: return _missing_deep_gemm() _bf16_gemm_nt_impl(a, b, output, c, compiled_dims) @@ -550,6 +589,7 @@ def m_grouped_bf16_gemm_nt_contiguous( compiled_dims (str, optional): Compiled dimensions. Defaults to "nk". """ global _m_grouped_bf16_gemm_nt_contiguous_impl + _ensure_initialized() if _m_grouped_bf16_gemm_nt_contiguous_impl is None: return _missing_deep_gemm() _m_grouped_bf16_gemm_nt_contiguous_impl( @@ -557,7 +597,6 @@ def m_grouped_bf16_gemm_nt_contiguous( b, output, m_indices, - compiled_dims, ) @@ -580,6 +619,7 @@ def m_grouped_bf16_gemm_nt_masked( compiled_dims (str, optional): Compiled dimensions. Defaults to "nk". """ global _m_grouped_bf16_gemm_nt_masked_impl + _ensure_initialized() if _m_grouped_bf16_gemm_nt_masked_impl is None: return _missing_deep_gemm() _m_grouped_bf16_gemm_nt_masked_impl( @@ -588,5 +628,4 @@ def m_grouped_bf16_gemm_nt_masked( output, masked_m, expected_m, - compiled_dims, ) diff --git a/rtp_llm/models_py/modules/factory/fused_moe/__init__.py b/rtp_llm/models_py/modules/factory/fused_moe/__init__.py index a5f0ff05fe..67ad48b200 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/__init__.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/__init__.py @@ -77,6 +77,7 @@ CudaFp8PerTensorEpNormalStrategy, CudaFp8PerTensorNoDPStrategy, CudaNoQuantCppStrategy, + CudaNoQuantDpNormalDeepGemmStrategy, CudaNoQuantDpNormalStrategy, CudaNoQuantEpLowLatencyStrategy, CudaW4a8Int4PerChannelEpLowLatencyStrategy, @@ -95,6 +96,7 @@ registry.register(CudaFp8PerBlockNoDPStrategy()) registry.register(CudaFp8PerTensorNoDPStrategy()) registry.register(CudaNoQuantEpLowLatencyStrategy()) + registry.register(CudaNoQuantDpNormalDeepGemmStrategy()) registry.register(CudaNoQuantDpNormalStrategy()) registry.register(CudaNoQuantCppStrategy()) registry.register(BatchedTritonStrategy()) diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py new file mode 100644 index 0000000000..eb602e5890 --- /dev/null +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py @@ -0,0 +1,354 @@ +"""DeepGemm BF16 Hybrid Executor for DeepEP Normal mode. + +Supports two compute paths selected at runtime by token count: + +- Masked (token_num <= masked_max_token_num, typical for decode): + ep_scatter_v2_bf16 → deepgemm bf16 masked (gate+up) → silu_mul_masked_bf16 + → deepgemm bf16 masked (down) → ep_gather + +- Contiguous (token_num > masked_max_token_num, typical for prefill): + ep_scatter_bf16 → deepgemm bf16 contiguous (gate+up) → silu_and_mul + → deepgemm bf16 contiguous (down) → ep_gather + +The masked path uses a 3D [E, alignment, K] layout where alignment = align(token_num, 128), +which is memory-efficient only when token_num is small (decode). For large token counts +(prefill), the alignment blows up to token_num, wasting E × token_num × K memory and +launching GEMM tiles over mostly-empty rows. + +The contiguous path uses a flat [Σ align(ei, 128), K] layout where each expert contributes +only its actual (padded) tokens, giving ~E× better memory utilization for prefill. +""" + +import logging +from typing import Any, Dict, Optional + +import torch + +from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import ( + configure_deep_gemm_num_sms, + m_grouped_bf16_gemm_nt_contiguous, + m_grouped_bf16_gemm_nt_masked, +) +from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import ( + MoEConfigAdapter, +) +from rtp_llm.models_py.modules.factory.fused_moe.defs.fused_moe import ( + CombineForwardPayload, + ExpertForwardPayload, + FusedMoeExpertExecutor, +) +from rtp_llm.models_py.modules.factory.fused_moe.defs.quant_config import ( + FusedMoEQuantConfig, +) +from rtp_llm.models_py.modules.factory.fused_moe.defs.type import ExecutorType +from rtp_llm.models_py.triton_kernels.common.activation import ( + silu_and_mul, + silu_mul_masked_bf16_no_post_quant_fwd, +) +from rtp_llm.models_py.triton_kernels.moe.ep_kernels import ( + ep_gather, + ep_scatter_bf16, + ep_scatter_v2_bf16, +) +from rtp_llm.models_py.utils.arch import get_num_device_sms, get_sm +from rtp_llm.models_py.utils.math import align, ceil_div +from rtp_llm.models_py.utils.memory import dispose_tensor +from rtp_llm.utils.model_weight import W + + +logger = logging.getLogger(__name__) + + +class DeepGemmBf16HybridExecutor(FusedMoeExpertExecutor): + """Executor for DeepEP Normal bf16 mode using deepgemm grouped GEMM. + + Dispatches between two paths at runtime based on token count: + - Masked (token_num <= masked_max_token_num): 3D layout, efficient for decode. + - Contiguous (token_num > masked_max_token_num): flat layout, efficient for prefill. + """ + + EXPERT_ALIGNMENT = 128 + DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] + + @classmethod + def executor_type(cls) -> ExecutorType: + # Returns DEEPGEMM_MASKED as the nominal type; consumers use this for + # logging/registration only — actual dispatch (masked vs contiguous) is + # done at runtime inside execute() based on token count. + return ExecutorType.DEEPGEMM_MASKED + + @classmethod + def check_conditions(cls, checker: Any, config: MoEConfigAdapter) -> None: + from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import has_deep_gemm + from rtp_llm.models_py.modules.factory.fused_moe.utils.config_resolver import ( + MoeConfigResolver, + ) + + resolver = MoeConfigResolver() + quant_method = resolver.get_quant_method(config) + checker.check(quant_method is None) + checker.check(resolver.is_bf16(config)) + checker.check(has_deep_gemm()) + checker.check(get_sm()[0] >= 9) + checker.check(not config.enable_cuda_graph) + + def __init__( + self, + config: MoEConfigAdapter, + quant_config: FusedMoEQuantConfig, + weights: Dict[str, torch.Tensor], + ): + super().__init__(config, quant_config, weights) + + self.ep_size = config.ep_size + self.ep_rank = config.ep_rank + self.num_experts = config.expert_num + + assert self.num_experts % self.ep_size == 0 + self.num_experts_per_partition = self.num_experts // self.ep_size + self.start_expert_id = self.ep_rank * self.num_experts_per_partition + self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 + + self.top_k = config.moe_k + self.activation = config.activation_type + self.masked_max_token_num = config.masked_max_token_num + + # Weight initialization (bf16, no quantization) + self.w1 = weights[W.moe_w1] + self.w2 = weights[W.moe_w2] + + self.num_local_experts, self.intermediate_size, self.hidden_size = self.w1.size() + assert self.intermediate_size % 2 == 0 + assert self.w2.size(0) == self.num_local_experts + assert self.w2.size(1) == self.hidden_size + assert self.w2.size(2) == self.intermediate_size // 2 + + self.num_gemm_sms = get_num_device_sms() + + def _to_local_expert_ids(self, topk_idx: torch.Tensor) -> torch.Tensor: + """Convert global expert IDs to partition-local IDs (0-based), -1 for out-of-partition.""" + local = topk_idx - self.start_expert_id + return torch.where( + (local >= 0) & (local < self.num_experts_per_partition), + local, + torch.tensor(-1, device=local.device, dtype=local.dtype), + ) + + def execute( + self, + payload: ExpertForwardPayload, + activation: str, + expert_map: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ) -> CombineForwardPayload: + assert payload.expert_x is not None, "hidden_states is not initialized" + assert payload.expert_topk_ids is not None, "expert_topk_ids is not initialized" + assert payload.expert_topk_weights is not None, "expert_topk_weights is not initialized" + assert payload.expert_tokens_meta is not None, "expert_tokens_meta is not initialized" + assert payload.expert_tokens_meta.expert_num_tokens is not None + # Router weight is always applied at the gather stage (ep_gather). DeepEP Normal + # callers must pass apply_router_weight_on_input=False. + assert not apply_router_weight_on_input, ( + "DeepGemmBf16HybridExecutor applies router weight at gather; " + "apply_router_weight_on_input=True is not supported." + ) + + token_num = payload.expert_x.shape[0] + if token_num <= self.masked_max_token_num: + return self.execute_masked( + payload, activation, expert_map, a2_scale, apply_router_weight_on_input, extra_expert_args + ) + else: + return self.execute_contiguous( + payload, activation, expert_map, a2_scale, apply_router_weight_on_input, extra_expert_args + ) + + def execute_masked( + self, + payload: ExpertForwardPayload, + activation: str, + expert_map: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ) -> CombineForwardPayload: + with configure_deep_gemm_num_sms(self.num_gemm_sms): + hidden_states = payload.expert_x + topk_idx = self._to_local_expert_ids(payload.expert_topk_ids) + topk_weights = payload.expert_topk_weights + num_recv_tokens_per_expert = payload.expert_tokens_meta.expert_num_tokens + + token_num = hidden_states.shape[0] + num_experts = num_recv_tokens_per_expert.shape[0] + max_token_num = token_num * self.top_k + token_num_mean_per_expert = ceil_div(max_token_num, num_experts) + alignment = align(token_num, self.EXPERT_ALIGNMENT) + expected_m = min(alignment, token_num_mean_per_expert) + + device = hidden_states.device + hidden_states_shape = hidden_states.shape + + # Step 1: Scatter flat [M, K] → 3D [E, alignment, K] + input_tensor = torch.empty( + (self.num_experts_per_partition, alignment, self.hidden_size), + device=device, + dtype=torch.bfloat16, + ) + output_index = torch.empty_like(topk_idx) + expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) + + ep_scatter_v2_bf16( + hidden_states, + topk_idx, + alignment, + expert_start_loc, + input_tensor.view(self.num_experts_per_partition * alignment, self.hidden_size), + output_index, + ) + dispose_tensor(hidden_states) + + # Step 2: Gate and Up GEMM (deepgemm bf16 masked) + upgate_output = torch.empty( + (self.num_experts_per_partition, alignment, self.intermediate_size), + device=device, + dtype=torch.bfloat16, + ) + m_grouped_bf16_gemm_nt_masked( + input_tensor, + self.w1, + upgate_output, + num_recv_tokens_per_expert, + expected_m, + ) + dispose_tensor(input_tensor) + + # Step 3: SiLU Activation (masked bf16) + down_input = torch.empty( + (self.num_experts_per_partition, alignment, self.intermediate_size // 2), + device=device, + dtype=torch.bfloat16, + ) + silu_mul_masked_bf16_no_post_quant_fwd( + input=upgate_output, + output=down_input, + masked_m=num_recv_tokens_per_expert, + expected_m=expected_m, + group_size=self.DEEPGEMM_BLOCK_SHAPE[0], + ) + dispose_tensor(upgate_output) + + # Step 4: Down GEMM (deepgemm bf16 masked) + down_output = torch.empty( + (self.num_experts_per_partition, alignment, self.hidden_size), + device=device, + dtype=torch.bfloat16, + ) + m_grouped_bf16_gemm_nt_masked( + down_input, + self.w2, + down_output, + num_recv_tokens_per_expert, + expected_m, + ) + dispose_tensor(down_input) + + # Step 5: Gather 3D → flat, with router weight multiplication + gather_out = torch.empty(hidden_states_shape, device=device, dtype=torch.bfloat16) + ep_gather( + down_output.view(self.num_experts_per_partition * alignment, self.hidden_size), + topk_idx, + topk_weights, + output_index, + gather_out, + ) + dispose_tensor(down_output) + + return CombineForwardPayload(fused_expert_output=gather_out) + + def execute_contiguous( + self, + payload: ExpertForwardPayload, + activation: str, + expert_map: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ) -> CombineForwardPayload: + """Large-token-count path: flat [all_tokens, K] layout, efficient for prefill. + + Each expert's tokens are packed contiguously (padded to EXPERT_ALIGNMENT), + so GEMM operates on dense data without wasting tiles over empty rows. + """ + hidden_states = payload.expert_x + topk_idx = self._to_local_expert_ids(payload.expert_topk_ids) + topk_weights = payload.expert_topk_weights + + # Get per-expert token counts as a Python list (needed for alignment arithmetic) + if payload.expert_tokens_meta.expert_num_tokens_cpu is not None: + tokens_per_expert_list = payload.expert_tokens_meta.expert_num_tokens_cpu + else: + tokens_per_expert_list = payload.expert_tokens_meta.expert_num_tokens.cpu().tolist() + if isinstance(tokens_per_expert_list, torch.Tensor): + tokens_per_expert_list = tokens_per_expert_list.tolist() + + # Align each expert's token count to EXPERT_ALIGNMENT (128) + aligned_tokens = [align(x, self.EXPERT_ALIGNMENT) for x in tokens_per_expert_list] + all_tokens = sum(aligned_tokens) + + device = hidden_states.device + hidden_states_shape = hidden_states.shape + + if all_tokens <= 0: + return CombineForwardPayload( + fused_expert_output=torch.zeros(hidden_states_shape, device=device, dtype=torch.bfloat16) + ) + + num_recv_tokens_per_expert_gpu = torch.tensor( + aligned_tokens, dtype=torch.int32, pin_memory=True, device="cpu" + ).cuda(non_blocking=True) + expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu) + m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32) + output_index = torch.empty_like(topk_idx) + + # Step 1: Scatter flat [M, K] → expert-sorted flat [all_tokens, K] + input_tensor = torch.empty((all_tokens, self.hidden_size), device=device, dtype=torch.bfloat16) + ep_scatter_bf16( + hidden_states, + topk_idx, + num_recv_tokens_per_expert_gpu, + expert_start_loc, + input_tensor, + m_indices, + output_index, + ) + # ep_scatter_bf16 fills m_indices for occupied slots and leaves padding slots at 0 + # (from torch.empty initialization). clamp_ guards against any stale values in + # unoccupied trailing slots that deepgemm uses as the expert-index array. + m_indices.clamp_(min=0, max=self.num_experts_per_partition - 1) + dispose_tensor(hidden_states) + + # Step 2: Gate and Up GEMM (deepgemm bf16 contiguous) + gateup_output = torch.empty((all_tokens, self.intermediate_size), device=device, dtype=torch.bfloat16) + with configure_deep_gemm_num_sms(self.num_gemm_sms): + m_grouped_bf16_gemm_nt_contiguous(input_tensor, self.w1, gateup_output, m_indices) + dispose_tensor(input_tensor) + + # Step 3: SiLU activation (flat, no mask needed) + down_input = torch.empty((all_tokens, self.intermediate_size // 2), device=device, dtype=torch.bfloat16) + silu_and_mul(down_input, gateup_output) + dispose_tensor(gateup_output) + + # Step 4: Down GEMM (deepgemm bf16 contiguous) + down_output = torch.empty((all_tokens, self.hidden_size), device=device, dtype=torch.bfloat16) + with configure_deep_gemm_num_sms(self.num_gemm_sms): + m_grouped_bf16_gemm_nt_contiguous(down_input, self.w2, down_output, m_indices) + dispose_tensor(down_input) + + # Step 5: Gather flat → [M, K], apply router weights + gather_out = torch.empty(hidden_states_shape, device=device, dtype=torch.bfloat16) + ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out) + dispose_tensor(down_output) + + return CombineForwardPayload(fused_expert_output=gather_out) diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/__init__.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/__init__.py index f0985001ca..32a6bc8816 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/__init__.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/__init__.py @@ -20,6 +20,7 @@ ) from .no_quant import ( CudaNoQuantCppStrategy, + CudaNoQuantDpNormalDeepGemmStrategy, CudaNoQuantDpNormalStrategy, CudaNoQuantEpLowLatencyStrategy, ) @@ -33,6 +34,7 @@ "CudaNoQuantEpLowLatencyStrategy", "CudaNoQuantCppStrategy", "CudaNoQuantDpNormalStrategy", + "CudaNoQuantDpNormalDeepGemmStrategy", # FP8 PerBlock "CudaFp8PerBlockNoDPMaskedStrategy", "CudaFp8PerBlockNoDPStrategy", diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py index 25ba75ed1c..81965b7782 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py @@ -96,3 +96,44 @@ def get_attributes(self) -> StrategyAttributes: executor_class=TritonFusedMoeExecutor, quant_config=quant_config, ) + + +class CudaNoQuantDpNormalDeepGemmStrategy(MoeStrategy): + """CUDA DeepEP Normal mode without quantization, using deepgemm bf16 grouped GEMM. + + Instead of fused_moe_kernel, this strategy uses the hybrid bf16 deepgemm path + (masked for decode, contiguous for prefill): + ep_scatter(_v2)_bf16 → deepgemm bf16 grouped GEMM → ep_gather + """ + + @classmethod + def check_conditions(cls, checker: Any, config: MoEConfigAdapter) -> None: + from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import has_deep_gemm + from rtp_llm.models_py.utils.arch import get_sm + + resolver = MoeConfigResolver() + quant_method = resolver.get_quant_method(config) + checker.check(quant_method is None) + # Opt-in only: must be explicitly requested via --moe_strategy. Not part of + # "auto" selection because the bf16 deepgemm path is not yet benchmarked on + # CUDA — keeping it out of auto avoids changing the default CUDA MoE path. + checker.check(config.moe_strategy == "no_quant_dp_normal_deepgemm") + checker.check(has_deep_gemm()) + checker.check(get_sm()[0] >= 9) + # executor dispatches masked/contiguous at runtime — incompatible with CUDA Graph replay + checker.check(not config.enable_cuda_graph) + + def get_attributes(self) -> StrategyAttributes: + from rtp_llm.models_py.modules.factory.fused_moe.impl.cuda.executors.deepgemm_bf16_hybrid_executor import ( + DeepGemmBf16HybridExecutor, + ) + from rtp_llm.models_py.modules.factory.fused_moe.impl.cuda.routers.deepep_normal_router import ( + DeepepNormalRouterNoQuant, + ) + + quant_config = FusedMoEQuantConfig(quant_dtype=None) + return StrategyAttributes( + router_class=DeepepNormalRouterNoQuant, + executor_class=DeepGemmBf16HybridExecutor, + quant_config=quant_config, + ) diff --git a/rtp_llm/models_py/triton_kernels/moe/ep_kernels.py b/rtp_llm/models_py/triton_kernels/moe/ep_kernels.py index 39532cb5ef..93441ae37c 100644 --- a/rtp_llm/models_py/triton_kernels/moe/ep_kernels.py +++ b/rtp_llm/models_py/triton_kernels/moe/ep_kernels.py @@ -181,6 +181,118 @@ def ep_scatter( return +@triton.jit +def _fwd_kernel_ep_scatter_2_bf16( + total_token_num, + expert_start_loc, + recv_x, + recv_x_stride0, + recv_x_stride1, + recv_topk, + recv_topk_stride0, + recv_topk_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + output_index, + output_index_stride0, + output_index_stride1, + topk_num: tl.constexpr, + num_experts: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + HIDDEN_SIZE_PAD: tl.constexpr, +): + """BF16 contiguous scatter kernel (no FP8 scale). + + Mirrors _fwd_kernel_ep_scatter_2 but for BF16 input without scale tensors. + Scatters tokens into expert-sorted flat layout for m_grouped_bf16_gemm_nt_contiguous. + """ + start_token_id = tl.program_id(0) + grid_num = tl.num_programs(0) + offset_in = tl.arange(0, HIDDEN_SIZE_PAD) + mask = offset_in < HIDDEN_SIZE + for token_id_int32 in range(start_token_id, total_token_num, grid_num): + token_id = token_id_int32.to(tl.int64) + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) + for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4): + topk_index = topk_idx_int32.to(tl.int64) + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) + if expert_id >= 0 and expert_id < num_experts: + dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1) + dest_token_index = dest_token_index_int32.to(tl.int64) + tl.store( + output_index + token_id * output_index_stride0 + topk_index, + dest_token_index_int32, + ) + output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0 + tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) + + +@torch.no_grad() +def ep_scatter_bf16( + recv_x: torch.Tensor, + recv_topk: torch.Tensor, + num_recv_tokens_per_expert: torch.Tensor, + expert_start_loc: torch.Tensor, + output_tensor: torch.Tensor, + m_indices: torch.Tensor, + output_index: torch.Tensor, +): + """BF16 contiguous scatter: flat [M, K] → expert-sorted flat [all_tokens, K]. + + Unlike ep_scatter_v2_bf16 (which produces a 3D masked layout for small token counts), + this produces a contiguous flat layout suited for large token counts (prefill). + Use with m_grouped_bf16_gemm_nt_contiguous. + + Args: + recv_x: Input tokens [M, K] in bf16. + recv_topk: Local expert assignments [M, topk] (local IDs, -1 = invalid). + num_recv_tokens_per_expert: Aligned token counts per local expert [E], GPU tensor. + expert_start_loc: Pre-allocated per-expert start locations [E], GPU tensor. + output_tensor: Output flat tensor [all_tokens, K] in bf16. + m_indices: Output expert index per slot [all_tokens]. + output_index: Output index mapping [M, topk]. + """ + BLOCK_E = 128 + num_warps = 8 + num_experts = num_recv_tokens_per_expert.shape[0] + hidden_size = recv_x.shape[1] + + assert m_indices.shape[0] % BLOCK_E == 0 + _fwd_kernel_ep_scatter_1[(num_experts,)]( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts=num_experts, + num_warps=num_warps, + BLOCK_E=BLOCK_E, + BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts), + ) + grid = min(recv_topk.shape[0], 1024 * 8) + _fwd_kernel_ep_scatter_2_bf16[(grid,)]( + recv_topk.shape[0], + expert_start_loc, + recv_x, + recv_x.stride(0), + recv_x.stride(1), + recv_topk, + recv_topk.stride(0), + recv_topk.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + output_index, + output_index.stride(0), + output_index.stride(1), + topk_num=recv_topk.shape[1], + num_experts=num_experts, + num_warps=num_warps, + HIDDEN_SIZE=hidden_size, + HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), + ) + return + + @triton.jit def _fwd_kernel_ep_scatter_1_v2( alignment, @@ -330,6 +442,110 @@ def ep_scatter_v2( return +@triton.jit +def _fwd_kernel_ep_scatter_2_v2_bf16( + total_token_num, + expert_start_loc, + recv_x, + recv_x_stride0, + recv_x_stride1, + recv_topk, + recv_topk_stride0, + recv_topk_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + output_index, + output_index_stride0, + output_index_stride1, + topk_num: tl.constexpr, + num_experts: tl.constexpr, + alignment: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + HIDDEN_SIZE_PAD: tl.constexpr, +): + start_token_id = tl.program_id(0) + grid_num = tl.num_programs(0) + offset_in = tl.arange(0, HIDDEN_SIZE_PAD) + mask = offset_in < HIDDEN_SIZE + for token_id_int32 in range(start_token_id, total_token_num, grid_num): + token_id = token_id_int32.to(tl.int64) + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) + for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4): + topk_index = topk_idx_int32.to(tl.int64) + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) + if expert_id >= 0 and expert_id < num_experts: + dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1) + dest_token_index = dest_token_index_int32.to(tl.int64) + tl.store( + output_index + token_id * output_index_stride0 + topk_index, + dest_token_index_int32, + ) + output_tensor_ptr = ( + output_tensor + dest_token_index * output_tensor_stride0 + ) + tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) + + +@torch.no_grad() +def ep_scatter_v2_bf16( + recv_x: torch.Tensor, + recv_topk: torch.Tensor, + alignment: int, + expert_start_loc: torch.Tensor, + output_tensor: torch.Tensor, + output_index: torch.Tensor, +): + """BF16 version of ep_scatter_v2 without FP8 scale handling. + + Scatters tokens from flat layout to 3D [E, alignment, K] layout + based on expert assignments, for use with deepgemm bf16 masked GEMM. + + Args: + recv_x: Input tokens [M, K] in bf16. + recv_topk: Expert assignments [M, topk]. + alignment: Token alignment per expert. + expert_start_loc: Per-expert start locations [E]. + output_tensor: Output 3D tensor [E * alignment, K] in bf16. + output_index: Output index mapping [M, topk]. + """ + num_warps = 8 + num_experts = expert_start_loc.shape[0] + hidden_size = recv_x.shape[1] + + _fwd_kernel_ep_scatter_1_v2[(1,)]( + alignment, + expert_start_loc, + num_experts=num_experts, + num_warps=num_warps, + BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts), + ) + grid = min(recv_topk.shape[0], 1024 * 8) + _fwd_kernel_ep_scatter_2_v2_bf16[(grid,)]( + recv_topk.shape[0], + expert_start_loc, + recv_x, + recv_x.stride(0), + recv_x.stride(1), + recv_topk, + recv_topk.stride(0), + recv_topk.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + output_index, + output_index.stride(0), + output_index.stride(1), + topk_num=recv_topk.shape[1], + num_experts=num_experts, + alignment=alignment, + num_warps=num_warps, + HIDDEN_SIZE=hidden_size, + HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), + ) + return + + @triton.jit def _fwd_kernel_ep_gather( total_token_num, diff --git a/rtp_llm/models_py/triton_kernels/moe/test/BUILD b/rtp_llm/models_py/triton_kernels/moe/test/BUILD index be0bb8fd2d..105667cf9a 100644 --- a/rtp_llm/models_py/triton_kernels/moe/test/BUILD +++ b/rtp_llm/models_py/triton_kernels/moe/test/BUILD @@ -22,3 +22,13 @@ py_test( visibility = ["//visibility:public"], exec_properties = {"gpu": "H20"}, ) + +py_test( + name = "test_ep_scatter_bf16", + srcs = ["test_ep_scatter_bf16.py"], + deps = [ + "//rtp_llm/models_py/triton_kernels:moe", + ] + [":triton", ":torch"], + visibility = ["//visibility:public"], + exec_properties = {"gpu": "H20"}, +) diff --git a/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py b/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py new file mode 100644 index 0000000000..7e660b852e --- /dev/null +++ b/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py @@ -0,0 +1,326 @@ +"""Unit tests for bf16 EP scatter kernels. + +ep_scatter_v2_bf16: flat [M, K] → 3D [E, alignment, K], for masked deepgemm (decode). +ep_scatter_bf16: flat [M, K] → flat expert-sorted [all_tokens, K], for contiguous deepgemm (prefill). + +Run with: + python -m pytest github-opensource/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py -v +""" + +import unittest + +import torch + +from rtp_llm.models_py.triton_kernels.moe.ep_kernels import ( + ep_scatter_v2_bf16, + ep_scatter_bf16, + ep_gather, +) +from rtp_llm.models_py.utils.math import align + + +class TestEpScatterV2Bf16(unittest.TestCase): + """Tests for ep_scatter_v2_bf16 kernel.""" + + def setUp(self) -> None: + if not torch.cuda.is_available(): + self.skipTest("CUDA required for triton kernel test") + self.device = torch.device("cuda") + + def _reference_scatter( + self, + recv_x: torch.Tensor, + recv_topk: torch.Tensor, + num_experts: int, + alignment: int, + ): + """CPU reference: scatter tokens to [E, alignment, K] based on topk assignments.""" + token_num, hidden_size = recv_x.shape + topk_num = recv_topk.shape[1] + output = torch.zeros( + (num_experts, alignment, hidden_size), + dtype=recv_x.dtype, + device=recv_x.device, + ) + output_index = torch.full_like(recv_topk, -1) + expert_counts = torch.zeros(num_experts, dtype=torch.int32, device=recv_x.device) + + for token_idx in range(token_num): + for topk_idx in range(topk_num): + expert_id = recv_topk[token_idx, topk_idx].item() + if 0 <= expert_id < num_experts: + slot = expert_counts[expert_id].item() + output[expert_id, slot, :] = recv_x[token_idx, :] + output_index[token_idx, topk_idx] = expert_id * alignment + slot + expert_counts[expert_id] += 1 + + return output, output_index + + def test_basic_scatter(self) -> None: + """Basic scatter: each token assigned to one expert.""" + num_experts = 4 + hidden_size = 256 + token_num = 8 + topk = 1 + alignment = align(token_num, 128) + + recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.tensor( + [[0], [1], [2], [3], [0], [1], [2], [3]], + dtype=torch.int32, + device=self.device, + ) + + output_tensor = torch.zeros( + num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + ) + output_index = torch.empty_like(recv_topk) + expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) + + ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) + + ref_output, ref_index = self._reference_scatter(recv_x, recv_topk, num_experts, alignment) + ref_output_flat = ref_output.view(num_experts * alignment, hidden_size) + + self.assertTrue( + torch.equal(output_tensor, ref_output_flat), + f"scatter output mismatch", + ) + self.assertTrue( + torch.equal(output_index, ref_index), + f"index mismatch:\n got={output_index}\n want={ref_index}", + ) + + def test_topk_scatter(self) -> None: + """Scatter with topk=2: each token placed in two expert slots.""" + num_experts = 4 + hidden_size = 128 + token_num = 4 + topk = 2 + alignment = align(token_num, 128) + + recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.tensor( + [[0, 1], [2, 3], [0, 2], [1, 3]], + dtype=torch.int32, + device=self.device, + ) + + output_tensor = torch.zeros( + num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + ) + output_index = torch.empty_like(recv_topk) + expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) + + ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) + + ref_output, ref_index = self._reference_scatter(recv_x, recv_topk, num_experts, alignment) + ref_output_flat = ref_output.view(num_experts * alignment, hidden_size) + + self.assertTrue( + torch.equal(output_tensor, ref_output_flat), + "scatter output mismatch for topk=2", + ) + self.assertTrue( + torch.equal(output_index, ref_index), + f"index mismatch for topk=2", + ) + + def test_negative_expert_ids_skipped(self) -> None: + """Expert id = -1 should be skipped (not scattered).""" + num_experts = 4 + hidden_size = 128 + token_num = 4 + topk = 2 + alignment = align(token_num, 128) + + recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.tensor( + [[0, -1], [-1, 2], [1, 3], [-1, -1]], + dtype=torch.int32, + device=self.device, + ) + + output_tensor = torch.zeros( + num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + ) + output_index = torch.empty_like(recv_topk) + expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) + + ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) + + # Expert 0 should have token 0, expert 1 should have token 2, + # expert 2 should have token 1, expert 3 should have token 2 + output_3d = output_tensor.view(num_experts, alignment, hidden_size) + self.assertTrue(torch.equal(output_3d[0, 0], recv_x[0])) + self.assertTrue(torch.equal(output_3d[2, 0], recv_x[1])) + self.assertTrue(torch.equal(output_3d[1, 0], recv_x[2])) + self.assertTrue(torch.equal(output_3d[3, 0], recv_x[2])) + + def test_scatter_gather_roundtrip(self) -> None: + """Scatter then gather should recover original (weighted) values.""" + num_experts = 4 + hidden_size = 256 + token_num = 16 + topk = 2 + alignment = align(token_num, 128) + + torch.manual_seed(42) + recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + + # Random expert assignments + recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) + topk_weights = torch.ones(token_num, topk, dtype=torch.float32, device=self.device) / topk + + output_tensor = torch.zeros( + num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + ) + output_index = torch.empty_like(recv_topk) + expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) + + ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) + + # Gather back (identity transform: no GEMM, just scatter → gather) + gather_out = torch.zeros_like(recv_x) + ep_gather(output_tensor, recv_topk, topk_weights, output_index, gather_out) + + # Each token was scattered topk times with weight 1/topk, so gather sum = original + self.assertTrue( + torch.allclose(gather_out.float(), recv_x.float(), atol=1e-2), + f"roundtrip mismatch: max diff = {(gather_out.float() - recv_x.float()).abs().max().item()}", + ) + + def test_large_random(self) -> None: + """Stress test with larger dimensions.""" + num_experts = 8 + hidden_size = 2048 + token_num = 256 + topk = 4 + alignment = align(token_num, 128) + + torch.manual_seed(123) + recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) + + output_tensor = torch.zeros( + num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + ) + output_index = torch.empty_like(recv_topk) + expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) + + ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) + + ref_output, ref_index = self._reference_scatter(recv_x, recv_topk, num_experts, alignment) + ref_output_flat = ref_output.view(num_experts * alignment, hidden_size) + + self.assertTrue( + torch.equal(output_tensor, ref_output_flat), + f"large random scatter mismatch", + ) + + +class TestEpScatterBf16Contiguous(unittest.TestCase): + """Tests for ep_scatter_bf16 — flat [M, K] → expert-sorted flat [all_tokens, K].""" + + def setUp(self) -> None: + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + self.device = torch.device("cuda") + + def _run_scatter(self, recv_x, recv_topk, num_experts, aligned_tokens_per_expert): + """Run ep_scatter_bf16 and return (output_tensor, m_indices, output_index).""" + all_tokens = sum(aligned_tokens_per_expert) + hidden_size = recv_x.shape[1] + num_recv_tokens_per_expert_gpu = torch.tensor( + aligned_tokens_per_expert, dtype=torch.int32, device=self.device + ) + expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) + output_tensor = torch.zeros(all_tokens, hidden_size, dtype=torch.bfloat16, device=self.device) + m_indices = torch.zeros(all_tokens, dtype=torch.int32, device=self.device) + output_index = torch.empty_like(recv_topk) + ep_scatter_bf16( + recv_x, recv_topk, + num_recv_tokens_per_expert_gpu, expert_start_loc, + output_tensor, m_indices, output_index, + ) + return output_tensor, m_indices, output_index + + def test_basic_contiguous_scatter(self) -> None: + """Each expert gets exactly one token; verify output ordering.""" + num_experts = 4 + hidden_size = 128 + alignment = 128 + recv_x = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.tensor([[0], [1], [2], [3]], dtype=torch.int32, device=self.device) + aligned_counts = [alignment] * num_experts # 1 token padded to 128 + + output, m_indices, output_index = self._run_scatter( + recv_x, recv_topk, num_experts, aligned_counts + ) + # Expert 0 token should appear in slot 0..127, expert 1 in 128..255, etc. + self.assertEqual(output.shape[0], alignment * num_experts) + # m_indices for slot 0 should be expert 0 + self.assertEqual(m_indices[0].item(), 0) + self.assertEqual(m_indices[alignment].item(), 1) + + def test_negative_expert_id_skipped(self) -> None: + """expert_id = -1 slots should not be scattered.""" + num_experts = 2 + hidden_size = 64 + alignment = 128 + recv_x = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.tensor([[0], [-1], [1], [-1]], dtype=torch.int32, device=self.device) + aligned_counts = [alignment, alignment] + + output, m_indices, output_index = self._run_scatter( + recv_x, recv_topk, num_experts, aligned_counts + ) + # Expert 0 gets token 0, expert 1 gets token 2 + self.assertTrue(torch.equal(output[0], recv_x[0])) + self.assertTrue(torch.equal(output[alignment], recv_x[2])) + + def test_scatter_gather_roundtrip(self) -> None: + """Scatter then gather with uniform weights recovers original.""" + num_experts = 4 + hidden_size = 256 + token_num = 16 + topk = 2 + alignment = 128 + torch.manual_seed(7) + recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) + topk_weights = torch.ones(token_num, topk, dtype=torch.float32, device=self.device) / topk + aligned_counts = [alignment] * num_experts + + output, _, output_index = self._run_scatter( + recv_x, recv_topk, num_experts, aligned_counts + ) + gather_out = torch.zeros_like(recv_x) + ep_gather(output, recv_topk, topk_weights, output_index, gather_out) + self.assertTrue( + torch.allclose(gather_out.float(), recv_x.float(), atol=1e-2), + f"roundtrip max diff={( gather_out.float() - recv_x.float()).abs().max():.4f}", + ) + + def test_large_random(self) -> None: + """Stress: large token count, multiple experts.""" + num_experts = 8 + hidden_size = 1024 + token_num = 512 + topk = 4 + alignment = 128 + torch.manual_seed(99) + recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) + aligned_counts = [alignment] * num_experts + + output, m_indices, _ = self._run_scatter( + recv_x, recv_topk, num_experts, aligned_counts + ) + self.assertEqual(output.shape, (alignment * num_experts, hidden_size)) + self.assertTrue(m_indices.min().item() >= 0) + self.assertTrue(m_indices.max().item() < num_experts) + + +if __name__ == "__main__": + unittest.main() diff --git a/rtp_llm/server/server_args/moe_group_args.py b/rtp_llm/server/server_args/moe_group_args.py index c899f50964..d5889e9d5c 100644 --- a/rtp_llm/server/server_args/moe_group_args.py +++ b/rtp_llm/server/server_args/moe_group_args.py @@ -166,6 +166,7 @@ def init_moe_group_args(parser, moe_config, eplb_config, deep_ep_config): "no_auant_ep_low_latency", "no_auant_cpp", "no_auant_dp_normal", + "no_quant_dp_normal_deepgemm", "fp8_per_block_no_dp_masked", "fp8_per_block_no_dp", "fp8_per_block_ep_low_latency", From 68b0a4426321f1c2d39c8085fe0836313c47c376 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=99=E5=A2=A8?= Date: Wed, 17 Jun 2026 11:29:25 +0800 Subject: [PATCH 2/3] test(moe): order-independent bf16 scatter tests + bf16 hybrid executor e2e test_ep_scatter_bf16: the scatter kernels assign output slots with a non-deterministic tl.atomic_add, so row order within an expert is not fixed. Rewrite all checks to be order-independent by following output_index (the authoritative token->slot map the gather stage uses) instead of assuming a token-sequential layout. This replaces the previously order-sensitive torch.equal comparisons that could spuriously fail. Also: - fix the roundtrip tests to use hidden_size % 512 == 0 (ep_gather BLOCK_D=512); - size the contiguous stress test's per-expert capacity from the real routing histogram (bincount, aligned) instead of a fixed count, matching the executor's allocation and avoiding under-allocation. deepgemm_bf16_hybrid_executor: new end-to-end test for the bf16 DeepEP-normal hybrid executor (scatter -> grouped GEMM -> silu_and_mul -> grouped GEMM -> gather with router weight), covering both the masked (small token count) and contiguous (large token count) runtime paths against a plain-torch reference. Tagged open_skip + H20 (requires deep_gemm + SM>=9). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../fused_moe/impl/cuda/executors/test/BUILD | 13 + .../deepgemm_bf16_hybrid_executor_test.py | 225 +++++++++++ .../moe/test/test_ep_scatter_bf16.py | 370 ++++++++++++------ 3 files changed, 484 insertions(+), 124 deletions(-) create mode 100644 rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/BUILD b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/BUILD index 219a98acb3..5f06dbf237 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/BUILD +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/BUILD @@ -21,6 +21,19 @@ py_test ( exec_properties = {'gpu':'H20', 'gpu_count':'1'}, ) +py_test ( + name = "deepgemm_bf16_hybrid_executor_test", + srcs = ["deepgemm_bf16_hybrid_executor_test.py"], + deps = [ + "//rtp_llm/models_py:modules", + "//rtp_llm:config", + "//rtp_llm:testlib", + ], + env = {"GPU_COUNT": "1"}, + tags = ["open_skip", "H20"], + exec_properties = {'gpu':'H20', 'gpu_count':'1'}, +) + py_test ( name = "deepgemm_masked_executor_sm100_test", srcs = ["deepgemm_masked_executor_test.py", diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py new file mode 100644 index 0000000000..dac42542a8 --- /dev/null +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py @@ -0,0 +1,225 @@ +"""End-to-end correctness test for DeepGemmBf16HybridExecutor. + +Validates the full bf16 DeepEP-normal expert path (scatter -> deepgemm grouped GEMM +-> silu_and_mul -> deepgemm grouped GEMM -> gather with router weight) against a +plain-torch reference, for BOTH runtime paths: + +- masked (token_num <= masked_max_token_num): 3D [E, alignment, K] layout +- contiguous (token_num > masked_max_token_num): flat [Σ align(ei, 128), K] layout + +Unlike DeepGemmMaskedExecutor (which consumes a pre-grouped 3D payload), the hybrid +executor consumes a FLAT [M, K] DeepEP-normal payload and does the scatter itself, so +this test builds a flat payload and a flat reference rather than reusing +fused_moe_executor_test_util. + +Requires the optional `deep_gemm` package; skips otherwise. Tagged H20 so internal CI +routes it to an SM>=9 worker where deep_gemm is available; open_skip keeps it out of the +open-source public CI lane. +""" + +import unittest +from typing import Dict, Tuple + +import torch + +from rtp_llm.config.model_config import ModelConfig +from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import has_deep_gemm +from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import ( + MoEConfigAdapter, +) +from rtp_llm.models_py.modules.factory.fused_moe.defs.fused_moe import ( + ExpertForwardPayload, + ExpertTokensMetadata, +) +from rtp_llm.models_py.modules.factory.fused_moe.defs.quant_config import ( + FusedMoEQuantConfig, +) +from rtp_llm.models_py.modules.factory.fused_moe.impl.cuda.executors.deepgemm_bf16_hybrid_executor import ( + DeepGemmBf16HybridExecutor, +) +from rtp_llm.ops import MoeConfig, ParallelismConfig +from rtp_llm.test.utils.numeric_util import calc_diff +from rtp_llm.utils.model_weight import W + + +class DeepGemmBf16HybridExecutorTestBase: + NUM_EXPERTS = 8 + TOP_K = 4 + HIDDEN_SIZE = 2048 + MOE_INTERMEDIATE_SIZE = 768 # N = 2 * 768 = 1536 + MASKED_MAX_TOKEN_NUM = 256 + + @property + def N(self) -> int: + return self.MOE_INTERMEDIATE_SIZE * 2 + + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + if not has_deep_gemm(): + self.skipTest("deep_gemm package required") + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + def _generate_config(self) -> MoEConfigAdapter: + model_config = ModelConfig() + model_config.attn_config.head_num = 2 + model_config.attn_config.size_per_head = 128 + model_config.num_layers = 2 + model_config.max_seq_len = 2048 + model_config.vocab_size = 500000 + model_config.expert_num = self.NUM_EXPERTS + model_config.hidden_size = self.HIDDEN_SIZE + model_config.moe_inter_size = self.MOE_INTERMEDIATE_SIZE + model_config.moe_k = self.TOP_K + + parallelism_config = ParallelismConfig() + parallelism_config.world_size = 1 + parallelism_config.dp_size = 1 + parallelism_config.tp_size = 1 + parallelism_config.ep_size = 1 + parallelism_config.dp_rank = 0 + parallelism_config.tp_rank = 0 + parallelism_config.ep_rank = 0 + parallelism_config.world_rank = 0 + parallelism_config.local_world_size = 1 + + moe_config = MoeConfig() + moe_config.masked_max_token_num = self.MASKED_MAX_TOKEN_NUM + return MoEConfigAdapter( + model_config=model_config, + parallelism_config=parallelism_config, + moe_config=moe_config, + ) + + def _make_payload_and_weights( + self, token_num: int + ) -> Tuple[ExpertForwardPayload, Dict[str, torch.Tensor], torch.Tensor]: + device = "cuda" + K = self.HIDDEN_SIZE + N = self.N + num_experts = self.NUM_EXPERTS + + expert_x = ( + torch.rand((token_num, K), device=device, dtype=torch.float32) * 0.1 - 0.05 + ).to(torch.bfloat16) + + # Each token routes to TOP_K *distinct* experts so per-expert load stays <= token_num + # (keeps the masked path's per-expert count within alignment = align(token_num, 128)). + topk_ids = torch.empty( + (token_num, self.TOP_K), device=device, dtype=torch.int32 + ) + for t in range(token_num): + topk_ids[t] = torch.randperm(num_experts, device=device)[: self.TOP_K].to( + torch.int32 + ) + topk_weights = torch.rand( + (token_num, self.TOP_K), device=device, dtype=torch.float32 + ) + + counts = torch.bincount( + topk_ids.flatten().to(torch.int64), minlength=num_experts + ).to(torch.int32) + + payload = ExpertForwardPayload( + expert_x=expert_x, + expert_x_scale=None, + expert_x_origin_dtype=torch.bfloat16, + expert_topk_ids=topk_ids, + expert_topk_weights=topk_weights, + expert_tokens_meta=ExpertTokensMetadata( + expert_num_tokens=counts, + expert_num_tokens_cpu=counts.tolist(), + ), + ) + + weights = { + W.moe_w1: ( + torch.rand((num_experts, N, K), device=device, dtype=torch.float32) * 2 + - 1 + ).to(torch.bfloat16), + W.moe_w2: ( + torch.rand((num_experts, K, N // 2), device=device, dtype=torch.float32) + * 2 + - 1 + ).to(torch.bfloat16), + W.moe_s1: None, + W.moe_s2: None, + } + # Keep a pristine copy of the inputs for the reference (executor disposes tensors). + ref_input = expert_x.clone() + return payload, weights, ref_input + + def _reference( + self, + expert_x: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + weights: Dict[str, torch.Tensor], + ) -> torch.Tensor: + """Plain-torch flat MoE reference, gate/value split matching the deepgemm path.""" + M, K = expert_x.shape + N = self.N + w1 = weights[W.moe_w1] + w2 = weights[W.moe_w2] + out = torch.zeros((M, K), device=expert_x.device, dtype=torch.float32) + for expert in range(self.NUM_EXPERTS): + sel = topk_ids == expert # [M, TOP_K] + if not bool(sel.any()): + continue + tok_idx, k_idx = sel.nonzero(as_tuple=True) + x = expert_x[tok_idx] # [n, K] bf16 + ws1 = x @ w1[expert].transpose(0, 1) # [n, N] + gate = ws1[..., N // 2 :].to(torch.float32) + value = ws1[..., : N // 2].to(torch.float32) + gate = gate * (1.0 / (1.0 + torch.exp(-gate))) + ws2 = (gate * value).to(torch.bfloat16) # [n, N//2] + down = ws2 @ w2[expert].transpose(0, 1) # [n, K] + weight = topk_weights[tok_idx, k_idx].unsqueeze(1).to(torch.float32) + out.index_add_(0, tok_idx, down.to(torch.float32) * weight) + return out.to(torch.bfloat16) + + def _run_path(self, token_num: int, expect_masked: bool) -> None: + config = self._generate_config() + self.assertEqual( + token_num <= config.masked_max_token_num, + expect_masked, + "token_num does not select the intended path", + ) + + payload, weights, ref_input = self._make_payload_and_weights(token_num) + ref_output = self._reference( + ref_input, payload.expert_topk_ids, payload.expert_topk_weights, weights + ) + + executor = DeepGemmBf16HybridExecutor( + config, FusedMoEQuantConfig(quant_dtype=None), weights + ) + combine_payload = executor.execute(payload, "silu", None, None, False, None) + out = combine_payload.fused_expert_output + + self.assertEqual(out.shape, ref_output.shape) + # bf16 grouped GEMM + fp32-accumulated topk-weighted gather vs torch reference; + # looser than the single-expert masked executor test (<0.003). + diff = calc_diff(out, ref_output) + self.assertLess( + diff, 0.01, f"output diff {diff} too large (token_num={token_num})" + ) + + def test_masked_path(self): + # token_num <= masked_max_token_num -> 3D masked layout (decode) + self._run_path(token_num=128, expect_masked=True) + + def test_contiguous_path(self): + # token_num > masked_max_token_num -> flat contiguous layout (prefill) + self._run_path(token_num=512, expect_masked=False) + + +class DeepGemmBf16HybridExecutorTest( + DeepGemmBf16HybridExecutorTestBase, unittest.TestCase +): + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py b/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py index 7e660b852e..d281dd7877 100644 --- a/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py +++ b/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py @@ -3,6 +3,13 @@ ep_scatter_v2_bf16: flat [M, K] → 3D [E, alignment, K], for masked deepgemm (decode). ep_scatter_bf16: flat [M, K] → flat expert-sorted [all_tokens, K], for contiguous deepgemm (prefill). +Both scatter kernels assign output slots with a non-deterministic tl.atomic_add, so the +row ORDER within an expert is not fixed across runs. All correctness checks here are +therefore order-independent: they follow output_index (the authoritative token->slot map +the gather stage uses) rather than assuming a token-sequential layout. + +ep_gather's BLOCK_D is 512, so any test that gathers must use hidden_size % 512 == 0. + Run with: python -m pytest github-opensource/rtp_llm/models_py/triton_kernels/moe/test/test_ep_scatter_bf16.py -v """ @@ -12,83 +19,98 @@ import torch from rtp_llm.models_py.triton_kernels.moe.ep_kernels import ( - ep_scatter_v2_bf16, - ep_scatter_bf16, ep_gather, + ep_scatter_bf16, + ep_scatter_v2_bf16, ) from rtp_llm.models_py.utils.math import align class TestEpScatterV2Bf16(unittest.TestCase): - """Tests for ep_scatter_v2_bf16 kernel.""" + """Tests for ep_scatter_v2_bf16 — flat [M, K] → 3D [E, alignment, K] (masked).""" def setUp(self) -> None: if not torch.cuda.is_available(): self.skipTest("CUDA required for triton kernel test") self.device = torch.device("cuda") - def _reference_scatter( - self, - recv_x: torch.Tensor, - recv_topk: torch.Tensor, - num_experts: int, - alignment: int, + @staticmethod + def _distinct_topk(token_num, topk, num_experts, device): + """One row per token with topk *distinct* experts (keeps per-expert count <= token_num).""" + ids = torch.empty((token_num, topk), dtype=torch.int32, device=device) + for t in range(token_num): + ids[t] = torch.randperm(num_experts, device=device)[:topk].to(torch.int32) + return ids + + def _verify_v2( + self, recv_x, recv_topk, num_experts, alignment, output_flat, output_index ): - """CPU reference: scatter tokens to [E, alignment, K] based on topk assignments.""" - token_num, hidden_size = recv_x.shape - topk_num = recv_topk.shape[1] - output = torch.zeros( - (num_experts, alignment, hidden_size), - dtype=recv_x.dtype, - device=recv_x.device, - ) - output_index = torch.full_like(recv_topk, -1) - expert_counts = torch.zeros(num_experts, dtype=torch.int32, device=recv_x.device) - - for token_idx in range(token_num): - for topk_idx in range(topk_num): - expert_id = recv_topk[token_idx, topk_idx].item() - if 0 <= expert_id < num_experts: - slot = expert_counts[expert_id].item() - output[expert_id, slot, :] = recv_x[token_idx, :] - output_index[token_idx, topk_idx] = expert_id * alignment + slot - expert_counts[expert_id] += 1 - - return output, output_index + """Order-independent check for the 3D masked layout via output_index. + + Each valid (token, topk) must map to a slot inside its expert's + [e * alignment, (e + 1) * alignment) block, carry the right row content, and + no slot may be reused. + """ + seen: set[int] = set() + topk = recv_topk.shape[1] + for t in range(recv_topk.shape[0]): + for k in range(topk): + expert = int(recv_topk[t, k].item()) + if expert < 0 or expert >= num_experts: + continue + slot = int(output_index[t, k].item()) + lo, hi = expert * alignment, (expert + 1) * alignment + self.assertTrue( + lo <= slot < hi, + f"token {t} topk {k} -> expert {expert}: slot {slot} outside [{lo}, {hi})", + ) + self.assertNotIn( + slot, seen, f"slot {slot} assigned to more than one token" + ) + seen.add(slot) + self.assertTrue( + torch.equal(output_flat[slot], recv_x[t]), + f"row content mismatch at slot {slot} (token {t})", + ) + + def _run_scatter_v2(self, recv_x, recv_topk, num_experts, alignment): + hidden_size = recv_x.shape[1] + output_tensor = torch.zeros( + num_experts * alignment, + hidden_size, + dtype=torch.bfloat16, + device=self.device, + ) + output_index = torch.empty_like(recv_topk) + expert_start_loc = torch.empty( + num_experts, dtype=torch.int32, device=self.device + ) + ep_scatter_v2_bf16( + recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index + ) + return output_tensor, output_index def test_basic_scatter(self) -> None: """Basic scatter: each token assigned to one expert.""" num_experts = 4 hidden_size = 256 token_num = 8 - topk = 1 alignment = align(token_num, 128) - recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_x = torch.randn( + token_num, hidden_size, dtype=torch.bfloat16, device=self.device + ) recv_topk = torch.tensor( [[0], [1], [2], [3], [0], [1], [2], [3]], dtype=torch.int32, device=self.device, ) - output_tensor = torch.zeros( - num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device - ) - output_index = torch.empty_like(recv_topk) - expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) - - ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) - - ref_output, ref_index = self._reference_scatter(recv_x, recv_topk, num_experts, alignment) - ref_output_flat = ref_output.view(num_experts * alignment, hidden_size) - - self.assertTrue( - torch.equal(output_tensor, ref_output_flat), - f"scatter output mismatch", + output_tensor, output_index = self._run_scatter_v2( + recv_x, recv_topk, num_experts, alignment ) - self.assertTrue( - torch.equal(output_index, ref_index), - f"index mismatch:\n got={output_index}\n want={ref_index}", + self._verify_v2( + recv_x, recv_topk, num_experts, alignment, output_tensor, output_index ) def test_topk_scatter(self) -> None: @@ -96,34 +118,22 @@ def test_topk_scatter(self) -> None: num_experts = 4 hidden_size = 128 token_num = 4 - topk = 2 alignment = align(token_num, 128) - recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_x = torch.randn( + token_num, hidden_size, dtype=torch.bfloat16, device=self.device + ) recv_topk = torch.tensor( [[0, 1], [2, 3], [0, 2], [1, 3]], dtype=torch.int32, device=self.device, ) - output_tensor = torch.zeros( - num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device - ) - output_index = torch.empty_like(recv_topk) - expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) - - ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) - - ref_output, ref_index = self._reference_scatter(recv_x, recv_topk, num_experts, alignment) - ref_output_flat = ref_output.view(num_experts * alignment, hidden_size) - - self.assertTrue( - torch.equal(output_tensor, ref_output_flat), - "scatter output mismatch for topk=2", + output_tensor, output_index = self._run_scatter_v2( + recv_x, recv_topk, num_experts, alignment ) - self.assertTrue( - torch.equal(output_index, ref_index), - f"index mismatch for topk=2", + self._verify_v2( + recv_x, recv_topk, num_experts, alignment, output_tensor, output_index ) def test_negative_expert_ids_skipped(self) -> None: @@ -131,54 +141,51 @@ def test_negative_expert_ids_skipped(self) -> None: num_experts = 4 hidden_size = 128 token_num = 4 - topk = 2 alignment = align(token_num, 128) - recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) + recv_x = torch.randn( + token_num, hidden_size, dtype=torch.bfloat16, device=self.device + ) recv_topk = torch.tensor( [[0, -1], [-1, 2], [1, 3], [-1, -1]], dtype=torch.int32, device=self.device, ) - output_tensor = torch.zeros( - num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + output_tensor, output_index = self._run_scatter_v2( + recv_x, recv_topk, num_experts, alignment ) - output_index = torch.empty_like(recv_topk) - expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) - - ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) - - # Expert 0 should have token 0, expert 1 should have token 2, - # expert 2 should have token 1, expert 3 should have token 2 + # Each expert here receives exactly one valid token, so slot 0 of each block is + # deterministic regardless of atomic ordering. output_3d = output_tensor.view(num_experts, alignment, hidden_size) self.assertTrue(torch.equal(output_3d[0, 0], recv_x[0])) self.assertTrue(torch.equal(output_3d[2, 0], recv_x[1])) self.assertTrue(torch.equal(output_3d[1, 0], recv_x[2])) self.assertTrue(torch.equal(output_3d[3, 0], recv_x[2])) + self._verify_v2( + recv_x, recv_topk, num_experts, alignment, output_tensor, output_index + ) def test_scatter_gather_roundtrip(self) -> None: """Scatter then gather should recover original (weighted) values.""" num_experts = 4 - hidden_size = 256 + hidden_size = 512 # ep_gather BLOCK_D = 512 token_num = 16 topk = 2 alignment = align(token_num, 128) torch.manual_seed(42) - recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) - - # Random expert assignments - recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) - topk_weights = torch.ones(token_num, topk, dtype=torch.float32, device=self.device) / topk - - output_tensor = torch.zeros( - num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + recv_x = torch.randn( + token_num, hidden_size, dtype=torch.bfloat16, device=self.device + ) + recv_topk = self._distinct_topk(token_num, topk, num_experts, self.device) + topk_weights = ( + torch.ones(token_num, topk, dtype=torch.float32, device=self.device) / topk ) - output_index = torch.empty_like(recv_topk) - expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) - ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) + output_tensor, output_index = self._run_scatter_v2( + recv_x, recv_topk, num_experts, alignment + ) # Gather back (identity transform: no GEMM, just scatter → gather) gather_out = torch.zeros_like(recv_x) @@ -196,26 +203,20 @@ def test_large_random(self) -> None: hidden_size = 2048 token_num = 256 topk = 4 - alignment = align(token_num, 128) + alignment = align(token_num, 128) # 256 >= max per-expert count torch.manual_seed(123) - recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) - recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) - - output_tensor = torch.zeros( - num_experts * alignment, hidden_size, dtype=torch.bfloat16, device=self.device + recv_x = torch.randn( + token_num, hidden_size, dtype=torch.bfloat16, device=self.device ) - output_index = torch.empty_like(recv_topk) - expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) - - ep_scatter_v2_bf16(recv_x, recv_topk, alignment, expert_start_loc, output_tensor, output_index) + # distinct experts per token => per-expert count <= token_num <= alignment + recv_topk = self._distinct_topk(token_num, topk, num_experts, self.device) - ref_output, ref_index = self._reference_scatter(recv_x, recv_topk, num_experts, alignment) - ref_output_flat = ref_output.view(num_experts * alignment, hidden_size) - - self.assertTrue( - torch.equal(output_tensor, ref_output_flat), - f"large random scatter mismatch", + output_tensor, output_index = self._run_scatter_v2( + recv_x, recv_topk, num_experts, alignment + ) + self._verify_v2( + recv_x, recv_topk, num_experts, alignment, output_tensor, output_index ) @@ -227,6 +228,70 @@ def setUp(self) -> None: self.skipTest("CUDA required") self.device = torch.device("cuda") + @staticmethod + def _aligned_counts_from_topk(recv_topk, num_experts, alignment): + """Per-expert capacity from the actual routing histogram, aligned up. + + Mirrors the production allocation in DeepGemmBf16HybridExecutor.execute_contiguous: + the scatter kernel assigns slots with an unchecked atomic_add, so the caller MUST + size output_tensor to cover every expert's real (aligned) token count. Hardcoding a + fixed per-expert capacity smaller than the real load causes out-of-bounds writes. + """ + valid = recv_topk[recv_topk >= 0].to(torch.int64).flatten() + counts = torch.bincount(valid, minlength=num_experts) + return [align(int(c), alignment) for c in counts.tolist()] + + def _verify_contiguous( + self, + recv_x, + recv_topk, + num_experts, + aligned_counts, + output, + m_indices, + output_index, + ): + """Order-independent correctness check via output_index. + + The atomic_add slot assignment order is non-deterministic, but every valid + (token, topk) assignment must land in its expert's [start, start + aligned_count) + region, with the right row content and m_indices label, and no slot may be reused. + m_indices must be correct on OCCUPIED rows because the downstream grouped GEMM uses + it to pick each row's expert. + """ + start_loc = [] + acc = 0 + for count in aligned_counts: + start_loc.append(acc) + acc += count + + seen: set[int] = set() + topk = recv_topk.shape[1] + for t in range(recv_topk.shape[0]): + for k in range(topk): + expert = int(recv_topk[t, k].item()) + if expert < 0: + continue + slot = int(output_index[t, k].item()) + lo, hi = start_loc[expert], start_loc[expert] + aligned_counts[expert] + self.assertTrue( + lo <= slot < hi, + f"token {t} topk {k} -> expert {expert}: slot {slot} outside [{lo}, {hi})", + ) + self.assertNotIn( + slot, seen, f"slot {slot} assigned to more than one token" + ) + seen.add(slot) + self.assertTrue( + torch.equal(output[slot], recv_x[t]), + f"row content mismatch at slot {slot} (token {t})", + ) + self.assertEqual( + int(m_indices[slot].item()), + expert, + f"m_indices[{slot}] != expert {expert}", + ) + def _run_scatter(self, recv_x, recv_topk, num_experts, aligned_tokens_per_expert): """Run ep_scatter_bf16 and return (output_tensor, m_indices, output_index).""" all_tokens = sum(aligned_tokens_per_expert) @@ -234,14 +299,22 @@ def _run_scatter(self, recv_x, recv_topk, num_experts, aligned_tokens_per_expert num_recv_tokens_per_expert_gpu = torch.tensor( aligned_tokens_per_expert, dtype=torch.int32, device=self.device ) - expert_start_loc = torch.empty(num_experts, dtype=torch.int32, device=self.device) - output_tensor = torch.zeros(all_tokens, hidden_size, dtype=torch.bfloat16, device=self.device) + expert_start_loc = torch.empty( + num_experts, dtype=torch.int32, device=self.device + ) + output_tensor = torch.zeros( + all_tokens, hidden_size, dtype=torch.bfloat16, device=self.device + ) m_indices = torch.zeros(all_tokens, dtype=torch.int32, device=self.device) output_index = torch.empty_like(recv_topk) ep_scatter_bf16( - recv_x, recv_topk, - num_recv_tokens_per_expert_gpu, expert_start_loc, - output_tensor, m_indices, output_index, + recv_x, + recv_topk, + num_recv_tokens_per_expert_gpu, + expert_start_loc, + output_tensor, + m_indices, + output_index, ) return output_tensor, m_indices, output_index @@ -251,7 +324,9 @@ def test_basic_contiguous_scatter(self) -> None: hidden_size = 128 alignment = 128 recv_x = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=self.device) - recv_topk = torch.tensor([[0], [1], [2], [3]], dtype=torch.int32, device=self.device) + recv_topk = torch.tensor( + [[0], [1], [2], [3]], dtype=torch.int32, device=self.device + ) aligned_counts = [alignment] * num_experts # 1 token padded to 128 output, m_indices, output_index = self._run_scatter( @@ -262,6 +337,15 @@ def test_basic_contiguous_scatter(self) -> None: # m_indices for slot 0 should be expert 0 self.assertEqual(m_indices[0].item(), 0) self.assertEqual(m_indices[alignment].item(), 1) + self._verify_contiguous( + recv_x, + recv_topk, + num_experts, + aligned_counts, + output, + m_indices, + output_index, + ) def test_negative_expert_id_skipped(self) -> None: """expert_id = -1 slots should not be scattered.""" @@ -269,7 +353,9 @@ def test_negative_expert_id_skipped(self) -> None: hidden_size = 64 alignment = 128 recv_x = torch.randn(4, hidden_size, dtype=torch.bfloat16, device=self.device) - recv_topk = torch.tensor([[0], [-1], [1], [-1]], dtype=torch.int32, device=self.device) + recv_topk = torch.tensor( + [[0], [-1], [1], [-1]], dtype=torch.int32, device=self.device + ) aligned_counts = [alignment, alignment] output, m_indices, output_index = self._run_scatter( @@ -278,19 +364,36 @@ def test_negative_expert_id_skipped(self) -> None: # Expert 0 gets token 0, expert 1 gets token 2 self.assertTrue(torch.equal(output[0], recv_x[0])) self.assertTrue(torch.equal(output[alignment], recv_x[2])) + self._verify_contiguous( + recv_x, + recv_topk, + num_experts, + aligned_counts, + output, + m_indices, + output_index, + ) def test_scatter_gather_roundtrip(self) -> None: """Scatter then gather with uniform weights recovers original.""" num_experts = 4 - hidden_size = 256 + hidden_size = 512 # ep_gather BLOCK_D = 512 token_num = 16 topk = 2 alignment = 128 torch.manual_seed(7) - recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) - recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) - topk_weights = torch.ones(token_num, topk, dtype=torch.float32, device=self.device) / topk - aligned_counts = [alignment] * num_experts + recv_x = torch.randn( + token_num, hidden_size, dtype=torch.bfloat16, device=self.device + ) + recv_topk = torch.randint( + 0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device + ) + topk_weights = ( + torch.ones(token_num, topk, dtype=torch.float32, device=self.device) / topk + ) + aligned_counts = self._aligned_counts_from_topk( + recv_topk, num_experts, alignment + ) output, _, output_index = self._run_scatter( recv_x, recv_topk, num_experts, aligned_counts @@ -310,16 +413,35 @@ def test_large_random(self) -> None: topk = 4 alignment = 128 torch.manual_seed(99) - recv_x = torch.randn(token_num, hidden_size, dtype=torch.bfloat16, device=self.device) - recv_topk = torch.randint(0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device) - aligned_counts = [alignment] * num_experts + recv_x = torch.randn( + token_num, hidden_size, dtype=torch.bfloat16, device=self.device + ) + recv_topk = torch.randint( + 0, num_experts, (token_num, topk), dtype=torch.int32, device=self.device + ) + # token_num * topk = 2048 assignments over 8 experts (~256 each); a fixed + # [alignment] * num_experts = 1024-slot buffer would be ~2x too small and the + # unchecked atomic_add in the kernel would write out of bounds. Size capacity + # from the real per-expert histogram, exactly as the executor does. + aligned_counts = self._aligned_counts_from_topk( + recv_topk, num_experts, alignment + ) - output, m_indices, _ = self._run_scatter( + output, m_indices, output_index = self._run_scatter( recv_x, recv_topk, num_experts, aligned_counts ) - self.assertEqual(output.shape, (alignment * num_experts, hidden_size)) + self.assertEqual(output.shape, (sum(aligned_counts), hidden_size)) self.assertTrue(m_indices.min().item() >= 0) self.assertTrue(m_indices.max().item() < num_experts) + self._verify_contiguous( + recv_x, + recv_topk, + num_experts, + aligned_counts, + output, + m_indices, + output_index, + ) if __name__ == "__main__": From 2658cff917e3d2132e1b44b9eff15454bf007b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=99=E5=A2=A8?= Date: Wed, 17 Jun 2026 18:15:47 +0800 Subject: [PATCH 3/3] fix(moe): bf16 deepgemm executor device + empty-rank guard + decoupled bf16 init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review fixes on the bf16 DeepEP-Normal deepgemm MoE path. All changes are confined to the opt-in no_quant_dp_normal_deepgemm path, the bf16 deep_gemm wrappers, and tests — the fp8 path and other default/cross-arch paths are unaffected. - DeepGemmBf16HybridExecutor.execute: handle empty rank before dispatch. DeepEP small-batch / skewed routing can leave a rank with token_num == 0; that would otherwise enter the masked path with alignment == 0 and launch 0-grid Triton scatter / 0-size DeepGEMM. Return an empty same-shape [0, K] bf16 output. - DeepGemmBf16HybridExecutor (contiguous path): build the per-expert token-count tensor with .to(device=hidden_states.device, non_blocking=True) instead of .cuda() so it honors the hidden-states device invariant. - deepgemm_wrapper: decouple bf16 symbol resolution from the fp8 path. Previously _ensure_initialized() resolved fp8 AND bf16 symbols together, so an older deep_gemm build missing the bf16 symbols would raise from _ensure_initialized() and break the fp8 wrappers. Now _ensure_initialized() resolves only _FP8_SYMBOLS (raises if missing — fp8 is core), while _ensure_bf16_initialized() resolves _BF16_SYMBOLS independently and tolerantly (missing -> impls stay None, never propagate). bf16 wrappers call _ensure_bf16_initialized(); fp8 wrappers keep _ensure_initialized(). has_deep_gemm_bf16_grouped() reports False (never raises) when the bf16 symbols are unavailable. - deepgemm_wrapper bf16 grouped wrappers: reject a non-default compiled_dims explicitly (NotImplementedError) instead of silently ignoring it; the wrapper does not forward compiled_dims (forwarding perturbs bf16 numerics on this shared path). No current caller passes a non-"nk" value. - CudaNoQuantDpNormalDeepGemmStrategy: fail fast at selection via has_deep_gemm_bf16_grouped(), and gate on the explicit opt-in moe_strategy FIRST with a short-circuit return (ConditionChecker does not stop at the first failed check, so this keeps the probe from running for non-opt-in / "auto" configs). - Tests: empty-rank (token_num==0) executor cases (ep1 + ep2); ep_size>1 executor coverage (rank 0/1, _to_local_expert_ids mapping + masking) vs a per-rank torch reference; strategy selection pos/neg; has_deep_gemm_bf16_grouped no-raise and _ensure_bf16_initialized tolerance; executor test skip uses has_deep_gemm_bf16_grouped() to match the gating. The ep_kernels contiguous padding-row m_indices contract is left to the feature kernel owner (padding output is discarded by the gather; a real fix needs a kernel signature change + a deep_gemm -1 skip contract). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../kernels/cuda/deepgemm_wrapper.py | 101 ++++++++++-- .../deepgemm_bf16_hybrid_executor.py | 10 +- .../deepgemm_bf16_hybrid_executor_test.py | 146 ++++++++++++++---- .../fused_moe/impl/cuda/strategy/no_quant.py | 20 ++- .../fused_moe/tests/test_cuda_strategies.py | 142 +++++++++++++++++ 5 files changed, 373 insertions(+), 46 deletions(-) diff --git a/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py b/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py index 1f7b859275..fccbea024e 100644 --- a/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py +++ b/rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py @@ -17,6 +17,7 @@ "m_grouped_bf16_gemm_nt_contiguous", "m_grouped_bf16_gemm_nt_masked", "has_deep_gemm", + "has_deep_gemm_bf16_grouped", "is_deep_gemm_e8m0_used", "configure_deep_gemm_num_sms", "maybe_pack_ue8m0_scale", @@ -73,6 +74,31 @@ def has_deep_gemm() -> bool: return available +def has_deep_gemm_bf16_grouped() -> bool: + """Whether the bf16 grouped GEMM kernels are actually resolvable. + + has_deep_gemm() only confirms the package is importable. This additionally + resolves and checks the specific bf16 grouped symbols the bf16 DeepGEMM MoE + path needs (contiguous + masked), so a strategy can fail fast at selection + time rather than deferring to the first execute() call when an older + deep_gemm build lacks them. + + Never raises: symbol resolution failures (old deep_gemm build missing the + bf16 grouped symbols) are reported as "unavailable" (False), so calling this + during strategy enumeration cannot break selection for unrelated configs. + """ + if not has_deep_gemm(): + return False + try: + _ensure_bf16_initialized() + except Exception: + return False + return ( + _m_grouped_bf16_gemm_nt_contiguous_impl is not None + and _m_grouped_bf16_gemm_nt_masked_impl is not None + ) + + @functools.cache def is_deep_gemm_e8m0_used() -> bool: return torch.cuda.get_device_capability()[0] in [10, 12] @@ -142,17 +168,27 @@ def _lazy_init_deep_gemm(symbols: List[str]) -> None: ) +# Core symbols required by the fp8 path. Resolved by _ensure_initialized(); a +# build missing these is broken for fp8 and raising is appropriate. +_FP8_SYMBOLS = [ + "fp8_gemm_nt", + "m_grouped_fp8_gemm_nt_contiguous", + "m_grouped_fp8_gemm_nt_masked", +] + +# Optional bf16 symbols, resolved separately and tolerantly (see +# _ensure_bf16_initialized) so an older deep_gemm build lacking them does NOT +# break the fp8 path's _ensure_initialized() — it only makes the bf16 deepgemm +# MoE strategy unselectable / its wrappers raise _missing_deep_gemm() at use. +_BF16_SYMBOLS = [ + "bf16_gemm_nt", + "m_grouped_bf16_gemm_nt_contiguous", + "m_grouped_bf16_gemm_nt_masked", +] + + def _lazy_init_deep_gemm_once(): - _lazy_init_deep_gemm( - [ - "fp8_gemm_nt", - "m_grouped_fp8_gemm_nt_contiguous", - "m_grouped_fp8_gemm_nt_masked", - "bf16_gemm_nt", - "m_grouped_bf16_gemm_nt_contiguous", - "m_grouped_bf16_gemm_nt_masked", - ] - ) + _lazy_init_deep_gemm(_FP8_SYMBOLS) _symbols_initialized = False @@ -173,6 +209,30 @@ def _ensure_initialized(): _symbols_initialized = True +_bf16_symbols_initialized = False + + +def _ensure_bf16_initialized() -> None: + """Resolve the optional bf16 deep_gemm symbols, independently of the fp8 path. + + Tolerant on purpose: if the deep_gemm build lacks the bf16 symbols, the impls + stay None and we still mark this attempted, so: + - the fp8 path (_ensure_initialized) is never affected; + - bf16 wrappers hit their `is None -> _missing_deep_gemm()` guard at use; + - has_deep_gemm_bf16_grouped() reports False. + """ + global _bf16_symbols_initialized + if _bf16_symbols_initialized: + return + if not has_deep_gemm(): + return # package not present yet; retry on a later call + try: + _lazy_init_deep_gemm(_BF16_SYMBOLS) + except Exception: + pass # missing bf16 symbols -> leave impls None, never propagate + _bf16_symbols_initialized = True + + @triton.jit def pack_ue8m0_kernel_vectorized( scale_ptr, @@ -565,7 +625,7 @@ def bf16_gemm_nt( compiled_dims (str, optional): Compiled dimensions. Defaults to "nk". """ global _bf16_gemm_nt_impl - _ensure_initialized() + _ensure_bf16_initialized() if _bf16_gemm_nt_impl is None: return _missing_deep_gemm() _bf16_gemm_nt_impl(a, b, output, c, compiled_dims) @@ -589,7 +649,16 @@ def m_grouped_bf16_gemm_nt_contiguous( compiled_dims (str, optional): Compiled dimensions. Defaults to "nk". """ global _m_grouped_bf16_gemm_nt_contiguous_impl - _ensure_initialized() + # Only the native "nk" layout is supported. The wrapper does not forward + # compiled_dims to the kernel (forwarding it perturbs bf16 numerics on this + # shared path); reject any non-default value explicitly instead of silently + # ignoring it. + if compiled_dims != "nk": + raise NotImplementedError( + "m_grouped_bf16_gemm_nt_contiguous only supports compiled_dims='nk', " + f"got {compiled_dims!r}" + ) + _ensure_bf16_initialized() if _m_grouped_bf16_gemm_nt_contiguous_impl is None: return _missing_deep_gemm() _m_grouped_bf16_gemm_nt_contiguous_impl( @@ -619,7 +688,13 @@ def m_grouped_bf16_gemm_nt_masked( compiled_dims (str, optional): Compiled dimensions. Defaults to "nk". """ global _m_grouped_bf16_gemm_nt_masked_impl - _ensure_initialized() + # Only the native "nk" layout is supported (see m_grouped_bf16_gemm_nt_contiguous). + if compiled_dims != "nk": + raise NotImplementedError( + "m_grouped_bf16_gemm_nt_masked only supports compiled_dims='nk', " + f"got {compiled_dims!r}" + ) + _ensure_bf16_initialized() if _m_grouped_bf16_gemm_nt_masked_impl is None: return _missing_deep_gemm() _m_grouped_bf16_gemm_nt_masked_impl( diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py index eb602e5890..cf910ecdf4 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/deepgemm_bf16_hybrid_executor.py @@ -156,6 +156,14 @@ def execute( ) token_num = payload.expert_x.shape[0] + # Empty rank: DeepEP small-batch / skewed routing can leave this rank with + # zero tokens. Return an empty same-shape output before dispatch — otherwise + # the masked path (token_num <= masked_max_token_num) would run with + # alignment == 0 and launch 0-grid Triton scatter / 0-size DeepGEMM. + if token_num == 0: + return CombineForwardPayload( + fused_expert_output=torch.empty_like(payload.expert_x) + ) if token_num <= self.masked_max_token_num: return self.execute_masked( payload, activation, expert_map, a2_scale, apply_router_weight_on_input, extra_expert_args @@ -307,7 +315,7 @@ def execute_contiguous( num_recv_tokens_per_expert_gpu = torch.tensor( aligned_tokens, dtype=torch.int32, pin_memory=True, device="cpu" - ).cuda(non_blocking=True) + ).to(device=device, non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu) m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32) output_index = torch.empty_like(topk_idx) diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py index dac42542a8..27387a596d 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/test/deepgemm_bf16_hybrid_executor_test.py @@ -7,11 +7,23 @@ - masked (token_num <= masked_max_token_num): 3D [E, alignment, K] layout - contiguous (token_num > masked_max_token_num): flat [Σ align(ei, 128), K] layout +and for BOTH single-rank (ep_size=1) and multi-rank Expert Parallelism (ep_size>1). + Unlike DeepGemmMaskedExecutor (which consumes a pre-grouped 3D payload), the hybrid executor consumes a FLAT [M, K] DeepEP-normal payload and does the scatter itself, so this test builds a flat payload and a flat reference rather than reusing fused_moe_executor_test_util. +EP coverage (ep_size>1) without a real multi-process DeepEP buffer: a true pass through +DeepepNormalRouter.prepare/finalize requires NVSHMEM all-to-all across processes and is +not single-process testable. Instead we simulate the post-dispatch state — the payload's +expert_topk_ids carry GLOBAL expert ids and only this rank's per-local-expert token +counts / weights are provided — so the executor's _to_local_expert_ids mapping (global -> +local id + masking experts outside [start_expert_id, end_expert_id]) is exercised, and +the reference computes ONLY this rank's in-partition contribution (the per-rank partial +that DeepEP combine would later sum across ranks). This is the same simulate-EP pattern +used by deepep_normal_executor_test.py. + Requires the optional `deep_gemm` package; skips otherwise. Tagged H20 so internal CI routes it to an SM>=9 worker where deep_gemm is available; open_skip keeps it out of the open-source public CI lane. @@ -23,7 +35,9 @@ import torch from rtp_llm.config.model_config import ModelConfig -from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import has_deep_gemm +from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import ( + has_deep_gemm_bf16_grouped, +) from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import ( MoEConfigAdapter, ) @@ -56,12 +70,12 @@ def N(self) -> int: def setUp(self): if not torch.cuda.is_available(): self.skipTest("CUDA required") - if not has_deep_gemm(): - self.skipTest("deep_gemm package required") + if not has_deep_gemm_bf16_grouped(): + self.skipTest("deep_gemm bf16 grouped GEMM kernels required") torch.manual_seed(42) torch.cuda.manual_seed(42) - def _generate_config(self) -> MoEConfigAdapter: + def _generate_config(self, ep_size: int = 1, ep_rank: int = 0) -> MoEConfigAdapter: model_config = ModelConfig() model_config.attn_config.head_num = 2 model_config.attn_config.size_per_head = 128 @@ -74,15 +88,15 @@ def _generate_config(self) -> MoEConfigAdapter: model_config.moe_k = self.TOP_K parallelism_config = ParallelismConfig() - parallelism_config.world_size = 1 + parallelism_config.world_size = ep_size parallelism_config.dp_size = 1 parallelism_config.tp_size = 1 - parallelism_config.ep_size = 1 + parallelism_config.ep_size = ep_size parallelism_config.dp_rank = 0 parallelism_config.tp_rank = 0 - parallelism_config.ep_rank = 0 - parallelism_config.world_rank = 0 - parallelism_config.local_world_size = 1 + parallelism_config.ep_rank = ep_rank + parallelism_config.world_rank = ep_rank + parallelism_config.local_world_size = ep_size moe_config = MoeConfig() moe_config.masked_max_token_num = self.MASKED_MAX_TOKEN_NUM @@ -93,19 +107,22 @@ def _generate_config(self) -> MoEConfigAdapter: ) def _make_payload_and_weights( - self, token_num: int + self, token_num: int, ep_size: int = 1, ep_rank: int = 0 ) -> Tuple[ExpertForwardPayload, Dict[str, torch.Tensor], torch.Tensor]: device = "cuda" K = self.HIDDEN_SIZE N = self.N num_experts = self.NUM_EXPERTS + num_local_experts = num_experts // ep_size + start_expert_id = ep_rank * num_local_experts expert_x = ( torch.rand((token_num, K), device=device, dtype=torch.float32) * 0.1 - 0.05 ).to(torch.bfloat16) - # Each token routes to TOP_K *distinct* experts so per-expert load stays <= token_num - # (keeps the masked path's per-expert count within alignment = align(token_num, 128)). + # Each token routes to TOP_K *distinct* GLOBAL experts. With ep_size>1 only the + # subset falling in [start_expert_id, start_expert_id + num_local_experts) is owned + # by this rank; the executor masks out the rest (mirrors the real dispatched state). topk_ids = torch.empty( (token_num, self.TOP_K), device=device, dtype=torch.int32 ) @@ -117,9 +134,13 @@ def _make_payload_and_weights( (token_num, self.TOP_K), device=device, dtype=torch.float32 ) - counts = torch.bincount( - topk_ids.flatten().to(torch.int64), minlength=num_experts - ).to(torch.int32) + # Per-LOCAL-expert token counts (length num_local_experts): how many tokens route + # to each expert this rank owns. Distinct top-k => each token contributes <= 1 per + # expert, so .any(dim=1) matches what ep_scatter places. + counts = torch.zeros(num_local_experts, device=device, dtype=torch.int32) + for local_e in range(num_local_experts): + g = start_expert_id + local_e + counts[local_e] = (topk_ids == g).any(dim=1).sum().to(torch.int32) payload = ExpertForwardPayload( expert_x=expert_x, @@ -133,13 +154,20 @@ def _make_payload_and_weights( ), ) + # Weights are per-LOCAL-expert ([num_local_experts, ...]); the executor indexes them + # by local id, and the reference uses weights[local] for global expert start+local. weights = { W.moe_w1: ( - torch.rand((num_experts, N, K), device=device, dtype=torch.float32) * 2 + torch.rand( + (num_local_experts, N, K), device=device, dtype=torch.float32 + ) + * 2 - 1 ).to(torch.bfloat16), W.moe_w2: ( - torch.rand((num_experts, K, N // 2), device=device, dtype=torch.float32) + torch.rand( + (num_local_experts, K, N // 2), device=device, dtype=torch.float32 + ) * 2 - 1 ).to(torch.bfloat16), @@ -156,40 +184,63 @@ def _reference( topk_ids: torch.Tensor, topk_weights: torch.Tensor, weights: Dict[str, torch.Tensor], + ep_size: int = 1, + ep_rank: int = 0, ) -> torch.Tensor: - """Plain-torch flat MoE reference, gate/value split matching the deepgemm path.""" + """Plain-torch flat MoE reference, gate/value split matching the deepgemm path. + + Computes only the experts THIS rank owns (global id in + [start_expert_id, start_expert_id + num_local_experts)); other experts contribute + nothing here (DeepEP combine sums the per-rank partials across ranks). + """ M, K = expert_x.shape N = self.N + num_local_experts = self.NUM_EXPERTS // ep_size + start_expert_id = ep_rank * num_local_experts w1 = weights[W.moe_w1] w2 = weights[W.moe_w2] out = torch.zeros((M, K), device=expert_x.device, dtype=torch.float32) - for expert in range(self.NUM_EXPERTS): - sel = topk_ids == expert # [M, TOP_K] + for local_e in range(num_local_experts): + g = start_expert_id + local_e + sel = topk_ids == g # [M, TOP_K] if not bool(sel.any()): continue tok_idx, k_idx = sel.nonzero(as_tuple=True) x = expert_x[tok_idx] # [n, K] bf16 - ws1 = x @ w1[expert].transpose(0, 1) # [n, N] + ws1 = x @ w1[local_e].transpose(0, 1) # [n, N] gate = ws1[..., N // 2 :].to(torch.float32) value = ws1[..., : N // 2].to(torch.float32) gate = gate * (1.0 / (1.0 + torch.exp(-gate))) ws2 = (gate * value).to(torch.bfloat16) # [n, N//2] - down = ws2 @ w2[expert].transpose(0, 1) # [n, K] + down = ws2 @ w2[local_e].transpose(0, 1) # [n, K] weight = topk_weights[tok_idx, k_idx].unsqueeze(1).to(torch.float32) out.index_add_(0, tok_idx, down.to(torch.float32) * weight) return out.to(torch.bfloat16) - def _run_path(self, token_num: int, expect_masked: bool) -> None: - config = self._generate_config() + def _run_path( + self, + token_num: int, + expect_masked: bool, + ep_size: int = 1, + ep_rank: int = 0, + ) -> None: + config = self._generate_config(ep_size=ep_size, ep_rank=ep_rank) self.assertEqual( token_num <= config.masked_max_token_num, expect_masked, "token_num does not select the intended path", ) - payload, weights, ref_input = self._make_payload_and_weights(token_num) + payload, weights, ref_input = self._make_payload_and_weights( + token_num, ep_size=ep_size, ep_rank=ep_rank + ) ref_output = self._reference( - ref_input, payload.expert_topk_ids, payload.expert_topk_weights, weights + ref_input, + payload.expert_topk_ids, + payload.expert_topk_weights, + weights, + ep_size=ep_size, + ep_rank=ep_rank, ) executor = DeepGemmBf16HybridExecutor( @@ -203,9 +254,13 @@ def _run_path(self, token_num: int, expect_masked: bool) -> None: # looser than the single-expert masked executor test (<0.003). diff = calc_diff(out, ref_output) self.assertLess( - diff, 0.01, f"output diff {diff} too large (token_num={token_num})" + diff, + 0.01, + f"output diff {diff} too large (token_num={token_num}, " + f"ep_size={ep_size}, ep_rank={ep_rank})", ) + # ---- single rank (ep_size=1) ---- def test_masked_path(self): # token_num <= masked_max_token_num -> 3D masked layout (decode) self._run_path(token_num=128, expect_masked=True) @@ -214,6 +269,43 @@ def test_contiguous_path(self): # token_num > masked_max_token_num -> flat contiguous layout (prefill) self._run_path(token_num=512, expect_masked=False) + # ---- multi-rank Expert Parallelism (ep_size=2) ---- + # Exercises _to_local_expert_ids: global->local id mapping + masking experts outside + # this rank's [start_expert_id, end_expert_id] partition. + def test_masked_path_ep2_rank0(self): + self._run_path(token_num=128, expect_masked=True, ep_size=2, ep_rank=0) + + def test_contiguous_path_ep2_rank0(self): + self._run_path(token_num=512, expect_masked=False, ep_size=2, ep_rank=0) + + def test_masked_path_ep2_rank1(self): + # ep_rank=1 -> non-zero start_expert_id (4); validates the local-id offset. + self._run_path(token_num=128, expect_masked=True, ep_size=2, ep_rank=1) + + def test_contiguous_path_ep2_rank1(self): + self._run_path(token_num=512, expect_masked=False, ep_size=2, ep_rank=1) + + # ---- empty rank (DeepEP small-batch / skewed routing -> 0 tokens) ---- + def _run_empty_rank(self, ep_size: int, ep_rank: int) -> None: + config = self._generate_config(ep_size=ep_size, ep_rank=ep_rank) + payload, weights, _ = self._make_payload_and_weights( + token_num=0, ep_size=ep_size, ep_rank=ep_rank + ) + executor = DeepGemmBf16HybridExecutor( + config, FusedMoEQuantConfig(quant_dtype=None), weights + ) + # Must not launch 0-grid scatter / 0-size DeepGEMM; returns an empty, + # same-shape [0, K] output. + out = executor.execute(payload, "silu", None, None, False, None) + self.assertEqual(tuple(out.fused_expert_output.shape), (0, self.HIDDEN_SIZE)) + self.assertEqual(out.fused_expert_output.dtype, torch.bfloat16) + + def test_empty_rank(self): + self._run_empty_rank(ep_size=1, ep_rank=0) + + def test_empty_rank_ep2(self): + self._run_empty_rank(ep_size=2, ep_rank=1) + class DeepGemmBf16HybridExecutorTest( DeepGemmBf16HybridExecutorTestBase, unittest.TestCase diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py index 81965b7782..61202a178e 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/strategy/no_quant.py @@ -108,17 +108,27 @@ class CudaNoQuantDpNormalDeepGemmStrategy(MoeStrategy): @classmethod def check_conditions(cls, checker: Any, config: MoEConfigAdapter) -> None: - from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import has_deep_gemm + from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import ( + has_deep_gemm_bf16_grouped, + ) from rtp_llm.models_py.utils.arch import get_sm - resolver = MoeConfigResolver() - quant_method = resolver.get_quant_method(config) - checker.check(quant_method is None) # Opt-in only: must be explicitly requested via --moe_strategy. Not part of # "auto" selection because the bf16 deepgemm path is not yet benchmarked on # CUDA — keeping it out of auto avoids changing the default CUDA MoE path. + # Gate on this FIRST and short-circuit: ConditionChecker does not stop at the + # first failed check, so without this return the bf16-grouped symbol probe + # below would run for every config during strategy enumeration (incl. "auto"). checker.check(config.moe_strategy == "no_quant_dp_normal_deepgemm") - checker.check(has_deep_gemm()) + if config.moe_strategy != "no_quant_dp_normal_deepgemm": + return + + resolver = MoeConfigResolver() + checker.check(resolver.get_quant_method(config) is None) + # Probe the actual bf16 grouped GEMM symbols (not just the deep_gemm package), + # so a build missing them rejects this strategy here instead of failing later + # in the executor's first forward. has_deep_gemm_bf16_grouped() never raises. + checker.check(has_deep_gemm_bf16_grouped()) checker.check(get_sm()[0] >= 9) # executor dispatches masked/contiguous at runtime — incompatible with CUDA Graph replay checker.check(not config.enable_cuda_graph) diff --git a/rtp_llm/models_py/modules/factory/fused_moe/tests/test_cuda_strategies.py b/rtp_llm/models_py/modules/factory/fused_moe/tests/test_cuda_strategies.py index b1d3010fd0..e4a09df710 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/tests/test_cuda_strategies.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/tests/test_cuda_strategies.py @@ -28,6 +28,7 @@ CudaFp8PerBlockPureCPStrategy, CudaFp8PerBlockPureDPStrategy, CudaFp8PerTensorNoDPStrategy, + CudaNoQuantDpNormalDeepGemmStrategy, CudaW4a8Int4PerChannelNoDPStrategy, ) from rtp_llm.ops import CPRotateMethod, MoeConfig, ParallelismConfig @@ -798,5 +799,146 @@ def test_priority(self) -> None: self.assertEqual(strategy.priority, expected_priority) +class TestCudaNoQuantDpNormalDeepGemmStrategy(unittest.TestCase): + """Test CUDA no-quant DeepEP-Normal bf16 DeepGEMM strategy (opt-in).""" + + def _make_config( + self, + *, + data_type: str = "bf16", + moe_strategy: str = "no_quant_dp_normal_deepgemm", + ep_size: int = 2, + tp_size: int = 1, + dp_size: int = 1, + enable_cuda_graph: bool = False, + ) -> MoEConfigAdapter: + model_config = create_model_config_without_quant() + model_config.data_type = data_type + return create_moe_config_adapter( + model_config=model_config, + parallelism_config=create_parallelism_config( + ep_size=ep_size, tp_size=tp_size, dp_size=dp_size + ), + moe_config=create_moe_config( + use_deepep_low_latency=False, moe_strategy=moe_strategy + ), + enable_cuda_graph=enable_cuda_graph, + ) + + @patch("rtp_llm.models_py.kernels.cuda.deepgemm_wrapper.has_deep_gemm") + @patch( + "rtp_llm.models_py.kernels.cuda.deepgemm_wrapper.has_deep_gemm_bf16_grouped" + ) + @patch("rtp_llm.models_py.utils.arch.get_sm") + @patch("rtp_llm.models_py.distributed.deepep_wrapper.DeepEPWrapper.supported") + def test_can_handle_and_attributes( + self, + mock_supported: Any, + mock_get_sm: Any, + mock_has_grouped: Any, + mock_has_deep_gemm: Any, + ) -> None: + """Positive: explicit opt-in, bf16, ep>1, sm9, no cuda graph -> selected, + and routes to DeepEP-Normal router + bf16 hybrid executor.""" + mock_supported.return_value = True + mock_get_sm.return_value = (9, 0) + mock_has_grouped.return_value = True + mock_has_deep_gemm.return_value = True + + strategy = CudaNoQuantDpNormalDeepGemmStrategy() + self.assertTrue(strategy.can_handle(self._make_config())) + + attrs = strategy.get_attributes() + self.assertEqual(attrs.router_class.__name__, "DeepepNormalRouterNoQuant") + self.assertEqual(attrs.executor_class.__name__, "DeepGemmBf16HybridExecutor") + + @patch("rtp_llm.models_py.kernels.cuda.deepgemm_wrapper.has_deep_gemm") + @patch( + "rtp_llm.models_py.kernels.cuda.deepgemm_wrapper.has_deep_gemm_bf16_grouped" + ) + @patch("rtp_llm.models_py.utils.arch.get_sm") + @patch("rtp_llm.models_py.distributed.deepep_wrapper.DeepEPWrapper.supported") + def test_negative_cases( + self, + mock_supported: Any, + mock_get_sm: Any, + mock_has_grouped: Any, + mock_has_deep_gemm: Any, + ) -> None: + """Negative: fp16 dtype, auto (not opt-in), cuda graph, and missing bf16 + grouped symbols each disqualify the strategy.""" + mock_supported.return_value = True + mock_get_sm.return_value = (9, 0) + mock_has_grouped.return_value = True + mock_has_deep_gemm.return_value = True + + strategy = CudaNoQuantDpNormalDeepGemmStrategy() + + # fp16 (not bf16) + self.assertFalse(strategy.can_handle(self._make_config(data_type="fp16"))) + # not explicitly opted in + self.assertFalse(strategy.can_handle(self._make_config(moe_strategy="auto"))) + # CUDA graph incompatible (runtime masked/contiguous dispatch) + self.assertFalse( + strategy.can_handle(self._make_config(enable_cuda_graph=True)) + ) + # bf16 grouped GEMM symbols unavailable -> fail fast at selection + mock_has_grouped.return_value = False + self.assertFalse(strategy.can_handle(self._make_config())) + + +class TestHasDeepGemmBf16Grouped(unittest.TestCase): + """has_deep_gemm_bf16_grouped() must report unavailability, never raise, so it + is safe to call during strategy enumeration for any config.""" + + @patch("rtp_llm.models_py.kernels.cuda.deepgemm_wrapper.has_deep_gemm") + def test_false_when_package_absent(self, mock_has_deep_gemm: Any) -> None: + from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import ( + has_deep_gemm_bf16_grouped, + ) + + mock_has_deep_gemm.return_value = False + self.assertFalse(has_deep_gemm_bf16_grouped()) + + @patch( + "rtp_llm.models_py.kernels.cuda.deepgemm_wrapper._ensure_bf16_initialized" + ) + @patch("rtp_llm.models_py.kernels.cuda.deepgemm_wrapper.has_deep_gemm") + def test_false_not_raise_on_symbol_resolution_error( + self, mock_has_deep_gemm: Any, mock_ensure_bf16: Any + ) -> None: + from rtp_llm.models_py.kernels.cuda.deepgemm_wrapper import ( + has_deep_gemm_bf16_grouped, + ) + + mock_has_deep_gemm.return_value = True + mock_ensure_bf16.side_effect = RuntimeError("symbol not found") + # Must swallow the resolution error and return False, not propagate. + self.assertFalse(has_deep_gemm_bf16_grouped()) + + +class TestEnsureBf16Initialized(unittest.TestCase): + """The bf16 symbol init must be decoupled from the fp8 path: missing bf16 + symbols leave the impls None and never raise, so _ensure_initialized() (fp8) + is unaffected.""" + + @patch( + "rtp_llm.models_py.kernels.cuda.deepgemm_wrapper._bf16_symbols_initialized", + False, + ) + @patch("rtp_llm.models_py.kernels.cuda.deepgemm_wrapper._lazy_init_deep_gemm") + @patch("rtp_llm.models_py.kernels.cuda.deepgemm_wrapper.has_deep_gemm") + def test_tolerates_missing_bf16_symbols( + self, mock_has_deep_gemm: Any, mock_lazy_init: Any + ) -> None: + from rtp_llm.models_py.kernels.cuda import deepgemm_wrapper + + mock_has_deep_gemm.return_value = True + mock_lazy_init.side_effect = RuntimeError("bf16 grouped symbol not found") + # Must not propagate the resolution error (would otherwise break the fp8 + # path's _ensure_initialized, which is a separate call). + deepgemm_wrapper._ensure_bf16_initialized() + + if __name__ == "__main__": unittest.main()