Commit a6e46c7

barrier working
1 parent 6584aac commit a6e46c7

2 files changed: +83 −153 lines changed

torchtitan/distributed/expert_parallel.py

Lines changed: 53 additions & 125 deletions
@@ -25,101 +25,47 @@
 
 class HookCoordinator:
     def __init__(self):
-        # Only need 2 semaphores - one for forward, one for backward
-        self._forward_semaphore = threading.Semaphore(0)  # Forward waits
-        self._backward_semaphore = threading.Semaphore(1)  # Backward starts first
-
-        # Semaphore mapping
-        self._semaphores = {
-            'forward': self._forward_semaphore,
-            'backward': self._backward_semaphore
-        }
-
-        # Cross-signaling mapping (forward signals backward, backward signals forward)
-        self._signal_targets = {
-            'forward': self._backward_semaphore,
-            'backward': self._forward_semaphore
-        }
+        # Barrier for 2 threads (forward and backward) to synchronize.
+        # This ensures we always alternate, executing one compute and one comm op together.
+        self._execution_barrier = threading.Barrier(2)
 
         self._coordination_enabled = False
         self._cycle_count = 0
+        self._num_layers = None
 
-    def _acquire_execution(self, direction: str, timeout: float = 20.0):
-        """Generic acquire method for both forward and backward"""
-        if not self._coordination_enabled:
-            return
-
-        direction_upper = direction.upper()
-        print(f"[{direction_upper}] Attempting acquire {direction} execution")
-        self._semaphores[direction].acquire(timeout=timeout)
-        print(f"[{direction_upper}] Acquired {direction} execution")
-
-    def _release_execution(self, direction: str, async_tensor: Optional[torch.Tensor] = None):
-        """Generic release method for both forward and backward"""
-        if not self._coordination_enabled:
+    def barrier(self):
+        """Barrier for the 2 threads to synchronize."""
+        if not self.is_coordination_enabled():
             return
 
-        direction_upper = direction.upper()
-
-        # Signal the other direction
-        other_direction = 'backward' if direction == 'forward' else 'forward'
-        print(f"[{direction_upper}] Releasing {direction}, signaling {other_direction}")
-        self._signal_targets[direction].release()
-
-        # Forward-specific logic
-        if direction == 'forward':
-            self._cycle_count += 1
-            print(f"cycle count {self._cycle_count}")
-            self.check_should_enable_coordination()
-
-    # Simple wrapper methods
-    def acquire_forward_execution(self):
-        self._acquire_execution('forward')
-
-    def release_forward_execution(self, async_tensor: Optional[torch.Tensor] = None):
-        self._release_execution('forward', async_tensor)
-
-    def acquire_backward_execution(self):
-        self._acquire_execution('backward')
-
-    def release_backward_execution(self, async_tensor: Optional[torch.Tensor] = None):
-        self._release_execution('backward', async_tensor)
+        try:
+            self._execution_barrier.wait()
+            print("Both threads ready, proceeding")
+        except threading.BrokenBarrierError:
+            print("Barrier broken - one thread has finished!")
 
     def enable_coordination(self, num_layers: Optional[int] = None):
-        self._coordination_enabled = True
-        self._cycle_count = 0
+        if num_layers is not None and num_layers > 0:
+            self._coordination_enabled = True
+            self._cycle_count = 0
 
-        # Reset semaphores
-        self._forward_semaphore = threading.Semaphore(0)
-        self._backward_semaphore = threading.Semaphore(1)
-
-        # Update semaphore references
-        self._semaphores['forward'] = self._forward_semaphore
-        self._semaphores['backward'] = self._backward_semaphore
-        self._signal_targets['forward'] = self._backward_semaphore
-        self._signal_targets['backward'] = self._forward_semaphore
-
-        self._num_layers = num_layers
-        self.check_should_enable_coordination()
-        print(f"[COORDINATION] Simplified hook coordination ENABLED with {num_layers} MoE layers")
+            # Reset barrier
+            self._execution_barrier = threading.Barrier(2)
+
+            self._num_layers = num_layers
+            print(f"Compute/Comm hook coordination ENABLED with {num_layers} MoE layers")
 
     def disable_coordination(self):
         self._coordination_enabled = False
-        # Release both semaphores to unblock any waiting threads
-        try:
-            self._forward_semaphore.release()
-            self._backward_semaphore.release()
-        except ValueError:
-            pass
-        print("[COORDINATION] Simplified hook coordination DISABLED")
-
-    def check_should_enable_coordination(self):
-        # TODO: better way to determine when to disable coordination
-        moe_multipler = 4
-        if self._num_layers is not None and self._cycle_count >= moe_multipler * self._num_layers:
+        self._cycle_count = 0
+        self._execution_barrier.abort()  # Break barrier to unblock threads
+        print("[COORDINATION] Compute/Comm hook coordination DISABLED")
+
+    def check_should_continue_coordination(self):
+        if self._num_layers is not None and self._cycle_count >= self._num_layers:
             print("[COORDINATION] Reached target number of cycles, disabling coordination")
-            self.disable_coordination()
-            return
+            return False
+        return True
 
     def is_coordination_enabled(self):
         return self._coordination_enabled
@@ -129,41 +75,34 @@ def is_coordination_enabled(self):
 
 class SyncHook(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, hook_name):
+    def forward(ctx, x, hook_name=""):
         ctx.hook_name = hook_name
-        _hook_coordinator.acquire_forward_execution()
-
-
-        try:
-            if _hook_coordinator.is_coordination_enabled():
-                if hook_name == "dispatch_A":
-                    # TODO: is this right?
-                    print("Calling torch.cuda.synchronize() from dispatch_A")
-                    # This does GPU-CPU sync so we need to wait explicitly before starting
-                    torch.cuda.synchronize()
-                print(f"[FORWARD] {hook_name}_fwd")
-            return x
-        finally:
-            _hook_coordinator.release_forward_execution(x)
+        # Handle edge case for the transformer-level boundary
+        if _hook_coordinator._coordination_enabled and hook_name == "D":
+            _hook_coordinator._cycle_count += 1
+            print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
+            if not _hook_coordinator.check_should_continue_coordination():
+                _hook_coordinator.disable_coordination()
+                return x
+
+        _hook_coordinator.barrier()
+
+        if _hook_coordinator.is_coordination_enabled():
+            print(f"[FORWARD] finished {hook_name}_fwd")
+        return x
 
     @staticmethod
     def backward(ctx, grad_output):
         hook_name = ctx.hook_name
-        _hook_coordinator.acquire_backward_execution()
-
 
-        try:
-            if _hook_coordinator.is_coordination_enabled():
-                if hook_name == "dispatch_B":
-                    # TODO: is this right?
-                    print("Calling torch.cuda.synchronize() from dispatch_B")
-                    # This does GPU-CPU sync so we need to wait explicitly before starting
-                    torch.cuda.synchronize()
-                print(f"[BACKWARD] {hook_name}_bwd")
-            # grad_output.record_stream(torch.cuda.current_stream())
+        # Edge case: skip the initial barrier; all subsequent backward hooks will wait on it
+        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
             return grad_output, None
-        finally:
-            _hook_coordinator.release_backward_execution(grad_output)
+
+        _hook_coordinator.barrier()
+        if _hook_coordinator.is_coordination_enabled():
+            print(f"[BACKWARD] finished {hook_name}_bwd")
+        return grad_output, None
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -231,17 +170,6 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
         ep_size = device_mesh.shape[0]
-
-        # TODO: what is causing the IMAs???
-        if not torch.isfinite(routed_input).all():
-            raise RuntimeError(f"routed_input contains non-finite values: {routed_input}")
-
-        if not torch.isfinite(num_tokens_per_expert).all():
-            raise RuntimeError(f"num_tokens_per_expert contains non-finite values: {num_tokens_per_expert}")
-
-        if routed_input.shape[0] > 1000000:  # Reasonable limit
-            raise RuntimeError(f"routed_input suspiciously large: {routed_input.shape}")
-
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
             num_tokens_per_expert_group = all_to_all_single(
@@ -325,18 +253,18 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 
     def _wrap_with_inner_hooks(self, module):
        def inner_pre_hook(module, input):
-            return (SyncHook.apply(input[0], "dispatch_A"),) + input[1:]
+            return (SyncHook.apply(input[0], "A"),) + input[1:]
        def inner_post_hook(module, input, output):
-            return SyncHook.apply(output, "combine_C")
+            return SyncHook.apply(output, "C")
        module.register_forward_pre_hook(inner_pre_hook)
        module.register_forward_hook(inner_post_hook)
        return module
 
     def _wrap_with_outer_hooks(self, module):
        def outer_pre_hook(module, input):
-            return (SyncHook.apply(input[0], "dispatch_B"),) + input[1:]
+            return (SyncHook.apply(input[0], "B"),) + input[1:]
        def outer_post_hook(module, input, output):
-            return SyncHook.apply(output, "combine_D")
+            return SyncHook.apply(output, "D")
        module.register_forward_pre_hook(outer_pre_hook)
        module.register_forward_hook(outer_post_hook)
        return module
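
The core change above replaces the forward/backward semaphore pair with a single two-party threading.Barrier. A minimal, self-contained sketch of that pattern (toy thread names, not the torchtitan API) shows how Barrier(2).wait() keeps a compute thread and a comm thread advancing in lockstep, and how abort() plays the role of disable_coordination() when one side runs out of work:

    # Toy illustration of the HookCoordinator barrier pattern (assumed/simplified names).
    import threading

    barrier = threading.Barrier(2)  # two parties: the "forward" and "backward" threads

    def worker(name: str, steps: int) -> None:
        for i in range(steps):
            print(f"[{name}] step {i}")         # this thread's compute or comm chunk
            try:
                barrier.wait()                  # rendezvous: both threads finish step i together
            except threading.BrokenBarrierError:
                print(f"[{name}] peer finished, continuing unsynchronized")
                return
        barrier.abort()                         # like disable_coordination(): unblock the peer

    t_fwd = threading.Thread(target=worker, args=("forward", 3))
    t_bwd = threading.Thread(target=worker, args=("backward", 5))
    t_fwd.start(); t_bwd.start()
    t_fwd.join(); t_bwd.join()

The same BrokenBarrierError handling is what lets SyncHook's barrier() call fall through once coordination is disabled mid-step.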

torchtitan/train.py

Lines changed: 30 additions & 28 deletions
@@ -12,6 +12,7 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.profiler import record_function
 
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
@@ -663,7 +664,6 @@ def _count_moe_modules(model):
     return moe_count
 
 def overlap_callback(action: _Action, ctx: _PipelineContext):
-    print("overlap_callback begin", "=" * 80, torch.distributed.get_rank())
     """Custom callback for OVERLAP_F_B computation that mimics the original implementation."""
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -700,6 +700,7 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
     assert backward_mb_index is not None
     bwd_recv_ops = schedule.bwd_recv_ops
 
+    print(f"overlap_callback begin {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}", "=" * 80, torch.distributed.get_rank())
     # PP communication ========================================================
 
     # Fwd receives
@@ -744,26 +745,27 @@ def run_backward():
             torch.cuda.set_stream(main_cuda_stream)
             print(f"BACKWARD {backward_stage_index} {torch.cuda.current_stream()}")
             # Backward ========================================================
-            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
-            schedule.backward_counter[backward_stage_index] += 1
-            last_backward = (
-                schedule.backward_counter[backward_stage_index] == schedule._n_microbatches
-            )
-            backward_stage.backward_one_chunk(
-                backward_mb_index,
-                loss=loss,
-                full_backward=True,
-                last_backward=last_backward,
-            )
-            grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
-            if last_backward:
-                backward_stage.scale_grads(grad_scale_factor)
-
-            if backward_is_prev_stage_on_this_rank:
-                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
-                    backward_stage.get_local_bwd_output(backward_mb_index),
+            with record_function(f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"):
+                loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
+                schedule.backward_counter[backward_stage_index] += 1
+                last_backward = (
+                    schedule.backward_counter[backward_stage_index] == schedule._n_microbatches
+                )
+                backward_stage.backward_one_chunk(
                     backward_mb_index,
+                    loss=loss,
+                    full_backward=True,
+                    last_backward=last_backward,
                 )
+                grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
+                if last_backward:
+                    backward_stage.scale_grads(grad_scale_factor)
+
+                if backward_is_prev_stage_on_this_rank:
+                    stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
+                        backward_stage.get_local_bwd_output(backward_mb_index),
+                        backward_mb_index,
+                    )
 
 
         # Forward ========================================================
@@ -783,25 +785,25 @@ def run_forward():
             )
 
         # Run forward and backward in parallel
-        if _hook_coordinator.is_coordination_enabled():
-            thread = threading.Thread(target=run_backward, daemon=True)
-            thread.start()
-            run_forward()
-            thread.join()
+        # if _hook_coordinator.is_coordination_enabled():
+        thread = threading.Thread(target=run_backward, daemon=True)
+        thread.start()
+        run_forward()
+        thread.join()
         # with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
         #     forward_future = executor.submit(run_forward)
        #     backward_future = executor.submit(run_backward)
 
        #     # Wait for both to complete simultaneously
        #     done, not_done = concurrent.futures.wait([forward_future, backward_future])
        #     output = forward_future.result()
-        else:
-            run_forward()
-            run_backward()
+        # else:
+        #     run_forward()
+        #     run_backward()
 
     _hook_coordinator.disable_coordination()
     forward_backward_overlapped()
-    print("overlap_callback end", "=" * 80)
+    print(f"overlap_callback end {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}", "=" * 80)
 
     import fbvscode
     fbvscode.attach_debugger()
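
The train.py side pairs the barrier with a simple overlap pattern: backward work runs on a daemon thread while forward work stays on the main thread, and torch.profiler.record_function labels each region so the overlap is visible in a trace. A minimal runnable sketch under assumed names (dummy matmuls stand in for the real stage calls, and the region labels are illustrative):

    # Sketch of the thread-overlap + record_function pattern (not the torchtitan code).
    import threading

    import torch
    from torch.profiler import profile, record_function

    def run_backward() -> None:
        with record_function("backward_stage_mb"):    # named block in the profiler trace
            torch.randn(512, 512) @ torch.randn(512, 512)

    def run_forward() -> None:
        with record_function("forward_stage_mb"):
            torch.randn(512, 512) @ torch.randn(512, 512)

    with profile() as prof:
        t = threading.Thread(target=run_backward, daemon=True)
        t.start()       # backward overlaps with the forward call below
        run_forward()   # forward stays on the main thread (and its default stream)
        t.join()        # wait for backward before the next scheduler action

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))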
