Commit c29fa82

Enable PP and EP overlap for MoE
1 parent a8899e4 commit c29fa82

9 files changed: +332 −20 lines

run_train.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@ set -ex
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+export LOG_RANK=${LOG_RANK:-0,2}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
```

torchtitan/config/job_config.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -375,6 +375,14 @@ class Parallelism:
     The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
     """

+    pipeline_parallel_expert_parallel_overlap: bool = True
+    """Whether to turn on the optimization to overlap expert parallel and pipeline parallel
+    communication. This is only effective when the pipeline parallel schedule is DualPipeV and
+    pipeline_parallel_degree > 1 and expert_parallel_degree > 1.
+
+    TODO: Does not support activation_checkpoint, set mode="none"
+    """
+
     context_parallel_degree: int = 1
     """Context parallelism degree. 1 means disabled."""

```

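Taken together, the docstring above and the gating added in pipeline_parallel.py and parallelize.py below imply the conditions under which the new flag actually takes effect. A minimal sketch, assuming torchtitan's Parallelism config object; the helper name is hypothetical and not part of this commit:

```python
# Hypothetical helper (not part of this commit): the overlap path only takes effect
# when all of these hold, per the docstring above and the checks added below.
def should_overlap_pp_ep(parallelism) -> bool:
    return (
        parallelism.pipeline_parallel_expert_parallel_overlap
        and parallelism.pipeline_parallel_schedule.lower() == "dualpipev"
        and parallelism.pipeline_parallel_degree > 1
        and parallelism.expert_parallel_degree > 1
    )
```
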
torchtitan/distributed/dual_pipe_v.py

Lines changed: 288 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import threading
from typing import Optional

import torch
import torch.nn as nn

from torch.distributed.pipelining.schedules import (
    _Action,
    _PipelineContext,
    _PipelineScheduleRuntime,
    _wait_batch_p2p,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.distributed.tensor import DeviceMesh, distribute_module
from torch.profiler import record_function

from torchtitan.distributed.expert_parallel import ExpertParallel

from torchtitan.tools.utils import get_device_info

"""
Below are optimizations related to pipeline parallelism with expert parallelism
"""


class DualPipeExpertParallel(ExpertParallel):
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        """
        The execution order is:
        A -> dispatch -> B -> module -> C -> combine -> D

        Hooks are called in the order they are registered:
        SyncHookA, _token_dispatch, SyncHookB (pre hooks)
        SyncHookC, _token_combine, SyncHookD (post hooks)
        """
        inner_wrapped_module = self._wrap_with_pre_comm_hooks(module)
        distributed_module = distribute_module(
            inner_wrapped_module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
        final_module = self._wrap_with_post_comm_hooks(distributed_module)
        return final_module

    def _wrap_with_pre_comm_hooks(self, module):
        def inner_pre_hook(module, input):
            return (SyncHook.apply(input[0], "A"),) + input[1:]

        def inner_post_hook(module, input, output):
            return SyncHook.apply(output, "C")

        module.register_forward_pre_hook(inner_pre_hook)
        module.register_forward_hook(inner_post_hook)
        return module

    def _wrap_with_post_comm_hooks(self, module):
        def outer_pre_hook(module, input):
            return (SyncHook.apply(input[0], "B"),) + input[1:]

        def outer_post_hook(module, input, output):
            return SyncHook.apply(output, "D")

        module.register_forward_pre_hook(outer_pre_hook)
        module.register_forward_hook(outer_post_hook)
        return module


class HookCoordinator:
    def __init__(self):
        # Barrier for 2 threads (forward and backward) to synchronize.
        # This ensures that we always alternate, executing one compute and one comm op together.
        self._execution_barrier = threading.Barrier(2)

        self._coordination_enabled = False
        self._cycle_count = 0
        self._num_layers = None

    def barrier(self):
        """Barrier for 2 threads to synchronize"""
        if not self.is_coordination_enabled():
            return

        try:
            self._execution_barrier.wait()
        except threading.BrokenBarrierError:
            pass

    def enable_coordination(self, num_layers: Optional[int] = None):
        if num_layers is not None and num_layers > 0:
            self._coordination_enabled = True
            self._cycle_count = 0

            # Reset barrier
            self._execution_barrier = threading.Barrier(2)
            self._num_layers = num_layers

    def disable_coordination(self):
        self._coordination_enabled = False
        self._cycle_count = 0
        self._execution_barrier.abort()  # Break barrier to unblock threads

    def check_should_continue_coordination(self):
        if self._num_layers is not None and self._cycle_count >= self._num_layers:
            return False
        return True

    def is_coordination_enabled(self):
        return self._coordination_enabled


# Global coordinator
_hook_coordinator = HookCoordinator()


class SyncHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, hook_name=""):
        ctx.hook_name = hook_name
        # Handle edge case at the transformer-layer boundary
        if _hook_coordinator._coordination_enabled and hook_name == "D":
            _hook_coordinator._cycle_count += 1
            if not _hook_coordinator.check_should_continue_coordination():
                _hook_coordinator.disable_coordination()
                return x

        _hook_coordinator.barrier()
        return x

    @staticmethod
    def backward(ctx, grad_output):
        hook_name = ctx.hook_name

        # Edge case: skip the initial barrier; all subsequent backward hooks will acquire it
        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
            return grad_output, None

        _hook_coordinator.barrier()
        return grad_output, None


def _count_moe_modules(model):
    """Count MoE modules directly"""
    from torchtitan.models.moe import MoE

    moe_count = 0
    for _, module in model.named_modules():
        if isinstance(module, MoE):
            moe_count += 1
    return moe_count


# import fbvscode
# fbvscode.attach_debugger()

device_type, device_module = get_device_info()


def overlap_callback(action: _Action, ctx: _PipelineContext):
    """
    Custom callback for OVERLAP_F_B computation that allows expert parallel communication
    and pipeline parallel computation to overlap.
    """
    schedule = ctx.schedule_ref
    assert isinstance(schedule, _PipelineScheduleRuntime)
    stage_index_to_stage: dict[int, _PipelineStageBase] = {
        stage.stage_index: stage for stage in schedule._stages
    }
    assert action.sub_actions is not None
    fwd_action = action.sub_actions[0]
    bwd_action = action.sub_actions[1]

    # Get stages
    forward_stage_index = fwd_action.stage_index
    forward_mb_index = fwd_action.microbatch_index
    assert forward_mb_index is not None
    backward_stage_index = bwd_action.stage_index
    backward_stage = stage_index_to_stage[backward_stage_index]

    # Forward setup
    arg_mbs = ctx.arg_mbs
    kwarg_mbs = ctx.kwarg_mbs
    assert arg_mbs is not None and kwarg_mbs is not None
    fwd_recv_ops = schedule.fwd_recv_ops
    forward_stage = stage_index_to_stage[forward_stage_index]
    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage

    # Backward setup
    backward_is_next_stage_on_this_rank = (
        backward_stage.stage_index + 1 in stage_index_to_stage
    )
    backward_is_prev_stage_on_this_rank = (
        backward_stage.stage_index - 1 in stage_index_to_stage
    )
    backward_mb_index = bwd_action.microbatch_index
    assert backward_mb_index is not None
    bwd_recv_ops = schedule.bwd_recv_ops

    # Fwd receives
    if (
        not forward_stage.is_first
        # no recv op expected for V-schedule special case
        and not forward_is_prev_stage_on_this_rank
    ):
        assert (
            forward_stage_index,
            forward_mb_index,
        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))

    # Bwd receives
    if (
        not backward_stage.is_last
        # no recv op expected for V-schedule special case
        and not backward_is_next_stage_on_this_rank
    ):
        assert (
            backward_stage_index,
            backward_mb_index,
        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))

    # We count MoE layers in case the forward and backward stages differ; if they do,
    # we only want coordination to happen for the minimum number of layers.
    min_num_layers = min(
        _count_moe_modules(forward_stage.submod),
        _count_moe_modules(backward_stage.submod),
    )
    # PP computation ========================================================
    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
    main_stream = torch.accelerator.current_stream(device_module)

    # Backward work for one microbatch; executed in a side thread below so it can
    # overlap with run_forward() on the main thread.
    def run_backward():
        schedule._assert_unsharded(backward_stage)
        # Set the backward thread to use the same stream as forward
        device_module.set_stream(main_stream)
        with record_function(
            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
        ):
            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
            schedule.backward_counter[backward_stage_index] += 1
            last_backward = (
                schedule.backward_counter[backward_stage_index]
                == schedule._n_microbatches
            )
            backward_stage.backward_one_chunk(
                backward_mb_index,
                loss=loss,
                full_backward=True,
                last_backward=last_backward,
            )

            if backward_is_prev_stage_on_this_rank:
                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
                    backward_stage.get_local_bwd_output(backward_mb_index),
                    backward_mb_index,
                )

    def run_forward():
        schedule._assert_unsharded(forward_stage)
        output = forward_stage.forward_one_chunk(
            forward_mb_index,
            arg_mbs[forward_mb_index],
            kwarg_mbs[forward_mb_index],
        )
        schedule._maybe_compute_loss(
            forward_stage, output, ctx.target_mbs, forward_mb_index
        )
        if forward_is_next_stage_on_this_rank:
            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
                output, forward_mb_index
            )

    # Run forward and backward in parallel
    thread = threading.Thread(target=run_backward, daemon=True)
    thread.start()
    run_forward()
    thread.join()

    _hook_coordinator.disable_coordination()
```

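The coordination above hinges on a single threading.Barrier(2) that the forward and backward threads hit from their SyncHook calls, so one thread's expert-parallel communication lines up with the other thread's compute. A self-contained sketch of that alternation pattern (not torchtitan code; the phase names are illustrative only):

```python
import threading

# Two threads rendezvous at a Barrier(2) before each per-layer phase, mirroring how
# HookCoordinator.barrier() pairs the forward and backward threads in overlap_callback.
NUM_LAYERS = 3
barrier = threading.Barrier(2)

def worker(name: str, phases: list[str]) -> None:
    for layer in range(NUM_LAYERS):
        for phase in phases:
            barrier.wait()  # line up with the peer thread before starting this phase
            print(f"{name}: layer {layer} -> {phase}")

fwd = threading.Thread(target=worker, args=("forward", ["dispatch", "experts", "combine"]))
bwd = threading.Thread(target=worker, args=("backward", ["combine-grad", "experts-grad", "dispatch-grad"]))
fwd.start()
bwd.start()
fwd.join()
bwd.join()
```
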
torchtitan/distributed/pipeline_parallel.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
     _PipelineSchedule,
     _PipelineScheduleRuntime,
     get_schedule_class,
+    OVERLAP_F_B,
     PipelineScheduleMulti,
     PipelineScheduleSingle,
     ScheduleDualPipeV,
@@ -27,6 +28,7 @@
 from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.dual_pipe_v import overlap_callback
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger

@@ -209,6 +211,11 @@ def build_pipeline_schedule(
         f"with {n_microbatches} microbatches and {num_total_stages} stages."
     )

+    if job_config.parallelism.pipeline_parallel_expert_parallel_overlap and isinstance(
+        schedule, ScheduleDualPipeV
+    ):
+        schedule.register_custom_function(OVERLAP_F_B, overlap_callback)
+
     if pp_schedule_csv:
         assert schedule_class in [
             PipelineScheduleSingle,
```

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -97,8 +97,8 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
-        attn_mask_type="block_causal",
+        use_flex_attn=False,
+        # attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
         vocab_size=102400,
```

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -101,6 +101,9 @@ def parallelize_deepseekv3(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
+            dual_pipe_v=job_config.parallelism.pipeline_parallel_expert_parallel_overlap
+            and job_config.parallelism.pipeline_parallel_schedule.lower()
+            == "dualpipev",
         )

     model_compile_enabled = (
```

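The diff above only threads a dual_pipe_v flag into the MoE parallelization call; the corresponding change inside torchtitan/distributed/expert_parallel.py is among the changed files not shown in this excerpt. A plausible sketch, assuming the flag simply selects between the two ParallelStyle classes involved in this commit:

```python
# Hypothetical sketch; the real apply_moe_ep_tp change is not shown in this excerpt.
from torchtitan.distributed.dual_pipe_v import DualPipeExpertParallel
from torchtitan.distributed.expert_parallel import ExpertParallel

def _choose_ep_style(dual_pipe_v: bool) -> type[ExpertParallel]:
    # Use the SyncHook-wrapped style only when the DualPipeV overlap path is enabled.
    return DualPipeExpertParallel if dual_pipe_v else ExpertParallel
```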