
Commit 808e556

[None][fix] : Fix OOM issue when dp padding is enabled (#8052)
Signed-off-by: peaceh <[email protected]>
1 parent 84aa3c9 commit 808e556

2 files changed: +10 −9 lines

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (5 additions, 4 deletions)

@@ -468,16 +468,17 @@ def forward_impl(
         else:
             num_rows = x.shape[0]
 
-        # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
-
         if use_dp_padding:
             all_rank_num_tokens_padded = [max(all_rank_num_tokens)
                                           ] * len(all_rank_num_tokens)
+            num_rows = sum(all_rank_num_tokens_padded)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
 
+        # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
+        num_chunks = (num_rows + self.moe_max_num_tokens -
+                      1) // self.moe_max_num_tokens
+
         if num_chunks == 1:
             outputs = self.forward_chunk(
                 x,
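
A minimal sketch of the arithmetic behind this change, with hypothetical token counts (moe_max_num_tokens, the per-rank counts, and the pre-fix num_rows below are illustrative, not taken from the commit): when dp padding is enabled, every rank is padded up to max(all_rank_num_tokens), so deriving num_chunks before the padding step can leave a single chunk far larger than moe_max_num_tokens.

# Hypothetical values illustrating why num_chunks must be computed
# after dp padding is applied.
moe_max_num_tokens = 8192
all_rank_num_tokens = [100, 8192, 4096, 50]

# Pre-fix order: chunk count derived from the unpadded row count.
num_rows = 100  # assumed unpadded row count before padding
num_chunks = (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens
assert num_chunks == 1  # a single chunk is planned

# Post-fix order: pad first, then derive the chunk count.
all_rank_num_tokens_padded = [max(all_rank_num_tokens)] * len(all_rank_num_tokens)
num_rows = sum(all_rank_num_tokens_padded)  # 4 * 8192 = 32768 rows
num_chunks = (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens
assert num_chunks == 4  # the padded workload is split across chunks

With the old ordering, the padded rows would all be processed as that single undersized-looking chunk, which is the OOM the commit message refers to.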

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (5 additions, 5 deletions)

@@ -771,16 +771,16 @@ def forward_impl(
 
         all_rank_max_num_tokens = max(all_rank_num_tokens)
 
-        # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
-        use_all_to_all = self.can_use_alltoall(all_rank_num_tokens,
-                                               all_rank_max_num_tokens)
-
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
                                           ] * len(all_rank_num_tokens)
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
+
+        # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens_padded)
+        use_all_to_all = self.can_use_alltoall(all_rank_num_tokens_padded,
+                                               all_rank_max_num_tokens)
         if num_chunks == 1:
             is_first_call = self.repeat_idx == 0
             is_last_call = self.repeat_idx == self.repeat_count - 1
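
The wide-EP path gets the same reordering. A hedged sketch follows, assuming calculate_num_chunks performs a ceiling division over the summed token count like the Cutlass path above (its actual body is not shown in this commit); all values are hypothetical.

# Hypothetical sketch of the wide-EP call sites touched by this commit.
moe_max_num_tokens = 8192  # assumed chunk-size limit, mirroring the Cutlass path
all_rank_num_tokens = [100, 8192, 4096, 50]

all_rank_max_num_tokens = max(all_rank_num_tokens)
all_rank_num_tokens_padded = [all_rank_max_num_tokens] * len(all_rank_num_tokens)

def calculate_num_chunks(rank_num_tokens):
    # Assumed behavior: ceiling division of the total token count,
    # mirroring the formula used in fused_moe_cutlass.py.
    total = sum(rank_num_tokens)
    return (total + moe_max_num_tokens - 1) // moe_max_num_tokens

# Before the fix, the helpers saw the unpadded counts:
assert calculate_num_chunks(all_rank_num_tokens) == 2
# After the fix, they see the padded counts, so the chunk count (and the
# can_use_alltoall decision) reflect the rows that will actually be processed:
assert calculate_num_chunks(all_rank_num_tokens_padded) == 4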
