
Commit 7cf98e4

Enable PP and EP overlap for MoE
1 parent 7c10480 commit 7cf98e4

7 files changed, +293 -23 lines changed

run_train.sh

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ set -ex
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+export LOG_RANK=${LOG_RANK:-0,2}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}

torchtitan/config/job_config.py

Lines changed: 5 additions & 0 deletions
@@ -375,6 +375,11 @@ class Parallelism:
     The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
     """
 
+    pipeline_parallel_expert_parallel_overlap: bool = True
+    """Whether to turn on the optimization to overlap expert parallel and pipeline parallel
+    communication. This is only effective when the pipeline parallel schedule is DualPipeV and
+    pipeline_parallel_degree > 1 and expert_parallel_degree > 1."""
+
     context_parallel_degree: int = 1
     """Context parallelism degree. 1 means disabled."""
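
Note: the docstring above states when the flag actually changes behavior. A minimal Python sketch of that condition, for reference only; `pipeline_parallel_schedule`, `pipeline_parallel_degree`, and `expert_parallel_degree` are assumed to be the existing fields of the same Parallelism config, and only `pipeline_parallel_expert_parallel_overlap` is introduced by this commit.

# Hedged sketch restating the docstring's conditions; not part of the diff.
def overlap_is_effective(parallelism) -> bool:
    return (
        parallelism.pipeline_parallel_expert_parallel_overlap
        and parallelism.pipeline_parallel_schedule == "DualPipeV"
        and parallelism.pipeline_parallel_degree > 1
        and parallelism.expert_parallel_degree > 1
    )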

torchtitan/distributed/expert_parallel.py

Lines changed: 33 additions & 2 deletions
@@ -19,6 +19,7 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 
+from torchtitan.distributed.pipeline_parallel import SyncHook
 from torchtitan.models.moe.utils import _permute, _unpermute
 
 
@@ -145,13 +146,43 @@ def _token_combine(self, mod, routed_output, device_mesh):
         return routed_output
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        return distribute_module(
-            module,
+        """
+        Hooks are called in the order they are registered:
+            SyncHookA, _token_dispatch, SyncHookB (pre hooks)
+            SyncHookC, _token_combine, SyncHookD (post hooks)
+        """
+        inner_wrapped_module = self._wrap_with_inner_hooks(module)
+        distributed_module = distribute_module(
+            inner_wrapped_module,
             device_mesh,
             partition_fn=ExpertParallel._partition_fn,
             input_fn=self._token_dispatch,
             output_fn=self._token_combine,
         )
+        final_module = self._wrap_with_outer_hooks(distributed_module)
+        return final_module
+
+    def _wrap_with_inner_hooks(self, module):
+        def inner_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "A"),) + input[1:]
+
+        def inner_post_hook(module, input, output):
+            return SyncHook.apply(output, "C")
+
+        module.register_forward_pre_hook(inner_pre_hook)
+        module.register_forward_hook(inner_post_hook)
+        return module
+
+    def _wrap_with_outer_hooks(self, module):
+        def outer_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "B"),) + input[1:]
+
+        def outer_post_hook(module, input, output):
+            return SyncHook.apply(output, "D")
+
+        module.register_forward_pre_hook(outer_pre_hook)
+        module.register_forward_hook(outer_post_hook)
+        return module
 
 
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
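
Note: the new docstring in `_apply` relies on nn.Module hooks firing in registration order. The inner hooks are registered before `distribute_module` wires up `_token_dispatch`/`_token_combine`, and the outer hooks are registered afterwards, so the pre hooks fire as A, then dispatch, then B, and the post hooks as C, then combine, then D. A self-contained toy sketch of that ordering (a plain nn.Identity stands in for the wrapped experts module; none of this is torchtitan code):

import torch
import torch.nn as nn

calls = []

def make_pre(name):
    def pre(mod, args):
        calls.append(name)
        return args  # pass inputs through unchanged
    return pre

def make_post(name):
    def post(mod, args, out):
        calls.append(name)
        return out  # pass output through unchanged
    return post

m = nn.Identity()
m.register_forward_pre_hook(make_pre("A"))         # inner pre hook
m.register_forward_pre_hook(make_pre("dispatch"))  # stands in for _token_dispatch
m.register_forward_pre_hook(make_pre("B"))         # outer pre hook
m.register_forward_hook(make_post("C"))            # inner post hook
m.register_forward_hook(make_post("combine"))      # stands in for _token_combine
m.register_forward_hook(make_post("D"))            # outer post hook

m(torch.zeros(1))
print(calls)  # ['A', 'dispatch', 'B', 'C', 'combine', 'D']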

torchtitan/distributed/pipeline_parallel.py

Lines changed: 234 additions & 2 deletions
@@ -7,22 +7,29 @@
 
 import math
 import os
-from typing import Callable
+import threading
+from typing import Callable, Optional
 
 import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
 
 from torch.distributed.pipelining.schedules import (
+    _Action,
+    _PipelineContext,
     _PipelineSchedule,
     _PipelineScheduleRuntime,
+    _wait_batch_p2p,
     get_schedule_class,
+    OVERLAP_F_B,
     PipelineScheduleMulti,
     PipelineScheduleSingle,
     ScheduleDualPipeV,
     ScheduleZBVZeroBubble,
 )
+from torch.distributed.pipelining.stage import _PipelineStageBase
+from torch.profiler import record_function
 
 from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
 from torchtitan.config import JobConfig
@@ -37,7 +44,8 @@
     "pipeline_module_split",
 ]
 
-
+import fbvscode
+fbvscode.attach_debugger()
 def pipeline_llm(
     model: nn.Module,
     parallel_dims: ParallelDims,
@@ -209,6 +217,11 @@ def build_pipeline_schedule(
         f"with {n_microbatches} microbatches and {num_total_stages} stages."
     )
 
+    if job_config.parallelism.pipeline_parallel_expert_parallel_overlap and isinstance(
+        schedule, ScheduleDualPipeV
+    ):
+        schedule.register_custom_function(OVERLAP_F_B, overlap_callback)
+
     if pp_schedule_csv:
         assert schedule_class in [
             PipelineScheduleSingle,
@@ -473,3 +486,222 @@ def _get_stage_indices() -> tuple[int]:
         models.append(model_chunk)
 
     return stages, models
+
+
+# TODO: is there a better place to put this?
+# Below are optimizations related to pipeline parallelism with expert parallelism
+
+
+class HookCoordinator:
+    def __init__(self):
+        # Barrier for 2 threads (forward and backward) to synchronize
+        # This ensures that we always alternate at executing one compute and one comm op together
+        self._execution_barrier = threading.Barrier(2)
+
+        self._coordination_enabled = False
+        self._cycle_count = 0
+        self._num_layers = None
+
+    def barrier(self):
+        """Barrier for 2 threads to synchronize"""
+        if not self.is_coordination_enabled():
+            return
+
+        try:
+            self._execution_barrier.wait()
+        except threading.BrokenBarrierError:
+            pass
+
+    def enable_coordination(self, num_layers: Optional[int] = None):
+        if num_layers is not None and num_layers > 0:
+            self._coordination_enabled = True
+            self._cycle_count = 0
+
+            # Reset barrier
+            self._execution_barrier = threading.Barrier(2)
+            self._num_layers = num_layers
+
+    def disable_coordination(self):
+        self._coordination_enabled = False
+        self._cycle_count = 0
+        self._execution_barrier.abort()  # Break barrier to unblock threads
+
+    def check_should_continue_coordination(self):
+        if self._num_layers is not None and self._cycle_count >= self._num_layers:
+            return False
+        return True
+
+    def is_coordination_enabled(self):
+        return self._coordination_enabled
+
+
+# Global coordinator
+_hook_coordinator = HookCoordinator()
+
+
+class SyncHook(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, hook_name=""):
+        ctx.hook_name = hook_name
+        # handle edge case for transformer level boundary
+        if _hook_coordinator._coordination_enabled and hook_name == "D":
+            _hook_coordinator._cycle_count += 1
+            if not _hook_coordinator.check_should_continue_coordination():
+                _hook_coordinator.disable_coordination()
+                return x
+
+        # print(f"forward {hook_name=} calling barrier")
+        _hook_coordinator.barrier()
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hook_name = ctx.hook_name
+
+        # Edge case, skip initial barrier, all subsequent backward hooks will acquire
+        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
+            return grad_output, None
+
+        # print(f"backward {hook_name=} calling barrier")
+        _hook_coordinator.barrier()
+        return grad_output, None
+
+
+def _count_moe_modules(model):
+    """Count MoE modules directly"""
+    from torchtitan.models.moe import MoE
+
+    moe_count = 0
+    for _, module in model.named_modules():
+        if isinstance(module, MoE):
+            moe_count += 1
+    return moe_count
+
+
+def overlap_callback(action: _Action, ctx: _PipelineContext):
+    """
+    Custom callback for OVERLAP_F_B computation that allows expert parallel communication
+    and pipeline parallel computation to overlap.
+    """
+    schedule = ctx.schedule_ref
+    assert isinstance(schedule, _PipelineScheduleRuntime)
+    stage_index_to_stage: dict[int, _PipelineStageBase] = {
+        stage.stage_index: stage for stage in schedule._stages
+    }
+    assert action.sub_actions is not None
+    fwd_action = action.sub_actions[0]
+    bwd_action = action.sub_actions[1]
+
+    # Get stages
+    forward_stage_index = fwd_action.stage_index
+    forward_mb_index = fwd_action.microbatch_index
+    assert forward_mb_index is not None
+    backward_stage_index = bwd_action.stage_index
+    backward_stage = stage_index_to_stage[backward_stage_index]
+
+    # Forward setup
+    arg_mbs = ctx.arg_mbs
+    kwarg_mbs = ctx.kwarg_mbs
+    assert arg_mbs is not None and kwarg_mbs is not None
+    fwd_recv_ops = schedule.fwd_recv_ops
+    forward_stage = stage_index_to_stage[forward_stage_index]
+    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
+    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage
+
+    # Backward setup
+    backward_is_next_stage_on_this_rank = (
+        backward_stage.stage_index + 1 in stage_index_to_stage
+    )
+    backward_is_prev_stage_on_this_rank = (
+        backward_stage.stage_index - 1 in stage_index_to_stage
+    )
+    backward_mb_index = bwd_action.microbatch_index
+    assert backward_mb_index is not None
+    bwd_recv_ops = schedule.bwd_recv_ops
+
+    # Fwd receives
+    if (
+        not forward_stage.is_first
+        # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
+        and not forward_is_prev_stage_on_this_rank
+    ):
+        assert (
+            forward_stage_index,
+            forward_mb_index,
+        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
+        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))
+
+    # Bwd receives
+    if (
+        not backward_stage.is_last
+        # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
+        and not backward_is_next_stage_on_this_rank
+    ):
+        assert (
+            backward_stage_index,
+            backward_mb_index,
+        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
+        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))
+
+    # We count num layers in case the stage layers differ
+    # If they differ, then we only want coordination to happen for the min amount of layers
+    min_num_layers = min(
+        _count_moe_modules(forward_stage.submod),
+        _count_moe_modules(backward_stage.submod),
+    )
+    # PP computation ========================================================
+    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
+    main_cuda_stream = torch.cuda.current_stream()
+
+    # Shared container for exception from backward thread
+    def run_backward():
+        schedule._assert_unsharded(backward_stage)
+        # Set the backward thread to use the same stream as forward
+        torch.cuda.set_stream(main_cuda_stream)
+        with record_function(
+            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
+        ):
+            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
+            schedule.backward_counter[backward_stage_index] += 1
+            last_backward = (
+                schedule.backward_counter[backward_stage_index]
+                == schedule._n_microbatches
+            )
+            backward_stage.backward_one_chunk(
+                backward_mb_index,
+                loss=loss,
+                full_backward=True,
+                last_backward=last_backward,
+            )
+            grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
+            if last_backward:
+                backward_stage.scale_grads(grad_scale_factor)

+            if backward_is_prev_stage_on_this_rank:
+                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
+                    backward_stage.get_local_bwd_output(backward_mb_index),
+                    backward_mb_index,
+                )
+
+    def run_forward():
+        schedule._assert_unsharded(forward_stage)
+        output = forward_stage.forward_one_chunk(
+            forward_mb_index,
+            arg_mbs[forward_mb_index],
+            kwarg_mbs[forward_mb_index],
+        )
+        schedule._maybe_compute_loss(
+            forward_stage, output, ctx.target_mbs, forward_mb_index
+        )
+        if forward_is_next_stage_on_this_rank:
+            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
+                output, forward_mb_index
+            )
+
+    # Run forward and backward in parallel
+    thread = threading.Thread(target=run_backward, daemon=True)
+    thread.start()
+    run_forward()
+    thread.join()
+
+    _hook_coordinator.disable_coordination()
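
Note: `overlap_callback` runs the backward chunk on a second thread while the forward chunk runs on the main thread, and the SyncHook boundaries call into the two-party threading.Barrier in HookCoordinator so the two threads alternate one compute op against one communication op per MoE layer. The edge cases in SyncHook (the first backward "D" and the final forward "D" skip the barrier) keep the two threads' barrier arrivals paired. Below is a minimal, self-contained toy sketch of just that rendezvous pattern; the `step` stand-ins, the layer loops, and the skip conditions are illustrative only and are not torchtitan code.

import threading

barrier = threading.Barrier(2)
trace = []

def step(thread, kind, name):
    # Stand-in for real work; records which thread did comm/compute.
    trace.append((thread, kind, name))

def forward_thread(num_layers):
    for layer in range(num_layers):
        barrier.wait()                      # SyncHook "A" (forward)
        step("fwd", "comm", f"dispatch L{layer}")
        barrier.wait()                      # SyncHook "B"
        step("fwd", "compute", f"experts L{layer}")
        barrier.wait()                      # SyncHook "C"
        step("fwd", "comm", f"combine L{layer}")
        if layer < num_layers - 1:          # final forward "D" skips the barrier
            barrier.wait()                  # SyncHook "D"

def backward_thread(num_layers):
    for layer in reversed(range(num_layers)):
        if layer < num_layers - 1:          # first backward "D" skips the barrier
            barrier.wait()                  # SyncHook "D" (backward)
        step("bwd", "comm", f"combine-grad L{layer}")
        barrier.wait()                      # SyncHook "C"
        step("bwd", "compute", f"experts-grad L{layer}")
        barrier.wait()                      # SyncHook "B"
        step("bwd", "comm", f"dispatch-grad L{layer}")
        barrier.wait()                      # SyncHook "A"

t = threading.Thread(target=backward_thread, args=(2,), daemon=True)
t.start()
forward_thread(2)
t.join()
print(len(trace))  # 12 steps, interleaved across the two threads

Because the skipped barriers offset the two threads by one segment, the stretch where one thread sits in a dispatch/combine all-to-all lines up with the stretch where the other thread runs expert compute, which is the overlap the real callback exploits; `min_num_layers` bounds how many of these cycles are coordinated when the two stages carry different numbers of MoE blocks.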

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -97,8 +97,8 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
-        attn_mask_type="block_causal",
+        use_flex_attn=False,
+        # attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
         vocab_size=102400,
