# LICENSE file in the root directory of this source tree.

- from typing import Callable, Literal
+ from typing import Callable, Literal, Dict

import torch
import torch.nn as nn
)
from torch.distributed.tensor.parallel import ParallelStyle

+ import threading
+ import torch
+ from typing import Optional
+ import time
+
+ class SimplifiedHookCoordinator:
+     """
+     Alternating forward/backward coordination with just two semaphores.
+
+     TODO: this hangs because FWD is doing dispatch while BWD is doing combine;
+     the two communications may be conflicting.
+     """
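+     # Handoff protocol (as implemented below): the two semaphores enforce a strict
+     # backward/forward alternation, with backward going first since its semaphore
+     # starts at 1. release_forward_execution() signals the backward side and
+     # release_backward_execution() signals the next forward, so the dispatch and
+     # combine all-to-alls of the two directions should not be issued concurrently.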
+
+     def __init__(self):
+         self._lock = threading.Lock()
+
+         # Only two semaphores are needed: one for forward, one for backward
+         self._forward_semaphore = threading.Semaphore(0)   # Forward waits
+         self._backward_semaphore = threading.Semaphore(1)  # Backward starts first
+
+         # CUDA event tracking
+         self._forward_cuda_event = torch.cuda.Event()
+         self._backward_cuda_event = torch.cuda.Event()
+         self._forward_event_recorded = False
+         self._backward_event_recorded = False
+
+         # Store AsyncCollectiveTensors from previous operations
+         self._stored_forward_async_tensor = None
+         self._stored_backward_async_tensor = None
+
+         self._coordination_enabled = False
+         self._cycle_count = 0
+
+     def is_coordination_enabled(self) -> bool:
+         """Check whether coordination is currently enabled."""
+         return self._coordination_enabled
+
+     def enable_coordination(self, num_layers: Optional[int] = None):
+         self._coordination_enabled = True
+         self._cycle_count = 0
+
+         # Reset semaphores
+         self._forward_semaphore = threading.Semaphore(0)
+         self._backward_semaphore = threading.Semaphore(1)
+
+         # Reset CUDA events
+         self._forward_cuda_event = torch.cuda.Event()
+         self._backward_cuda_event = torch.cuda.Event()
+         self._forward_event_recorded = False
+         self._backward_event_recorded = False
+
+         # Number of layers, used to decide when coordination can be disabled
+         self._num_layers = num_layers
+
+         print("[COORDINATION] Simplified hook coordination with CUDA events ENABLED")
+
+     def disable_coordination(self):
+         self._coordination_enabled = False
+         # Release both semaphores to unblock any waiting threads
+         try:
+             self._forward_semaphore.release()
+             self._backward_semaphore.release()
+         except ValueError:
+             # Plain Semaphore.release() never raises; this guard only matters if
+             # these are ever changed to BoundedSemaphores.
+             pass
+         print("[COORDINATION] Simplified hook coordination DISABLED")
+
+     def acquire_forward_execution(self):
+         if not self._coordination_enabled:
+             return
+
+         print("[FORWARD] Attempting to acquire forward execution")
+         self._forward_semaphore.acquire()
+
+         # Wait for the PREVIOUS forward's CUDA operations to complete.
+         # Note: event.wait() makes the current stream wait on the event; it does
+         # not block the host thread.
+         if self._forward_event_recorded:
+             print("[FORWARD] Waiting for previous FORWARD CUDA operations to complete...")
+             self._forward_cuda_event.wait()
+             print("[FORWARD] Previous FORWARD CUDA operations completed")
+
+         # Wait for forward's own previously stored AsyncCollectiveTensor
+         if self._stored_forward_async_tensor is not None:
+             from torch.distributed._functional_collectives import AsyncCollectiveTensor
+             if isinstance(self._stored_forward_async_tensor, AsyncCollectiveTensor):
+                 print("[FORWARD] Waiting for forward's own previous AsyncCollectiveTensor...")
+                 torch.ops._c10d_functional.wait_tensor(self._stored_forward_async_tensor)
+                 print("[FORWARD] Forward's previous AsyncCollectiveTensor completed")
+             self._stored_forward_async_tensor = None  # Clear after waiting
+
+         print("[FORWARD] Acquired forward execution")
+
+     def release_forward_execution(self, async_tensor: Optional[torch.Tensor] = None):
+         if not self._coordination_enabled:
+             return
+
+         # Record a CUDA event for the current forward operations
+         current_stream = torch.cuda.current_stream()
+         self._forward_cuda_event = torch.cuda.Event()  # Create a new event
+         self._forward_cuda_event.record(current_stream)
+         self._forward_event_recorded = True
+         print("[FORWARD] Recorded forward CUDA completion event")
+
+         # Store forward's AsyncCollectiveTensor for forward's own future use
+         self._stored_forward_async_tensor = async_tensor
+         if async_tensor is not None:
+             from torch.distributed._functional_collectives import AsyncCollectiveTensor
+             if isinstance(async_tensor, AsyncCollectiveTensor):
+                 print("[FORWARD] Stored forward AsyncCollectiveTensor for forward's future use")
+
+         print("[FORWARD] Releasing forward, signaling backward")
+         self._backward_semaphore.release()  # Signal that backward can start
+
+         self._cycle_count += 1
+         print(f"cycle count {self._cycle_count}")
+         # TODO: find a better way to determine when to disable coordination.
+         # Each MoE layer triggers four SyncHook forward calls per microbatch
+         # (dispatch_A, dispatch_B, combine_C, combine_D), which appears to be why
+         # the multiplier is 4.
+         moe_multiplier = 4
+         if self._num_layers is not None and self._cycle_count >= moe_multiplier * self._num_layers:
+             print("[COORDINATION] Reached target number of cycles, disabling coordination")
+             self.disable_coordination()
+             return  # Exit early since coordination is now disabled
+
+     def acquire_backward_execution(self):
+         if not self._coordination_enabled:
+             return
+
+         print("[BACKWARD] Attempting to acquire backward execution")
+         self._backward_semaphore.acquire()
+
+         # Wait for the PREVIOUS backward's CUDA operations to complete
+         if self._backward_event_recorded:
+             print("[BACKWARD] Waiting for previous BACKWARD CUDA operations to complete...")
+             self._backward_cuda_event.wait()
+             print("[BACKWARD] Previous BACKWARD CUDA operations completed")
+
+         # Wait for backward's own previously stored AsyncCollectiveTensor
+         if self._stored_backward_async_tensor is not None:
+             from torch.distributed._functional_collectives import AsyncCollectiveTensor
+             if isinstance(self._stored_backward_async_tensor, AsyncCollectiveTensor):
+                 print("[BACKWARD] Waiting for backward's own previous AsyncCollectiveTensor...")
+                 torch.ops._c10d_functional.wait_tensor(self._stored_backward_async_tensor)
+                 print("[BACKWARD] Backward's previous AsyncCollectiveTensor completed")
+             self._stored_backward_async_tensor = None  # Clear after waiting
+
+         print("[BACKWARD] Acquired backward execution")
+
+     def release_backward_execution(self, async_tensor: Optional[torch.Tensor] = None):
+         if not self._coordination_enabled:
+             return
+
+         # Record a CUDA event for the current backward operations
+         current_stream = torch.cuda.current_stream()
+         self._backward_cuda_event = torch.cuda.Event()  # Create a new event
+         self._backward_cuda_event.record(current_stream)
+         self._backward_event_recorded = True
+         print("[BACKWARD] Recorded backward CUDA completion event")
+
+         # Store backward's AsyncCollectiveTensor for backward's own future use
+         self._stored_backward_async_tensor = async_tensor
+         if async_tensor is not None:
+             from torch.distributed._functional_collectives import AsyncCollectiveTensor
+             if isinstance(async_tensor, AsyncCollectiveTensor):
+                 print("[BACKWARD] Stored backward AsyncCollectiveTensor for backward's future use")
+
+         print("[BACKWARD] Releasing backward, signaling next forward")
+         self._forward_semaphore.release()  # Signal that the next forward can start
+         # self._cycle_count += 1
+         # print(f"[CYCLE] Completed cycle {self._cycle_count}")
+
+ # Global coordinator
+ _hook_coordinator = SimplifiedHookCoordinator()
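+ # Usage sketch (illustrative; the call sites are not part of this diff): some
+ # schedule/training-loop code is expected to call
+ # _hook_coordinator.enable_coordination(num_layers=<number of MoE layers>) once a
+ # forward and a backward microbatch can overlap; coordination then disables itself
+ # after moe_multiplier * num_layers hook cycles (see release_forward_execution).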
+
+ class SyncHook(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x, hook_name):
+         ctx.hook_name = hook_name
+
+         _hook_coordinator.acquire_forward_execution()
+
+         try:
+             if _hook_coordinator.is_coordination_enabled():
+                 print(f"[FORWARD] {hook_name}_fwd")
+             return x
+         finally:
+             _hook_coordinator.release_forward_execution(x)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         hook_name = ctx.hook_name
+
+         _hook_coordinator.acquire_backward_execution()
+
+         try:
+             if _hook_coordinator.is_coordination_enabled():
+                 print(f"[BACKWARD] {hook_name}_bwd")
+             return grad_output, None
+         finally:
+             _hook_coordinator.release_backward_execution(grad_output)
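+ # Note: SyncHook is an identity op in both directions (forward returns x, backward
+ # passes grad_output through). Its only purpose is to bracket the surrounding
+ # dispatch/combine with the coordinator's acquire/release calls; release happens in
+ # a `finally` block so the opposite direction is always signaled, even on error.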

TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -109,11 +305,15 @@ def _token_dispatch(self, mod, inputs, device_mesh):
            .to(torch.device("cpu"), non_blocking=True)
        )
        # NOTE: this would incur a device-to-host sync
+         # CPU-GPU sync here!!!
+         # start_time = time.time()
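+         # Why the sync is unavoidable here: the split sizes are consumed as Python
+         # ints by the all-to-all below, so copying output_splits to the CPU and
+         # calling .tolist() must wait for the token-count collective to finish.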
        output_splits = (
            num_tokens_per_expert_group.view(ep_size, -1)
            .sum(dim=1)
            .to(torch.device("cpu"), non_blocking=False)
        )
+         # sync_time = time.time() - start_time
+         # print(f"CPU-GPU sync took {sync_time:.4f}s")
        self.input_splits = input_splits.tolist()
        self.output_splits = output_splits.tolist()
@@ -125,6 +325,11 @@ def _token_dispatch(self, mod, inputs, device_mesh):
            device_mesh.get_group(),
        )

+         # TODO: FIX NEEDING THIS???
+         routed_input = torch.ops._c10d_functional.wait_tensor(
+             routed_input
+         )
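+         # Likely reason this explicit wait is needed: the all-to-all above returns an
+         # AsyncCollectiveTensor, and if it is left unmaterialized its collective can be
+         # deferred past the SyncHook handoff and interleave with the other direction's
+         # combine on the same process group (see the TODO in SimplifiedHookCoordinator).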
+
        # NOTE: After this all-to-all, the routed input is put on proper EP rank.
        # However, the num_tokens_per_expert_group is not of the final target format
        # [#tokens for local expert 0, #tokens for local expert 1, ...]
@@ -152,16 +357,48 @@ def _token_combine(self, mod, routed_output, device_mesh):
            self.output_splits,
            device_mesh.get_group(),
        )
+         # TODO: FIX NEEDING THIS???
+         # CRITICAL: Wait for AsyncCollectiveTensor BEFORE coordination
+         from torch.distributed._functional_collectives import AsyncCollectiveTensor
+         if isinstance(routed_output, AsyncCollectiveTensor):
+             routed_output = torch.ops._c10d_functional.wait_tensor(routed_output)
+
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-         return distribute_module(
-             module,
+         """
+         Hooks are called in the order they are registered:
+             A, dispatch, B (pre hooks)
+             C, combine, D (post hooks)
+         """
+         inner_wrapped_module = self._wrap_with_inner_hooks(module)
+         distributed_module = distribute_module(
+             inner_wrapped_module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
+         final_module = self._wrap_with_outer_hooks(distributed_module)
+         return final_module
+
+     def _wrap_with_inner_hooks(self, module):
+         def inner_pre_hook(module, input):
+             return (SyncHook.apply(input[0], "dispatch_A"),) + input[1:]
+         def inner_post_hook(module, input, output):
+             return SyncHook.apply(output, "combine_C")
+         module.register_forward_pre_hook(inner_pre_hook)
+         module.register_forward_hook(inner_post_hook)
+         return module
+
+     def _wrap_with_outer_hooks(self, module):
+         def outer_pre_hook(module, input):
+             return (SyncHook.apply(input[0], "dispatch_B"),) + input[1:]
+         def outer_post_hook(module, input, output):
+             return SyncHook.apply(output, "combine_D")
+         module.register_forward_pre_hook(outer_pre_hook)
+         module.register_forward_hook(outer_post_hook)
+         return module
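+     # Resulting per-forward hook order on the wrapped experts module (registration
+     # order): dispatch_A -> _token_dispatch -> dispatch_B -> expert computation ->
+     # combine_C -> _token_combine -> combine_D. Autograd runs the corresponding
+     # SyncHook.backward calls in the reverse order during the backward pass.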


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)