5 changes: 5 additions & 0 deletions torchtitan/config/job_config.py
@@ -365,6 +365,11 @@ class Parallelism:
The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
"""

pipeline_parallel_expert_parallel_overlap: bool = True
"""Whether to turn on the optimization to overlap expert parallel and pipeline parallel
communication. This is only effective when the pipeline paralel schedule is DualPipeV and
pipeline_parallel_degree > 1 and expert_parallel_degree > 1."""

context_parallel_degree: int = 1
"""Context parallelism degree. 1 means disabled."""

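For reference, a minimal sketch (not part of this diff) of how the new flag fits into a job config TOML. Section and key names mirror the debug_model.toml changes later in this PR; the value shown for the flag is its default.

```toml
[parallelism]
pipeline_parallel_degree = 2
pipeline_parallel_schedule = "DualPipeV"           # overlap only takes effect with DualPipeV
expert_parallel_degree = 2
pipeline_parallel_expert_parallel_overlap = true   # default; set to false to disable the overlap
```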
35 changes: 33 additions & 2 deletions torchtitan/distributed/expert_parallel.py
@@ -22,6 +22,7 @@
)
from torch.distributed.tensor.parallel import ParallelStyle

from torchtitan.distributed.pipeline_parallel import SyncHook
from torchtitan.tools.utils import _round_up


@@ -157,13 +158,43 @@ def _token_combine(self, mod, routed_output, device_mesh):
return routed_output

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
"""
Hooks are called in the order they are registered:
SyncHookA, _token_dispatch, SyncHookB (pre hooks)
SyncHookC, _token_combine, SyncHookD (post hooks)
Comment on lines +162 to +164 (Contributor):
This looks beautiful!!
I think to decouple the complexity, we can have another dedicated class DualPipeExpertParallel inheriting this class and only override this _apply function. We can put it also in dual_pipe_v.py.
WDYT?

"""
inner_wrapped_module = self._wrap_with_inner_hooks(module)
Contributor comment:
Hmm it seems this clearly has something to do with the order of hooks -- from your comments it looks like hooks that are inserted earlier are also executed earlier, like a queue.
If this is the case, what are "inner vs. outer" referring to?

distributed_module = distribute_module(
inner_wrapped_module,
device_mesh,
partition_fn=ExpertParallel._partition_fn,
input_fn=self._token_dispatch,
output_fn=self._token_combine,
)
final_module = self._wrap_with_outer_hooks(distributed_module)
return final_module

def _wrap_with_inner_hooks(self, module):
def inner_pre_hook(module, input):
return (SyncHook.apply(input[0], "A"),) + input[1:]

def inner_post_hook(module, input, output):
return SyncHook.apply(output, "C")

module.register_forward_pre_hook(inner_pre_hook)
module.register_forward_hook(inner_post_hook)
return module

def _wrap_with_outer_hooks(self, module):
def outer_pre_hook(module, input):
return (SyncHook.apply(input[0], "B"),) + input[1:]

def outer_post_hook(module, input, output):
return SyncHook.apply(output, "D")

module.register_forward_pre_hook(outer_pre_hook)
module.register_forward_hook(outer_post_hook)
return module
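To make the ordering in the docstring concrete (and the "inner vs. outer" naming asked about above), here is a minimal standalone sketch, not part of the diff: nn.Module forward pre-hooks and forward hooks each fire in registration order, so hooks registered before distribute_module attaches _token_dispatch/_token_combine act as the "inner" pair and hooks registered afterwards as the "outer" pair.

```python
import torch
import torch.nn as nn

# Standalone illustration: both pre-hooks and post-hooks fire in the order
# they were registered on the module.
order = []
mod = nn.Linear(4, 4)

mod.register_forward_pre_hook(lambda m, inp: order.append("pre: SyncHook A (inner)"))
mod.register_forward_pre_hook(lambda m, inp: order.append("pre: _token_dispatch"))
mod.register_forward_pre_hook(lambda m, inp: order.append("pre: SyncHook B (outer)"))
mod.register_forward_hook(lambda m, inp, out: order.append("post: SyncHook C (inner)"))
mod.register_forward_hook(lambda m, inp, out: order.append("post: _token_combine"))
mod.register_forward_hook(lambda m, inp, out: order.append("post: SyncHook D (outer)"))

mod(torch.randn(2, 4))
print(order)
# ['pre: SyncHook A (inner)', 'pre: _token_dispatch', 'pre: SyncHook B (outer)',
#  'post: SyncHook C (inner)', 'post: _token_combine', 'post: SyncHook D (outer)']
```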


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
230 changes: 229 additions & 1 deletion torchtitan/distributed/pipeline_parallel.py
@@ -5,22 +5,29 @@
# LICENSE file in the root directory of this source tree.
import copy
import os
from typing import Callable
import threading
from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import (
_Action,
_PipelineContext,
_PipelineSchedule,
_PipelineScheduleRuntime,
_wait_batch_p2p,
get_schedule_class,
OVERLAP_F_B,
PipelineScheduleMulti,
PipelineScheduleSingle,
ScheduleDualPipeV,
ScheduleZBVZeroBubble,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.profiler import record_function

from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.config import JobConfig
@@ -91,6 +98,12 @@ def build_pipeline_schedule(
f"with {n_microbatches} microbatches and {num_total_stages} stages."
)

if (
job_config.parallelism.pipeline_parallel_expert_parallel_overlap
and isinstance(schedule, ScheduleDualPipeV)
):
schedule.register_custom_function(OVERLAP_F_B, overlap_callback)

if pp_schedule_csv:
assert schedule_class in [
PipelineScheduleSingle,
@@ -357,3 +370,218 @@ def _build_stage_from_modules(
models.append(model_chunk)

return stages, models


# TODO: is there a better place to put this?
Contributor comment:
How about putting them into distributed/dual_pipe_v.py?

# Below are optimizations related to pipeline parallelism with expert parallelism


class HookCoordinator:
def __init__(self):
# Barrier for 2 threads (forward and backward) to synchronize
# This ensures that we always alternate at executing one compute and one comm op together
self._execution_barrier = threading.Barrier(2)

self._coordination_enabled = False
self._cycle_count = 0
self._num_layers = None

def barrier(self):
"""Barrier for 2 threads to synchronize"""
if not self.is_coordination_enabled():
return

try:
self._execution_barrier.wait()
except threading.BrokenBarrierError:
pass

def enable_coordination(self, num_layers: Optional[int] = None):
if num_layers is not None and num_layers > 0:
self._coordination_enabled = True
self._cycle_count = 0

# Reset barrier
self._execution_barrier = threading.Barrier(2)
self._num_layers = num_layers

def disable_coordination(self):
self._coordination_enabled = False
self._cycle_count = 0
self._execution_barrier.abort() # Break barrier to unblock threads

def check_should_continue_coordination(self):
if self._num_layers is not None and self._cycle_count >= self._num_layers:
return False
return True

def is_coordination_enabled(self):
return self._coordination_enabled


# Global coordinator
_hook_coordinator = HookCoordinator()
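A toy usage sketch (not in the PR) of the coordinator above: the two-party barrier makes the forward and backward threads advance in lockstep, one step at a time. The worker names below are illustrative only.

```python
import threading

coord = HookCoordinator()
coord.enable_coordination(num_layers=2)

def worker(name: str) -> None:
    for step in range(2):
        print(f"{name}: step {step}")
        coord.barrier()  # blocks until the other thread has also finished this step

bwd = threading.Thread(target=worker, args=("backward",), daemon=True)
bwd.start()
worker("forward")
bwd.join()
coord.disable_coordination()
```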


class SyncHook(torch.autograd.Function):
@staticmethod
def forward(ctx, x, hook_name=""):
ctx.hook_name = hook_name
# handle edge case for transformer level boundary
if _hook_coordinator._coordination_enabled and hook_name == "D":
_hook_coordinator._cycle_count += 1
# print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
if not _hook_coordinator.check_should_continue_coordination():
Contributor comment:
This check is only called in SyncHook.forward. Is it safe if for a particular overlap_f_b call, the backward stage has more layers than the forward stage?

_hook_coordinator.disable_coordination()
return x

_hook_coordinator.barrier()

Reviewer comment:
Strictly speaking, the barrier only has an effect on the CPU threads; it only forces the compute and the all-to-all to be dispatched to the GPU at the same time. From the GPU's perspective, it doesn't guarantee that the compute kernels and the all-to-all actually overlap.
It may work in cases where there happen to be GPU-CPU syncs in the right places in the MoE layer (e.g. the token-index H2D copy), but I suspect it would fail to overlap as we remove those syncs (the community is working toward more efficient no-sync MoE implementations).
Theoretically we should use a CUDA event wait between the compute/comm streams, not a thread wait (see the event-based sketch after this class).

return x

@staticmethod
def backward(ctx, grad_output):
hook_name = ctx.hook_name

# Edge case, skip initial barrier, all subsequent backward hooks will acquire
if hook_name == "D" and _hook_coordinator._cycle_count == 0:
return grad_output, None

_hook_coordinator.barrier()
return grad_output, None
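As the review comment above points out, the thread barrier only orders host-side dispatch. Below is a rough sketch of the event-based alternative it suggests, using a dedicated communication stream; the helper name and the callback arguments are hypothetical and not part of the PR.

```python
import torch

comm_stream = torch.cuda.Stream()

def overlapped_dispatch(tokens, all_to_all_fn, compute_fn):
    compute_stream = torch.cuda.current_stream()
    comm_stream.wait_stream(compute_stream)      # a2a may only start once `tokens` is ready
    with torch.cuda.stream(comm_stream):
        routed = all_to_all_fn(tokens)           # all-to-all runs on the comm stream
        a2a_done = torch.cuda.Event()
        a2a_done.record(comm_stream)
    out = compute_fn()                           # independent compute overlaps on the compute stream
    compute_stream.wait_event(a2a_done)          # GPU-side ordering, no CPU-side wait
    # Tensors crossing streams may additionally need record_stream() for allocator safety.
    return routed, out
```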


def _count_moe_modules(model):
"""Count MoE modules directly"""
from torchtitan.models.moe import MoE

moe_count = 0
for _, module in model.named_modules():
if isinstance(module, MoE):
moe_count += 1
return moe_count


def overlap_callback(action: _Action, ctx: _PipelineContext):
"""
Custom callback for OVERLAP_F_B computation that allows expert parallel communication
and pipeline parallel computation to overlap.
"""
print("calling into overlap callback")
Contributor comment:
please remove or change to logger calls

schedule = ctx.schedule_ref
assert isinstance(schedule, _PipelineScheduleRuntime)
stage_index_to_stage: dict[int, _PipelineStageBase] = {
stage.stage_index: stage for stage in schedule._stages
}
assert action.sub_actions is not None
fwd_action = action.sub_actions[0]
bwd_action = action.sub_actions[1]

# Get stages
forward_stage_index = fwd_action.stage_index
forward_mb_index = fwd_action.microbatch_index
assert forward_mb_index is not None
backward_stage_index = bwd_action.stage_index
backward_stage = stage_index_to_stage[backward_stage_index]

# Forward setup
arg_mbs = ctx.arg_mbs
kwarg_mbs = ctx.kwarg_mbs
assert arg_mbs is not None and kwarg_mbs is not None
fwd_recv_ops = schedule.fwd_recv_ops
forward_stage = stage_index_to_stage[forward_stage_index]
forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage

# Backward setup
backward_is_next_stage_on_this_rank = (
backward_stage.stage_index + 1 in stage_index_to_stage
)
backward_is_prev_stage_on_this_rank = (
backward_stage.stage_index - 1 in stage_index_to_stage
)
backward_mb_index = bwd_action.microbatch_index
assert backward_mb_index is not None
bwd_recv_ops = schedule.bwd_recv_ops

# Fwd receives
if (
not forward_stage.is_first
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
Contributor comment:
where is "[Note: V-schedule special case]"?

and not forward_is_prev_stage_on_this_rank
):
assert (
forward_stage_index,
forward_mb_index,
) in fwd_recv_ops, f"Computing {action=} before receiving input"
_wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))

# Bwd receives
if (
not backward_stage.is_last
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
and not backward_is_next_stage_on_this_rank
):
assert (
backward_stage_index,
backward_mb_index,
) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
_wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))

# We count the number of MoE layers in case the stage layers differ
# If they differ, we only want coordination to happen for the min number of layers
min_num_layers = min(
_count_moe_modules(forward_stage.submod),
_count_moe_modules(backward_stage.submod),
)
# PP computation ========================================================
_hook_coordinator.enable_coordination(num_layers=min_num_layers)
main_cuda_stream = torch.cuda.current_stream()
Contributor comment:
Can we change this to device-neutral calls? We have https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py#L33 (a device-neutral sketch follows this function).


def run_backward():
# Set the backward thread to use the same stream as forward
torch.cuda.set_stream(main_cuda_stream)
Contributor comment:
similar -- can we change it to neutral calls

with record_function(
Contributor comment:
always enabling this may hurt perf?

f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
):
loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
schedule.backward_counter[backward_stage_index] += 1
last_backward = (
schedule.backward_counter[backward_stage_index]
== schedule._n_microbatches
)
backward_stage.backward_one_chunk(
backward_mb_index,
loss=loss,
full_backward=True,
last_backward=last_backward,
)
grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
Contributor comment:
This may not work well with gradient accumulation. See what we did in #1732

if last_backward:
backward_stage.scale_grads(grad_scale_factor)

if backward_is_prev_stage_on_this_rank:
stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
backward_stage.get_local_bwd_output(backward_mb_index),
backward_mb_index,
)

def run_forward():
Contributor comment:
For my education: The run_forward() and run_backward() functions look general and not tied to DualPipe. Do we not have such functions in pytorch pipelining code?

output = forward_stage.forward_one_chunk(
forward_mb_index,
arg_mbs[forward_mb_index],
kwarg_mbs[forward_mb_index],
)
schedule._maybe_compute_loss(
forward_stage, output, ctx.target_mbs, forward_mb_index
)
if forward_is_next_stage_on_this_rank:
stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
output, forward_mb_index
)

# Run forward and backward in parallel
thread = threading.Thread(target=run_backward, daemon=True)
thread.start()
run_forward()
thread.join()
_hook_coordinator.disable_coordination()
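Regarding the two device-neutral review comments above, one possible variant of the stream handling is sketched below. It assumes a device_module such as torch.cuda or torch.xpu is obtained from the helper in the linked torchtitan/tools/utils.py; torch.cuda is used here only as a stand-in.

```python
import torch

device_module = torch.cuda  # stand-in; in practice this would come from the shared utility

main_stream = device_module.current_stream()

def run_backward_device_neutral() -> None:
    # Same stream pinning as run_backward(), routed through the device module
    # instead of hard-coded torch.cuda calls.
    device_module.set_stream(main_stream)
```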
4 changes: 2 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -99,8 +99,8 @@
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=True,
attn_mask_type="block_causal",
use_flex_attn=False,
Contributor comment:
Is FlexAttention not supported? It sounds unrelated.

# attn_mask_type="block_causal",
),
"236B": DeepSeekV3ModelArgs(
vocab_size=102400,
21 changes: 11 additions & 10 deletions torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
print_args = false

[profiling]
enable_profiling = false
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profile_freq = 1
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

@@ -30,28 +30,29 @@ lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
warmup_steps = 0 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
local_batch_size = 4
seq_len = 4
max_norm = 1.0 # grad norm clipping
steps = 10
steps = 6
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
# dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
pipeline_parallel_schedule = "1F1B"
pipeline_parallel_degree = 2
expert_parallel_degree = 2
context_parallel_degree = 1
expert_parallel_degree = 1
pipeline_parallel_schedule = "DualPipeV"
expert_tensor_parallel_degree = 1

[checkpoint]
@@ -63,7 +64,7 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
mode = "none" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]