@@ -430,15 +430,14 @@ def is_post_quant_all2all_supported(self):
         return False
 
     def forward_chunk(
-        self,
-        x: Union[torch.Tensor, Fp4QuantizedTensor],
-        router_logits: torch.Tensor,
-        use_all_to_all: bool,
-        output_dtype: Optional[torch.dtype] = None,
-        all_rank_num_tokens: Optional[List[int]] = None,
-        use_dp_padding: Optional[bool] = None,
-        repeating_info: Tuple = (True, True),
-        alltoall_result_do_sum: bool = True,
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        use_all_to_all: bool,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: Tuple = (True, True),
     ) -> torch.Tensor:
         all_rank_max_num_tokens = max(all_rank_num_tokens)
         if isinstance(x, Fp4QuantizedTensor):
@@ -453,7 +452,7 @@ def forward_chunk(
             self.layer_load_balancer.start_wait_gpu_stage()
 
         if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-            alltoall_result_do_sum = True
+            pass
 
         weight_dtype = self.w3_w1_weight.dtype
 
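
Review note: with alltoall_result_do_sum removed, the branch above becomes a no-op. A minimal standalone sketch (toy values; only the attribute and flag names come from this diff) showing that nothing observable happens either way, so a follow-up could drop the check entirely:

    use_all_to_all = False
    is_mnnvl = False  # stands in for self.alltoall_method_type == AlltoallMethodType.MNNVL

    if not use_all_to_all or not is_mnnvl:
        pass  # previously: alltoall_result_do_sum = True
    # no state is modified in either case; the branch is dead code after this PR
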
@@ -720,8 +719,7 @@ def forward_chunk(
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count,
-                alltoall_result_do_sum)
+                final_hidden_states, alltoall_info, token_count)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
@@ -766,7 +764,6 @@ def forward_impl(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
-        alltoall_result_do_sum: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
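
Review note: forward_impl keeps **kwargs, so a stale caller that still passes the removed flag will not raise; the value is silently swallowed. A minimal sketch with placeholder names (not the real class) demonstrating that behavior:

    def forward_impl(x, router_logits, *, all_rank_num_tokens=None,
                     use_dp_padding=None, **kwargs):
        assert all_rank_num_tokens is not None
        if kwargs:
            print("ignored kwargs:", sorted(kwargs))  # e.g. alltoall_result_do_sum
        return x

    forward_impl("x", "logits", all_rank_num_tokens=[4, 4],
                 alltoall_result_do_sum=True)  # accepted and silently ignored
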
@@ -794,8 +791,7 @@ def forward_impl(
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call),
-                alltoall_result_do_sum=alltoall_result_do_sum)
+                repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -853,8 +849,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call),
-                        alltoall_result_do_sum=alltoall_result_do_sum)
+                        repeating_info=(is_first_call, is_last_call))
                     if idx_chunk > 0:
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -870,8 +865,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call),
-                        alltoall_result_do_sum=alltoall_result_do_sum)
+                        repeating_info=(is_first_call, is_last_call))
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -885,8 +879,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     router_logits,
                     use_all_to_all,
                     all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-                    repeating_info=(is_first_call, is_last_call),
-                    alltoall_result_do_sum=alltoall_result_do_sum)
+                    repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
             if not use_all_to_all:
@@ -942,8 +935,7 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
         return x, x_sf, token_selected_slots, token_final_scales
 
     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int,
-                         alltoall_result_do_sum: bool):
+                         alltoall_info: MoEAlltoallInfo, token_count: int):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
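
Review note: unlike forward_impl, alltoall_combine now has a fixed arity, so any call site still passing the old fourth argument fails loudly. A toy standalone check (hypothetical stand-in function, not the real method):

    def alltoall_combine(final_hidden_states, alltoall_info, token_count):
        return final_hidden_states

    try:
        alltoall_combine([1.0], None, 1, True)  # stale four-argument call
    except TypeError as err:
        print("stale call site rejected:", err)
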
@@ -956,7 +948,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
             top_k=top_k,
             token_count=token_count,
             use_low_precision_combine=self.use_low_precision_combine,
-            do_reduce=alltoall_result_do_sum)
+            do_reduce=False)
 
         return final_hidden_states
 
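
Review note: do_reduce is now hardcoded to False. Assuming the flag means "sum the top_k expert contributions per token inside the combine kernel" (an assumption for this sketch, not confirmed by the diff), the combine output keeps one row per (token, expert) pair and a later stage must fold the top_k dimension. A toy shape check:

    import torch

    token_count, top_k, hidden = 8, 2, 16
    unreduced = torch.randn(token_count * top_k, hidden)  # what do_reduce=False would leave
    reduced = unreduced.view(token_count, top_k, hidden).sum(dim=1)  # the fold a later stage must do
    print(unreduced.shape, reduced.shape)  # torch.Size([16, 16]) torch.Size([8, 16])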