feat(moe): add bf16 DeepEP-normal MoE path via DeepGEMM grouped GEMM#1111
Tanmo-ai wants to merge 3 commits into
Add an opt-in expert path for unquantized (bf16) MoE under DeepEP normal mode that uses DeepGEMM grouped GEMM instead of the Triton fused_moe_kernel.

- DeepGemmBf16HybridExecutor: runtime-dispatches between a masked 3D layout (small token count / decode) and a contiguous flat layout (large token count / prefill) for better memory utilization.
- ep_scatter_bf16 / ep_scatter_v2_bf16: bf16 variants of the existing fp8 scatter kernels (flat -> contiguous, flat -> 3D masked).
- CudaNoQuantDpNormalDeepGemmStrategy: opt-in only, selected via --moe_strategy no_quant_dp_normal_deepgemm, gated on bf16 + has_deep_gemm + SM>=9 + no CUDA graph. It is NOT part of "auto" selection, so the default MoE path on existing CUDA deployments is unchanged.

deepgemm_wrapper.py changes are backward-compatible and do not affect existing fp8/bf16 callers:

- has_deep_gemm() re-checks until the first successful import (then caches True) instead of caching the first result; for normal processes where deep_gemm is importable at import time it returns True on the first call exactly as before. Needed for spawned subprocesses whose sys.path is set up after module import.
- Symbol resolution is deferred from import-time to first use (_ensure_initialized); functionally identical, only lazy.
- bf16 grouped-GEMM legacy fallback names corrected to the real deep_gemm symbols (gemm_bf16_bf16_bf16_nt*). resolve_symbol() tries the standard name first, so existing resolution is unchanged; this only makes the previously dormant bf16 path resolvable.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/1 · P3/0
Checklist Violations: 5 fail / 104 total
…r e2e

test_ep_scatter_bf16: the scatter kernels assign output slots with a non-deterministic tl.atomic_add, so row order within an expert is not fixed. Rewrite all checks to be order-independent by following output_index (the authoritative token->slot map the gather stage uses) instead of assuming a token-sequential layout. This replaces the previously order-sensitive torch.equal comparisons that could spuriously fail. Also:
- fix the roundtrip tests to use hidden_size % 512 == 0 (ep_gather BLOCK_D=512);
- size the contiguous stress test's per-expert capacity from the real routing histogram (bincount, aligned) instead of a fixed count, matching the executor's allocation and avoiding under-allocation.

deepgemm_bf16_hybrid_executor: new end-to-end test for the bf16 DeepEP-normal hybrid executor (scatter -> grouped GEMM -> silu_and_mul -> grouped GEMM -> gather with router weight), covering both the masked (small token count) and contiguous (large token count) runtime paths against a plain-torch reference. Tagged open_skip + H20 (requires deep_gemm + SM>=9).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
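The order-independent check can be illustrated with a small pure-Python toy. This is only a sketch of the idea, not the actual test: simulated_scatter is a hypothetical stand-in for the Triton kernel (a seeded shuffle mimics the arbitrary atomic_add arrival order), and the helper names and data layout are assumptions. Per-expert capacity is sized from the routing histogram and aligned, as the commit message describes.

```python
import random


def align_up(n, alignment):
    return (n + alignment - 1) // alignment * alignment


def simulated_scatter(tokens, topk_ids, num_experts, alignment, seed):
    """Toy stand-in for the scatter kernel: slot order within an expert is arbitrary."""
    counts = [0] * num_experts                      # routing histogram (bincount)
    for ids in topk_ids:
        for e in ids:
            counts[e] += 1
    capacity = [align_up(c, alignment) for c in counts]
    starts = [sum(capacity[:e]) for e in range(num_experts)]
    pairs = [(t, k) for t, ids in enumerate(topk_ids) for k in range(len(ids))]
    random.Random(seed).shuffle(pairs)              # mimic non-deterministic arrival
    next_free = list(starts)
    output = [None] * sum(capacity)
    output_index = [[-1] * len(ids) for ids in topk_ids]
    for t, k in pairs:
        e = topk_ids[t][k]
        slot = next_free[e]
        next_free[e] += 1
        output[slot] = tokens[t]
        output_index[t][k] = slot                   # token -> slot map used by gather
    return output, output_index, starts, capacity


def check_scatter(tokens, topk_ids, output, output_index, starts, capacity):
    """Follow output_index instead of assuming a token-sequential layout."""
    seen = set()
    for t, ids in enumerate(topk_ids):
        for k, e in enumerate(ids):
            slot = output_index[t][k]
            assert starts[e] <= slot < starts[e] + capacity[e]  # inside expert e's region
            assert output[slot] == tokens[t]                    # row content matches
            assert slot not in seen                             # no slot reused
            seen.add(slot)
    return True
```

Because the check never looks at the position of a row relative to its neighbors, it passes for every arrival order, whereas a direct comparison of the whole output buffer would depend on the seed.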
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/3 · P3/0
lgtm ready to ci
Checklist Violations: 7 fail / 104 total
e4564ea to 68b0a44
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/0 · P3/0
lgtm ready to ci
Checklist ✅ (104 items passed)
2012875 to 8e0e71d
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/2 · P3/0
lgtm ready to ci
Checklist Violations: 2 fail / 56 total
8e0e71d to 0a39628
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/3 · P3/0
lgtm ready to ci
Checklist Violations: 5 fail / 104 total
0a39628 to e284541
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/1 · P3/0
Checklist Violations: 4 fail / 60 total
e284541 to cc2025a
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/1 · P3/0
Checklist Violations: 3 fail / 104 total
cc2025a to 9ecf9d8
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/0 · P3/0
Checklist Violations: 5 fail / 104 total
…d bf16 init

Review fixes on the bf16 DeepEP-Normal deepgemm MoE path. All changes are confined to the opt-in no_quant_dp_normal_deepgemm path, the bf16 deep_gemm wrappers, and tests; the fp8 path and other default/cross-arch paths are unaffected.

- DeepGemmBf16HybridExecutor.execute: handle empty rank before dispatch. DeepEP small-batch / skewed routing can leave a rank with token_num == 0; that would otherwise enter the masked path with alignment == 0 and launch 0-grid Triton scatter / 0-size DeepGEMM. Return an empty same-shape [0, K] bf16 output.
- DeepGemmBf16HybridExecutor (contiguous path): build the per-expert token-count tensor with .to(device=hidden_states.device, non_blocking=True) instead of .cuda() so it honors the hidden-states device invariant.
- deepgemm_wrapper: decouple bf16 symbol resolution from the fp8 path. Previously _ensure_initialized() resolved fp8 AND bf16 symbols together, so an older deep_gemm build missing the bf16 symbols would raise from _ensure_initialized() and break the fp8 wrappers. Now _ensure_initialized() resolves only _FP8_SYMBOLS (raises if missing, since fp8 is core), while _ensure_bf16_initialized() resolves _BF16_SYMBOLS independently and tolerantly (missing -> impls stay None, never propagate). bf16 wrappers call _ensure_bf16_initialized(); fp8 wrappers keep _ensure_initialized(). has_deep_gemm_bf16_grouped() reports False (never raises) when the bf16 symbols are unavailable.
- deepgemm_wrapper bf16 grouped wrappers: reject a non-default compiled_dims explicitly (NotImplementedError) instead of silently ignoring it; the wrapper does not forward compiled_dims (forwarding perturbs bf16 numerics on this shared path). No current caller passes a non-"nk" value.
- CudaNoQuantDpNormalDeepGemmStrategy: fail fast at selection via has_deep_gemm_bf16_grouped(), and gate on the explicit opt-in moe_strategy FIRST with a short-circuit return (ConditionChecker does not stop at the first failed check, so this keeps the probe from running for non-opt-in / "auto" configs).
- Tests: empty-rank (token_num==0) executor cases (ep1 + ep2); ep_size>1 executor coverage (rank 0/1, _to_local_expert_ids mapping + masking) vs a per-rank torch reference; strategy selection pos/neg; has_deep_gemm_bf16_grouped no-raise and _ensure_bf16_initialized tolerance; executor test skip uses has_deep_gemm_bf16_grouped() to match the gating.

The ep_kernels contiguous padding-row m_indices contract is left to the feature kernel owner (padding output is discarded by the gather; a real fix needs a kernel signature change + a deep_gemm -1 skip contract).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
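The strict-fp8 / tolerant-bf16 split described above might look roughly like the sketch below. It is an illustration under assumptions, not the wrapper's code: the LazySymbols class, its method names without the leading underscore, and the symbol names in the two tuples are all hypothetical, and a SimpleNamespace stands in for the deep_gemm module.

```python
from types import SimpleNamespace

# Hypothetical symbol names, chosen only for this sketch.
_FP8_SYMBOLS = ("fp8_gemm_nt",)
_BF16_SYMBOLS = ("bf16_grouped_gemm",)


class LazySymbols:
    """Resolve fp8 and bf16 symbols independently, on first use."""

    def __init__(self, module):
        self._module = module
        self._fp8 = None
        self._bf16 = None

    def ensure_initialized(self):
        """Strict: fp8 is treated as core, so a missing symbol raises."""
        if self._fp8 is None:
            resolved = {}
            for name in _FP8_SYMBOLS:
                if not hasattr(self._module, name):
                    raise RuntimeError(f"missing fp8 symbol: {name}")
                resolved[name] = getattr(self._module, name)
            self._fp8 = resolved
        return self._fp8

    def ensure_bf16_initialized(self):
        """Tolerant: a missing symbol resolves to None and never raises."""
        if self._bf16 is None:
            self._bf16 = {n: getattr(self._module, n, None) for n in _BF16_SYMBOLS}
        return self._bf16

    def has_bf16_grouped(self):
        """Report availability as a bool; never propagates an exception."""
        return all(fn is not None for fn in self.ensure_bf16_initialized().values())
```

With this split, an older build that lacks the bf16 symbols still serves fp8 callers, and the strategy can probe has_bf16_grouped() at selection time without risking an exception.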
9ecf9d8 to 2658cff
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/0 · P3/0
lgtm ready to ci
Checklist ✅ (56 items passed)
Internal source has been updated, please review the changes!
Replace runtime signature introspection with explicit validation: bf16 grouped GEMM only supports compiled_dims='nk', reject others with ValueError. Matches PR alibaba#1111 approach. Remove _has_param helper and all inspect.signature usage — eliminates unintrospectable callable, **kwargs, and positional-vs-keyword edge cases.
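As a rough illustration of explicit validation in place of signature introspection, the check could be as small as the sketch below. The function name is hypothetical; only the "nk" default and the ValueError come from the commit message above.

```python
def check_bf16_compiled_dims(compiled_dims: str = "nk") -> None:
    """Reject anything other than the single supported value up front."""
    if compiled_dims != "nk":
        raise ValueError(
            f"bf16 grouped GEMM only supports compiled_dims='nk', got {compiled_dims!r}"
        )
```

A fixed comparison like this has no dependence on how the underlying callable declares its parameters, which is why the introspection edge cases go away.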
Summary
Add an opt-in expert path for unquantized (bf16) MoE under DeepEP normal mode
that uses DeepGEMM grouped GEMM instead of the Triton fused_moe_kernel.
Changes
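- DeepGemmBf16HybridExecutor: runtime-dispatches between a masked 3D layout (small token count / decode) and a contiguous flat layout (large token count / prefill) for better memory utilization.
- ep_scatter_bf16 / ep_scatter_v2_bf16: bf16 variants of the existing fp8 scatter kernels (flat → contiguous, flat → 3D masked).
- CudaNoQuantDpNormalDeepGemmStrategy: opt-in only, selected via --moe_strategy no_quant_dp_normal_deepgemm, gated on bf16 + has_deep_gemm + SM≥9 + no CUDA graph. It is not part of "auto" selection, so the default MoE path on existing deployments is unchanged.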
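To make the memory argument for the two layouts concrete, the sketch below counts the rows each layout would allocate for a given per-expert token histogram. The sizing rules, the alignment value, and the dispatch threshold are assumptions for illustration; the executor's real capacity computation and switch point are not stated in this PR description.

```python
def align_up(n, alignment):
    return (n + alignment - 1) // alignment * alignment


def masked_rows(counts, alignment):
    """Masked 3D layout: every expert gets the same padded capacity (assumed sizing)."""
    return len(counts) * align_up(max(counts, default=0), alignment)


def contiguous_rows(counts, alignment):
    """Contiguous flat layout: each expert only gets its own aligned count (assumed sizing)."""
    return sum(align_up(c, alignment) for c in counts)


def choose_layout(token_num, threshold=256):
    """Hypothetical dispatch rule; the threshold value is made up for this sketch."""
    return "masked" if token_num <= threshold else "contiguous"
```

With a skewed prefill histogram such as [1000, 10, 0, 14] and an alignment of 128, the masked layout would need 4096 rows under these assumptions and the contiguous layout 1280, which is the kind of gap the contiguous path is meant to avoid at large token counts.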
deepgemm_wrapper.py changes (backward-compatible)
These changes enable the bf16 grouped-GEMM path and do not affect existing
fp8/bf16 callers:
- bf16 grouped-GEMM legacy fallback names corrected to the real deep_gemm symbols (gemm_bf16_bf16_bf16_nt*). resolve_symbol() tries the standard name first and only falls back, so existing resolution is unchanged; this only makes the previously dormant bf16 path resolvable. The stale compiled_dims argument is dropped from the contiguous/masked bf16 calls to match the actual deep_gemm signature.
- Symbol resolution is deferred to first use (_ensure_initialized) instead of at import time. Functionally identical, only lazy: the same symbols are resolved, just on the first actual GEMM call.
- has_deep_gemm() re-checks until the first successful import (then caches True) instead of caching the first result. For normal processes where deep_gemm is importable at import time it returns True on the first call exactly as before; this only adds resilience when the package becomes importable slightly later, and does not change existing behavior.
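The "standard name first, legacy names only as fallback" resolution order can be sketched as follows. The signature shown here is an assumption (the real resolve_symbol() may differ), and the attribute names in the usage lines are placeholders rather than real deep_gemm symbols.

```python
from types import SimpleNamespace


def resolve_symbol(module, standard_name, legacy_names=()):
    """Return the first attribute found, trying the standard name before any legacy name."""
    for name in (standard_name, *legacy_names):
        fn = getattr(module, name, None)
        if fn is not None:
            return fn
    return None  # nothing resolvable; the caller decides whether that is an error


# Placeholder modules: one exposes only a legacy name, one exposes both.
legacy_only = SimpleNamespace(legacy_fn=lambda: "legacy")
both = SimpleNamespace(standard_fn=lambda: "standard", legacy_fn=lambda: "legacy")
```

Since the standard name always wins when it exists, correcting the legacy list cannot change what an already-working build resolves to; it only affects builds where the standard name is absent.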
Testing
test_ep_scatter_bf16.py covers bf16 scatter kernel correctness.