Skip to content

Commit 6e4ef27

Browse files
committed
contain changes into pipeline_parallel.py
1 parent 5810c54 commit 6e4ef27

File tree

5 files changed

+268
-285
lines changed

5 files changed

+268
-285
lines changed

torchtitan/config/job_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ class Parallelism:
365365
The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
366366
"""
367367

368+
pipeline_parallel_expert_parallel_overlap: bool = True
369+
"""Whether to turn on the optimization to overlap expert parallel and pipeline parallel
370+
communication. This is only effective when the pipeline parallel schedule is DualPipeV and
371+
pipeline_parallel_degree > 1 and expert_parallel_degree > 1."""
372+
368373
context_parallel_degree: int = 1
369374
"""Context parallelism degree. 1 means disabled."""
370375

@@ -693,7 +698,7 @@ class Comm:
693698
init_timeout_seconds: int = 300
694699
"""Timeout for communication operations, during initialization and first train step."""
695700

696-
train_timeout_seconds: int = 30
701+
train_timeout_seconds: int = 100
697702
"""
698703
Timeout for communication operations after the first train step --
699704
usually a tighter bound than during initialization.

torchtitan/distributed/expert_parallel.py

Lines changed: 4 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8-
import threading
9-
from typing import Callable, Literal, Optional
8+
from typing import Callable, Literal
109

1110
import torch
1211
import torch.nn as nn
@@ -22,80 +21,9 @@
2221
Shard,
2322
)
2423
from torch.distributed.tensor.parallel import ParallelStyle
25-
from torchtitan.tools.utils import _round_up
26-
27-
class HookCoordinator:
28-
def __init__(self):
29-
# Barrier for 2 threads (forward and backward) to synchronize
30-
# This ensures that we always alternate at executing one compute and one comm op together
31-
self._execution_barrier = threading.Barrier(2)
32-
33-
self._coordination_enabled = False
34-
self._cycle_count = 0
35-
self._num_layers = None
36-
37-
def barrier(self):
38-
"""Barrier for 2 threads to synchronize"""
39-
if not self.is_coordination_enabled():
40-
return
41-
42-
try:
43-
self._execution_barrier.wait()
44-
except threading.BrokenBarrierError:
45-
pass
46-
47-
def enable_coordination(self, num_layers: Optional[int] = None):
48-
if num_layers is not None and num_layers > 0:
49-
self._coordination_enabled = True
50-
self._cycle_count = 0
51-
52-
# Reset barrier
53-
self._execution_barrier = threading.Barrier(2)
54-
self._num_layers = num_layers
55-
56-
def disable_coordination(self):
57-
self._coordination_enabled = False
58-
self._cycle_count = 0
59-
self._execution_barrier.abort() # Break barrier to unblock threads
60-
61-
def check_should_continue_coordination(self):
62-
if self._num_layers is not None and self._cycle_count >= self._num_layers:
63-
return False
64-
return True
65-
66-
def is_coordination_enabled(self):
67-
return self._coordination_enabled
6824

69-
70-
# Global coordinator
71-
_hook_coordinator = HookCoordinator()
72-
73-
74-
class SyncHook(torch.autograd.Function):
75-
@staticmethod
76-
def forward(ctx, x, hook_name=""):
77-
ctx.hook_name = hook_name
78-
# handle edge case for transformer level boundary
79-
if _hook_coordinator._coordination_enabled and hook_name == "D":
80-
_hook_coordinator._cycle_count += 1
81-
# print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
82-
if not _hook_coordinator.check_should_continue_coordination():
83-
_hook_coordinator.disable_coordination()
84-
return x
85-
86-
_hook_coordinator.barrier()
87-
return x
88-
89-
@staticmethod
90-
def backward(ctx, grad_output):
91-
hook_name = ctx.hook_name
92-
93-
# Edge case, skip initial barrier, all subsequent backward hooks will acquire
94-
if hook_name == "D" and _hook_coordinator._cycle_count == 0:
95-
return grad_output, None
96-
97-
_hook_coordinator.barrier()
98-
return grad_output, None
25+
from torchtitan.distributed.pipeline_parallel import SyncHook
26+
from torchtitan.tools.utils import _round_up
9927

10028

10129
TOKEN_GROUP_ALIGN_SIZE_M = 8
@@ -164,6 +92,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
16492
# annotate module input placements/sharding with input_layouts
16593
routed_input, num_tokens_per_expert = inputs
16694
ep_size = device_mesh.shape[0]
95+
16796
# generate the input splits and output splits for all-to-all
16897
with torch.no_grad():
16998
num_tokens_per_expert_group = all_to_all_single(

0 commit comments

Comments
 (0)