Skip to content

Commit 28b4dd4

Browse files
committed
Upd
1 parent b26c69d commit 28b4dd4

File tree

3 files changed

+31
-20
lines changed

3 files changed

+31
-20
lines changed

flashinfer/comm/mnnvl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def supports_mnnvl() -> bool:
547547

548548
class McastDeviceMemory:
549549
"""Python port of McastDeviceMemory from TensorRT-LLM"""
550+
550551
def __init__(
551552
self,
552553
buf_size: int,
@@ -753,7 +754,7 @@ def get_world_size(self) -> int:
753754
"""Get the total number of devices in the group"""
754755
return self.group_size
755756

756-
def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()):
757+
def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any = None):
757758
"""Allocate multi-node multicast memory using MNNVL"""
758759

759760
# Verify CUDA context
@@ -766,7 +767,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()):
766767
)
767768
except Exception as e:
768769
print(f"Error checking CUDA context: {e}")
769-
770+
if comm is None:
771+
comm = MpiComm()
770772
# Set up allocation properties
771773
handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
772774

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from ..jit import gen_trtllm_mnnvl_comm_module
1717
from ..utils import register_custom_op
18-
from .mnnvl import (McastGPUBuffer, CommBackend)
18+
from .mnnvl import McastGPUBuffer, CommBackend
1919

2020

2121
def mpi_barrier():
@@ -122,9 +122,10 @@ def trtllm_mnnvl_rmsnorm(
122122

123123

124124
def get_allreduce_mnnvl_workspace(
125-
mapping: Mapping, dtype: torch.dtype,
126-
buffer_size_in_bytes: Optional[int] = None,
125+
mapping: Mapping,
126+
dtype: torch.dtype,
127127
comm: Optional[CommBackend] = None,
128+
buffer_size_in_bytes: Optional[int] = None,
128129
) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
129130
"""Get workspace buffers needed for multi-node NVLink all-reduce operation.
130131
@@ -140,8 +141,8 @@ def get_allreduce_mnnvl_workspace(
140141
Args:
141142
mapping: Tensor parallel mapping configuration containing rank info
142143
dtype: Data type of the tensors being reduced
143-
buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
144144
comm: Optional communication backend for multi-node synchronization
145+
buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
145146
146147
Returns:
147148
Tuple containing:

tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,16 @@
66
import pytest
77
import torch
88
import torch.distributed as dist
9-
from mpi4py import MPI # Added MPI import
109

1110
import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
1211
from flashinfer.comm.mapping import Mapping
13-
14-
# Use flashinfer.norm.rmsnorm as reference implementation.
15-
from flashinfer.norm import rmsnorm
1612
from flashinfer.comm.mnnvl import CommBackend as CommBackend
1713

1814
import pynvml
1915

2016
pynvml.nvmlInit()
2117

18+
2219
class CustomCommunicator(CommBackend):
2320
def __init__(self, group):
2421
self._group = group
@@ -59,7 +56,7 @@ def bcast(self, data, root: int = 0):
5956
# broadcast_object_list mutates obj_list in-place
6057
dist.broadcast_object_list(obj_list, src=root, group=self._group)
6158
return obj_list[0]
62-
59+
6360
def barrier(self):
6461
"""
6562
Synchronize all ranks in this communicator.
@@ -69,6 +66,7 @@ def barrier(self):
6966
def Split(self, color: int, key: int) -> "CustomCommunicator":
7067
return self
7168

69+
7270
def get_open_port() -> int:
7371
try:
7472
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -78,7 +76,8 @@ def get_open_port() -> int:
7876
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
7977
s.bind(("::1", 0))
8078
return s.getsockname()[1]
81-
79+
80+
8281
def multi_process_parallel(
8382
world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
8483
) -> None:
@@ -98,6 +97,7 @@ def multi_process_parallel(
9897
f"Process {i} failed with exit code {procs[i].exitcode}"
9998
)
10099

100+
101101
@torch.inference_mode()
102102
def row_linear_residual_norm_forward(
103103
x: torch.Tensor,
@@ -182,6 +182,7 @@ def func(
182182
atol=0.15,
183183
)
184184

185+
185186
def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidden_size):
186187
# Set CUDA device based on rank
187188
device = torch.device(f"cuda:{rank}")
@@ -223,8 +224,11 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
223224
# Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list
224225
# This workspace is sized for the maximum expected sequence length and can be reused within each list
225226
# Each parameterized list gets its own fresh workspace allocation
227+
explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * seq_len
226228
mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = (
227-
trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype, comm)
229+
trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
230+
mapping, dtype, comm, explicit_workspace_bytes
231+
)
228232
)
229233

230234
multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr()
@@ -282,16 +286,16 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
282286
# Synchronize before next test
283287
comm.barrier()
284288

285-
print(
286-
f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}"
287-
)
289+
print(f"PASSED[rank={rank}]: seq_len={seq_len}, dtype={dtype}")
288290

289291
except Exception as e:
290292
rank_failed = True
291-
failure_message = f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}"
293+
failure_message = (
294+
f"FAILED[rank={rank}]: seq_lens={seq_len}, dtype={dtype} failed: {e}"
295+
)
292296
print(failure_message)
293297
# Gather failure status from all ranks
294-
all_failures = MPI.COMM_WORLD.allgather(rank_failed)
298+
all_failures = comm.allgather(rank_failed)
295299

296300
# If any rank failed, fail the test
297301
if any(all_failures):
@@ -302,7 +306,7 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
302306
# Fail the test on all ranks
303307
pytest.fail(f"Test failed on ranks {failed_ranks}")
304308
comm.barrier()
305-
309+
306310
finally:
307311
# Ensure cleanup happens for this list's workspace
308312
if "mcast_buffer_mnnvl" in locals():
@@ -311,10 +315,14 @@ def _run_mnnvl_ar(world_size, rank, dtype, distributed_init_port, seq_len, hidde
311315
# Final synchronization and check for failures across all ranks
312316
comm.barrier()
313317

318+
314319
"""Main test function that runs on each MPI rank"""
320+
321+
315322
@pytest.mark.parametrize("world_size", [2, 4])
316323
def test_mnnvl_allreduce_custom_communicator(
317-
monkeypatch, world_size,
324+
monkeypatch,
325+
world_size,
318326
):
319327
monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce.
320328
seq_len = 24

0 commit comments

Comments
 (0)