@@ -6,19 +6,16 @@
 import pytest
 import torch
 import torch.distributed as dist
-from mpi4py import MPI  # Added MPI import

 import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
 from flashinfer.comm.mapping import Mapping
-
-# Use flashinfer.norm.rmsnorm as reference implementation.
-from flashinfer.norm import rmsnorm
 from flashinfer.comm.mnnvl import CommBackend as CommBackend

 import pynvml

 pynvml.nvmlInit()

+
 class CustomCommunicator(CommBackend):
     def __init__(self, group):
         self._group = group
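
Note for reviewers: `CustomCommunicator` is the heart of this change. It adapts a `torch.distributed` process group to the `CommBackend` interface that `flashinfer.comm.mnnvl` expects, which is what lets the mpi4py import above be dropped. A minimal sketch of how it gets wired up on a rank; the init values here are illustrative, not the test's exact ones:

import torch.distributed as dist

# Hypothetical single-process setup so the sketch is runnable as-is; the real
# test spawns world_size workers and uses get_open_port() for the rendezvous.
rank, world_size = 0, 1
dist.init_process_group(
    backend="gloo",                       # object collectives work on gloo
    init_method="tcp://127.0.0.1:29500",  # assumed free port
    rank=rank,
    world_size=world_size,
)
comm = CustomCommunicator(dist.group.WORLD)
value = comm.bcast({"seed": 1234}, root=0)  # same dict on every rank
dist.destroy_process_group()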
@@ -59,7 +56,7 @@ def bcast(self, data, root: int = 0):
         # broadcast_object_list mutates obj_list in-place
         dist.broadcast_object_list(obj_list, src=root, group=self._group)
         return obj_list[0]
-
+
     def barrier(self):
         """
         Synchronize all ranks in this communicator.
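
A note on the `bcast` implementation in this hunk: `dist.broadcast_object_list` pickles each element on the source rank and overwrites the corresponding slots of `obj_list` in place on every other rank, which is why wrapping `data` in a one-element list and returning `obj_list[0]` round-trips arbitrary picklable Python objects. A distilled sketch of just that mechanism:

import torch.distributed as dist

def bcast_one(obj, src=0, group=None):
    # On src, slot 0 holds the payload; elsewhere it is a placeholder that
    # broadcast_object_list overwrites in place with the unpickled payload.
    slot = [obj if dist.get_rank(group) == src else None]
    dist.broadcast_object_list(slot, src=src, group=group)
    return slot[0]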
@@ -69,6 +66,7 @@ def barrier(self):
     def Split(self, color: int, key: int) -> "CustomCommunicator":
         return self

+
 def get_open_port() -> int:
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -78,7 +76,8 @@ def get_open_port() -> int:
         with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
             s.bind(("::1", 0))
             return s.getsockname()[1]
-
+
+
 def multi_process_parallel(
     world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
 ) -> None:
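
`multi_process_parallel` itself is mostly untouched here (only surrounding blank lines change), but the exit-code check visible in the next hunk implies a spawn-and-join pattern. A self-contained sketch of that pattern, under the assumption that the helper uses torch.multiprocessing; the argument order is guessed from the signature above:

import torch
import torch.multiprocessing as mp
from typing import Any

def multi_process_parallel_sketch(
    world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)  # CUDA requires spawn, not fork
    procs = []
    for i in range(world_size):
        # Hypothetical argument order: (world_size, rank, dtype, *extras).
        proc = mp.Process(target=test_target, args=(world_size, i, dtype, *target_args))
        proc.start()
        procs.append(proc)
    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0, (
            f"Process {i} failed with exit code {procs[i].exitcode}"
        )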
@@ -98,6 +97,7 @@ def multi_process_parallel(
9897 f"Process { i } failed with exit code { procs [i ].exitcode } "
9998 )
10099
100+
101101@torch .inference_mode ()
102102def row_linear_residual_norm_forward (
103103 x : torch .Tensor ,
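
Context for the decorator above: `torch.inference_mode()` is a stricter, faster variant of `torch.no_grad()`; tensors created under it carry no autograd metadata at all, which suits a forward-only numerical test like this one. A two-line illustration:

import torch

with torch.inference_mode():
    y = torch.ones(3) * 2   # inference tensor: no grad tracking, no version counter
print(y.requires_grad)       # False; reusing y in autograd later raises an error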
@@ -182,6 +182,7 @@ def func(
         atol=0.15,
     )

+
 def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidden_size):
     # Set CUDA device based on rank
     device = torch.device(f"cuda:{rank}")
@@ -223,8 +224,11 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
     # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list
     # This workspace is sized for the maximum expected sequence length and can be reused within each list
     # Each parameterized list gets its own fresh workspace allocation
+    explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * seq_len
     mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = (
-        trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype, comm)
+        trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
+            mapping, dtype, comm, explicit_workspace_bytes
+        )
     )

     multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr()
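
The new `explicit_workspace_bytes` argument makes the previously implicit workspace size explicit at the call site. The `3 * 2` factor is taken verbatim from the diff; reading it as three lamport-style buffers with a doubling safety margin is an assumption here, not something the hunk states. The arithmetic itself is easy to sanity-check:

import torch

# Illustrative values; hidden_size is not shown in this hunk.
dtype = torch.bfloat16           # itemsize == 2
hidden_size = 4096               # assumed
seq_len = 24                     # matches the test body below

explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * seq_len
assert explicit_workspace_bytes == 1_179_648  # ~1.1 MiB for these values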
@@ -282,16 +286,16 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
         # Synchronize before next test
         comm.barrier()

-        print(
-            f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}"
-        )
+        print(f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}")

     except Exception as e:
         rank_failed = True
-        failure_message = f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}"
+        failure_message = (
+            f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}"
+        )
         print(failure_message)
         # Gather failure status from all ranks
-        all_failures = MPI.COMM_WORLD.allgather(rank_failed)
+        all_failures = comm.allgather(rank_failed)

         # If any rank failed, fail the test
         if any(all_failures):
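
Replacing `MPI.COMM_WORLD.allgather` with `comm.allgather` is the functional core of this commit: the last collective still hard-wired to mpi4py now goes through the injected backend. The surrounding failure-propagation pattern, distilled (and hoisted out of the `except` block so that every rank, failing or not, reaches the collective; `run_allreduce_checks` is a hypothetical stand-in for the per-rank test body):

import pytest

rank_failed = False
try:
    run_allreduce_checks()   # hypothetical per-rank test body
except Exception as e:
    rank_failed = True
    print(f"FAILED[rank={rank}]: {e}")
# Every rank contributes its status; the result is the same list everywhere.
all_failures = comm.allgather(rank_failed)
if any(all_failures):
    failed_ranks = [r for r, failed in enumerate(all_failures) if failed]
    pytest.fail(f"Test failed on ranks {failed_ranks}")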
@@ -302,7 +306,7 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
             # Fail the test on all ranks
             pytest.fail(f"Test failed on ranks {failed_ranks}")
         comm.barrier()
-
+
     finally:
         # Ensure cleanup happens for this list's workspace
         if "mcast_buffer_mnnvl" in locals():
@@ -311,10 +315,14 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
     # Final synchronization and check for failures across all ranks
     comm.barrier()

+
 """Main test function that runs on each MPI rank"""
+
+
 @pytest.mark.parametrize("world_size", [2, 4])
 def test_mnnvl_allreduce_custom_communicator(
-    monkeypatch, world_size,
+    monkeypatch,
+    world_size,
 ):
     monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1")  # force multi-node allreduce.
     seq_len = 24