Commit 227c288

[TRTLLM-8827] [feat] Enable low precision alltoall for Cutlass and TRTLLMGen backends (#8675)
Signed-off-by: Kaiyu Xie <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 00161b3 commit 227c288

File tree: 5 files changed (+138 −53 lines)

  tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
  tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
  tensorrt_llm/_torch/modules/fused_moe/interface.py
  tests/unittest/_torch/modules/test_fused_moe.py
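The hunks below thread a new use_low_precision_moe_combine flag from the model config into both the Cutlass and TRTLLMGen backends and select an all-to-all method at construction time. A minimal sketch of how the flag is gated, assuming a stand-in ModelConfig with only the field taken from the diff (the real tensorrt_llm class has many more):

# Sketch only: mirrors the __init__ gating added in fused_moe_cutlass.py and
# fused_moe_trtllm_gen.py below. `ModelConfig` here is a stand-in, not the
# tensorrt_llm class; only `use_low_precision_moe_combine` comes from the diff.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    use_low_precision_moe_combine: bool = False


def resolve_low_precision_combine(model_config: ModelConfig,
                                  enable_alltoall: bool) -> bool:
    # The flag only takes effect when alltoall is enabled; otherwise the
    # module keeps the default of False and uses the regular combine path.
    use_low_precision_combine = False
    if enable_alltoall:
        use_low_precision_combine = model_config.use_low_precision_moe_combine
    return use_low_precision_combine


assert resolve_low_precision_combine(ModelConfig(True), enable_alltoall=True)
assert not resolve_low_precision_combine(ModelConfig(True), enable_alltoall=False)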

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 67 additions & 23 deletions
@@ -6,11 +6,12 @@
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
 from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
+from tensorrt_llm.logger import logger
 
 from ...distributed import allgather
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, ceil_div
-from .interface import MoE
+from .interface import AlltoallMethodType, MoE
 
 # isort: off
 from .quantization import (
@@ -140,28 +141,44 @@ def __init__(
         self.has_been_profiled_min_latency = False
 
         # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
+        self.alltoall_method_type = self.select_alltoall_method_type()
+        logger.info_once(
+            f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
+            key="alltoall_method_type")
         self.alltoall_workspace = None
         self.alltoall_prepare_workspace = None
+        self.use_low_precision_combine = False
         if self.enable_alltoall:
-            if self.moe_alltoall_backend == "mnnvllatency":
-                MnnvlMemory.initialize()
-                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                    model_config.mapping)
-                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                    model_config.mapping)
-            elif self.moe_alltoall_backend == "mnnvlthroughput":
-                workspace_mb = int(
-                    os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
-                self.moe_a2a = MoeAlltoAll(
-                    mapping=self.mapping,
-                    max_num_tokens_per_rank=model_config.max_num_tokens,
-                    top_k=self.routing_method.experts_per_token,
-                    num_experts=self.num_experts,
-                    workspace_size_per_rank=workspace_mb * 1024 * 1024,
+            self.use_low_precision_combine = model_config.use_low_precision_moe_combine
+
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                if self.moe_alltoall_backend == "mnnvllatency":
+                    MnnvlMemory.initialize()
+                    self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                        model_config.mapping)
+                    self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                        model_config.mapping)
+                elif self.moe_alltoall_backend == "mnnvlthroughput":
+                    workspace_mb = int(
+                        os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
+                    self.moe_a2a = MoeAlltoAll(
+                        mapping=self.mapping,
+                        max_num_tokens_per_rank=model_config.max_num_tokens,
+                        top_k=self.routing_method.experts_per_token,
+                        num_experts=self.num_experts,
+                        workspace_size_per_rank=workspace_mb * 1024 * 1024,
+                    )
+                else:
+                    raise ValueError(
+                        f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                    )
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                raise NotImplementedError(
+                    "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
                 )
             else:
-                raise ValueError(
-                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {self.alltoall_method_type!r}"
                 )
 
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
@@ -204,13 +221,38 @@ def has_int8_woq_per_channel(self):
         return self.quant_config.layer_quant_mode.is_int8_weight_only(
         ) and not self.quant_config.layer_quant_mode.has_per_group_scaling()
 
+    def select_alltoall_method_type(self) -> AlltoallMethodType:
+        all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
+        if all2all_method_type is not None:
+            if AlltoallMethodType[all2all_method_type] in [
+                    AlltoallMethodType.DeepEP,
+                    AlltoallMethodType.DeepEPLowLatency
+            ]:
+                raise NotImplementedError(
+                    "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
+                )
+            return AlltoallMethodType[all2all_method_type]
+
+        if not self.mapping.enable_attention_dp:
+            return AlltoallMethodType.NotEnabled
+
+        if self.mapping.tp_size == 1:
+            return AlltoallMethodType.NotEnabled
+
+        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
+            return AlltoallMethodType.NotEnabled
+
+        if not (self.mapping.moe_ep_size > self.routing_method.experts_per_token
+                and MnnvlMemory.supports_mnnvl()):
+            return AlltoallMethodType.NotEnabled
+
+        return AlltoallMethodType.MNNVL
+
     @cached_property
     def enable_alltoall(self):
-        return (self.mapping.moe_ep_size > self.routing_method.experts_per_token
-                and self.mapping.enable_attention_dp
-                and self.mapping.tp_size > 1
-                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
-                and MnnvlMemory.supports_mnnvl())
+        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+        """
+        return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
     @cached_property
     def moe_alltoall_backend(self):
@@ -510,6 +552,8 @@ def forward_chunk(
                     ep_rank=self.ep_rank,
                     ep_size=self.ep_size,
                     top_k=top_k,
+                    use_low_precision_combine=self.
+                    use_low_precision_combine,
                     token_count=token_count)
             elif self.moe_alltoall_backend == "mnnvlthroughput":
                 hidden = final_hidden_states.shape[-1]
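For reference, the selection logic added above can be read in isolation. The sketch below is not the TensorRT-LLM class: `mapping` is any object exposing the three attributes the method consults, and `supports_mnnvl` stands in for `MnnvlMemory.supports_mnnvl()`.

# Standalone sketch of select_alltoall_method_type(), under the assumptions above.
import os
from enum import IntEnum
from types import SimpleNamespace


class AlltoallMethodType(IntEnum):
    NotEnabled = 0
    MNNVL = 1
    DeepEP = 2
    DeepEPLowLatency = 3


def select_alltoall_method_type(mapping, experts_per_token: int,
                                supports_mnnvl: bool) -> AlltoallMethodType:
    # An explicit override wins, except that DeepEP variants are rejected
    # because CutlassFusedMoE does not implement them yet.
    forced = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
    if forced is not None:
        method = AlltoallMethodType[forced]
        if method in (AlltoallMethodType.DeepEP,
                      AlltoallMethodType.DeepEPLowLatency):
            raise NotImplementedError(
                "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet")
        return method
    # Alltoall only applies with attention DP and TP > 1.
    if not mapping.enable_attention_dp or mapping.tp_size == 1:
        return AlltoallMethodType.NotEnabled
    # Environment kill switch for alltoallv.
    if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
        return AlltoallMethodType.NotEnabled
    # Need more EP ranks than experts-per-token, plus MNNVL support.
    if not (mapping.moe_ep_size > experts_per_token and supports_mnnvl):
        return AlltoallMethodType.NotEnabled
    return AlltoallMethodType.MNNVL


mapping = SimpleNamespace(enable_attention_dp=True, tp_size=8, moe_ep_size=8)
print(select_alltoall_method_type(mapping, experts_per_token=4, supports_mnnvl=True))
# AlltoallMethodType.MNNVL (unless one of the environment overrides is set)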

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 55 additions & 12 deletions
@@ -7,13 +7,14 @@
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.logger import logger
 
 from ...custom_ops.trtllm_gen_custom_ops import \
     fp4_block_scale_fake_output_without_finalize
 from ...distributed import allgather
 from ...model_config import ModelConfig
 from ...utils import Fp4QuantizedTensor, ceil_div
-from .interface import MoE, MoEWeightLoadingMode
+from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            NVFP4TRTLLMGenFusedMoEMethod,
                            W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
@@ -109,27 +110,68 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
+        # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
+        self.alltoall_method_type = self.select_alltoall_method_type()
+        logger.info_once(
+            f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
+            key="alltoall_method_type")
         self.alltoall_workspace = None
         self.alltoall_prepare_workspace = None
+        self.use_low_precision_combine = False
         if self.enable_alltoall:
-            MnnvlMemory.initialize()
-            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                model_config.mapping)
-            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                model_config.mapping)
+            self.use_low_precision_combine = model_config.use_low_precision_moe_combine
+
+            if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                MnnvlMemory.initialize()
+                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                    model_config.mapping)
+                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                    model_config.mapping)
+            elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+                raise NotImplementedError(
+                    "DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet"
+                )
+            else:
+                raise NotImplementedError(
+                    f"Not available alltoall method type: {self.alltoall_method_type!r}"
+                )
 
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
            self.create_weights()
 
+    def select_alltoall_method_type(self) -> AlltoallMethodType:
+        all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
+        if all2all_method_type is not None:
+            if AlltoallMethodType[all2all_method_type] in [
+                    AlltoallMethodType.DeepEP,
+                    AlltoallMethodType.DeepEPLowLatency
+            ]:
+                raise NotImplementedError(
+                    "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
+                )
+            return AlltoallMethodType[all2all_method_type]
+
+        if not self.mapping.enable_attention_dp:
+            return AlltoallMethodType.NotEnabled
+
+        if self.mapping.tp_size == 1:
+            return AlltoallMethodType.NotEnabled
+
+        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
+            return AlltoallMethodType.NotEnabled
+
+        if not (self.mapping.moe_ep_size > self.routing_method.experts_per_token
+                and MnnvlMemory.supports_mnnvl()):
+            return AlltoallMethodType.NotEnabled
+
+        return AlltoallMethodType.MNNVL
+
     @cached_property
     def enable_alltoall(self):
-        mapping = self.mapping
-        routing_experts = self.routing_method.experts_per_token
-        return (mapping.moe_ep_size > routing_experts
-                and mapping.enable_attention_dp and mapping.tp_size > 1
-                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
-                and MnnvlMemory.supports_mnnvl())
+        """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
+        """
+        return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
@@ -608,6 +650,7 @@ def forward_impl(
                 ep_rank=self.ep_rank,
                 ep_size=self.ep_size,
                 top_k=top_k,
+                use_low_precision_combine=self.use_low_precision_combine,
                 token_count=token_count,
             )
 
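Both backends log the chosen method via `logger.info_once(..., key=...)`, so constructing many MoE layers reports the selection once rather than per layer. A rough illustration of that keyed log-once pattern follows; this is a sketch of the idea, not the actual tensorrt_llm.logger implementation.

import logging

logging.basicConfig(level=logging.INFO)
_seen_keys: set = set()


def info_once(logger: logging.Logger, msg: str, *, key: str) -> None:
    # Emit a message only the first time its key is seen, so repeated module
    # construction does not flood the log with identical lines.
    if key not in _seen_keys:
        _seen_keys.add(key)
        logger.info(msg)


log = logging.getLogger("fused_moe")
info_once(log, "TRTLLMGenFusedMoE selects alltoall_method_type <AlltoallMethodType.MNNVL: 1>",
          key="alltoall_method_type")
info_once(log, "TRTLLMGenFusedMoE selects alltoall_method_type <AlltoallMethodType.MNNVL: 1>",
          key="alltoall_method_type")  # suppressed: same key already logged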

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 1 addition & 14 deletions
@@ -1,5 +1,4 @@
 import os
-from enum import IntEnum
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -15,7 +14,7 @@
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
-from .interface import MoE
+from .interface import AlltoallMethodType, MoE
 from .moe_load_balancer import get_moe_load_balancer
 from .ops import MoEOp, MoEOpSelector
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
@@ -26,18 +25,6 @@
 from .routing import BaseMoeRoutingMethod
 
 
-# The type of alltoall method
-class AlltoallMethodType(IntEnum):
-    # Not available
-    NotEnabled = 0
-    # MNNVL
-    MNNVL = 1
-    # DeepEP intranode or internode: CUDA Graphs are supported, IBGDA is required by internode
-    DeepEP = 2
-    # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
-    DeepEPLowLatency = 3
-
-
 class WideEPMoE(MoE):
     """
     Fused Mixture of Experts (MoE) Layer with for wide EP.

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 13 additions & 1 deletion
@@ -1,6 +1,6 @@
 import weakref
 from abc import abstractmethod
-from enum import Enum
+from enum import Enum, IntEnum
 from typing import Dict, List, Optional, Union, final
 
 import torch
@@ -22,6 +22,18 @@ class MoEWeightLoadingMode(Enum):
     W4A8_CUSTOM = 2
 
 
+# The type of alltoall method
+class AlltoallMethodType(IntEnum):
+    # Not available
+    NotEnabled = 0
+    # MNNVL
+    MNNVL = 1
+    # DeepEP intranode or internode: CUDA Graphs are supported, IBGDA is required by internode
+    DeepEP = 2
+    # DeepEP low latency: CUDA Graphs are supported, IBGDA is required
+    DeepEPLowLatency = 3
+
+
 def extract_extra_attrs(layer_idx: str):
     extra_attrs = get_model_extra_attrs()
     assert extra_attrs is not None, "Model extra attrs are not set"
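With the enum promoted to interface.py, the backends and the unit test all import it from one place (see the test_fused_moe.py hunk below). A quick check of the name-based lookup that the TRTLLM_FORCE_ALLTOALL_METHOD override relies on; running it requires a TensorRT-LLM installation.

# New canonical import path for the enum.
from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType

# os.environ values are strings, so the override maps a name back to a member.
assert AlltoallMethodType["MNNVL"] is AlltoallMethodType.MNNVL
# IntEnum keeps stable integer values, so members compare equal to ints.
assert AlltoallMethodType.DeepEPLowLatency == 3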

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 2 additions & 3 deletions
@@ -25,9 +25,8 @@
     CuteDslFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
     DeepGemmFusedMoE
-from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
-    AlltoallMethodType
-from tensorrt_llm._torch.modules.fused_moe.interface import MoEWeightLoadingMode
+from tensorrt_llm._torch.modules.fused_moe.interface import (
+    AlltoallMethodType, MoEWeightLoadingMode)
 
 # isort and yapf will fight against each other here, so we disable isort
 # isort: off
