 from torch import nn
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
+from tensorrt_llm._torch.distributed.moe_alltoall import MoeAlltoAll
 from tensorrt_llm._utils import get_sm_version
 
 from ...custom_ops.trtllm_gen_custom_ops import \
@@ -112,11 +113,26 @@ def __init__(
         self.alltoall_workspace = None
         self.alltoall_prepare_workspace = None
         if self.enable_alltoall:
-            MnnvlMemory.initialize()
-            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
-                model_config.mapping)
-            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
-                model_config.mapping)
+            if self.moe_alltoall_backend == "MnnvlLatency":
+                MnnvlMemory.initialize()
+                self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                    model_config.mapping)
+                self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                    model_config.mapping)
+            elif self.moe_alltoall_backend == "MnnvlThroughput":
+                workspace_mb = int(
+                    os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
+                self.moe_a2a = MoeAlltoAll(
+                    mapping=self.mapping,
+                    max_num_tokens_per_rank=model_config.max_num_tokens,
+                    top_k=self.routing_method.experts_per_token,
+                    num_experts=self.num_experts,
+                    workspace_size_per_rank=workspace_mb * 1024 * 1024,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                )
 
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
@@ -131,6 +147,12 @@ def enable_alltoall(self):
                 and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
                 and MnnvlMemory.supports_mnnvl())
 
+    @cached_property
+    def moe_alltoall_backend(self):
+        # "MnnvlLatency" (default) or "MnnvlThroughput"
+        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
+                              "MnnvlLatency").strip()
+
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
             or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
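Backend selection is driven entirely by environment variables read when the module is constructed (moe_alltoall_backend is a cached_property consulted in __init__), so switching paths needs no code change. A minimal usage sketch, assuming the variables behave as in the hunks above; the values shown are only examples, and TRTLLM_MOE_A2A_WORKSPACE_MB is only consulted on the MnnvlThroughput path:

import os

# Must be set before the MoE module is constructed, since the backend choice
# is cached and acted on in __init__.
os.environ["TRTLLM_MOE_ALLTOALL_BACKEND"] = "MnnvlThroughput"

# Per-rank MoeAlltoAll workspace size in MiB (the patch defaults to 512).
os.environ["TRTLLM_MOE_A2A_WORKSPACE_MB"] = "1024"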
@@ -298,45 +320,89 @@ def forward_impl(
             else:
                 token_final_scales = token_final_scales.to(torch.float32)
 
-            assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
-            alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
-                token_selected_experts,
-                None,
-                self.alltoall_prepare_workspace,
-                max_num_token,
-                self.ep_rank,
-                self.ep_size,
-                self.num_experts,
-                self.num_slots,
-                top_k,
-            )
+            if self.moe_alltoall_backend == "MnnvlLatency":
+                assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
+                alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+                    token_selected_experts,
+                    None,
+                    self.alltoall_prepare_workspace,
+                    max_num_token,
+                    self.ep_rank,
+                    self.ep_size,
+                    self.num_experts,
+                    self.num_slots,
+                    top_k,
+                )
 
-            if x_sf is not None:
-                x_sf = x_sf.view(x_row, ceil_div(x_col,
-                                                 self.scaling_vector_size))
+                if x_sf is not None:
+                    x_sf = x_sf.view(x_row,
+                                     ceil_div(x_col, self.scaling_vector_size))
 
-            x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
-                [x, x_sf, token_selected_experts, token_final_scales],
-                alltoall_info,
-                self.alltoall_workspace,
-                self.ep_rank,
-                self.ep_size,
-            )
+                x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
+                    [x, x_sf, token_selected_experts, token_final_scales],
+                    alltoall_info,
+                    self.alltoall_workspace,
+                    self.ep_rank,
+                    self.ep_size,
+                )
 
-            torch.ops.trtllm.memset_expert_ids(
-                token_selected_experts,
-                alltoall_info.recv_rank_count_cumsum,
-                max_num_token,
-                top_k,
-                self.num_slots,
-                self.ep_size,
-            )
+                torch.ops.trtllm.memset_expert_ids(
+                    token_selected_experts,
+                    alltoall_info.recv_rank_count_cumsum,
+                    max_num_token,
+                    top_k,
+                    self.num_slots,
+                    self.ep_size,
+                )
 
-            if x_sf is not None:
-                x_sf = x_sf.flatten()
+                if x_sf is not None:
+                    x_sf = x_sf.flatten()
+
+                if token_final_scales is not None:
+                    token_final_scales = token_final_scales.to(torch.bfloat16)
+            elif self.moe_alltoall_backend == "MnnvlThroughput":
+                if x_sf is not None:
+                    x_sf = x_sf.view(x_row,
+                                     ceil_div(x_col, self.scaling_vector_size))
+
+                payloads = []
+                payloads.append(x)
+                if x_sf is not None:
+                    payloads.append(x_sf)
+                    expert_id_payload_index = 2
+                else:
+                    expert_id_payload_index = 1
+                payloads.append(token_selected_experts)
+                payloads.append(token_final_scales)
+
+                recv_buffers = self.moe_a2a.dispatch(
+                    token_selected_experts,
+                    payloads,
+                    # Note: Cutlass MoE uses num_experts as the invalid token expert id
+                    invalid_token_expert_id=-1,
+                    expert_id_payload_index=expert_id_payload_index,
+                )
 
-            if token_final_scales is not None:
-                token_final_scales = token_final_scales.to(torch.bfloat16)
+                if x_sf is not None:
+                    x_recv, x_sf_recv, token_selected_experts_recv, token_final_scales_recv = recv_buffers
+                    x_sf = x_sf_recv.view(-1, x_sf_recv.shape[-1])
+                else:
+                    x_recv, token_selected_experts_recv, token_final_scales_recv = recv_buffers
+                x = x_recv.view(-1, x_recv.shape[-1])
+                token_selected_experts = token_selected_experts_recv.view(
+                    -1, token_selected_experts_recv.shape[-1])
+                token_final_scales = token_final_scales_recv.view(
+                    -1, token_final_scales_recv.shape[-1])
+
+                if x_sf is not None:
+                    x_sf = x_sf.flatten()
+
+                if token_final_scales is not None:
+                    token_final_scales = token_final_scales.to(torch.bfloat16)
+            else:
+                raise ValueError(
+                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                )
 
         elif run_post_quant_allgather:
             if x_sf is not None:
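The MnnvlThroughput dispatch above ships a list of payload tensors plus the index of the expert-id payload inside that list, which shifts by one depending on whether a scaling-factor tensor x_sf accompanies the activations. A standalone sketch of that bookkeeping, as a hypothetical helper mirroring the ordering [x, x_sf (optional), token_selected_experts, token_final_scales] used in the hunk above:

def build_dispatch_payloads(x, x_sf, token_selected_experts, token_final_scales):
    # Hypothetical helper (not part of the patch): returns the payload list and
    # the index at which token_selected_experts sits within it.
    payloads = [x]
    if x_sf is not None:
        payloads.append(x_sf)
    # token_selected_experts is appended next, so its index is the current length.
    expert_id_payload_index = len(payloads)
    payloads.append(token_selected_experts)
    payloads.append(token_final_scales)
    return payloads, expert_id_payload_index

With x_sf present this yields expert_id_payload_index == 2, without it 1, matching the hard-coded values in the dispatch branch.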
@@ -600,16 +666,28 @@ def forward_impl(
             )
 
         # Combine results if using alltoall
-        if self.enable_alltoall and alltoall_info is not None:
-            final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
-                final_hidden_states,
-                alltoall_info,
-                self.alltoall_workspace,
-                ep_rank=self.ep_rank,
-                ep_size=self.ep_size,
-                top_k=top_k,
-                token_count=token_count,
-            )
+        if self.enable_alltoall:
+            if self.moe_alltoall_backend == "MnnvlLatency":
+                if alltoall_info is not None:
+                    final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
+                        final_hidden_states,
+                        alltoall_info,
+                        self.alltoall_workspace,
+                        ep_rank=self.ep_rank,
+                        ep_size=self.ep_size,
+                        top_k=top_k,
+                        token_count=token_count,
+                    )
+            elif self.moe_alltoall_backend == "MnnvlThroughput":
+                hidden = final_hidden_states.shape[-1]
+                payload = final_hidden_states.view(
+                    self.ep_size, self.moe_a2a.max_num_tokens_per_rank, hidden)
+                final_hidden_states = self.moe_a2a.combine(
+                    payload, payload_in_workspace=False)
+            else:
+                raise ValueError(
+                    f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
+                )
 
         final_hidden_states = self.reducescatter_or_allreduce(
             final_hidden_states,
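For reference, a standalone sketch of the shape contract the MnnvlThroughput path appears to rely on: dispatch is assumed to hand back per-source-rank padded buffers shaped [ep_size, max_num_tokens_per_rank, ...] that get flattened before the expert kernels, and combine is handed the expert output reshaped back into that padded layout. All sizes below are illustrative and no TensorRT-LLM APIs are called:

import torch

ep_size, max_tokens_per_rank, hidden, top_k = 4, 8, 16, 2

# Assumed layout of two of the buffers returned by dispatch().
x_recv = torch.randn(ep_size, max_tokens_per_rank, hidden)
experts_recv = torch.randint(0, 32, (ep_size, max_tokens_per_rank, top_k))

# Flattened views fed to the MoE kernels, as in the dispatch hunk.
x = x_recv.view(-1, x_recv.shape[-1])                  # [ep_size * max_tokens, hidden]
token_selected_experts = experts_recv.view(-1, top_k)  # [ep_size * max_tokens, top_k]

# Expert output is reshaped back to the padded per-rank layout before combine(),
# as in the combine hunk.
final_hidden_states = torch.randn(ep_size * max_tokens_per_rank, hidden)
payload = final_hidden_states.view(ep_size, max_tokens_per_rank, hidden)
assert payload.shape == (ep_size, max_tokens_per_rank, hidden)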