 
 class HookCoordinator:
     def __init__(self):
-        # Only need 2 semaphores - one for forward, one for backward
-        self._forward_semaphore = threading.Semaphore(0)   # Forward waits
-        self._backward_semaphore = threading.Semaphore(1)  # Backward starts first
-
-        # Semaphore mapping
-        self._semaphores = {
-            'forward': self._forward_semaphore,
-            'backward': self._backward_semaphore
-        }
-
-        # Cross-signaling mapping (forward signals backward, backward signals forward)
-        self._signal_targets = {
-            'forward': self._backward_semaphore,
-            'backward': self._forward_semaphore
-        }
+        # Barrier for 2 threads (forward and backward) to synchronize.
+        # This keeps the threads in lockstep, always executing one compute op and one comm op together.
+        self._execution_barrier = threading.Barrier(2)
 
         self._coordination_enabled = False
         self._cycle_count = 0
+        self._num_layers = None
 
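For reference, a minimal standalone sketch of the rendezvous behavior that threading.Barrier(2) provides: two threads each call wait() once per step, and neither proceeds until both have arrived. The thread names and step count below are illustrative only.

import threading

barrier = threading.Barrier(2)

def worker(name, steps=3):
    for step in range(steps):
        # ... this thread's share of the work for this step ...
        barrier.wait()  # block until the other thread also reaches this point
        print(f"{name} finished step {step}")

threads = [threading.Thread(target=worker, args=(name,)) for name in ("forward", "backward")]
for t in threads:
    t.start()
for t in threads:
    t.join()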
-    def _acquire_execution(self, direction: str, timeout: float = 20.0):
-        """Generic acquire method for both forward and backward"""
-        if not self._coordination_enabled:
-            return
-
-        direction_upper = direction.upper()
-        print(f"[{direction_upper}] {direction} waiting")
-        self._semaphores[direction].acquire(timeout=timeout)
-        print(f"[{direction_upper}] {direction} proceeding")
-
-    def _release_execution(self, direction: str, async_tensor: Optional[torch.Tensor] = None):
-        """Generic release method for both forward and backward"""
-        if not self._coordination_enabled:
+    def barrier(self):
+        """Barrier for 2 threads to synchronize"""
+        if not self.is_coordination_enabled():
             return
 
-        direction_upper = direction.upper()
-
-        # Signal the other direction
-        other_direction = 'backward' if direction == 'forward' else 'forward'
-        print(f"[{direction_upper}] {direction} signaling {other_direction}")
-        self._signal_targets[direction].release()
-
-        # Forward-specific logic
-        if direction == 'forward':
-            self._cycle_count += 1
-            print(f"cycle count {self._cycle_count}")
-            self.check_should_enable_coordination()
-
-    # Simple wrapper methods
-    def acquire_forward_execution(self):
-        self._acquire_execution('forward')
-
-    def release_forward_execution(self, async_tensor: Optional[torch.Tensor] = None):
-        self._release_execution('forward', async_tensor)
-
-    def acquire_backward_execution(self):
-        self._acquire_execution('backward')
-
-    def release_backward_execution(self, async_tensor: Optional[torch.Tensor] = None):
-        self._release_execution('backward', async_tensor)
+        try:
+            self._execution_barrier.wait()
+            print("Both threads ready, proceeding")
+        except threading.BrokenBarrierError:
+            print("Barrier broken - one thread has finished!")
 
     def enable_coordination(self, num_layers: Optional[int] = None):
-        self._coordination_enabled = True
-        self._cycle_count = 0
+        if num_layers is not None and num_layers > 0:
+            self._coordination_enabled = True
+            self._cycle_count = 0
 
-        # Reset semaphores
-        self._forward_semaphore = threading.Semaphore(0)
-        self._backward_semaphore = threading.Semaphore(1)
-
-        # Update semaphore references
-        self._semaphores['forward'] = self._forward_semaphore
-        self._semaphores['backward'] = self._backward_semaphore
-        self._signal_targets['forward'] = self._backward_semaphore
-        self._signal_targets['backward'] = self._forward_semaphore
-
-        self._num_layers = num_layers
-        self.check_should_enable_coordination()
-        print(f"[COORDINATION] Simplified hook coordination ENABLED with {num_layers} layers")
+            # Reset barrier
+            self._execution_barrier = threading.Barrier(2)
+
+            self._num_layers = num_layers
+            print(f"Compute/Comm hook coordination ENABLED with {num_layers} layers")
 
     def disable_coordination(self):
         self._coordination_enabled = False
-        # Release both semaphores to unblock any waiting threads
-        try:
-            self._forward_semaphore.release()
-            self._backward_semaphore.release()
-        except ValueError:
-            pass
-        print("[COORDINATION] Simplified hook coordination DISABLED")
-
-    def check_should_enable_coordination(self):
-        # TODO: better way to determine when to disable coordination
-        moe_multipler = 4
-        if self._num_layers is not None and self._cycle_count >= moe_multipler * self._num_layers:
+        self._cycle_count = 0
+        self._execution_barrier.abort()  # Break barrier to unblock threads
+        print("[COORDINATION] Compute/Comm hook coordination DISABLED")
+
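A small sketch of the behavior disable_coordination() relies on: Barrier.abort() wakes a thread blocked in wait() (and fails any later waits) with BrokenBarrierError instead of leaving it stuck, which barrier() above catches and logs. The sleep is only there to let the waiter block first.

import threading
import time

b = threading.Barrier(2)

def waiter():
    try:
        b.wait()  # the second party never arrives
    except threading.BrokenBarrierError:
        print("unblocked by abort()")

t = threading.Thread(target=waiter)
t.start()
time.sleep(0.1)  # give the waiter time to block on the barrier
b.abort()        # break the barrier; the pending wait() raises BrokenBarrierError
t.join()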
+    def check_should_continue_coordination(self):
+        if self._num_layers is not None and self._cycle_count >= self._num_layers:
             print("[COORDINATION] Reached target number of cycles, disabling coordination")
-            self.disable_coordination()
-            return
+            return False
+        return True
 
     def is_coordination_enabled(self):
         return self._coordination_enabled
@@ -129,41 +75,34 @@ def is_coordination_enabled(self):
 
 class SyncHook(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, hook_name):
+    def forward(ctx, x, hook_name=""):
         ctx.hook_name = hook_name
-        _hook_coordinator.acquire_forward_execution()
-
-
-        try:
-            if _hook_coordinator.is_coordination_enabled():
-                if hook_name == "dispatch_A":
-                    # TODO: is this right?
-                    print("Calling torch.cuda.synchronize() from dispatch_A")
-                    # This does GPU-CPU sync so we need to wait explicitly before starting
-                    torch.cuda.synchronize()
-                print(f"[FORWARD] {hook_name}")
-            return x
-        finally:
-            _hook_coordinator.release_forward_execution(x)
+        # Handle edge case at the transformer-layer boundary
+        if _hook_coordinator._coordination_enabled and hook_name == "D":
+            _hook_coordinator._cycle_count += 1
+            print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
+            if not _hook_coordinator.check_should_continue_coordination():
+                _hook_coordinator.disable_coordination()
+                return x
+
+        _hook_coordinator.barrier()
+
+        if _hook_coordinator.is_coordination_enabled():
+            print(f"[FORWARD] finished {hook_name}")
+        return x
 
     @staticmethod
     def backward(ctx, grad_output):
         hook_name = ctx.hook_name
-        _hook_coordinator.acquire_backward_execution()
-
 
-        try:
-            if _hook_coordinator.is_coordination_enabled():
-                if hook_name == "dispatch_B":
-                    # TODO: is this right?
-                    print("Calling torch.cuda.synchronize() from dispatch_B")
-                    # This does GPU-CPU sync so we need to wait explicitly before starting
-                    torch.cuda.synchronize()
-                print(f"[BACKWARD] {hook_name}")
-                # grad_output.record_stream(torch.cuda.current_stream())
+        # Edge case: skip the initial barrier; all subsequent backward hooks will synchronize
+        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
             return grad_output, None
-        finally:
-            _hook_coordinator.release_backward_execution(grad_output)
+
+        _hook_coordinator.barrier()
+        if _hook_coordinator.is_coordination_enabled():
+            print(f"[BACKWARD] finished {hook_name}")
+        return grad_output, None
 
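For intuition, a stripped-down identity Function in the same style as SyncHook, without the coordinator: it stores a tag on ctx, passes the tensor through unchanged, and its backward fires when gradients flow back past that point, returning a gradient for the tensor and None for the non-tensor tag. The TagHook name and tensor shapes are made up for this sketch.

import torch

class TagHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, tag=""):
        ctx.tag = tag
        print(f"[FORWARD] {tag}")
        return x  # identity: values and gradients pass through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        print(f"[BACKWARD] {ctx.tag}")
        return grad_output, None  # one entry per forward input; None for the tag

x = torch.ones(3, requires_grad=True)
y = TagHook.apply(TagHook.apply(x, "A"), "D")
y.sum().backward()  # prints [BACKWARD] D, then [BACKWARD] A (reverse of forward order)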
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -231,17 +170,6 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
         ep_size = device_mesh.shape[0]
-
-        # TODO: what is causing the IMAs???
-        if not torch.isfinite(routed_input).all():
-            raise RuntimeError(f"routed_input contains non-finite values: {routed_input}")
-
-        if not torch.isfinite(num_tokens_per_expert).all():
-            raise RuntimeError(f"num_tokens_per_expert contains non-finite values: {num_tokens_per_expert}")
-
-        if routed_input.shape[0] > 1000000:  # Reasonable limit
-            raise RuntimeError(f"routed_input suspiciously large: {routed_input.shape}")
-
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
             num_tokens_per_expert_group = all_to_all_single(
@@ -325,18 +253,18 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 
     def _wrap_with_inner_hooks(self, module):
         def inner_pre_hook(module, input):
-            return (SyncHook.apply(input[0], "dispatch_A"),) + input[1:]
+            return (SyncHook.apply(input[0], "A"),) + input[1:]
         def inner_post_hook(module, input, output):
-            return SyncHook.apply(output, "combine_C")
+            return SyncHook.apply(output, "C")
         module.register_forward_pre_hook(inner_pre_hook)
         module.register_forward_hook(inner_post_hook)
         return module
 
     def _wrap_with_outer_hooks(self, module):
         def outer_pre_hook(module, input):
-            return (SyncHook.apply(input[0], "dispatch_B"),) + input[1:]
+            return (SyncHook.apply(input[0], "B"),) + input[1:]
         def outer_post_hook(module, input, output):
-            return SyncHook.apply(output, "combine_D")
+            return SyncHook.apply(output, "D")
         module.register_forward_pre_hook(outer_pre_hook)
         module.register_forward_hook(outer_post_hook)
         return module
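A toy illustration of how these pre/post forward hooks compose when the module wrapped by _wrap_with_outer_hooks contains the one wrapped by _wrap_with_inner_hooks, which the A/B/C/D naming and the backward "D" edge case suggest: the tags fire as B, A, C, D during forward, and the matching SyncHook.backward calls unwind in the reverse order during backward. The Inner/Outer modules below are placeholders, not the real MoE modules.

import torch
import torch.nn as nn

class Inner(nn.Module):
    def forward(self, x):
        return x * 2

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()
    def forward(self, x):
        return self.inner(x) + 1

def tag_pre(tag):
    def hook(module, inputs):
        print(f"pre  {tag}")
        return inputs
    return hook

def tag_post(tag):
    def hook(module, inputs, output):
        print(f"post {tag}")
        return output
    return hook

outer = Outer()
outer.inner.register_forward_pre_hook(tag_pre("A"))
outer.inner.register_forward_hook(tag_post("C"))
outer.register_forward_pre_hook(tag_pre("B"))
outer.register_forward_hook(tag_post("D"))

outer(torch.ones(2))  # prints: pre B, pre A, post C, post D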