# LICENSE file in the root directory of this source tree.


-from typing import Callable, Literal
+from typing import Callable, Literal, Dict

import torch
import torch.nn as nn
)
from torch.distributed.tensor.parallel import ParallelStyle

+import threading
+from typing import Optional
+import time
+
+class HookSequenceCoordinator:
+    """Coordinates hooks based on a predefined sequence"""
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
+
+        # Define your desired execution sequence matching:
+        # stageB.combine() -> stageA.forward_attention() -> stageB.backward_moe() ->
+        # stageA.dispatch() -> stageB.dispatch() -> stageA.forward_moe() ->
+        # stageB.backward_attention() -> stageA.combine()
+        self._hook_sequence = [
+            "combine_D_bwd",
+            "dispatch_A_fwd",
+            "combine_C_bwd",
+            "dispatch_B_fwd",
+            "dispatch_B_bwd",
+            "combine_C_fwd",
+            "dispatch_A_bwd",
+            "combine_D_fwd",
+        ]
+        # Create a semaphore for each hook in the sequence
+        self._semaphores: Dict[str, threading.Semaphore] = {}
+        self._reset_semaphores()
+
+        # Coordination control - disabled by default
+        self._coordination_enabled = False
+        self._cycle_count = 0
+
+    def _reset_semaphores(self):
+        """Reset all semaphores - first one gets 1 permit, others get 0"""
+        self._semaphores.clear()
+        for i, hook_name in enumerate(self._hook_sequence):
+            # First semaphore starts with 1 permit, others start with 0
+            initial_permits = 1 if i == 0 else 0
+            self._semaphores[hook_name] = threading.Semaphore(initial_permits)
+
+    def enable_coordination(self):
+        """Enable hook coordination"""
+        self._coordination_enabled = True
+        self._reset_semaphores()  # Reset semaphores when enabling
+        print("[COORDINATION] Hook coordination ENABLED")
+
+    def disable_coordination(self):
+        """Disable hook coordination"""
+        self._coordination_enabled = False
+        # Release all semaphores so no threads get stuck
+        for semaphore in self._semaphores.values():
+            try:
+                semaphore.release()
+            except ValueError:
+                pass  # Semaphore was already at max value
+        print("[COORDINATION] Hook coordination DISABLED")
+
+    def is_coordination_enabled(self) -> bool:
+        """Check if coordination is currently enabled"""
+        return self._coordination_enabled
+
+    def reset_coordination(self):
+        """Reset coordination state (useful between training runs)"""
+        self._cycle_count = 0
+        self._reset_semaphores()
+        print("[COORDINATION] Hook coordination state RESET")
+
+    def acquire_execution(self, hook_name: str):
+        """Acquire execution permission using semaphores"""
+        # If coordination is disabled, just pass through
+        if not self._coordination_enabled:
+            print(f"[PASSTHROUGH] {hook_name} executing (coordination disabled)")
+            return
+
+        # Check if hook is in our sequence
+        if hook_name not in self._semaphores:
+            print(f"[WARNING] {hook_name} not in sequence, executing without coordination")
+            return
+
+        # Acquire the semaphore for this hook (blocks until available)
+        print(f"[WAITING] {hook_name} waiting for semaphore")
+        self._semaphores[hook_name].acquire()
+        print(f"[EXECUTING] {hook_name} acquired semaphore")
+
+    def release_execution(self, hook_name: str):
+        """Release execution and signal next hook"""
+        # If coordination is disabled, just pass through
+        if not self._coordination_enabled:
+            return
+
+        # Check if hook is in our sequence
+        if hook_name not in self._semaphores:
+            return
+
+        # Find the next hook in the sequence and release its semaphore
+        try:
+            current_index = self._hook_sequence.index(hook_name)
+            next_index = (current_index + 1) % len(self._hook_sequence)
+            next_hook = self._hook_sequence[next_index]
+
+            print(f"[COMPLETED] {hook_name} completed, signaling {next_hook}")
+            self._semaphores[next_hook].release()
+
+            # Check if we completed a full cycle
+            if next_index == 0:
+                self._cycle_count += 1
+                print(f"[CYCLE] Completed cycle {self._cycle_count}")
+
+        except ValueError:
+            print(f"[ERROR] {hook_name} not found in sequence")
+
+# Global coordinator
+_hook_coordinator = HookSequenceCoordinator()
+
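+# Illustrative driver sketch (assumption, not part of this file): `stage_a_step`
+# and `stage_b_step` are hypothetical callables that each run one interleaved
+# forward/backward step of their pipeline stage on its own thread.
+#
+#   _hook_coordinator.enable_coordination()
+#   t_a = threading.Thread(target=stage_a_step)
+#   t_b = threading.Thread(target=stage_b_step)
+#   t_a.start(); t_b.start()
+#   t_a.join(); t_b.join()
+#   _hook_coordinator.disable_coordination()
+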
+class SyncHook(torch.autograd.Function):
+    """Sync hook that follows a predefined execution sequence"""
+
+    @staticmethod
+    def forward(ctx, x, hook_name):
+        ctx.hook_name = hook_name
+
+        # Use forward-specific hook name
+        forward_hook_name = f"{hook_name}_fwd"
+        _hook_coordinator.acquire_execution(forward_hook_name)
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[FORWARD HOOK] {forward_hook_name} (coordinated)")
+            else:
+                print(f"[FORWARD HOOK] {forward_hook_name} (uncoordinated)")
+            return x
+        finally:
+            _hook_coordinator.release_execution(forward_hook_name)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hook_name = ctx.hook_name
+
+        # Use backward-specific hook name
+        backward_hook_name = f"{hook_name}_bwd"
+        _hook_coordinator.acquire_execution(backward_hook_name)
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[BACKWARD HOOK] {backward_hook_name} (coordinated)")
+            else:
+                print(f"[BACKWARD HOOK] {backward_hook_name} (uncoordinated)")
+            return grad_output, None
+        finally:
+            _hook_coordinator.release_execution(backward_hook_name)
+
+
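+# Note: SyncHook is an identity op on the wrapped tensor; it only gates execution.
+# SyncHook.apply(x, "dispatch_A") waits on the "dispatch_A_fwd" semaphore in
+# forward and on "dispatch_A_bwd" in backward (when coordination is enabled),
+# then releases the next hook in _hook_coordinator's sequence.
+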

TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -77,7 +231,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
            self._partition_fn,
        )

-
class ExpertParallel(ParallelStyle):
    def __init__(self):
        super().__init__()
@@ -90,6 +243,9 @@ def _token_dispatch(self, mod, inputs, device_mesh):
        routed_input, num_tokens_per_expert = inputs
        ep_size = device_mesh.shape[0]

+        # HOOK: signal ready for sync
+        routed_input = SyncHook.apply(routed_input, "dispatch_A")
+
        # generate the input splits and output splits for all-to-all
        with torch.no_grad():
            num_tokens_per_expert_group = all_to_all_single(
@@ -135,6 +291,9 @@ def _token_dispatch(self, mod, inputs, device_mesh):
        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
        # each expert gets locally is a multiple of ALIGN_SIZE_M.

+        # HOOK: signal ready for sync
+        routed_input = SyncHook.apply(routed_input, "dispatch_B")
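+        # NOTE: dispatch_A fires before the dispatch all-to-all and dispatch_B after
+        # it, so the coordinator can order them against the other stage's hooks.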
+
        return routed_input, num_tokens_per_expert_group

    @staticmethod
@@ -146,12 +305,16 @@ def _partition_fn(name, mod, device_mesh):

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
+        # HOOK: signal ready for sync
+        routed_output = SyncHook.apply(routed_output, "combine_C")
        routed_output = all_to_all_single_autograd(
            routed_output,
            self.input_splits,
            self.output_splits,
            device_mesh.get_group(),
        )
+        # HOOK: signal ready for sync
+        routed_output = SyncHook.apply(routed_output, "combine_D")
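+        # NOTE: in forward, combine_C fires before the combine all-to-all and
+        # combine_D after it; in backward the order reverses (combine_D_bwd runs
+        # before combine_C_bwd), matching the sequence in HookSequenceCoordinator.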
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: