Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 134 additions & 20 deletions rtp_llm/models_py/kernels/cuda/deepgemm_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import functools
import importlib.util
from contextlib import contextmanager
from typing import Any, Callable, Generator, List, NoReturn, Optional, Tuple

import torch
import triton
import triton.language as tl

from rtp_llm.utils.module_util import has_module, resolve_symbol
from rtp_llm.utils.module_util import resolve_symbol

__all__ = [
"fp8_gemm_nt",
Expand All @@ -16,6 +17,7 @@
"m_grouped_bf16_gemm_nt_contiguous",
"m_grouped_bf16_gemm_nt_masked",
"has_deep_gemm",
"has_deep_gemm_bf16_grouped",
"is_deep_gemm_e8m0_used",
"configure_deep_gemm_num_sms",
"maybe_pack_ue8m0_scale",
Expand All @@ -30,13 +32,18 @@
"m_grouped_bf16_gemm_nt_masked": "m_grouped_bf16_gemm_nt_masked",
}

# Legacy/fallback symbol names. resolve_symbol() tries _new_map first and only
# falls back here, so existing fp8 resolution is unchanged. The bf16 entries are
# the real deep_gemm symbols (gemm_bf16_bf16_bf16_nt*); they only make the bf16
# grouped-GEMM path (previously dormant) resolvable and do not affect any caller
# that already resolved via _new_map.
_deep_gemm_impl_old_map = {
"fp8_gemm_nt": "fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous": "m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_fp8_gemm_nt_masked": "fp8_m_grouped_gemm_nt_masked",
"bf16_gemm_nt": "bf16_gemm_nt",
"m_grouped_bf16_gemm_nt_contiguous": "m_grouped_bf16_gemm_nt_contiguous",
"m_grouped_bf16_gemm_nt_masked": "m_grouped_bf16_gemm_nt_masked",
"bf16_gemm_nt": "gemm_bf16_bf16_bf16_nt",
"m_grouped_bf16_gemm_nt_contiguous": "m_grouped_gemm_bf16_bf16_bf16_nt_contiguous",
"m_grouped_bf16_gemm_nt_masked": "m_grouped_gemm_bf16_bf16_bf16_nt_masked",
}


Expand All @@ -48,10 +55,48 @@
_m_grouped_bf16_gemm_nt_masked_impl: Callable[..., Any] | None = None


@functools.cache
_deep_gemm_available: bool | None = None


def has_deep_gemm() -> bool:
"""Whether the optional `deep_gemm` package is available."""
return has_module("deep_gemm")
"""Whether the optional `deep_gemm` package is available.

Re-checks until first successful detection, then caches True.
This handles late sys.path setup in spawned subprocesses where
deep_gemm may not be importable at module-load time.
"""
global _deep_gemm_available
if _deep_gemm_available is True:
return True
available = importlib.util.find_spec("deep_gemm") is not None
if available:
_deep_gemm_available = True
return available


def has_deep_gemm_bf16_grouped() -> bool:
"""Whether the bf16 grouped GEMM kernels are actually resolvable.

has_deep_gemm() only confirms the package is importable. This additionally
resolves and checks the specific bf16 grouped symbols the bf16 DeepGEMM MoE
path needs (contiguous + masked), so a strategy can fail fast at selection
time rather than deferring to the first execute() call when an older
deep_gemm build lacks them.

Never raises: symbol resolution failures (old deep_gemm build missing the
bf16 grouped symbols) are reported as "unavailable" (False), so calling this
during strategy enumeration cannot break selection for unrelated configs.
"""
if not has_deep_gemm():
return False
try:
_ensure_bf16_initialized()
except Exception:
return False
return (
_m_grouped_bf16_gemm_nt_contiguous_impl is not None
and _m_grouped_bf16_gemm_nt_masked_impl is not None
)


@functools.cache
Expand All @@ -62,6 +107,7 @@ def is_deep_gemm_e8m0_used() -> bool:
@contextmanager
def configure_deep_gemm_num_sms(num_sms: int) -> Generator[None, None, None]:
"""Configure the number of sms for deep gemm."""
_ensure_initialized()
if not has_deep_gemm():
raise RuntimeError(
"DeepGEMM is not available. Please install the `deep_gemm` package to enable DeepGEMM kernels."
Expand Down Expand Up @@ -122,20 +168,69 @@ def _lazy_init_deep_gemm(symbols: List[str]) -> None:
)


# Core symbols required by the fp8 path. Resolved by _ensure_initialized(); a
# build missing these is broken for fp8 and raising is appropriate.
_FP8_SYMBOLS = [
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_fp8_gemm_nt_masked",
]

# Optional bf16 symbols, resolved separately and tolerantly (see
# _ensure_bf16_initialized) so an older deep_gemm build lacking them does NOT
# break the fp8 path's _ensure_initialized() — it only makes the bf16 deepgemm
# MoE strategy unselectable / its wrappers raise _missing_deep_gemm() at use.
_BF16_SYMBOLS = [
"bf16_gemm_nt",
"m_grouped_bf16_gemm_nt_contiguous",
"m_grouped_bf16_gemm_nt_masked",
]


def _lazy_init_deep_gemm_once():
_lazy_init_deep_gemm(
[
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_fp8_gemm_nt_masked",
"bf16_gemm_nt",
"m_grouped_bf16_gemm_nt_contiguous",
"m_grouped_bf16_gemm_nt_masked",
]
)
_lazy_init_deep_gemm(_FP8_SYMBOLS)


_symbols_initialized = False


