@@ -226,10 +226,10 @@ def run_moe_a2a_dispatch_single_rank(
 ):
     """Worker function for MPI testing."""
     comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    torch.cuda.set_device(rank)
-
     try:
+        rank = comm.Get_rank()
+        torch.cuda.set_device(rank)
+
         mapping = Mapping(
             rank=rank,
             tp_size=ep_size,
@@ -262,79 +262,83 @@ def run_moe_a2a_dispatch_single_rank(
         payloads, expert_id_payload_index = make_nvfp4_payloads(
             rank_local_tokens, hidden_size, top_k, rank, token_selected_experts
         )
-
-        recv_tensors = moe_a2a.dispatch(
-            token_selected_experts,
-            payloads,
-            max_num_tokens,
-            invalid_token_expert_id=invalid_token_expert_id,
-            expert_id_payload_index=expert_id_payload_index,
-        )
-
-        # Read counters and compact routing tensors from workspace
-        send_counters_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"]
-        ].item()
-        recv_counters_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"]
-        ].item()
-        topk_target_ranks_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"]
-        ].item()
-        topk_send_indices_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"]
-        ].item()
-
-        send_counters = (
-            moe_a2a.workspace[
-                rank, send_counters_offset : send_counters_offset + ep_size * 4
-            ]
-            .view(torch.int32)
-            .cpu()
-        )
-        recv_counters = (
-            moe_a2a.workspace[
-                rank, recv_counters_offset : recv_counters_offset + ep_size * 4
-            ]
-            .view(torch.int32)
-            .cpu()
-        )
-        topk_target_ranks = (
-            moe_a2a.workspace[
-                rank,
-                topk_target_ranks_offset : topk_target_ranks_offset
-                + max_num_tokens * top_k * 4,
-            ]
-            .view(torch.int32)
-            .view(max_num_tokens, top_k)
-            .cpu()
-        )
-        topk_send_indices = (
-            moe_a2a.workspace[
-                rank,
-                topk_send_indices_offset : topk_send_indices_offset
-                + max_num_tokens * top_k * 4,
-            ]
-            .view(torch.int32)
-            .view(max_num_tokens, top_k)
-            .cpu()
-        )
-
-        # Return results to be collected (move to CPU for MPI transfer)
-        return (
-            token_selected_experts.cpu(),
-            [p.cpu() for p in payloads],
-            [rt.cpu() for rt in recv_tensors],
-            send_counters,
-            topk_send_indices,
-            topk_target_ranks,
-            recv_counters,
-            expert_id_payload_index,
-        )
     except Exception:
         traceback.print_exc()
+        comm.allgather(True)
         raise

+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")
+
+    recv_tensors = moe_a2a.dispatch(
+        token_selected_experts,
+        payloads,
+        max_num_tokens,
+        invalid_token_expert_id=invalid_token_expert_id,
+        expert_id_payload_index=expert_id_payload_index,
+    )
+
+    # Read counters and compact routing tensors from workspace
+    send_counters_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"]
+    ].item()
+    recv_counters_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"]
+    ].item()
+    topk_target_ranks_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"]
+    ].item()
+    topk_send_indices_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"]
+    ].item()
+
+    send_counters = (
+        moe_a2a.workspace[
+            rank, send_counters_offset : send_counters_offset + ep_size * 4
+        ]
+        .view(torch.int32)
+        .cpu()
+    )
+    recv_counters = (
+        moe_a2a.workspace[
+            rank, recv_counters_offset : recv_counters_offset + ep_size * 4
+        ]
+        .view(torch.int32)
+        .cpu()
+    )
+    topk_target_ranks = (
+        moe_a2a.workspace[
+            rank,
+            topk_target_ranks_offset : topk_target_ranks_offset
+            + max_num_tokens * top_k * 4,
+        ]
+        .view(torch.int32)
+        .view(max_num_tokens, top_k)
+        .cpu()
+    )
+    topk_send_indices = (
+        moe_a2a.workspace[
+            rank,
+            topk_send_indices_offset : topk_send_indices_offset
+            + max_num_tokens * top_k * 4,
+        ]
+        .view(torch.int32)
+        .view(max_num_tokens, top_k)
+        .cpu()
+    )
+
+    # Return results to be collected (move to CPU for MPI transfer)
+    return (
+        token_selected_experts.cpu(),
+        [p.cpu() for p in payloads],
+        [rt.cpu() for rt in recv_tensors],
+        send_counters,
+        topk_send_indices,
+        topk_target_ranks,
+        recv_counters,
+        expert_id_payload_index,
+    )
+

 def verify_dispatch(
     all_token_selected_experts,
@@ -538,19 +542,19 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k):
         pytest.skip(f"Test requires exactly {ep_size} ranks")

     try:
-        MnnvlMemory.initialize()
-        if not MnnvlMemory.supports_mnnvl():
+        try:
+            MnnvlMemory.initialize()
+            if not MnnvlMemory.supports_mnnvl():
+                pytest.skip("MNNVL not supported on this system")
+        except Exception:
             pytest.skip("MNNVL not supported on this system")
-    except Exception:
-        pytest.skip("MNNVL not supported on this system")

-    hidden_size = 1024
-    num_experts_per_rank = 8
-    workspace_size_per_rank = 512 * 1024 * 1024
-    invalid_token_expert_id = -1
+        hidden_size = 1024
+        num_experts_per_rank = 8
+        workspace_size_per_rank = 512 * 1024 * 1024
+        invalid_token_expert_id = -1

-    # Run dispatch on this rank
-    try:
+        # Run dispatch on this rank
         result = run_moe_a2a_dispatch_single_rank(
             ep_size,
             all_num_tokens,
@@ -562,12 +566,11 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k):
         )
     except Exception as e:
         traceback.print_exc()
-        comm.allgather(e)
+        comm.allgather(True)
         raise e

-    exceptions = comm.allgather(None)
-    if any(exceptions):
-        raise filter(lambda x: x is not None, exceptions)[0]
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")

     # Gather results from all ranks
     all_results = comm.allgather(result)
@@ -631,19 +634,18 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         pytest.skip(f"Test requires exactly {ep_size} ranks")

     try:
-        MnnvlMemory.initialize()
-        if not MnnvlMemory.supports_mnnvl():
+        try:
+            MnnvlMemory.initialize()
+            if not MnnvlMemory.supports_mnnvl():
+                pytest.skip("MNNVL not supported on this system")
+        except Exception:
             pytest.skip("MNNVL not supported on this system")
-    except Exception:
-        pytest.skip("MNNVL not supported on this system")
-
-    torch.cuda.set_device(rank)

-    hidden_size = 2880  # gpt-oss
-    num_experts_per_rank = 8
-    workspace_size_per_rank = 512 * 1024 * 1024
+        torch.cuda.set_device(rank)

-    try:
+        hidden_size = 2880  # gpt-oss
+        num_experts_per_rank = 8
+        workspace_size_per_rank = 512 * 1024 * 1024
         mapping = Mapping(
             rank=rank,
             moe_ep_size=world_size,
@@ -700,7 +702,15 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
             num_experts=ep_size * num_experts_per_rank,
             workspace_size_per_rank=workspace_size_per_rank,
         )
+    except Exception as e:
+        traceback.print_exc()
+        comm.allgather(True)
+        raise e

+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")
+
+    try:
         # Dispatch
         recv_tensors = moe_a2a.dispatch(
             token_selected_experts=token_selected_experts,
@@ -741,12 +751,11 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         )
     except Exception as e:
         traceback.print_exc()
-        comm.allgather(e)
+        comm.allgather(True)
         raise e

-    exceptions = comm.allgather(None)
-    if any(exceptions):
-        raise filter(lambda x: x is not None, exceptions)[0]
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")

     try:
         # Combine
@@ -762,12 +771,11 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         )
     except Exception as e:
         traceback.print_exc()
-        comm.allgather(e)
+        comm.allgather(True)
         raise e

-    exceptions = comm.allgather(None)
-    if any(exceptions):
-        raise filter(lambda x: x is not None, exceptions)[0]
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")


 if __name__ == "__main__":
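
Note on the error-handling pattern this change introduces (a minimal sketch, not part of the patch; the run_step helper name is hypothetical): each rank reports success or failure through a boolean comm.allgather, so the collective calls stay matched across ranks and a failure on one rank aborts all ranks instead of leaving the survivors hanging in a mismatched allgather.

    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    def run_step(step_fn):
        """Run one test phase; make all ranks fail together if any rank fails."""
        try:
            result = step_fn()
        except Exception:
            comm.allgather(True)  # report this rank's failure to every rank
            raise
        # Collective call that pairs with a failing rank's allgather(True);
        # if any rank reported True, the healthy ranks bail out too.
        if any(comm.allgather(False)):
            raise RuntimeError("Another rank failed")
        return result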