Commit 90686db

Enable MnnvlThroughput in Trtllm MoE.
Signed-off-by: Bo Li <[email protected]>
1 parent 2c5989e commit 90686db

File tree: 1 file changed, +127 -49 lines

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (127 additions, 49 deletions)
@@ -6,6 +6,7 @@
 from torch import nn

 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
+from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
 from tensorrt_llm._utils import get_sm_version

 from ...custom_ops.trtllm_gen_custom_ops import \
@@ -112,11 +113,26 @@ def __init__(
         self.alltoall_workspace = None
         self.alltoall_prepare_workspace = None
         if self.enable_alltoall:
-            MnnvlMemory.initialize()
-            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                model_config.mapping)
-            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                model_config.mapping)
+            if self.moe_alltoall_backend == "MnnvlLatency":
+                MnnvlMemory.initialize()
+                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                    model_config.mapping)
+                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                    model_config.mapping)
+            elif self.moe_alltoall_backend == "MnnvlThroughput":
+                workspace_mb = int(
+                    os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
+                self.moe_a2a = MoeAlltoAll(
+                    mapping=self.mapping,
+                    max_num_tokens_per_rank=model_config.max_num_tokens,
+                    top_k=self.routing_method.experts_per_token,
+                    num_experts=self.num_experts,
+                    workspace_size_per_rank=workspace_mb * 1024 * 1024,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                )

         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
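For context, a minimal sketch of how this new constructor branch would be reached. The environment variable names TRTLLM_MOE_ALLTOALL_BACKEND and TRTLLM_MOE_A2A_WORKSPACE_MB come from this diff; the LLM entry point and model path below are illustrative assumptions, not part of the change.

import os

# Select the throughput-oriented MoE all-to-all backend before the model is built.
os.environ["TRTLLM_MOE_ALLTOALL_BACKEND"] = "MnnvlThroughput"
# Per-rank MoeAlltoAll workspace in MB; the constructor falls back to 512 if unset.
os.environ["TRTLLM_MOE_A2A_WORKSPACE_MB"] = "1024"

from tensorrt_llm import LLM  # illustrative entry point for this sketch

llm = LLM(model="/path/to/moe-model")  # placeholder model path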
@@ -131,6 +147,12 @@ def enable_alltoall(self):
                 and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
                 and MnnvlMemory.supports_mnnvl())

+    @cached_property
+    def moe_alltoall_backend(self):
+        # "MnnvlLatency" (default) or "MnnvlThroughput"; comparison is case-sensitive
+        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
+                              "MnnvlLatency").strip()
+
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
             or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
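As a sanity check, the same selection logic as a standalone function (the function name and the _VALID_BACKENDS tuple are illustrative, not part of the diff):

import os

_VALID_BACKENDS = ("MnnvlLatency", "MnnvlThroughput")

def resolve_moe_alltoall_backend() -> str:
    # Mirrors the cached property above: read the env var, default to MnnvlLatency,
    # and fail fast on any value the dispatch/combine paths do not understand.
    backend = os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND", "MnnvlLatency").strip()
    if backend not in _VALID_BACKENDS:
        raise ValueError(f"Unsupported moe alltoall backend: {backend}")
    return backend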
@@ -298,45 +320,89 @@ def forward_impl(
             else:
                 token_final_scales = token_final_scales.to(torch.float32)

-            assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
-            alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
-                token_selected_experts,
-                None,
-                self.alltoall_prepare_workspace,
-                max_num_token,
-                self.ep_rank,
-                self.ep_size,
-                self.num_experts,
-                self.num_slots,
-                top_k,
-            )
+            if self.moe_alltoall_backend == "MnnvlLatency":
+                assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
+                alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+                    token_selected_experts,
+                    None,
+                    self.alltoall_prepare_workspace,
+                    max_num_token,
+                    self.ep_rank,
+                    self.ep_size,
+                    self.num_experts,
+                    self.num_slots,
+                    top_k,
+                )

-            if x_sf is not None:
-                x_sf = x_sf.view(x_row, ceil_div(x_col,
-                                                 self.scaling_vector_size))
+                if x_sf is not None:
+                    x_sf = x_sf.view(x_row,
+                                     ceil_div(x_col, self.scaling_vector_size))

-            x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
-                [x, x_sf, token_selected_experts, token_final_scales],
-                alltoall_info,
-                self.alltoall_workspace,
-                self.ep_rank,
-                self.ep_size,
-            )
+                x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
+                    [x, x_sf, token_selected_experts, token_final_scales],
+                    alltoall_info,
+                    self.alltoall_workspace,
+                    self.ep_rank,
+                    self.ep_size,
+                )

-            torch.ops.trtllm.memset_expert_ids(
-                token_selected_experts,
-                alltoall_info.recv_rank_count_cumsum,
-                max_num_token,
-                top_k,
-                self.num_slots,
-                self.ep_size,
-            )
+                torch.ops.trtllm.memset_expert_ids(
+                    token_selected_experts,
+                    alltoall_info.recv_rank_count_cumsum,
+                    max_num_token,
+                    top_k,
+                    self.num_slots,
+                    self.ep_size,
+                )

-            if x_sf is not None:
-                x_sf = x_sf.flatten()
+                if x_sf is not None:
+                    x_sf = x_sf.flatten()
+
+                if token_final_scales is not None:
+                    token_final_scales = token_final_scales.to(torch.bfloat16)
+            elif self.moe_alltoall_backend == "MnnvlThroughput":
+                if x_sf is not None:
+                    x_sf = x_sf.view(x_row,
+                                     ceil_div(x_col, self.scaling_vector_size))
+
+                payloads = []
+                payloads.append(x)
+                if x_sf is not None:
+                    payloads.append(x_sf)
+                    expert_id_payload_index = 2
+                else:
+                    expert_id_payload_index = 1
+                payloads.append(token_selected_experts)
+                payloads.append(token_final_scales)
+
+                recv_buffers = self.moe_a2a.dispatch(
+                    token_selected_experts,
+                    payloads,
+                    invalid_token_expert_id=
+                    -1,  # Note: Cutlass MoE uses num_experts as the invalid token expert id
+                    expert_id_payload_index=expert_id_payload_index,
+                )

-            if token_final_scales is not None:
-                token_final_scales = token_final_scales.to(torch.bfloat16)
+                if x_sf is not None:
+                    x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_buffers
+                    x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1])
+                else:
+                    x_recv, token_selected_experts_recv, token_final_scales_recv = recv_buffers
+                x = x_recv.view(-1, x_recv.shape[-1])
+                token_selected_experts = token_selected_experts_recv.view(
+                    -1, token_selected_experts_recv.shape[-1])
+                token_final_scales = token_final_scales_recv.view(
+                    -1, token_final_scales_recv.shape[-1])
+
+                if x_sf is not None:
+                    x_sf = x_sf.flatten()
+
+                if token_final_scales is not None:
+                    token_final_scales = token_final_scales.to(torch.bfloat16)
+            else:
+                raise ValueError(
+                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                )

         elif run_post_quant_allgather:
             if x_sf is not None:
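The dispatch payload layout in the MnnvlThroughput branch is positional, so a small sketch of the ordering may help. The helper name below is illustrative; the ordering and index logic mirror the hunk above.

from typing import List, Optional, Tuple
import torch

def build_dispatch_payloads(
        x: torch.Tensor,
        x_sf: Optional[torch.Tensor],
        token_selected_experts: torch.Tensor,
        token_final_scales: torch.Tensor) -> Tuple[List[torch.Tensor], int]:
    # Payload order expected by the MnnvlThroughput branch:
    #   [x, (x_sf), token_selected_experts, token_final_scales]
    # expert_id_payload_index points at token_selected_experts so that
    # MoeAlltoAll.dispatch can stamp invalid tokens with invalid_token_expert_id=-1.
    payloads = [x]
    if x_sf is not None:
        payloads.append(x_sf)
        expert_id_payload_index = 2
    else:
        expert_id_payload_index = 1
    payloads.append(token_selected_experts)
    payloads.append(token_final_scales)
    return payloads, expert_id_payload_index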
@@ -600,16 +666,28 @@ def forward_impl(
             )

         # Combine results if using alltoall
-        if self.enable_alltoall and alltoall_info is not None:
-            final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
-                final_hidden_states,
-                alltoall_info,
-                self.alltoall_workspace,
-                ep_rank=self.ep_rank,
-                ep_size=self.ep_size,
-                top_k=top_k,
-                token_count=token_count,
-            )
+        if self.enable_alltoall:
+            if self.moe_alltoall_backend == "MnnvlLatency":
+                if alltoall_info is not None:
+                    final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
+                        final_hidden_states,
+                        alltoall_info,
+                        self.alltoall_workspace,
+                        ep_rank=self.ep_rank,
+                        ep_size=self.ep_size,
+                        top_k=top_k,
+                        token_count=token_count,
+                    )
+            elif self.moe_alltoall_backend == "MnnvlThroughput":
+                hidden = final_hidden_states.shape[-1]
+                payload = final_hidden_states.view(
+                    self.ep_size, self.moe_a2a.max_num_tokens_per_rank, hidden)
+                final_hidden_states = self.moe_a2a.combine(
+                    payload, payload_in_workspace=False)
+            else:
+                raise ValueError(
+                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                )

         final_hidden_states = self.reducescatter_or_allreduce(
             final_hidden_states,
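A shape-only sketch of the MnnvlThroughput combine step above; the sizes are made up for illustration and only the reshape mirrors the diff.

import torch

ep_size, max_num_tokens_per_rank, hidden = 4, 8, 16  # illustrative sizes

# Expert outputs arrive flattened as (ep_size * max_num_tokens_per_rank, hidden);
# combine() expects them grouped per source rank, hence the view below.
final_hidden_states = torch.randn(ep_size * max_num_tokens_per_rank, hidden)
payload = final_hidden_states.view(ep_size, max_num_tokens_per_rank, hidden)
assert payload.shape == (ep_size, max_num_tokens_per_rank, hidden)
# self.moe_a2a.combine(payload, payload_in_workspace=False) then reduces the
# per-rank contributions back down to this rank's local tokens.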
