Enable PP and EP overlap for MoE #1721
base: main
Changes from all commits: cce81df, 21dcff4, 5810c54, 6e4ef27, 4f8e621, a23ab5b, 9e43a67
@@ -22,6 +22,7 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle

 from torchtitan.distributed.pipeline_parallel import SyncHook
 from torchtitan.tools.utils import _round_up


@@ -157,13 +158,43 @@ def _token_combine(self, mod, routed_output, device_mesh):
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
        """
        Hooks are called in the order they are registered:
            SyncHookA, _token_dispatch, SyncHookB (pre hooks)
            SyncHookC, _token_combine, SyncHookD (post hooks)
        """
        inner_wrapped_module = self._wrap_with_inner_hooks(module)
Review comment: Hmm, it seems this clearly has something to do with the order of hooks -- from your comments it looks like hooks that are inserted earlier are also executed earlier, like a queue. If this is the case, what do "inner" vs. "outer" refer to?
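For reference, a minimal standalone sketch (not from this PR) showing that nn.Module forward pre-hooks and forward hooks fire in registration order, which is what the inner/outer wrapping relies on: the "inner" hooks are registered before distribute_module installs _token_dispatch/_token_combine, and the "outer" hooks after, so the runtime order matches the docstring above.

```python
import torch
import torch.nn as nn

calls = []
m = nn.Linear(4, 4)

# Registered first -> runs first (analogous to the "inner" pre-hook, SyncHook "A").
m.register_forward_pre_hook(lambda mod, inp: calls.append("A (inner pre)"))
# Registered next, like the input_fn installed by distribute_module (_token_dispatch).
m.register_forward_pre_hook(lambda mod, inp: calls.append("token dispatch"))
# Registered last -> runs last (analogous to the "outer" pre-hook, SyncHook "B").
m.register_forward_pre_hook(lambda mod, inp: calls.append("B (outer pre)"))

m.register_forward_hook(lambda mod, inp, out: calls.append("C (inner post)"))
m.register_forward_hook(lambda mod, inp, out: calls.append("token combine"))
m.register_forward_hook(lambda mod, inp, out: calls.append("D (outer post)"))

m(torch.randn(2, 4))
print(calls)
# ['A (inner pre)', 'token dispatch', 'B (outer pre)',
#  'C (inner post)', 'token combine', 'D (outer post)']
```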
        distributed_module = distribute_module(
            inner_wrapped_module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
        final_module = self._wrap_with_outer_hooks(distributed_module)
        return final_module

    def _wrap_with_inner_hooks(self, module):
        def inner_pre_hook(module, input):
            return (SyncHook.apply(input[0], "A"),) + input[1:]

        def inner_post_hook(module, input, output):
            return SyncHook.apply(output, "C")

        module.register_forward_pre_hook(inner_pre_hook)
        module.register_forward_hook(inner_post_hook)
        return module

    def _wrap_with_outer_hooks(self, module):
        def outer_pre_hook(module, input):
            return (SyncHook.apply(input[0], "B"),) + input[1:]

        def outer_post_hook(module, input, output):
            return SyncHook.apply(output, "D")

        module.register_forward_pre_hook(outer_pre_hook)
        module.register_forward_hook(outer_post_hook)
        return module


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
@@ -5,22 +5,29 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import os
from typing import Callable
import threading
from typing import Callable, Optional

 import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import (
    _Action,
    _PipelineContext,
    _PipelineSchedule,
    _PipelineScheduleRuntime,
    _wait_batch_p2p,
    get_schedule_class,
    OVERLAP_F_B,
    PipelineScheduleMulti,
    PipelineScheduleSingle,
    ScheduleDualPipeV,
    ScheduleZBVZeroBubble,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.profiler import record_function

 from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.config import JobConfig
@@ -91,6 +98,12 @@ def build_pipeline_schedule(
        f"with {n_microbatches} microbatches and {num_total_stages} stages."
    )

    if (
        job_config.parallelism.pipeline_parallel_expert_parallel_overlap
        and isinstance(schedule, ScheduleDualPipeV)
    ):
        schedule.register_custom_function(OVERLAP_F_B, overlap_callback)

    if pp_schedule_csv:
        assert schedule_class in [
            PipelineScheduleSingle,
@@ -357,3 +370,218 @@ def _build_stage_from_modules(
        models.append(model_chunk)

    return stages, models


# TODO: is there a better place to put this?
Review comment: How about putting them into
# Below are optimizations related to pipeline parallelism with expert parallelism


class HookCoordinator:
    def __init__(self):
        # Barrier for 2 threads (forward and backward) to synchronize
        # This ensures that we always alternate at executing one compute and one comm op together
        self._execution_barrier = threading.Barrier(2)

        self._coordination_enabled = False
        self._cycle_count = 0
        self._num_layers = None

    def barrier(self):
        """Barrier for 2 threads to synchronize"""
        if not self.is_coordination_enabled():
            return

        try:
            self._execution_barrier.wait()
        except threading.BrokenBarrierError:
            pass

    def enable_coordination(self, num_layers: Optional[int] = None):
        if num_layers is not None and num_layers > 0:
            self._coordination_enabled = True
            self._cycle_count = 0

            # Reset barrier
            self._execution_barrier = threading.Barrier(2)
            self._num_layers = num_layers

    def disable_coordination(self):
        self._coordination_enabled = False
        self._cycle_count = 0
        self._execution_barrier.abort()  # Break barrier to unblock threads

    def check_should_continue_coordination(self):
        if self._num_layers is not None and self._cycle_count >= self._num_layers:
            return False
        return True

    def is_coordination_enabled(self):
        return self._coordination_enabled


# Global coordinator
_hook_coordinator = HookCoordinator()


class SyncHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, hook_name=""):
        ctx.hook_name = hook_name
        # handle edge case for transformer level boundary
        if _hook_coordinator._coordination_enabled and hook_name == "D":
            _hook_coordinator._cycle_count += 1
            # print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
        if not _hook_coordinator.check_should_continue_coordination():
Review comment: This check is only called in
            _hook_coordinator.disable_coordination()
            return x

        _hook_coordinator.barrier()
Review comment: Strictly speaking, the barrier only has an effect on the CPU threads, and it only forces the compute and a2a to be dispatched to the GPU at the same time. But looking from the GPU perspective, it doesn't guarantee that the execution of the compute kernels and the a2a is actually overlapped. It may work in cases where there happen to be GPU-CPU syncs in the right places in the MoE layer (e.g. token index H2D copy etc.). But I suspect it would fail to overlap as we remove those syncs (the community is working toward more efficient no-sync MoE implementations). Theoretically we should use a CUDA event wait between the compute/comm streams, not a thread wait.
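A minimal standalone sketch of the event-based direction described above, assuming a CUDA device is available; this illustrates the suggestion and is not code from this PR.

```python
import torch

# Order work across GPU streams with an event instead of blocking CPU threads.
compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()

x = torch.randn(1024, 1024, device="cuda")
compute_done = torch.cuda.Event()

# Make the compute stream wait for the allocation done on the default stream.
compute_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(compute_stream):
    y = x @ x                          # compute kernel on the compute stream
    compute_done.record(compute_stream)

with torch.cuda.stream(comm_stream):
    # The dependency is enforced on the GPU timeline; the host thread is free
    # to keep dispatching work.
    comm_stream.wait_event(compute_done)
    # ... enqueue all-to-all / other communication kernels here ...

torch.cuda.synchronize()
```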
        return x

    @staticmethod
    def backward(ctx, grad_output):
        hook_name = ctx.hook_name

        # Edge case, skip initial barrier, all subsequent backward hooks will acquire
        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
            return grad_output, None

        _hook_coordinator.barrier()
        return grad_output, None


def _count_moe_modules(model):
    """Count MoE modules directly"""
    from torchtitan.models.moe import MoE

    moe_count = 0
    for _, module in model.named_modules():
        if isinstance(module, MoE):
            moe_count += 1
    return moe_count


def overlap_callback(action: _Action, ctx: _PipelineContext):
    """
    Custom callback for OVERLAP_F_B computation that allows expert parallel communication
    and pipeline parallel computation to overlap.
    """
    print("calling into overlap callback")
Review comment: please remove or change to logger calls
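A small sketch of the suggested change, assuming torchtitan's logger helper lives at torchtitan.tools.logging (adjust the import if it does not):

```python
# Hedged sketch: use the project logger instead of a bare print.
from torchtitan.tools.logging import logger  # assumed import path

logger.debug("calling into overlap callback")
```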
    schedule = ctx.schedule_ref
    assert isinstance(schedule, _PipelineScheduleRuntime)
    stage_index_to_stage: dict[int, _PipelineStageBase] = {
        stage.stage_index: stage for stage in schedule._stages
    }
    assert action.sub_actions is not None
    fwd_action = action.sub_actions[0]
    bwd_action = action.sub_actions[1]

    # Get stages
    forward_stage_index = fwd_action.stage_index
    forward_mb_index = fwd_action.microbatch_index
    assert forward_mb_index is not None
    backward_stage_index = bwd_action.stage_index
    backward_stage = stage_index_to_stage[backward_stage_index]

    # Forward setup
    arg_mbs = ctx.arg_mbs
    kwarg_mbs = ctx.kwarg_mbs
    assert arg_mbs is not None and kwarg_mbs is not None
    fwd_recv_ops = schedule.fwd_recv_ops
    forward_stage = stage_index_to_stage[forward_stage_index]
    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage

    # Backward setup
    backward_is_next_stage_on_this_rank = (
        backward_stage.stage_index + 1 in stage_index_to_stage
    )
    backward_is_prev_stage_on_this_rank = (
        backward_stage.stage_index - 1 in stage_index_to_stage
    )
    backward_mb_index = bwd_action.microbatch_index
    assert backward_mb_index is not None
    bwd_recv_ops = schedule.bwd_recv_ops

    # Fwd receives
    if (
        not forward_stage.is_first
        # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
Review comment: where is "[Note: V-schedule special case]"?
        and not forward_is_prev_stage_on_this_rank
    ):
        assert (
            forward_stage_index,
            forward_mb_index,
        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))

    # Bwd receives
    if (
        not backward_stage.is_last
        # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
        and not backward_is_next_stage_on_this_rank
    ):
        assert (
            backward_stage_index,
            backward_mb_index,
        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))

    # We count num layers in case the stage layers differ
    # If they differ then we only want coordination to happen for the min amount of layers
    min_num_layers = min(
        _count_moe_modules(forward_stage.submod),
        _count_moe_modules(backward_stage.submod),
    )
    # PP computation ========================================================
    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
    main_cuda_stream = torch.cuda.current_stream()
Review comment: Can we change this to device-neutral calls? We have https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py#L33
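A hedged sketch of a device-neutral variant, assuming torch.get_device_module (available in recent PyTorch releases); the utility the reviewer links may already expose the device type and module directly.

```python
import torch

# "cuda" is a stand-in here; in torchtitan the device type would come from the
# shared utility rather than being hard-coded.
device_type = "cuda"
device_module = torch.get_device_module(device_type)

main_stream = device_module.current_stream()
# ... later, inside the backward thread:
device_module.set_stream(main_stream)
```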

    def run_backward():
        # Set the backward thread to use the same stream as forward
        torch.cuda.set_stream(main_cuda_stream)
Review comment: similar -- can we change it to neutral calls
        with record_function(
Review comment: always enabling this may hurt perf? (An opt-in sketch follows after this hunk.)
            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
        ):
            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
            schedule.backward_counter[backward_stage_index] += 1
            last_backward = (
                schedule.backward_counter[backward_stage_index]
                == schedule._n_microbatches
            )
            backward_stage.backward_one_chunk(
                backward_mb_index,
                loss=loss,
                full_backward=True,
                last_backward=last_backward,
            )
            grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
Review comment: This may not work well with gradient accumulation. See what we did in #1732
            if last_backward:
                backward_stage.scale_grads(grad_scale_factor)

        if backward_is_prev_stage_on_this_rank:
            stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
                backward_stage.get_local_bwd_output(backward_mb_index),
                backward_mb_index,
            )

    def run_forward():
Review comment: For my education:
        output = forward_stage.forward_one_chunk(
            forward_mb_index,
            arg_mbs[forward_mb_index],
            kwarg_mbs[forward_mb_index],
        )
        schedule._maybe_compute_loss(
            forward_stage, output, ctx.target_mbs, forward_mb_index
        )
        if forward_is_next_stage_on_this_rank:
            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
                output, forward_mb_index
            )

    # Run forward and backward in parallel
    thread = threading.Thread(target=run_backward, daemon=True)
    thread.start()
    run_forward()
    thread.join()
    _hook_coordinator.disable_coordination()
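Following up on the review note above about always-on record_function annotations: a minimal sketch of making the annotation opt-in, assuming a hypothetical profiling_enabled flag (in torchtitan this would likely come from the job config).

```python
from contextlib import nullcontext

from torch.profiler import record_function

profiling_enabled = False                        # hypothetical flag
backward_stage_index, backward_mb_index = 0, 0   # placeholders for the values in the diff

ctx = (
    record_function(f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}")
    if profiling_enabled
    else nullcontext()
)
with ctx:
    pass  # the backward chunk would run here
```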
@@ -99,8 +99,8 @@
        qk_rope_head_dim=64,
        v_head_dim=128,
        mscale=0.70,
        use_flex_attn=True,
        attn_mask_type="block_causal",
        use_flex_attn=False,
Review comment: Is FlexAttention not supported? It sounds unrelated.
        # attn_mask_type="block_causal",
    ),
    "236B": DeepSeekV3ModelArgs(
        vocab_size=102400,
Review comment: This looks beautiful!! I think to decouple the complexity, we can have another dedicated class DualPipeExpertParallel inheriting this class and only override the _apply function. We can put it also in dual_pipe_v.py. WDYT?
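A hedged sketch of the refactor suggested above: a dedicated subclass that carries the hook wrapping, leaving the base ExpertParallel untouched. The class name follows the reviewer's suggestion; the import paths are assumptions based on the diff.

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module

from torchtitan.distributed.expert_parallel import ExpertParallel  # assumed path
from torchtitan.distributed.pipeline_parallel import SyncHook


class DualPipeExpertParallel(ExpertParallel):
    """ExpertParallel variant that adds the SyncHook wrapping used for PP/EP
    overlap; everything else is inherited from the base class."""

    def _wrap_with_inner_hooks(self, module: nn.Module) -> nn.Module:
        def inner_pre_hook(module, input):
            return (SyncHook.apply(input[0], "A"),) + input[1:]

        def inner_post_hook(module, input, output):
            return SyncHook.apply(output, "C")

        module.register_forward_pre_hook(inner_pre_hook)
        module.register_forward_hook(inner_post_hook)
        return module

    def _wrap_with_outer_hooks(self, module: nn.Module) -> nn.Module:
        def outer_pre_hook(module, input):
            return (SyncHook.apply(input[0], "B"),) + input[1:]

        def outer_post_hook(module, input, output):
            return SyncHook.apply(output, "D")

        module.register_forward_pre_hook(outer_pre_hook)
        module.register_forward_hook(outer_post_hook)
        return module

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # Pre-hook order:  SyncHook "A" -> _token_dispatch -> SyncHook "B"
        # Post-hook order: SyncHook "C" -> _token_combine  -> SyncHook "D"
        inner_wrapped_module = self._wrap_with_inner_hooks(module)
        distributed_module = distribute_module(
            inner_wrapped_module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
        return self._wrap_with_outer_hooks(distributed_module)
```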