)
from torch.distributed.tensor.parallel import ParallelStyle

+from torchtitan.tools.logging import logger
from torchtitan.tools.utils import _round_up


@@ -87,20 +88,60 @@ class ExpertParallel(ParallelStyle):

    Args:
        a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default", "mxfp8"].
-        max_tokens_per_ep_rank (int): The maximum number of tokens per expert rank. Only used for "mxfp8".
    """

-    def __init__(self, a2a_impl: str = "default", max_tokens_per_ep_rank: int = -1):
+    def __init__(self, a2a_impl: str = "default"):
        super().__init__()
        self.input_splits = None
        self.output_splits = None
-        self.a2a_impl = a2a_impl
-        self.max_tokens_per_ep_rank = max_tokens_per_ep_rank
+        self.a2a_func = self._get_a2a_func(a2a_impl)
+
+    def _get_a2a_func(self, a2a_impl: str):
+        if a2a_impl == "default":
+            logger.info("Using default all-to-all implementation")
+            return all_to_all_single_autograd
+        elif a2a_impl == "mxfp8":
+            logger.info("Using mxfp8 all-to-all implementation")
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                mxfp8_sync_all_to_all_v,
+            )
+
+            return mxfp8_sync_all_to_all_v
+        else:
+            raise ValueError(f"Unknown a2a_impl: {a2a_impl}")

    # performing all-to-all dispatch on the input
    def _token_dispatch(self, mod, inputs, device_mesh):
        # annotate module input placements/sharding with input_layouts
        routed_input, num_tokens_per_expert = inputs
+        ep_size = device_mesh.size(0)
+
+        # generate the input splits and output splits for all-to-all
+        with torch.no_grad():
+            num_tokens_per_expert_group = all_to_all_single(
+                num_tokens_per_expert,
+                None,
+                None,
+                group=device_mesh.get_group(),
+            )
+            # Need to wait explicitly because it is used by a triton kernel later
+            # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
+            input_splits = (
+                num_tokens_per_expert.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=True)
+            )
+            # NOTE: this would incur a device-to-host sync
+            output_splits = (
+                num_tokens_per_expert_group.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=False)
+            )
+        self.input_splits = input_splits.tolist()
+        self.output_splits = output_splits.tolist()

        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
        # However, the num_tokens_per_expert_group is not of the final target format
@@ -111,25 +152,12 @@ def _token_dispatch(self, mod, inputs, device_mesh):
        # We need to perform another shuffle to get the correct format -- this is done via the function
        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
        # each expert gets locally is a multiple of ALIGN_SIZE_M.
-        if self.a2a_impl == "mxfp8":
-            (
-                routed_input,
-                self.input_splits,
-                self.output_splits,
-                num_tokens_per_expert_group,
-            ) = mxfp8_a2a_dispatch(
-                routed_input,
-                num_tokens_per_expert,
-                device_mesh,
-                self.max_tokens_per_ep_rank,
-            )
-        else:
-            (
-                routed_input,
-                self.input_splits,
-                self.output_splits,
-                num_tokens_per_expert_group,
-            ) = default_a2a_dispatch(routed_input, num_tokens_per_expert, device_mesh)
+        routed_input = self.a2a_func(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
        return routed_input, num_tokens_per_expert_group

    @staticmethod
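For intuition, the split computation added above only collapses per-expert token counts into per-EP-rank counts. A minimal standalone sketch with made-up numbers (2 EP ranks, 2 local experts per rank; illustrative only, not part of this change):

import torch

ep_size = 2
# Tokens this rank routes to each of the 4 global experts.
num_tokens_per_expert = torch.tensor([3, 1, 2, 4])

# input_splits: how many tokens this rank sends to each EP rank at dispatch time.
input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
print(input_splits.tolist())  # [4, 6]

The output splits are derived the same way, but from num_tokens_per_expert_group, i.e. from the counts that peer ranks report via the metadata all_to_all_single.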
@@ -141,25 +169,13 @@ def _partition_fn(name, mod, device_mesh):

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
-        if self.a2a_impl == "mxfp8":
-            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
-                mxfp8_on_device_all_to_all_v,
-            )
-
-            # For a2a combine, output splits are the input splits, and input splits are the output splits.
-            routed_output, self.input_splits = mxfp8_on_device_all_to_all_v(
-                routed_output,
-                self.output_splits,
-                self.max_tokens_per_ep_rank,
-                device_mesh.get_group().group_name,
-            )
-        else:
-            routed_output = all_to_all_single_autograd(
-                routed_output,
-                self.input_splits,
-                self.output_splits,
-                device_mesh.get_group(),
-            )
+        # For a2a combine, the input splits and output splits are swapped relative to a2a dispatch.
+        routed_output = self.a2a_func(
+            routed_output,
+            self.input_splits,
+            self.output_splits,
+            device_mesh.get_group(),
+        )
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
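With this refactor the implementation choice is resolved once at construction time rather than re-branched in every dispatch and combine call. A minimal sketch of the selection behavior; the import path is an assumption about where this class lives in torchtitan, and the mxfp8 option additionally requires a torchao build with the mxfp8 MoE kernels:

from torchtitan.distributed.expert_parallel import ExpertParallel  # assumed import path

ep = ExpertParallel()                     # logs "Using default all-to-all implementation"
ep_mx = ExpertParallel(a2a_impl="mxfp8")  # resolves to torchao's mxfp8_sync_all_to_all_v

try:
    ExpertParallel(a2a_impl="bf16_fused")  # hypothetical unsupported value
except ValueError as err:
    print(err)  # Unknown a2a_impl: bf16_fused

Resolving the callable up front also means a typo in a2a_impl fails at model setup instead of on the first forward pass.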
@@ -349,105 +365,3 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
            input_fn=self._prepare_inputput_fn,
            output_fn=self._prepare_output_fn,
        )
-
-
-def default_a2a_dispatch(
-    routed_input: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
-    device_mesh: DeviceMesh,
-):
-    """
-    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    ep_degree = device_mesh.size(0)
-    # generate the input splits and output splits for all-to-all
-    with torch.no_grad():
-        num_tokens_per_expert_group = all_to_all_single(
-            num_tokens_per_expert,
-            None,
-            None,
-            group=device_mesh.get_group(),
-        )
-        # Need to wait explicitly because it is used by a triton kernel later
-        # which doesn't realize that AsyncCollectiveTensor needs unwrapping
-        num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-            num_tokens_per_expert_group
-        )
-        input_splits = (
-            num_tokens_per_expert.view(ep_degree, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_degree, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=False)
-        )
-    input_splits_list = input_splits.tolist()
-    output_splits_list = output_splits.tolist()
-
-    # perform all-to-all
-    routed_input = all_to_all_single_autograd(
-        routed_input,
-        output_splits_list,
-        input_splits_list,
-        device_mesh.get_group(),
-    )
-    return (
-        routed_input,
-        input_splits_list,
-        output_splits_list,
-        num_tokens_per_expert_group,
-    )
-
-
-def mxfp8_a2a_dispatch(
-    routed_input: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
-    device_mesh: DeviceMesh,
-    max_tokens_per_ep_rank: int,
-):
-    """
-    Perform on-device all-to-all dispatch with dynamically quantized mxfp8 inputs to save network bandwidth
-    and avoid device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    from torchao.prototype.moe_training.kernels.mxfp8.comms import (
-        mxfp8_on_device_all_to_all_v,
-    )
-
-    ep_degree = device_mesh.size(0)
-    input_splits_per_ep_rank = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
-    num_tokens_per_expert_group = all_to_all_single(
-        num_tokens_per_expert,
-        None,
-        None,
-        group=device_mesh.get_group(),
-    )
-    num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-        num_tokens_per_expert_group
-    )
-    routed_input, output_splits_per_ep_rank = mxfp8_on_device_all_to_all_v(
-        routed_input,
-        input_splits_per_ep_rank,
-        max_tokens_per_ep_rank,
-        device_mesh.get_group().group_name,
-    )
-    return (
-        routed_input,
-        input_splits_per_ep_rank,
-        output_splits_per_ep_rank,
-        num_tokens_per_expert_group,
-    )
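End to end, nothing about applying the style changes; only the a2a_impl knob selects the all-to-all kernel. A hedged usage sketch: the import path, the model attribute path, and the mesh shape are assumptions, and it presumes a distributed launch (e.g. torchrun) with torchao installed for the mxfp8 path:

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.distributed.expert_parallel import ExpertParallel  # assumed import path

# 1-D expert-parallel mesh over 8 ranks (assumed world size).
ep_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("ep",))

# `model` stands in for a hypothetical MoE model whose routed-experts submodule
# should be sharded over the EP mesh; dispatch/combine then use the mxfp8
# all-to-all selected in ExpertParallel.__init__.
parallelize_module(
    module=model.moe_layer.experts,
    device_mesh=ep_mesh,
    parallelize_plan=ExpertParallel(a2a_impl="mxfp8"),
)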