Commit bd82a2b

More exit handling for rank failures
1 parent 2bac65a commit bd82a2b

File tree

1 file changed: +112 −104 lines changed

tests/comm/test_mnnvl_a2a.py

Lines changed: 112 additions & 104 deletions
@@ -226,10 +226,10 @@ def run_moe_a2a_dispatch_single_rank(
 ):
     """Worker function for MPI testing."""
     comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    torch.cuda.set_device(rank)
-
     try:
+        rank = comm.Get_rank()
+        torch.cuda.set_device(rank)
+
         mapping = Mapping(
             rank=rank,
             tp_size=ep_size,
@@ -262,79 +262,83 @@ def run_moe_a2a_dispatch_single_rank(
         payloads, expert_id_payload_index = make_nvfp4_payloads(
             rank_local_tokens, hidden_size, top_k, rank, token_selected_experts
         )
-
-        recv_tensors = moe_a2a.dispatch(
-            token_selected_experts,
-            payloads,
-            max_num_tokens,
-            invalid_token_expert_id=invalid_token_expert_id,
-            expert_id_payload_index=expert_id_payload_index,
-        )
-
-        # Read counters and compact routing tensors from workspace
-        send_counters_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"]
-        ].item()
-        recv_counters_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"]
-        ].item()
-        topk_target_ranks_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"]
-        ].item()
-        topk_send_indices_offset = moe_a2a.metainfo[
-            MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"]
-        ].item()
-
-        send_counters = (
-            moe_a2a.workspace[
-                rank, send_counters_offset : send_counters_offset + ep_size * 4
-            ]
-            .view(torch.int32)
-            .cpu()
-        )
-        recv_counters = (
-            moe_a2a.workspace[
-                rank, recv_counters_offset : recv_counters_offset + ep_size * 4
-            ]
-            .view(torch.int32)
-            .cpu()
-        )
-        topk_target_ranks = (
-            moe_a2a.workspace[
-                rank,
-                topk_target_ranks_offset : topk_target_ranks_offset
-                + max_num_tokens * top_k * 4,
-            ]
-            .view(torch.int32)
-            .view(max_num_tokens, top_k)
-            .cpu()
-        )
-        topk_send_indices = (
-            moe_a2a.workspace[
-                rank,
-                topk_send_indices_offset : topk_send_indices_offset
-                + max_num_tokens * top_k * 4,
-            ]
-            .view(torch.int32)
-            .view(max_num_tokens, top_k)
-            .cpu()
-        )
-
-        # Return results to be collected (move to CPU for MPI transfer)
-        return (
-            token_selected_experts.cpu(),
-            [p.cpu() for p in payloads],
-            [rt.cpu() for rt in recv_tensors],
-            send_counters,
-            topk_send_indices,
-            topk_target_ranks,
-            recv_counters,
-            expert_id_payload_index,
-        )
     except Exception:
         traceback.print_exc()
+        comm.allgather(True)
         raise
 
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")
+
+    recv_tensors = moe_a2a.dispatch(
+        token_selected_experts,
+        payloads,
+        max_num_tokens,
+        invalid_token_expert_id=invalid_token_expert_id,
+        expert_id_payload_index=expert_id_payload_index,
+    )
+
+    # Read counters and compact routing tensors from workspace
+    send_counters_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"]
+    ].item()
+    recv_counters_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"]
+    ].item()
+    topk_target_ranks_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"]
+    ].item()
+    topk_send_indices_offset = moe_a2a.metainfo[
+        MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"]
+    ].item()
+
+    send_counters = (
+        moe_a2a.workspace[
+            rank, send_counters_offset : send_counters_offset + ep_size * 4
+        ]
+        .view(torch.int32)
+        .cpu()
+    )
+    recv_counters = (
+        moe_a2a.workspace[
+            rank, recv_counters_offset : recv_counters_offset + ep_size * 4
+        ]
+        .view(torch.int32)
+        .cpu()
+    )
+    topk_target_ranks = (
+        moe_a2a.workspace[
+            rank,
+            topk_target_ranks_offset : topk_target_ranks_offset
+            + max_num_tokens * top_k * 4,
+        ]
+        .view(torch.int32)
+        .view(max_num_tokens, top_k)
+        .cpu()
+    )
+    topk_send_indices = (
+        moe_a2a.workspace[
+            rank,
+            topk_send_indices_offset : topk_send_indices_offset
+            + max_num_tokens * top_k * 4,
+        ]
+        .view(torch.int32)
+        .view(max_num_tokens, top_k)
+        .cpu()
+    )
+
+    # Return results to be collected (move to CPU for MPI transfer)
+    return (
+        token_selected_experts.cpu(),
+        [p.cpu() for p in payloads],
+        [rt.cpu() for rt in recv_tensors],
+        send_counters,
+        topk_send_indices,
+        topk_target_ranks,
+        recv_counters,
+        expert_id_payload_index,
+    )
+
 
 def verify_dispatch(
     all_token_selected_experts,
@@ -538,19 +542,19 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k):
         pytest.skip(f"Test requires exactly {ep_size} ranks")
 
     try:
-        MnnvlMemory.initialize()
-        if not MnnvlMemory.supports_mnnvl():
+        try:
+            MnnvlMemory.initialize()
+            if not MnnvlMemory.supports_mnnvl():
+                pytest.skip("MNNVL not supported on this system")
+        except Exception:
             pytest.skip("MNNVL not supported on this system")
-    except Exception:
-        pytest.skip("MNNVL not supported on this system")
 
-    hidden_size = 1024
-    num_experts_per_rank = 8
-    workspace_size_per_rank = 512 * 1024 * 1024
-    invalid_token_expert_id = -1
+        hidden_size = 1024
+        num_experts_per_rank = 8
+        workspace_size_per_rank = 512 * 1024 * 1024
+        invalid_token_expert_id = -1
 
-    # Run dispatch on this rank
-    try:
+        # Run dispatch on this rank
         result = run_moe_a2a_dispatch_single_rank(
             ep_size,
             all_num_tokens,
@@ -562,12 +566,11 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k):
         )
     except Exception as e:
         traceback.print_exc()
-        comm.allgather(e)
+        comm.allgather(True)
         raise e
 
-    exceptions = comm.allgather(None)
-    if any(exceptions):
-        raise filter(lambda x: x is not None, exceptions)[0]
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")
 
     # Gather results from all ranks
     all_results = comm.allgather(result)
@@ -631,19 +634,18 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         pytest.skip(f"Test requires exactly {ep_size} ranks")
 
     try:
-        MnnvlMemory.initialize()
-        if not MnnvlMemory.supports_mnnvl():
+        try:
+            MnnvlMemory.initialize()
+            if not MnnvlMemory.supports_mnnvl():
+                pytest.skip("MNNVL not supported on this system")
+        except Exception:
             pytest.skip("MNNVL not supported on this system")
-    except Exception:
-        pytest.skip("MNNVL not supported on this system")
-
-    torch.cuda.set_device(rank)
 
-    hidden_size = 2880  # gpt-oss
-    num_experts_per_rank = 8
-    workspace_size_per_rank = 512 * 1024 * 1024
+        torch.cuda.set_device(rank)
 
-    try:
+        hidden_size = 2880  # gpt-oss
+        num_experts_per_rank = 8
+        workspace_size_per_rank = 512 * 1024 * 1024
         mapping = Mapping(
             rank=rank,
             moe_ep_size=world_size,
@@ -700,7 +702,15 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
             num_experts=ep_size * num_experts_per_rank,
             workspace_size_per_rank=workspace_size_per_rank,
         )
+    except Exception as e:
+        traceback.print_exc()
+        comm.allgather(True)
+        raise e
 
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")
+
+    try:
         # Dispatch
         recv_tensors = moe_a2a.dispatch(
             token_selected_experts=token_selected_experts,
@@ -741,12 +751,11 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         )
     except Exception as e:
         traceback.print_exc()
-        comm.allgather(e)
+        comm.allgather(True)
         raise e
 
-    exceptions = comm.allgather(None)
-    if any(exceptions):
-        raise filter(lambda x: x is not None, exceptions)[0]
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")
 
     try:
         # Combine
@@ -762,12 +771,11 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         )
     except Exception as e:
         traceback.print_exc()
-        comm.allgather(e)
+        comm.allgather(True)
        raise e
 
-    exceptions = comm.allgather(None)
-    if any(exceptions):
-        raise filter(lambda x: x is not None, exceptions)[0]
+    if any(comm.allgather(False)):
+        raise Exception("Another rank failed")
 
 
 if __name__ == "__main__":
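
The recurring change in this diff swaps exception-object gathering (comm.allgather(e) on the failing rank, comm.allgather(None) on the others, followed by the broken filter(...)[0] indexing) for a plain boolean exchange: a rank that fails contributes True from its except block, every surviving rank contributes False and then checks whether any peer reported failure. Each rank therefore performs exactly one matching allgather per step, so the ranks either proceed together or abort together instead of one rank hanging in a collective that a failed peer will never enter. A minimal standalone sketch of that pattern, distilled from the diff (the checkpoint helper name and the example step are hypothetical, not part of the test file):

import traceback

from mpi4py import MPI


def checkpoint(comm, step_fn):
    """Run step_fn and make sure every rank learns whether any rank failed."""
    try:
        result = step_fn()
    except Exception:
        traceback.print_exc()
        comm.allgather(True)  # tell the other ranks that this rank failed
        raise
    # Success path: contribute False and inspect the flags from all ranks.
    if any(comm.allgather(False)):
        raise Exception("Another rank failed")
    return result


if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    # Each rank runs its setup step; all ranks stop together if any one fails.
    checkpoint(comm, lambda: print(f"rank {comm.Get_rank()} setup ok"))

Because both tests skip unless MPI.COMM_WORLD has exactly ep_size ranks, they are presumably launched under an MPI runner (something like mpirun -n <ep_size> python -m pytest tests/comm/test_mnnvl_a2a.py), which is why a collective left waiting on a crashed rank would otherwise stall the whole job.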
