 )
 from torch.distributed.tensor.parallel import ParallelStyle
 
-from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import _round_up
 
 
@@ -90,18 +89,19 @@ class ExpertParallel(ParallelStyle):
         a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
     """
 
-    def __init__(self, a2a_impl: str = "default"):
+    def __init__(
+        self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default"
+    ):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
-        self.a2a_func = self._get_a2a_func(a2a_impl)
+        self.a2a_dispatch_func = self._get_a2a_func(a2a_dispatch_impl)
+        self.a2a_combine_func = self._get_a2a_func(a2a_combine_impl)
 
     def _get_a2a_func(self, a2a_impl: str):
         if a2a_impl == "default":
-            logger.info("Using default all-to-all implementation")
             return all_to_all_single_autograd
         elif a2a_impl == "mxfp8":
-            logger.info("Using mxfp8 all-to-all implementation")
             from torchao.prototype.moe_training.kernels.mxfp8.comms import (
                 to_mxfp8_a2a_dequant,
             )
@@ -143,6 +143,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         self.input_splits = input_splits.tolist()
         self.output_splits = output_splits.tolist()
 
+        routed_input = self.a2a_dispatch_func(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
+
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
         # [#tokens for local expert 0, #tokens for local expert 1, ...]
@@ -152,12 +159,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # We need to perform another shuffle to get the correct format -- this is done via the function
         # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
         # each expert gets locally is a multiple of ALIGN_SIZE_M.
-        routed_input = self.a2a_func(
-            routed_input,
-            self.output_splits,
-            self.input_splits,
-            device_mesh.get_group(),
-        )
+
         return routed_input, num_tokens_per_expert_group
 
     @staticmethod
@@ -170,7 +172,7 @@ def _partition_fn(name, mod, device_mesh):
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
         # For a2a combine, input splits and output splits are opposite of a2a dispatch.
-        routed_output = self.a2a_combine_func(
+        routed_output = self.a2a_combine_func(
             routed_output,
             self.input_splits,
             self.output_splits,
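
With the constructor split into separate dispatch and combine knobs, a caller can mix implementations, for example a quantized mxfp8 all-to-all for token dispatch with the default autograd all-to-all for combine. Below is a minimal usage sketch; the import path, mesh size, and the `moe_experts` placeholder are assumptions for illustration, and only the new `ExpertParallel(a2a_dispatch_impl=..., a2a_combine_impl=...)` signature comes from the diff above.

    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module

    # Import path assumed; use wherever ExpertParallel is defined in your checkout.
    from torchtitan.distributed.expert_parallel import ExpertParallel

    # 1-D expert-parallel mesh (8 ranks here is an arbitrary example).
    ep_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("ep",))

    # mxfp8 all-to-all for token dispatch, default autograd all-to-all for combine.
    ep_style = ExpertParallel(a2a_dispatch_impl="mxfp8", a2a_combine_impl="default")

    # Placeholder for the grouped-experts module of an MoE layer.
    moe_experts: nn.Module = ...

    parallelize_module(
        module=moe_experts,
        device_mesh=ep_mesh,
        parallelize_plan=ep_style,
    )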