def _ensure_initialized():
"""Resolve deep_gemm symbols on first actual use (not at import time).

Retries until deep_gemm becomes available on sys.path, which handles
spawned subprocesses where path setup happens after module import.
"""
global _symbols_initialized
if _symbols_initialized:
return
if not has_deep_gemm():
return
_lazy_init_deep_gemm_once()
_symbols_initialized = True

_lazy_init_deep_gemm_once()

_bf16_symbols_initialized = False


def _ensure_bf16_initialized() -> None:
"""Resolve the optional bf16 deep_gemm symbols, independently of the fp8 path.

Tolerant on purpose: if the deep_gemm build lacks the bf16 symbols, the impls
stay None and we still mark this attempted, so:
- the fp8 path (_ensure_initialized) is never affected;
- bf16 wrappers hit their `is None -> _missing_deep_gemm()` guard at use;
- has_deep_gemm_bf16_grouped() reports False.
"""
global _bf16_symbols_initialized
if _bf16_symbols_initialized:
return
if not has_deep_gemm():
return # package not present yet; retry on a later call
try:
_lazy_init_deep_gemm(_BF16_SYMBOLS)
except Exception:
pass # missing bf16 symbols -> leave impls None, never propagate
_bf16_symbols_initialized = True


@triton.jit
Expand Down Expand Up @@ -387,6 +482,7 @@ def fp8_gemm_nt(
None
"""
global _fp8_gemm_nt_impl
_ensure_initialized()
if _fp8_gemm_nt_impl is None:
return _missing_deep_gemm()
_fp8_gemm_nt_impl(
Expand Down Expand Up @@ -424,6 +520,7 @@ def m_grouped_fp8_gemm_nt_contiguous(
"""

global _m_grouped_fp8_gemm_nt_contiguous_impl
_ensure_initialized()
if _m_grouped_fp8_gemm_nt_contiguous_impl is None:
return _missing_deep_gemm()
_m_grouped_fp8_gemm_nt_contiguous_impl(
Expand Down Expand Up @@ -487,6 +584,7 @@ def m_grouped_fp8_gemm_nt_masked(
Defaults to None, which will be set to False if E8M0 scale is used, otherwise True.
"""
global _m_grouped_fp8_gemm_nt_masked_impl
_ensure_initialized()
if _m_grouped_fp8_gemm_nt_masked_impl is None:
return _missing_deep_gemm()

Expand Down Expand Up @@ -527,6 +625,7 @@ def bf16_gemm_nt(
compiled_dims (str, optional): Compiled dimensions. Defaults to "nk".
"""
global _bf16_gemm_nt_impl
_ensure_bf16_initialized()
if _bf16_gemm_nt_impl is None:
return _missing_deep_gemm()
_bf16_gemm_nt_impl(a, b, output, c, compiled_dims)
Expand All @@ -550,14 +649,23 @@ def m_grouped_bf16_gemm_nt_contiguous(
compiled_dims (str, optional): Compiled dimensions. Defaults to "nk".
"""
global _m_grouped_bf16_gemm_nt_contiguous_impl
# Only the native "nk" layout is supported. The wrapper does not forward
# compiled_dims to the kernel (forwarding it perturbs bf16 numerics on this
# shared path); reject any non-default value explicitly instead of silently
# ignoring it.
if compiled_dims != "nk":
raise NotImplementedError(
"m_grouped_bf16_gemm_nt_contiguous only supports compiled_dims='nk', "
f"got {compiled_dims!r}"
)
_ensure_bf16_initialized()
if _m_grouped_bf16_gemm_nt_contiguous_impl is None:
return _missing_deep_gemm()
_m_grouped_bf16_gemm_nt_contiguous_impl(
a,
b,
output,
m_indices,
compiled_dims,
)


Expand All @@ -580,6 +688,13 @@ def m_grouped_bf16_gemm_nt_masked(
compiled_dims (str, optional): Compiled dimensions. Defaults to "nk".
"""
global _m_grouped_bf16_gemm_nt_masked_impl
# Only the native "nk" layout is supported (see m_grouped_bf16_gemm_nt_contiguous).
if compiled_dims != "nk":
raise NotImplementedError(
"m_grouped_bf16_gemm_nt_masked only supports compiled_dims='nk', "
f"got {compiled_dims!r}"
)
_ensure_bf16_initialized()
if _m_grouped_bf16_gemm_nt_masked_impl is None:
return _missing_deep_gemm()
_m_grouped_bf16_gemm_nt_masked_impl(
Expand All @@ -588,5 +703,4 @@ def m_grouped_bf16_gemm_nt_masked(
output,
masked_m,
expected_m,
compiled_dims,
)
2 changes: 2 additions & 0 deletions rtp_llm/models_py/modules/factory/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
CudaFp8PerTensorEpNormalStrategy,
CudaFp8PerTensorNoDPStrategy,
CudaNoQuantCppStrategy,
CudaNoQuantDpNormalDeepGemmStrategy,
CudaNoQuantDpNormalStrategy,
CudaNoQuantEpLowLatencyStrategy,
CudaW4a8Int4PerChannelEpLowLatencyStrategy,
Expand All @@ -95,6 +96,7 @@
registry.register(CudaFp8PerBlockNoDPStrategy())
registry.register(CudaFp8PerTensorNoDPStrategy())
registry.register(CudaNoQuantEpLowLatencyStrategy())
registry.register(CudaNoQuantDpNormalDeepGemmStrategy())
registry.register(CudaNoQuantDpNormalStrategy())
registry.register(CudaNoQuantCppStrategy())
registry.register(BatchedTritonStrategy())
Expand Down
Loading
Loading