File tree Expand file tree Collapse file tree 2 files changed +10
-9
lines changed
tensorrt_llm/_torch/modules/fused_moe Expand file tree Collapse file tree 2 files changed +10
-9
lines changed Original file line number Diff line number Diff line change @@ -468,16 +468,17 @@ def forward_impl(
468468 else :
469469 num_rows = x .shape [0 ]
470470
471- # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
472- num_chunks = (num_rows + self .moe_max_num_tokens -
473- 1 ) // self .moe_max_num_tokens
474-
475471 if use_dp_padding :
476472 all_rank_num_tokens_padded = [max (all_rank_num_tokens )
477473 ] * len (all_rank_num_tokens )
474+ num_rows = sum (all_rank_num_tokens_padded )
478475 else :
479476 all_rank_num_tokens_padded = all_rank_num_tokens
480477
478+ # If num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
479+ num_chunks = (num_rows + self .moe_max_num_tokens -
480+ 1 ) // self .moe_max_num_tokens
481+
481482 if num_chunks == 1 :
482483 outputs = self .forward_chunk (
483484 x ,
Original file line number Diff line number Diff line change @@ -771,16 +771,16 @@ def forward_impl(
771771
772772 all_rank_max_num_tokens = max (all_rank_num_tokens )
773773
774- # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
775- num_chunks = self .calculate_num_chunks (all_rank_num_tokens )
776- use_all_to_all = self .can_use_alltoall (all_rank_num_tokens ,
777- all_rank_max_num_tokens )
778-
779774 if use_dp_padding :
780775 all_rank_num_tokens_padded = [all_rank_max_num_tokens
781776 ] * len (all_rank_num_tokens )
782777 else :
783778 all_rank_num_tokens_padded = all_rank_num_tokens
779+
780+ # If num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
781+ num_chunks = self .calculate_num_chunks (all_rank_num_tokens_padded )
782+ use_all_to_all = self .can_use_alltoall (all_rank_num_tokens_padded ,
783+ all_rank_max_num_tokens )
784784 if num_chunks == 1 :
785785 is_first_call = self .repeat_idx == 0
786786 is_last_call = self .repeat_idx == self .repeat_count - 1
You can’t perform that action at this time.
0 commit comments