|
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 |
|
8 | | -import threading |
9 | | -from typing import Callable, Literal, Optional |
| 8 | +from typing import Callable, Literal |
10 | 9 |
|
11 | 10 | import torch |
12 | 11 | import torch.nn as nn |
|
22 | 21 | Shard, |
23 | 22 | ) |
24 | 23 | from torch.distributed.tensor.parallel import ParallelStyle |
25 | | -from torchtitan.tools.utils import _round_up |
26 | | - |
27 | | -class HookCoordinator: |
28 | | - def __init__(self): |
29 | | - # Barrier for 2 threads (forward and backward) to synchronize |
30 | | - # This ensures that we always alternate at executing one compute and one comm op together |
31 | | - self._execution_barrier = threading.Barrier(2) |
32 | | - |
33 | | - self._coordination_enabled = False |
34 | | - self._cycle_count = 0 |
35 | | - self._num_layers = None |
36 | | - |
37 | | - def barrier(self): |
38 | | - """Barrier for 2 threads to synchronize""" |
39 | | - if not self.is_coordination_enabled(): |
40 | | - return |
41 | | - |
42 | | - try: |
43 | | - self._execution_barrier.wait() |
44 | | - except threading.BrokenBarrierError: |
45 | | - pass |
46 | | - |
47 | | - def enable_coordination(self, num_layers: Optional[int] = None): |
48 | | - if num_layers is not None and num_layers > 0: |
49 | | - self._coordination_enabled = True |
50 | | - self._cycle_count = 0 |
51 | | - |
52 | | - # Reset barrier |
53 | | - self._execution_barrier = threading.Barrier(2) |
54 | | - self._num_layers = num_layers |
55 | | - |
56 | | - def disable_coordination(self): |
57 | | - self._coordination_enabled = False |
58 | | - self._cycle_count = 0 |
59 | | - self._execution_barrier.abort() # Break barrier to unblock threads |
60 | | - |
61 | | - def check_should_continue_coordination(self): |
62 | | - if self._num_layers is not None and self._cycle_count >= self._num_layers: |
63 | | - return False |
64 | | - return True |
65 | | - |
66 | | - def is_coordination_enabled(self): |
67 | | - return self._coordination_enabled |
68 | 24 |
|
69 | | - |
70 | | -# Global coordinator |
71 | | -_hook_coordinator = HookCoordinator() |
72 | | - |
73 | | - |
74 | | -class SyncHook(torch.autograd.Function): |
75 | | - @staticmethod |
76 | | - def forward(ctx, x, hook_name=""): |
77 | | - ctx.hook_name = hook_name |
78 | | - # handle edge case for transformer level boundary |
79 | | - if _hook_coordinator._coordination_enabled and hook_name == "D": |
80 | | - _hook_coordinator._cycle_count += 1 |
81 | | - # print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40) |
82 | | - if not _hook_coordinator.check_should_continue_coordination(): |
83 | | - _hook_coordinator.disable_coordination() |
84 | | - return x |
85 | | - |
86 | | - _hook_coordinator.barrier() |
87 | | - return x |
88 | | - |
89 | | - @staticmethod |
90 | | - def backward(ctx, grad_output): |
91 | | - hook_name = ctx.hook_name |
92 | | - |
93 | | - # Edge case, skip initial barrier, all subsequent backward hooks will acquire |
94 | | - if hook_name == "D" and _hook_coordinator._cycle_count == 0: |
95 | | - return grad_output, None |
96 | | - |
97 | | - _hook_coordinator.barrier() |
98 | | - return grad_output, None |
| 25 | +from torchtitan.distributed.pipeline_parallel import SyncHook |
| 26 | +from torchtitan.tools.utils import _round_up |
99 | 27 |
|
100 | 28 |
|
101 | 29 | TOKEN_GROUP_ALIGN_SIZE_M = 8 |
@@ -164,6 +92,7 @@ def _token_dispatch(self, mod, inputs, device_mesh): |
164 | 92 | # annotate module input placements/sharding with input_layouts |
165 | 93 | routed_input, num_tokens_per_expert = inputs |
166 | 94 | ep_size = device_mesh.shape[0] |
| 95 | + |
167 | 96 | # generate the input splits and output splits for all-to-all |
168 | 97 | with torch.no_grad(): |
169 | 98 | num_tokens_per_expert_group = all_to_all_single( |
|
0 commit comments