Commit 751a472

use sync based mxfp8 a2a
1 parent: 5c38830

2 files changed: +61 −160 lines changed


torchtitan/distributed/expert_parallel.py

Lines changed: 58 additions & 144 deletions
@@ -22,6 +22,7 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle

+from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import _round_up


@@ -87,20 +88,60 @@ class ExpertParallel(ParallelStyle):

     Args:
         a2a_impl (str): The implementation of all-to-all. Default is "default". Options are ["default","mxfp8"].
-        max_tokens_per_ep_rank (int): The maximum number of tokens per expert rank. Only used for "mxfp8".
     """

-    def __init__(self, a2a_impl: str = "default", max_tokens_per_ep_rank: int = -1):
+    def __init__(self, a2a_impl: str = "default"):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
-        self.a2a_impl = a2a_impl
-        self.max_tokens_per_ep_rank = max_tokens_per_ep_rank
+        self.a2a_func = self._get_a2a_func(a2a_impl)
+
+    def _get_a2a_func(self, a2a_impl: str):
+        if a2a_impl == "default":
+            logger.info("Using default all-to-all implementation")
+            return all_to_all_single_autograd
+        elif a2a_impl == "mxfp8":
+            logger.info("Using mxfp8 all-to-all implementation")
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                mxfp8_sync_all_to_all_v,
+            )
+
+            return mxfp8_sync_all_to_all_v
+        else:
+            raise ValueError(f"Unknown a2a_impl: {a2a_impl}")

     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
+        ep_size = device_mesh.size(0)
+
+        # generate the input splits and output splits for all-to-all
+        with torch.no_grad():
+            num_tokens_per_expert_group = all_to_all_single(
+                num_tokens_per_expert,
+                None,
+                None,
+                group=device_mesh.get_group(),
+            )
+            # Need to wait explicitly because it is used by a triton kernel later
+            # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
+            input_splits = (
+                num_tokens_per_expert.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=True)
+            )
+            # NOTE: this would incur a device-to-host sync
+            output_splits = (
+                num_tokens_per_expert_group.view(ep_size, -1)
+                .sum(dim=1)
+                .to(torch.device("cpu"), non_blocking=False)
+            )
+        self.input_splits = input_splits.tolist()
+        self.output_splits = output_splits.tolist()

         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
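
The sync-based dispatch added above rebuilds the split lists on every call: input_splits comes from the local per-expert token counts, output_splits from the counts exchanged via all_to_all_single, and the .tolist() / non_blocking=False copy is what forces the device-to-host sync the NOTE mentions. Below is a minimal single-process sketch of just that split arithmetic, with made-up token counts standing in for the exchanged values (no collectives involved):

import torch

# Hypothetical setup: 2 EP ranks, 2 local experts per rank (4 experts total).
ep_size = 2
# Tokens this rank routes to each of the 4 global experts.
num_tokens_per_expert = torch.tensor([3, 1, 2, 2])

# input_splits[r] = tokens this rank sends to EP rank r (sum over r's local experts).
input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1)
print(input_splits.tolist())  # [4, 4]

# In the real code num_tokens_per_expert_group comes back from all_to_all_single;
# here it is faked: what the peer ranks route to our two local experts.
num_tokens_per_expert_group = torch.tensor([3, 1, 5, 0])

# output_splits[r] = tokens this rank receives from EP rank r. The .tolist() call
# (like .to("cpu", non_blocking=False) in the diff) is the device-to-host sync point.
output_splits = num_tokens_per_expert_group.view(ep_size, -1).sum(dim=1)
print(output_splits.tolist())  # [4, 5]
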
@@ -111,25 +152,12 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # We need to perform another shuffle to get the correct format -- this is done via the function
         # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
         # each expert gets locally is a multiple of ALIGN_SIZE_M.
-        if self.a2a_impl == "mxfp8":
-            (
-                routed_input,
-                self.input_splits,
-                self.output_splits,
-                num_tokens_per_expert_group,
-            ) = mxfp8_a2a_dispatch(
-                routed_input,
-                num_tokens_per_expert,
-                device_mesh,
-                self.max_tokens_per_ep_rank,
-            )
-        else:
-            (
-                routed_input,
-                self.input_splits,
-                self.output_splits,
-                num_tokens_per_expert_group,
-            ) = default_a2a_dispatch(routed_input, num_tokens_per_expert, device_mesh)
+        routed_input = self.a2a_func(
+            routed_input,
+            self.output_splits,
+            self.input_splits,
+            device_mesh.get_group(),
+        )
         return routed_input, num_tokens_per_expert_group

     @staticmethod
@@ -141,25 +169,13 @@ def _partition_fn(name, mod, device_mesh):

     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        if self.a2a_impl == "mxfp8":
-            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
-                mxfp8_on_device_all_to_all_v,
-            )
-
-            # For a2a combine, output splits are the input splits, and input splits are the output splits.
-            routed_output, self.input_splits = mxfp8_on_device_all_to_all_v(
-                routed_output,
-                self.output_splits,
-                self.max_tokens_per_ep_rank,
-                device_mesh.get_group().group_name,
-            )
-        else:
-            routed_output = all_to_all_single_autograd(
-                routed_output,
-                self.input_splits,
-                self.output_splits,
-                device_mesh.get_group(),
-            )
+        # For a2a combine, input splits and output splits are opposite of a2a dispatch.
+        routed_output = self.a2a_func(
+            routed_output,
+            self.input_splits,
+            self.output_splits,
+            device_mesh.get_group(),
+        )
         return routed_output

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
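
Collapsing both branches into self.a2a_func works because dispatch and combine are mirror images: the combine call above reuses the dispatch split lists with the input/output roles swapped. A small two-rank round-trip sketch of that convention, using plain torch.distributed.all_to_all_single for illustration (an assumption; not the autograd or mxfp8 variants used in the diff). Launch with e.g. torchrun --nproc_per_node=2, on a backend whose build supports all_to_all_single:

import torch
import torch.distributed as dist


def main():
    dist.init_process_group("gloo")  # assumes a gloo build with all_to_all support
    rank, world = dist.get_rank(), dist.get_world_size()

    # Pretend rank r routes (r + 1) tokens to every peer.
    input_splits = [rank + 1] * world                     # send sizes per destination rank
    output_splits = [peer + 1 for peer in range(world)]   # receive sizes per source rank

    send = torch.full((sum(input_splits),), float(rank))
    recv = torch.empty(sum(output_splits))

    # Dispatch: receive-side splits first, send-side splits second.
    dist.all_to_all_single(recv, send, output_splits, input_splits)

    # Combine: same call with the split lists swapped, routing every token back home.
    back = torch.empty(sum(input_splits))
    dist.all_to_all_single(back, recv, input_splits, output_splits)

    assert torch.equal(back, send), "round trip should restore the original tokens"
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
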
@@ -349,105 +365,3 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             input_fn=self._prepare_inputput_fn,
             output_fn=self._prepare_output_fn,
         )
-
-
-def default_a2a_dispatch(
-    routed_input: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
-    device_mesh: DeviceMesh,
-):
-    """
-    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    ep_degree = device_mesh.size(0)
-    # generate the input splits and output splits for all-to-all
-    with torch.no_grad():
-        num_tokens_per_expert_group = all_to_all_single(
-            num_tokens_per_expert,
-            None,
-            None,
-            group=device_mesh.get_group(),
-        )
-        # Need to wait explicitly because it is used by a triton kernel later
-        # which doesn't realize that AsyncCollectiveTensor needs unwrapping
-        num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-            num_tokens_per_expert_group
-        )
-        input_splits = (
-            num_tokens_per_expert.view(ep_degree, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=True)
-        )
-        # NOTE: this would incur a device-to-host sync
-        output_splits = (
-            num_tokens_per_expert_group.view(ep_degree, -1)
-            .sum(dim=1)
-            .to(torch.device("cpu"), non_blocking=False)
-        )
-    input_splits_list = input_splits.tolist()
-    output_splits_list = output_splits.tolist()
-
-    # perform all-to-all
-    routed_input = all_to_all_single_autograd(
-        routed_input,
-        output_splits_list,
-        input_splits_list,
-        device_mesh.get_group(),
-    )
-    return (
-        routed_input,
-        input_splits_list,
-        output_splits_list,
-        num_tokens_per_expert_group,
-    )
-
-
-def mxfp8_a2a_dispatch(
-    routed_input: torch.Tensor,
-    num_tokens_per_expert: torch.Tensor,
-    device_mesh: DeviceMesh,
-    max_tokens_per_ep_rank: int,
-):
-    """
-    Perform on-device all-to-all dispatch with dynamically quantized mxfp8 inputs to save network bandwidth
-    and avoid device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    from torchao.prototype.moe_training.kernels.mxfp8.comms import (
-        mxfp8_on_device_all_to_all_v,
-    )
-
-    ep_degree = device_mesh.size(0)
-    input_splits_per_ep_rank = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
-    num_tokens_per_expert_group = all_to_all_single(
-        num_tokens_per_expert,
-        None,
-        None,
-        group=device_mesh.get_group(),
-    )
-    num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
-        num_tokens_per_expert_group
-    )
-    routed_input, output_splits_per_ep_rank = mxfp8_on_device_all_to_all_v(
-        routed_input,
-        input_splits_per_ep_rank,
-        max_tokens_per_ep_rank,
-        device_mesh.get_group().group_name,
-    )
-    return (
-        routed_input,
-        input_splits_per_ep_rank,
-        output_splits_per_ep_rank,
-        num_tokens_per_expert_group,
-    )
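
With the two module-level helpers above removed, selecting the communication path is now just a name passed to the constructor, and the call sites in parallelize.py (next file) shrink accordingly. A hypothetical direct instantiation, assuming only what this diff shows:

from torchtitan.distributed.expert_parallel import ExpertParallel

# "default" resolves to all_to_all_single_autograd; "mxfp8" resolves to torchao's
# mxfp8_sync_all_to_all_v (requires torchao with the MoE-training prototype kernels).
ep_default = ExpertParallel()
ep_mxfp8 = ExpertParallel(a2a_impl="mxfp8")
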

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 3 additions & 16 deletions
@@ -94,13 +94,6 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

-    # Worst case = single expert receives all tokens
-    # TODO: explore using token dropping to avoid this huge overallocation
-    max_tokens_per_ep_rank = (
-        job_config.training.seq_len
-        * job_config.training.local_batch_size
-        * model.model_args.moe_args.num_experts
-    )
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
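
The deleted block sized the old on-device mxfp8 all-to-all for the worst case in which a single expert receives every routed token, which is why the TODO calls the allocation huge. For a sense of scale only (these numbers are hypothetical, not from the commit):

# Hypothetical config values, purely to illustrate the old worst-case buffer size.
seq_len = 8192
local_batch_size = 8
num_experts = 16

max_tokens_per_ep_rank = seq_len * local_batch_size * num_experts
print(max_tokens_per_ep_rank)  # 1048576 tokens reserved per EP rank, regardless of actual routing

The sync-based path no longer needs this bound: output buffer sizes are known exactly once the split exchange completes.
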
@@ -115,7 +108,6 @@
             ),
             etp_enabled=parallel_dims.etp_enabled,
             a2a_impl=job_config.parallelism.expert_parallel_a2a_impl,
-            max_tokens_per_ep_rank=max_tokens_per_ep_rank,
         )

     model_compile_enabled = (
@@ -447,8 +439,7 @@ def apply_moe_ep_tp(
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
-    a2a_impl: str,
-    max_tokens_per_ep_rank: int = -1,  # Only used for mxfp8 a2a
+    a2a_impl: str = "default",
 ):
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
@@ -498,17 +489,13 @@ def apply_moe_ep_tp(
         elif tp_mesh is None:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel(
-                a2a_impl=a2a_impl, max_tokens_per_ep_rank=max_tokens_per_ep_rank
-            )
+            experts_plan = ExpertParallel(a2a_impl=a2a_impl)
         elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
             experts_mesh = ep_mesh
-            experts_plan = ExpertParallel(
-                a2a_impl=a2a_impl, max_tokens_per_ep_rank=max_tokens_per_ep_rank
-            )
+            experts_plan = ExpertParallel(a2a_impl=a2a_impl)

         parallelize_module(
             module=transformer_block.moe.experts,
