diff --git a/.azure-pipelines/multi-nodes-test.yml b/.azure-pipelines/multi-nodes-test.yml
index 8aab0d1a3..daa9aeb3e 100644
--- a/.azure-pipelines/multi-nodes-test.yml
+++ b/.azure-pipelines/multi-nodes-test.yml
@@ -30,6 +30,7 @@ jobs:
         mkdir build && cd build
         MPI_HOME=/usr/local/mpi /tmp/cmake-3.26.4-linux-x86_64/bin/cmake -DCMAKE_BUILD_TYPE=Release ..
         make -j
+        make pylib-copy
       workingDirectory: '$(System.DefaultWorkingDirectory)'

  - task: DownloadSecureFile@1
@@ -104,6 +105,25 @@ jobs:
          -O $SSH_OPTION -o output 'sudo docker exec -t mscclpp-test bash /root/mscclpp/run_tests.sh mp-ut'
        kill $CHILD_PID

+  - task: Bash@3
+    name: RunMultiNodePythonTests
+    displayName: Run multi-node Python tests
+    inputs:
+      targetType: 'inline'
+      script: |
+        set -e
+        HOSTFILE=$(System.DefaultWorkingDirectory)/test/mscclpp-test/deploy/hostfile
+        SSH_OPTION="StrictHostKeyChecking=no"
+        KeyFilePath=${SSHKEYFILE_SECUREFILEPATH}
+        rm -rf output/*
+        mkdir -p output
+        touch output/mscclpp-it-000000
+        tail -f output/mscclpp-it-000000 &
+        CHILD_PID=$!
+        parallel-ssh -t 0 -H mscclpp-it-000000 -l azureuser -x "-i ${KeyFilePath}" \
+          -O $SSH_OPTION -o output 'sudo docker exec -t mscclpp-test bash /root/mscclpp/run_tests.sh pytests'
+        kill $CHILD_PID
+
  - task: AzureCLI@2
    name: StopVMSS
    displayName: Deallocate VMSS
diff --git a/.azure-pipelines/ut.yml b/.azure-pipelines/ut.yml
index b31ad8ad0..31b8091cd 100644
--- a/.azure-pipelines/ut.yml
+++ b/.azure-pipelines/ut.yml
@@ -70,3 +70,20 @@ jobs:
         mpirun -tag-output -np 4 ./build/test/mp_unit_tests
         mpirun -tag-output -np 8 ./build/test/mp_unit_tests
       workingDirectory: '$(System.DefaultWorkingDirectory)'
+
+  - task: Bash@3
+    name: PyTests
+    displayName: Run pytests
+    inputs:
+      targetType: 'inline'
+      script: |
+        set -e
+        export PATH=/usr/local/mpi/bin:$PATH
+        cd build && make pylib-copy
+        if [[ '$(containerImage)' == *'cuda11'* ]]; then
+          pip3 install -r ../python/test/requirements_cu11.txt
+        else
+          pip3 install -r ../python/test/requirements_cu12.txt
+        fi
+        mpirun -tag-output -np 8 ~/.local/bin/pytest ../python/test/test_mscclpp.py -x
+      workingDirectory: '$(System.DefaultWorkingDirectory)'
diff --git a/.gitignore b/.gitignore
index 072a32fb2..af2117f72 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ dist/
 __pycache__
 .*.swp
 .idea/
+*.so
diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp
index d3af3103a..1e9e6abd8 100644
--- a/include/mscclpp/core.hpp
+++ b/include/mscclpp/core.hpp
@@ -326,7 +326,7 @@ class RegisteredMemory {
   /// Get a pointer to the memory block.
   ///
   /// @return A pointer to the memory block.
-  void* data();
+  void* data() const;

   /// Get the size of the memory block.
   ///
diff --git a/include/mscclpp/poll.hpp b/include/mscclpp/poll.hpp
index cb32a9743..0bc09a949 100644
--- a/include/mscclpp/poll.hpp
+++ b/include/mscclpp/poll.hpp
@@ -6,6 +6,8 @@
 #ifdef __CUDACC__

+#include <cassert>
+
 extern __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
                                      const char *__function) __THROW;
diff --git a/include/mscclpp/sm_channel.hpp b/include/mscclpp/sm_channel.hpp
index 947eea21d..a1d1daf2b 100644
--- a/include/mscclpp/sm_channel.hpp
+++ b/include/mscclpp/sm_channel.hpp
@@ -15,8 +15,8 @@ namespace mscclpp {
 struct SmChannel {
  private:
   std::shared_ptr<SmDevice2DeviceSemaphore> semaphore_;
+  RegisteredMemory dst_;
   void* src_;
-  void* dst_;
   void* getPacketBuffer_;

  public:
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 22fd318e9..7776be62c 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -1,14 +1,17 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
-include(FetchContent)
-FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
-FetchContent_MakeAvailable(nanobind)
+add_subdirectory(mscclpp)
+add_subdirectory(test)
+
+add_custom_target(pylib-copy)
+add_custom_command(TARGET pylib-copy POST_BUILD
+  COMMAND ${CMAKE_COMMAND} -E copy_if_different
+    ${CMAKE_CURRENT_BINARY_DIR}/mscclpp/_mscclpp.cpython-38-x86_64-linux-gnu.so
+    ${CMAKE_CURRENT_SOURCE_DIR}/mscclpp
+  COMMAND ${CMAKE_COMMAND} -E copy_if_different
+    ${CMAKE_CURRENT_BINARY_DIR}/test/_ext.cpython-38-x86_64-linux-gnu.so
+    ${CMAKE_CURRENT_SOURCE_DIR}/test/_cpp
+  COMMAND ${CMAKE_COMMAND} -E echo "Copy python libraries"
+)

-file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
-nanobind_add_module(mscclpp_py ${SOURCES})
-set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
-target_link_libraries(mscclpp_py PRIVATE mscclpp_static)
-target_include_directories(mscclpp_py PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
-install(TARGETS mscclpp_py LIBRARY DESTINATION .)
diff --git a/python/mscclpp/CMakeLists.txt b/python/mscclpp/CMakeLists.txt
new file mode 100644
index 000000000..22fd318e9
--- /dev/null
+++ b/python/mscclpp/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
+include(FetchContent)
+FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
+FetchContent_MakeAvailable(nanobind)
+
+file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
+nanobind_add_module(mscclpp_py ${SOURCES})
+set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
+target_link_libraries(mscclpp_py PRIVATE mscclpp_static)
+target_include_directories(mscclpp_py PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
+install(TARGETS mscclpp_py LIBRARY DESTINATION .)
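Note on pylib-copy: the copy_if_different destinations above hard-code the CPython 3.8 / Linux x86-64 extension suffix. If the CI image ever moves to another interpreter, the expected suffix can be checked with a one-liner (a sketch, not part of this change):

    import sysconfig
    # Prints the suffix nanobind modules are built with,
    # e.g. ".cpython-38-x86_64-linux-gnu.so" on the current CI image.
    print(sysconfig.get_config_var("EXT_SUFFIX"))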
diff --git a/python/core_py.cpp b/python/mscclpp/core_py.cpp
similarity index 100%
rename from python/core_py.cpp
rename to python/mscclpp/core_py.cpp
diff --git a/python/error_py.cpp b/python/mscclpp/error_py.cpp
similarity index 100%
rename from python/error_py.cpp
rename to python/mscclpp/error_py.cpp
diff --git a/python/fifo_py.cpp b/python/mscclpp/fifo_py.cpp
similarity index 100%
rename from python/fifo_py.cpp
rename to python/mscclpp/fifo_py.cpp
diff --git a/python/numa_py.cpp b/python/mscclpp/numa_py.cpp
similarity index 100%
rename from python/numa_py.cpp
rename to python/mscclpp/numa_py.cpp
diff --git a/python/proxy_channel_py.cpp b/python/mscclpp/proxy_channel_py.cpp
similarity index 100%
rename from python/proxy_channel_py.cpp
rename to python/mscclpp/proxy_channel_py.cpp
diff --git a/python/semaphore_py.cpp b/python/mscclpp/semaphore_py.cpp
similarity index 100%
rename from python/semaphore_py.cpp
rename to python/mscclpp/semaphore_py.cpp
diff --git a/python/sm_channel_py.cpp b/python/mscclpp/sm_channel_py.cpp
similarity index 100%
rename from python/sm_channel_py.cpp
rename to python/mscclpp/sm_channel_py.cpp
diff --git a/python/utils_py.cpp b/python/mscclpp/utils_py.cpp
similarity index 100%
rename from python/utils_py.cpp
rename to python/mscclpp/utils_py.cpp
diff --git a/python/test/CMakeLists.txt b/python/test/CMakeLists.txt
new file mode 100644
index 000000000..356934536
--- /dev/null
+++ b/python/test/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
+include(FetchContent)
+FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
+FetchContent_MakeAvailable(nanobind)
+
+file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
+nanobind_add_module(mscclpp_py_test ${SOURCES})
+set_target_properties(mscclpp_py_test PROPERTIES OUTPUT_NAME _ext)
+target_link_libraries(mscclpp_py_test PRIVATE mscclpp_static)
+target_include_directories(mscclpp_py_test PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
diff --git a/python/test/__init__.py b/python/test/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/test/_cpp/__init__.py b/python/test/_cpp/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/test/_cpp/proxy_test.cpp b/python/test/_cpp/proxy_test.cpp
new file mode 100644
index 000000000..e44f0f6f5
--- /dev/null
+++ b/python/test/_cpp/proxy_test.cpp
@@ -0,0 +1,83 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/shared_ptr.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/vector.h>
+
+#include <cuda_runtime.h>
+#include <memory>
+#include <mscclpp/core.hpp>
+#include <mscclpp/cuda_utils.hpp>
+#include <mscclpp/fifo.hpp>
+#include <mscclpp/numa.hpp>
+#include <mscclpp/proxy.hpp>
+#include <mscclpp/semaphore.hpp>
+#include <vector>
+
+namespace nb = nanobind;
+
+class MyProxyService {
+ private:
+  int deviceNumaNode_;
+  int my_rank_, nranks_, dataSize_;
+  std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
+  std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem_;
+  std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
+  mscclpp::Proxy proxy_;
+
+ public:
+  MyProxyService(int my_rank, int nranks, int dataSize, std::vector<std::shared_ptr<mscclpp::Connection>> conns,
+                 std::vector<std::shared_ptr<mscclpp::RegisteredMemory>> allRegMem,
+                 std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores)
+      : my_rank_(my_rank),
+        nranks_(nranks),
+        dataSize_(dataSize),
+        connections_(conns),
+        allRegMem_(allRegMem),
+        semaphores_(semaphores),
+        proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
+    int cudaDevice;
+    cudaGetDevice(&cudaDevice);
+    deviceNumaNode_ = mscclpp::getDeviceNumaNode(cudaDevice);
+  }
+
+  void bindThread() {
+    if (deviceNumaNode_ >= 0) {
+      mscclpp::numaBind(deviceNumaNode_);
+    }
+  }
+
+  mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger) {
+    int dataSizePerRank = dataSize_ / nranks_;
+    for (int r = 1; r < nranks_; ++r) {
+      int nghr = (my_rank_ + r) % nranks_;
+      connections_[nghr]->write(*allRegMem_[nghr], my_rank_ * (uint64_t)dataSizePerRank, *allRegMem_[my_rank_],
+                                my_rank_ * (uint64_t)dataSizePerRank, dataSizePerRank);
+      semaphores_[nghr]->signal();
+      connections_[nghr]->flush();
+    }
+    return mscclpp::ProxyHandlerResult::FlushFifoTailAndContinue;
+  }
+
+  void start() { proxy_.start(); }
+
+  void stop() { proxy_.stop(); }
+
+  mscclpp::FifoDeviceHandle fifoDeviceHandle() { return proxy_.fifo().deviceHandle(); }
+};
+
+void init_mscclpp_proxy_test_module(nb::module_ &m) {
+  nb::class_<MyProxyService>(m, "MyProxyService")
+      .def(nb::init<int, int, int, std::vector<std::shared_ptr<mscclpp::Connection>>,
+                    std::vector<std::shared_ptr<mscclpp::RegisteredMemory>>,
+                    std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>>>(),
+           nb::arg("rank"), nb::arg("nranks"), nb::arg("data_size"), nb::arg("conn_vec"), nb::arg("reg_mem_vec"),
+           nb::arg("h2d_sem_vec"))
+      .def("fifo_device_handle", &MyProxyService::fifoDeviceHandle)
+      .def("start", &MyProxyService::start)
+      .def("stop", &MyProxyService::stop);
+}
+
+NB_MODULE(_ext, m) { init_mscclpp_proxy_test_module(m); }
diff --git a/python/test/d2d_semaphore_test.cu b/python/test/d2d_semaphore_test.cu
new file mode 100644
index 000000000..04b945e3d
--- /dev/null
+++ b/python/test/d2d_semaphore_test.cu
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <mscclpp/semaphore_device.hpp>
+
+// be careful about using semaphore[my_rank] as it is an invalid semaphore and it is there just for simplicity of
+// indexing
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    d2d_semaphore(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks) {
+  int tid = threadIdx.x;
+  if (tid < nranks && tid != my_rank) {
+    semaphores[tid].signal();
+    semaphores[tid].wait();
+  }
+}
diff --git a/python/test/fifo_test.cu b/python/test/fifo_test.cu
new file mode 100644
index 000000000..e3e39c79a
--- /dev/null
+++ b/python/test/fifo_test.cu
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <cuda_runtime.h>
+
+#include "mscclpp/fifo_device.hpp"
+
+extern "C" __global__ void __launch_bounds__(1024, 1) fifo(mscclpp::FifoDeviceHandle fifo) {
+  mscclpp::ProxyTrigger trigger;
+  trigger.fst = 123;
+  fifo.push(trigger);
+}
diff --git a/python/test/h2d_semaphore_test.cu b/python/test/h2d_semaphore_test.cu
new file mode 100644
index 000000000..b68d6d762
--- /dev/null
+++ b/python/test/h2d_semaphore_test.cu
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <mscclpp/semaphore_device.hpp>
+
+// be careful about using semaphore[my_rank] as it is an invalid semaphore and it is there just for simplicity of
+// indexing
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    h2d_semaphore(mscclpp::Host2DeviceSemaphoreDeviceHandle* semaphores, int my_rank, int nranks) {
+  int tid = threadIdx.x;
+  if (tid < nranks && tid != my_rank) semaphores[tid].wait();
+}
diff --git a/python/test/mscclpp_group.py b/python/test/mscclpp_group.py
new file mode 100644
index 000000000..2412854ab
--- /dev/null
+++ b/python/test/mscclpp_group.py
@@ -0,0 +1,154 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+import logging
+from typing import Type
+
+import cupy as cp
+from mscclpp import (
+    Communicator,
+    Connection,
+    Host2DeviceSemaphore,
+    Host2HostSemaphore,
+    ProxyService,
+    RegisteredMemory,
+    SimpleProxyChannel,
+    SmChannel,
+    SmDevice2DeviceSemaphore,
+    TcpBootstrap,
+    Transport,
+    TransportFlags,
+)
+import numpy as np
+
+from .mscclpp_mpi import MpiGroup
+
+logger = logging.getLogger(__name__)
+
+
+class MscclppGroup:
+    def __init__(self, mpi_group: MpiGroup, interfaceIpPortTrio=""):
+        self.bootstrap = TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size)
+        if interfaceIpPortTrio == "":
+            uniq_id = None
+            if mpi_group.comm.rank == 0:
+                # similar to NCCL's unique id
+                uniq_id = self.bootstrap.create_unique_id()
+            uniq_id_global = mpi_group.comm.bcast(uniq_id, 0)
+            self.bootstrap.initialize(uniq_id_global)
+        else:
+            # use the given interface:ip:port trio instead
+            self.bootstrap.initialize(interfaceIpPortTrio)
+        self.communicator = Communicator(self.bootstrap)
+        self.my_rank = self.bootstrap.get_rank()
+        self.nranks = self.bootstrap.get_n_ranks()
+
+    def barrier(self):
+        self.bootstrap.barrier()
+
+    def send(self, tensor: np.ndarray, peer: int, tag: int):
+        self.bootstrap.send(tensor.ctypes.data, tensor.size * tensor.itemsize, peer, tag)
+
+    def recv(self, tensor: np.ndarray, peer: int, tag: int):
+        self.bootstrap.recv(tensor.ctypes.data, tensor.size * tensor.itemsize, peer, tag)
+
+    def my_ib_device(self, local_rank: int) -> Transport:
+        if local_rank == 0:
+            return Transport.IB0
+        if local_rank == 1:
+            return Transport.IB1
+        if local_rank == 2:
+            return Transport.IB2
+        if local_rank == 3:
+            return Transport.IB3
+        if local_rank == 4:
+            return Transport.IB4
+        if local_rank == 5:
+            return Transport.IB5
+        if local_rank == 6:
+            return Transport.IB6
+        if local_rank == 7:
+            return Transport.IB7
+        else:
+            assert False  # only 8 IBs are supported
+
+    def make_connection(self, remote_ranks: list[int], transport: Transport) -> dict[int, Connection]:
+        connections = {}
+        for rank in remote_ranks:
+            connections[rank] = self.communicator.connect_on_setup(rank, 0, transport)
+        self.communicator.setup()
+        return connections
+
+    def register_tensor_with_connections(
+        self, tensor: Type[cp.ndarray] or Type[np.ndarray], connections: dict[int, Connection]
+    ) -> dict[int, RegisteredMemory]:
+        transport_flags = TransportFlags()
+        for rank in connections:
+            transport_flags |= connections[rank].transport()
+        data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
+        local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
+        all_registered_memories = {}
+        all_registered_memories[self.my_rank] = local_reg_memory
+        future_memories = {}
+        for rank in connections:
+            self.communicator.send_memory_on_setup(local_reg_memory, rank, 0)
+            future_memories[rank] = self.communicator.recv_memory_on_setup(rank, 0)
+        self.communicator.setup()
+        for rank in connections:
+            all_registered_memories[rank] = future_memories[rank].get()
+        return all_registered_memories
+
+    def make_semaphore(
+        self,
+        connections: dict[int, Connection],
+        semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[SmDevice2DeviceSemaphore],
+    ) -> dict[int, Host2HostSemaphore]:
+        semaphores = {}
+        for rank in connections:
+            semaphores[rank] = semaphore_type(self.communicator, connections[rank])
+        self.communicator.setup()
+        return semaphores
+
+    def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]:
+        semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
+        registered_memories = self.register_tensor_with_connections(tensor, connections)
+        channels = {}
+        for rank in connections:
+            channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor.data.ptr)
+        return channels
+
+    def make_sm_channels_with_packet(
+        self, tensor: cp.ndarray, packetTensor: cp.ndarray, connections: dict[int, Connection]
+    ) -> dict[int, SmChannel]:
+        semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
+        registered_memories = self.register_tensor_with_connections(packetTensor, connections)
+        channels = {}
+        for rank in connections:
+            channels[rank] = SmChannel(
+                semaphores[rank],
+                registered_memories[rank],
+                tensor.data.ptr,
+                packetTensor.data.ptr,
+            )
+        return channels
+
+    def make_proxy_channels_with_packet(
+        self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
+    ) -> dict[int, SmChannel]:
+        semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
+        registered_memories = self.register_tensor_with_connections(tensor, connections)
+        memory_ids = {}
+        semaphore_ids = {}
+        for rank in registered_memories:
+            memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
+        for rank in semaphores:
+            semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
+        channels = {}
+        for rank in semaphores:
+            channels[rank] = SimpleProxyChannel(
+                proxy_service.proxy_channel(semaphore_ids[rank]),
+                memory_ids[rank],
+                memory_ids[self.my_rank],
+            )
+        return channels
diff --git a/python/test/mscclpp_mpi.py b/python/test/mscclpp_mpi.py
new file mode 100644
index 000000000..1f37eb9c6
--- /dev/null
+++ b/python/test/mscclpp_mpi.py
@@ -0,0 +1,72 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import atexit
+import logging
+
+import cupy as cp
+import mpi4py
+
+mpi4py.rc.initialize = False
+mpi4py.rc.finalize = False
+
+from mpi4py import MPI
+import pytest
+
+N_GPUS_PER_NODE = 8
+
+logging.basicConfig(level=logging.INFO)
+
+
+def init_mpi():
+    if not MPI.Is_initialized():
+        MPI.Init()
+        shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
+        N_GPUS_PER_NODE = shm_comm.size
+        shm_comm.Free()
+        cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()
+
+
+# Define a function to finalize MPI
+def finalize_mpi():
+    if MPI.Is_initialized():
+        MPI.Finalize()
+
+
+# Register the function to be called on exit
+atexit.register(finalize_mpi)
+
+
+class MpiGroup:
+    def __init__(self, ranks: list):
+        world_group = MPI.COMM_WORLD.group
+        group = world_group.Incl(ranks)
+        self.comm = MPI.COMM_WORLD.Create(group)
+
+
+@pytest.fixture
+def mpi_group(request: pytest.FixtureRequest):
+    MPI.COMM_WORLD.barrier()
+    if request.param is None:
+        pytest.skip(f"Skip for rank {MPI.COMM_WORLD.rank}")
+    yield request.param
+
+
+def parametrize_mpi_groups(*tuples: tuple):
+    def decorator(func):
+        mpi_groups = []
+        for group_size in list(tuples):
+            if MPI.COMM_WORLD.size < group_size:
+                logging.warning(f"MPI.COMM_WORLD.size < {group_size}, skip")
+                continue
+            mpi_group = MpiGroup(list(range(group_size)))
+            if mpi_group.comm == MPI.COMM_NULL:
+                mpi_groups.append(None)
+            else:
+                mpi_groups.append(mpi_group)
+        return pytest.mark.parametrize("mpi_group", mpi_groups, indirect=True)(func)
+
+    return decorator
+
+
+init_mpi()
diff --git a/python/test/proxy_test.cu b/python/test/proxy_test.cu
new file mode 100644
index 000000000..78b932c94
--- /dev/null
+++ b/python/test/proxy_test.cu
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <mscclpp/fifo_device.hpp>
+#include <mscclpp/semaphore_device.hpp>
+
+extern "C" __global__ void __launch_bounds__(1024, 1) proxy(int my_rank, int nranks, mscclpp::FifoDeviceHandle fifo,
+                                                            mscclpp::Host2DeviceSemaphoreDeviceHandle* semaphores) {
+  int tid = threadIdx.x;
+  if (tid == 0) {
+    mscclpp::ProxyTrigger trigger;
+    trigger.fst = 123;
+    trigger.snd = 0;
+    uint64_t currentFifoHead = fifo.push(trigger);
+    // wait for the work to be done on the CPU side
+    fifo.sync(currentFifoHead);
+  }
+  __syncthreads();
+  if (tid < nranks && tid != my_rank) {
+    semaphores[tid].wait();
+  }
+}
diff --git a/python/test/requirements_cu11.txt b/python/test/requirements_cu11.txt
new file mode 100644
index 000000000..2b79ab977
--- /dev/null
+++ b/python/test/requirements_cu11.txt
@@ -0,0 +1,6 @@
+cuda-python==12.1.0
+mpi4py==3.1.4
+netifaces==0.11.0
+numpy==1.22.2
+pytest==7.2.2
+cupy-cuda11x
diff --git a/python/test/requirements_cu12.txt b/python/test/requirements_cu12.txt
new file mode 100644
index 000000000..0061438d2
--- /dev/null
+++ b/python/test/requirements_cu12.txt
@@ -0,0 +1,6 @@
+cuda-python==12.1.0
+mpi4py==3.1.4
+netifaces==0.11.0
+numpy==1.22.2
+pytest==7.2.2
+cupy-cuda12x
diff --git a/python/test/simple_proxy_channel_test.cu b/python/test/simple_proxy_channel_test.cu
new file mode 100644
index 000000000..51b5f3472
--- /dev/null
+++ b/python/test/simple_proxy_channel_test.cu
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <mscclpp/packet.hpp>
+#include <mscclpp/proxy_channel_device.hpp>
+
+// be careful about using channels[my_rank] as it is invalid and is there just for simplicity of indexing
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    simple_proxy_channel(mscclpp::SimpleProxyChannelDeviceHandle* channels, int my_rank, int nranks, int* data,
+                         int* scratch, int num_elements, int use_packet) {
+  int tid = threadIdx.x;
+  int nthreads = blockDim.x;
+  uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
+  uint64_t my_offset = size_per_rank * my_rank;
+  int nthreads_per_rank = nthreads / nranks;
+  int my_nghr = tid / nthreads_per_rank;
+  uint64_t my_nghr_offset = size_per_rank * my_nghr;
+  __syncthreads();
+  int flag = 123;
+  if (use_packet) {
+    mscclpp::putPackets(scratch, 2 * my_offset, data, my_offset, size_per_rank, tid, nthreads, flag);
+    __syncthreads();
+    if (tid < nranks && tid != my_rank) {
+      channels[tid].put(2 * my_offset, 2 * my_offset, 2 * size_per_rank);
+    }
+    if (my_nghr != my_rank && my_nghr < nranks)
+      mscclpp::getPackets(data, my_nghr_offset, scratch, 2 * my_nghr_offset, size_per_rank, tid % nthreads_per_rank,
+                          nthreads_per_rank, flag);
+  } else {
+    if (tid < nranks && tid != my_rank) {
+      channels[tid].putWithSignalAndFlush(my_offset, my_offset, size_per_rank);
+      channels[tid].wait();
+    }
+  }
+}
diff --git a/python/test/sm_channel_test.cu b/python/test/sm_channel_test.cu
new file mode 100644
index 000000000..7902687b6
--- /dev/null
+++ b/python/test/sm_channel_test.cu
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <mscclpp/sm_channel_device.hpp>
+
+// be careful about using channels[my_rank] as it is invalid and is there just for simplicity of indexing
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    sm_channel(mscclpp::SmChannelDeviceHandle* channels, int my_rank, int nranks, int num_elements, int use_packet) {
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  uint64_t size_per_rank = (num_elements * sizeof(int)) / nranks;
+  uint64_t my_offset = size_per_rank * my_rank;
+  uint64_t my_nghr_offset = size_per_rank * bid;
+  int flag = 123;
+  if (bid < nranks && bid != my_rank) {
+    if (use_packet) {
+      channels[bid].putPackets(2 * my_offset, my_offset, size_per_rank, tid, blockDim.x, flag);
+      channels[bid].getPackets(my_nghr_offset, 2 * my_nghr_offset, size_per_rank, tid, blockDim.x, flag);
+    } else {
+      channels[bid].put(my_offset, my_offset, size_per_rank, tid, blockDim.x);
+      __syncthreads();
+      if (!use_packet && tid == 0) {
+        channels[bid].signal();
+        channels[bid].wait();
+      }
+    }
+  }
+}
diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py
new file mode 100644
index 000000000..0be3b2126
--- /dev/null
+++ b/python/test/test_mscclpp.py
@@ -0,0 +1,478 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
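+#
+# Note: this suite assumes an MPI launch (one process per GPU); CI invokes it as
+#   mpirun -tag-output -np 8 pytest python/test/test_mscclpp.py -x
+# (see .azure-pipelines/ut.yml and test/deploy/pytest.sh in this change).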
+
+from concurrent.futures import ThreadPoolExecutor
+import time
+
+import cupy as cp
+import numpy as np
+import netifaces as ni
+import pytest
+
+from mscclpp import (
+    Fifo,
+    Host2DeviceSemaphore,
+    Host2HostSemaphore,
+    ProxyService,
+    SmDevice2DeviceSemaphore,
+    Transport,
+)
+from ._cpp import _ext
+from .mscclpp_group import MscclppGroup
+from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
+from .utils import KernelBuilder, pack
+
+ethernet_interface_name = "eth0"
+
+
+def all_ranks_on_the_same_node(mpi_group: MpiGroup):
+    if (ethernet_interface_name in ni.interfaces()) is False:
+        pytest.skip(f"{ethernet_interface_name} is not an interface to use on this node")
+    my_ip = ni.ifaddresses(ethernet_interface_name)[ni.AF_INET][0]["addr"]
+    root_ip = mpi_group.comm.bcast(my_ip, 0)
+    last_rank_ip = mpi_group.comm.bcast(my_ip, mpi_group.comm.size - 1)
+    return last_rank_ip == root_ip
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("ifIpPortTrio", ["eth0:localhost:50000", ethernet_interface_name, ""])
+def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
+    if (ethernet_interface_name in ni.interfaces()) is False:
+        pytest.skip(f"{ethernet_interface_name} is not an interface to use on this node")
+    my_ip = ni.ifaddresses(ethernet_interface_name)[ni.AF_INET][0]["addr"]
+    root_ip = mpi_group.comm.bcast(my_ip, 0)
+    if ifIpPortTrio == ethernet_interface_name:
+        ifIpPortTrio += ":" + root_ip + ":50000"  # some random port
+
+    if all_ranks_on_the_same_node(mpi_group) is False and "localhost" in ifIpPortTrio:
+        # ranks are on different nodes
+        pytest.skip("this case is not supported as localhost will be different for different nodes")
+
+    group = MscclppGroup(mpi_group, ifIpPortTrio)
+
+    nelem = 1024
+    memory = np.zeros(nelem, dtype=np.int32)
+    nelemPerRank = nelem // group.nranks
+    memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
+    memory_expected = np.zeros_like(memory)
+    for rank in range(group.nranks):
+        memory_expected[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))] = rank + 1
+
+    for rank in range(group.nranks):
+        if rank == group.my_rank:
+            continue
+        group.send(
+            memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))],
+            rank,
+            0,
+        )
+    for rank in range(group.nranks):
+        if rank == group.my_rank:
+            continue
+        group.recv(memory[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))], rank, 0)
+
+    assert np.array_equal(memory, memory_expected)
+
+
+def create_and_connect(mpi_group: MpiGroup, transport: str):
+    if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
+        pytest.skip("cannot use nvlink for cross node")
+    group = MscclppGroup(mpi_group)
+
+    remote_nghrs = list(range(mpi_group.comm.size))
+    remote_nghrs.remove(mpi_group.comm.rank)
+    if transport == "NVLink":
+        tran = Transport.CudaIpc
+    elif transport == "IB":
+        tran = group.my_ib_device(group.my_rank % 8)
+    else:
+        assert False
+    connections = group.make_connection(remote_nghrs, tran)
+    return group, connections
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+def test_group_with_connections(mpi_group: MpiGroup, transport: str):
+    create_and_connect(mpi_group, transport)
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
+def test_connection_write(mpi_group: MpiGroup, transport: str, nelem: int):
+    group, connections = create_and_connect(mpi_group, transport)
+    memory = cp.zeros(nelem, dtype=cp.int32)
+    nelemPerRank = nelem // group.nranks
+    sizePerRank = nelemPerRank * memory.itemsize
+    memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
+    memory_expected = cp.zeros_like(memory)
+    for rank in range(group.nranks):
+        memory_expected[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))] = rank + 1
+    group.barrier()
+    all_reg_memories = group.register_tensor_with_connections(memory, connections)
+    for rank in connections:
+        connections[rank].write(
+            all_reg_memories[rank],
+            sizePerRank * group.my_rank,
+            all_reg_memories[group.my_rank],
+            sizePerRank * group.my_rank,
+            sizePerRank,
+        )
+    poll_for = 100
+    for i in range(poll_for):
+        all_correct = cp.array_equal(memory, memory_expected)
+        if all_correct:
+            break
+        time.sleep(0.1)
+    for conn in connections:
+        connections[conn].flush()
+    cp.cuda.runtime.deviceSynchronize()
+    group.barrier()
+    assert all_correct
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
+@pytest.mark.parametrize("device", ["cuda", "cpu"])
+def test_connection_write_and_signal(mpi_group: MpiGroup, transport: str, nelem: int, device: str):
+    # this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
+    # and finally, comes back to rank 0 to make sure it matches all the original values
+
+    if device == "cpu" and transport == "NVLink":
+        pytest.skip("nvlink doesn't work with host allocated memory")
+    group, connections = create_and_connect(mpi_group, transport)
+    xp = cp if device == "cuda" else np
+    if group.my_rank == 0:
+        memory = xp.random.randn(nelem)
+        memory = memory.astype(xp.float32)
+        memory_expected = memory.copy()
+    else:
+        memory = xp.zeros(nelem, dtype=xp.float32)
+
+    signal_memory = xp.zeros(1, dtype=xp.int64)
+    all_reg_memories = group.register_tensor_with_connections(memory, connections)
+    all_signal_memories = group.register_tensor_with_connections(signal_memory, connections)
+
+    next_rank = (group.my_rank + 1) % group.nranks
+    bufferSize = nelem * memory.itemsize
+    dummy_memory_on_cpu = np.zeros(1, dtype=np.int64)
+
+    signal_val = 123
+    if group.my_rank != 0:
+        while signal_memory[0] != signal_val:
+            time.sleep(0.1)
+    connections[next_rank].write(all_reg_memories[next_rank], 0, all_reg_memories[group.my_rank], 0, bufferSize)
+    connections[next_rank].flush()
+    if group.my_rank == 0:
+        memory[:] = 0
+    connections[next_rank].update_and_sync(
+        all_signal_memories[next_rank], 0, dummy_memory_on_cpu.ctypes.data, signal_val
+    )
+    all_correct = False
+    if group.my_rank == 0:
+        while signal_memory[0] != signal_val:
+            time.sleep(0.1)
+        all_correct = cp.array_equal(memory, memory_expected)
+    group.barrier()
+    all_correct = mpi_group.comm.bcast(all_correct, 0)
+    assert all_correct
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+def test_h2h_semaphores(mpi_group: MpiGroup):
+    group, connections = create_and_connect(mpi_group, "IB")
+
+    semaphores = group.make_semaphore(connections, Host2HostSemaphore)
+    for rank in connections:
+        semaphores[rank].signal()
+
+    for rank in connections:
+        semaphores[rank].wait()
+    group.barrier()
+
+
+class MscclppKernel:
+    def __init__(
+        self,
+        test_name,
+        my_rank=None,
+        nranks=None,
+        semaphore_or_channels=None,
+        tensor=None,
+        use_packet=False,
+        scratch=None,
+        fifo=None,
+    ):
+        if test_name == "h2d_semaphore":
+            self._kernel = KernelBuilder(
+                file="h2d_semaphore_test.cu",
+                kernel_name="h2d_semaphore",
+            ).get_compiled_kernel()
+            self.nblocks = 1
+            self.nthreads = nranks
+        elif test_name == "d2d_semaphore":
+            self._kernel = KernelBuilder(
+                file="d2d_semaphore_test.cu",
+                kernel_name="d2d_semaphore",
+            ).get_compiled_kernel()
+            self.nblocks = 1
+            self.nthreads = nranks
+        elif test_name == "sm_channel":
+            self._kernel = KernelBuilder(
+                file="sm_channel_test.cu",
+                kernel_name="sm_channel",
+            ).get_compiled_kernel()
+            self.nblocks = nranks
+            self.nthreads = 1024
+        elif test_name == "fifo":
+            self._kernel = KernelBuilder(
+                file="fifo_test.cu",
+                kernel_name="fifo",
+            ).get_compiled_kernel()
+            self.nblocks = 1
+            self.nthreads = 1
+        elif test_name == "proxy":
+            self._kernel = KernelBuilder(
+                file="proxy_test.cu",
+                kernel_name="proxy",
+            ).get_compiled_kernel()
+            self.nblocks = 1
+            self.nthreads = nranks
+        elif test_name == "simple_proxy_channel":
+            self._kernel = KernelBuilder(
+                file="simple_proxy_channel_test.cu",
+                kernel_name="simple_proxy_channel",
+            ).get_compiled_kernel()
+            self.nblocks = 1
+            self.nthreads = 1024
+        else:
+            assert False
+
+        self.params = b""
+        if test_name in ["h2d_semaphore", "d2d_semaphore", "sm_channel", "simple_proxy_channel"]:
+            first_arg = next(iter(semaphore_or_channels.values()))
+            size_of_semaphore_or_channels = len(first_arg.device_handle().raw)
+            device_handles = []
+            for rank in range(nranks):
+                if rank == my_rank:
+                    device_handles.append(
+                        bytes(size_of_semaphore_or_channels)
+                    )  # just zeros for semaphores that do not exist
+                else:
+                    device_handles.append(semaphore_or_channels[rank].device_handle().raw)
+            # keep a reference to the device handles so that they don't get garbage collected
+            self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8)
+            self.params += pack(self._d_semaphore_or_channels, my_rank, nranks)
+            if test_name == "sm_channel":
+                self.params += pack(tensor.size, use_packet)
+            if test_name == "simple_proxy_channel":
+                self.params += pack(tensor, scratch, tensor.size, use_packet)
+        elif test_name == "fifo":
+            self.params = fifo.device_handle().raw
+        elif test_name == "proxy":
+            semaphore_device_handles = [semaphore.device_handle().raw for semaphore in semaphore_or_channels]
+            self._d_semaphore_or_channels = cp.asarray(memoryview(b"".join(semaphore_device_handles)), dtype=cp.uint8)
+            self.params = pack(my_rank, nranks) + fifo.raw + pack(self._d_semaphore_or_channels)
+
+    def __call__(self):
+        return self._kernel.launch_kernel(self.params, self.nblocks, self.nthreads, 0, None)
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("transport", ["NVLink", "IB"])
+def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
+    def signal(semaphores):
+        for rank in semaphores:
+            semaphores[rank].signal()
+
+    group, connections = create_and_connect(mpi_group, transport)
+
+    semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
+    kernel = MscclppKernel("h2d_semaphore", group.my_rank, group.nranks, semaphores)
+    kernel()
+
+    # workaround: use a separate thread to let cudaMemcpyAsync run concurrently with the kernel
+    with ThreadPoolExecutor(max_workers=1) as executor:
+        executor.submit(signal, semaphores)
+
+    cp.cuda.runtime.deviceSynchronize()
+    group.barrier()
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+def test_d2d_semaphores(mpi_group: MpiGroup):
+    group, connections = create_and_connect(mpi_group, "NVLink")
+
+    semaphores = group.make_semaphore(connections, SmDevice2DeviceSemaphore)
+    group.barrier()
+    kernel = MscclppKernel("d2d_semaphore", group.my_rank, group.nranks, semaphores)
+    kernel()
+    cp.cuda.runtime.deviceSynchronize()
+    group.barrier()
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
+@pytest.mark.parametrize("use_packet", [False, True])
+def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
+    group, connections = create_and_connect(mpi_group, "NVLink")
+
+    memory = cp.zeros(nelem, dtype=cp.int32)
+    if use_packet:
+        scratch = cp.zeros(nelem * 2, dtype=cp.int32)
+    else:
+        scratch = None
+    nelemPerRank = nelem // group.nranks
+    nelemPerRank * memory.itemsize
+    memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
+    memory_expected = cp.zeros_like(memory)
+    for rank in range(group.nranks):
+        memory_expected[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))] = rank + 1
+
+    if use_packet:
+        channels = group.make_sm_channels_with_packet(memory, scratch, connections)
+    else:
+        channels = group.make_sm_channels(memory, connections)
+    kernel = MscclppKernel("sm_channel", group.my_rank, group.nranks, channels, memory, use_packet, scratch)
+
+    group.barrier()
+    kernel()
+    cp.cuda.runtime.deviceSynchronize()
+    group.barrier()
+    assert cp.array_equal(memory, memory_expected)
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+def test_fifo(
+    mpi_group: MpiGroup,
+):
+    fifo = Fifo()
+    kernel = MscclppKernel("fifo", fifo=fifo)
+
+    kernel()
+    poll_for = 100
+    for _ in range(poll_for):
+        trigger = fifo.poll()
+        if trigger.fst == 123:
+            return
+        time.sleep(0.1)
+    assert False
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
+@pytest.mark.parametrize("transport", ["IB", "NVLink"])
+def test_proxy(
+    mpi_group: MpiGroup,
+    nelem: int,
+    transport: str,
+):
+    group, connections = create_and_connect(mpi_group, transport)
+
+    memory = cp.zeros(
+        nelem,
+        dtype=cp.int32,
+    )
+    nelemPerRank = nelem // group.nranks
+    nelemPerRank * memory.itemsize
+    memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
+    memory_expected = cp.zeros_like(memory)
+    for rank in range(group.nranks):
+        memory_expected[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))] = rank + 1
+    group.barrier()
+    all_reg_memories = group.register_tensor_with_connections(memory, connections)
+
+    semaphores = group.make_semaphore(connections, Host2DeviceSemaphore)
+
+    list_conn = []
+    list_sem = []
+    list_reg_mem = []
+    first_conn = next(iter(connections.values()))
+    first_sem = next(iter(semaphores.values()))
+    for rank in range(group.nranks):
+        if rank in connections:
+            list_conn.append(connections[rank])
+            list_sem.append(semaphores[rank])
+        else:
+            list_conn.append(first_conn)  # just for simplicity of indexing
+            list_sem.append(first_sem)
+
+        list_reg_mem.append(all_reg_memories[rank])
+
+    proxy = _ext.MyProxyService(
+        group.my_rank,
+        group.nranks,
+        nelem * memory.itemsize,
+        list_conn,
+        list_reg_mem,
+        list_sem,
+    )
+
+    fifo_device_handle = proxy.fifo_device_handle()
+
+    kernel = MscclppKernel(
+        "proxy",
+        my_rank=group.my_rank,
+        nranks=group.nranks,
+        semaphore_or_channels=list_sem,
+        fifo=fifo_device_handle,
+    )
+    proxy.start()
+    group.barrier()
+    kernel()
+    cp.cuda.runtime.deviceSynchronize()
+    proxy.stop()
+    group.barrier()
+    assert cp.array_equal(memory, memory_expected)
+
+
+@parametrize_mpi_groups(2, 4, 8, 16)
+@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
+@pytest.mark.parametrize("transport", ["NVLink", "IB"])
+@pytest.mark.parametrize("use_packet", [False, True])
+def test_simple_proxy_channel(
+    mpi_group: MpiGroup,
+    nelem: int,
+    transport: str,
+    use_packet: bool,
+):
+    group, connections = create_and_connect(mpi_group, transport)
+
+    memory = cp.zeros(nelem, dtype=cp.int32)
+    if use_packet:
+        scratch = cp.zeros(nelem * 2, dtype=cp.int32)
+    else:
+        scratch = cp.zeros(1, dtype=cp.int32)  # just so that we can pass a valid ptr
+    nelemPerRank = nelem // group.nranks
+    nelemPerRank * memory.itemsize
+    memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1
+    memory_expected = cp.zeros_like(memory)
+    for rank in range(group.nranks):
+        memory_expected[(nelemPerRank * rank) : (nelemPerRank * (rank + 1))] = rank + 1
+    group.barrier()
+
+    proxy_service = ProxyService()
+    if use_packet:
+        memory_to_register = scratch
+    else:
+        memory_to_register = memory
+    simple_channels = group.make_proxy_channels_with_packet(proxy_service, memory_to_register, connections)
+
+    kernel = MscclppKernel(
+        "simple_proxy_channel",
+        my_rank=group.my_rank,
+        nranks=group.nranks,
+        semaphore_or_channels=simple_channels,
+        tensor=memory,
+        use_packet=use_packet,
+        scratch=scratch,
+    )
+    proxy_service.start_proxy()
+    group.barrier()
+    kernel()
+    cp.cuda.runtime.deviceSynchronize()
+    proxy_service.stop_proxy()
+    group.barrier()
+    assert cp.array_equal(memory, memory_expected)
diff --git a/python/test/utils.py b/python/test/utils.py
new file mode 100644
index 000000000..d32c74421
--- /dev/null
+++ b/python/test/utils.py
@@ -0,0 +1,122 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import ctypes
+import os
+import struct
+import subprocess
+import tempfile
+from typing import Type
+
+from cuda import cuda, nvrtc, cudart
+import cupy as cp
+import numpy as np
+
+
+def _check_cuda_errors(result):
+    if result[0].value:
+        raise RuntimeError(f"CUDA error code={result[0].value}({_cuda_get_error(result[0])})")
+    if len(result) == 1:
+        return None
+    elif len(result) == 2:
+        return result[1]
+    else:
+        return result[1:]
+
+
+def _cuda_get_error(error):
+    if isinstance(error, cuda.CUresult):
+        err, name = cuda.cuGetErrorName(error)
+        return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
+    elif isinstance(error, cudart.cudaError_t):
+        return cudart.cudaGetErrorName(error)[1]
+    elif isinstance(error, nvrtc.nvrtcResult):
+        return nvrtc.nvrtcGetErrorString(error)[1]
+    else:
+        raise RuntimeError("Unknown error type: {}".format(error))
+
+
+class Kernel:
+    def __init__(self, ptx: bytes, kernel_name: str, device_id: int):
+        self._context = _check_cuda_errors(cuda.cuCtxGetCurrent())
+        assert self._context is not None
+        self._module = _check_cuda_errors(cuda.cuModuleLoadData(ptx))
+        self._kernel = _check_cuda_errors(cuda.cuModuleGetFunction(self._module, kernel_name.encode()))
+
+    def launch_kernel(
+        self,
+        params: bytes,
+        nblocks: int,
+        nthreads: int,
+        shared: int,
+        stream: Type[cuda.CUstream] or Type[cudart.cudaStream_t],
+    ):
+        buffer = (ctypes.c_byte * len(params)).from_buffer_copy(params)
+        buffer_size = ctypes.c_size_t(len(params))
+        config = np.array(
+            [
+                cuda.CU_LAUNCH_PARAM_BUFFER_POINTER,
+                ctypes.addressof(buffer),
+                cuda.CU_LAUNCH_PARAM_BUFFER_SIZE,
+                ctypes.addressof(buffer_size),
+                cuda.CU_LAUNCH_PARAM_END,
+            ],
+            dtype=np.uint64,
+        )
+        _check_cuda_errors(
+            cuda.cuLaunchKernel(self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, stream, 0, config.ctypes.data)
+        )
+
+    def __del__(self):
+        cuda.cuModuleUnload(self._module)
+
+
+class KernelBuilder:
+    def __init__(self, file: str, kernel_name: str):
+        self._tempdir = tempfile.TemporaryDirectory()
+        self._current_file_dir = os.path.dirname(os.path.abspath(__file__))
+        device_id = cp.cuda.Device().id
+        ptx = self._compile_cuda(os.path.join(self._current_file_dir, file), f"{kernel_name}.ptx", device_id)
+        self._kernel = Kernel(ptx, kernel_name, device_id)
+
+    def _compile_cuda(self, source_file, output_file, device_id, std_version="c++17"):
+        include_dir = os.path.join(self._current_file_dir, "../../include")
+        major = _check_cuda_errors(
+            cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device_id)
+        )
+        minor = _check_cuda_errors(
+            cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device_id)
+        )
+        command = (
+            f"nvcc -std={std_version} -ptx -Xcompiler -Wall,-Wextra -I{include_dir} {source_file} "
+            f"--gpu-architecture=compute_{major}{minor} --gpu-code=sm_{major}{minor},compute_{major}{minor} "
+            f"-o {self._tempdir.name}/{output_file}"
+        )
+        try:
+            subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            with open(f"{self._tempdir.name}/{output_file}", "rb") as f:
+                return f.read()
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError("Compilation failed:", e.stderr.decode(), command)
+
+    def get_compiled_kernel(self):
+        return self._kernel
+
+    def __del__(self):
+        self._tempdir.cleanup()
+
+
+def pack(*args):
+    res = b""
+    for arg in list(args):
+        if isinstance(arg, int):
+            res += struct.pack("i", arg)
+        elif isinstance(arg, np.ndarray):
+            res += struct.pack("P", arg.ctypes.data)
+        elif isinstance(arg, cp.ndarray):
+            res += struct.pack("P", arg.data.ptr)
+        # use int to represent bool, which can avoid CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES error
+        elif isinstance(arg, bool):
+            res += struct.pack("i", arg)
+        else:
+            raise RuntimeError(f"Unsupported type: {type(arg)}")
+    return res
diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp
index 779cd7965..627960a8e 100644
--- a/src/include/registered_memory.hpp
+++ b/src/include/registered_memory.hpp
@@ -35,12 +35,16 @@ struct RegisteredMemory::Impl {
   void* data;
   size_t size;
   int rank;
+  bool isRemote;
   uint64_t hostHash;
   TransportFlags transports;
   std::vector<TransportInfo> transportInfos;

   Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
+  /// Constructs a RegisteredMemory::Impl from a vector of data. This constructor should only be used for remote
+  /// memory.
+  Impl(const std::vector<char>& data);
+  ~Impl();

   const TransportInfo& getTransportInfo(Transport transport) const;
 };
diff --git a/src/registered_memory.cc b/src/registered_memory.cc
index bb1ae3563..39a5ebb6f 100644
--- a/src/registered_memory.cc
+++ b/src/registered_memory.cc
@@ -15,7 +15,12 @@ namespace mscclpp {

 RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports,
                              Communicator::Impl& commImpl)
-    : data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports) {
+    : data(data),
+      size(size),
+      rank(rank),
+      isRemote(false),
+      hostHash(commImpl.rankToHash_.at(rank)),
+      transports(transports) {
   if (transports.has(Transport::CudaIpc)) {
     TransportInfo transportInfo;
     transportInfo.transport = Transport::CudaIpc;
@@ -56,7 +61,7 @@ MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr<Impl> pimpl)

 MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default;

-MSCCLPP_API_CPP void* RegisteredMemory::data() { return pimpl->data; }
+MSCCLPP_API_CPP void* RegisteredMemory::data() const { return pimpl->data; }

 MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; }

@@ -142,6 +147,20 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
       INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data);
     }
   }
+  this->isRemote = true;
+}
+
+RegisteredMemory::Impl::~Impl() {
+  uint64_t localHostHash = getHostHash();
+  if (this->isRemote && localHostHash == this->hostHash && transports.has(Transport::CudaIpc)) {
+    void* base = static_cast<char*>(data) - getTransportInfo(Transport::CudaIpc).cudaIpcOffsetFromBase;
+    cudaError_t err = cudaIpcCloseMemHandle(base);
+    if (err != cudaSuccess) {
+      WARN("Failed to close cuda IPC handle: %s", cudaGetErrorString(err));
+    }
+    INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base);
+    data = nullptr;
+  }
 }

 const TransportInfo& RegisteredMemory::Impl::getTransportInfo(Transport transport) const {
diff --git a/src/sm_channel.cc b/src/sm_channel.cc
index f1b2e2bc8..a148595bf 100644
--- a/src/sm_channel.cc
+++ b/src/sm_channel.cc
@@ -10,16 +10,17 @@ namespace mscclpp {

 MSCCLPP_API_CPP SmChannel::SmChannel(std::shared_ptr<SmDevice2DeviceSemaphore> semaphore, RegisteredMemory dst,
                                      void* src, void* getPacketBuffer)
-    : semaphore_(semaphore), src_(src), getPacketBuffer_(getPacketBuffer) {
+    : semaphore_(semaphore), dst_(dst), src_(src), getPacketBuffer_(getPacketBuffer) {
   if (!dst.transports().has(Transport::CudaIpc)) {
     throw Error("SmChannel: dst must be registered with CudaIpc", ErrorCode::InvalidUsage);
   }
-  dst_ = dst.data();
 }

 MSCCLPP_API_CPP SmChannel::DeviceHandle SmChannel::deviceHandle() const {
-  return DeviceHandle{
-      .semaphore_ = semaphore_->deviceHandle(), .src_ = src_, .dst_ = dst_, .getPacketBuffer_ = getPacketBuffer_};
+  return DeviceHandle{.semaphore_ = semaphore_->deviceHandle(),
+                      .src_ = src_,
+                      .dst_ = dst_.data(),
+                      .getPacketBuffer_ = getPacketBuffer_};
 }

 }  // namespace mscclpp
diff --git a/test/deploy/deploy.sh b/test/deploy/deploy.sh
index 51c86d7b4..248c09c43 100644
--- a/test/deploy/deploy.sh
+++ b/test/deploy/deploy.sh
@@ -2,6 +2,8 @@ set -e
 KeyFilePath=${SSHKEYFILE_SECUREFILEPATH}
 SRC_DIR="${SYSTEM_DEFAULTWORKINGDIRECTORY}/build"
+SRC_INCLUDE_DIR="${SYSTEM_DEFAULTWORKINGDIRECTORY}/include"
+PYTHON_SRC_DIR="${SYSTEM_DEFAULTWORKINGDIRECTORY}/python"
 DST_DIR="/tmp/mscclpp"
 HOSTFILE="${SYSTEM_DEFAULTWORKINGDIRECTORY}/test/deploy/hostfile"
 DEPLOY_DIR="${SYSTEM_DEFAULTWORKINGDIRECTORY}/test/deploy"
@@ -25,6 +27,8 @@ set -e
 parallel-ssh -i -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION "rm -rf ${DST_DIR}"
 parallel-ssh -i -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION "mkdir -p ${DST_DIR}"
 parallel-scp -t 0 -r -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION ${SRC_DIR} ${DST_DIR}
+parallel-scp -t 0 -r -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION ${PYTHON_SRC_DIR} ${DST_DIR}
+parallel-scp -t 0 -r -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION ${SRC_INCLUDE_DIR} ${DST_DIR}
 parallel-scp -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION sshkey ${DST_DIR}
 parallel-scp -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" -O $SSH_OPTION sshkey.pub ${DST_DIR}
diff --git a/test/deploy/pytest.sh b/test/deploy/pytest.sh
new file mode 100644
index 000000000..26fec2e4f
--- /dev/null
+++ b/test/deploy/pytest.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+if [[ $OMPI_COMM_WORLD_RANK == 0 ]]
+then
+  pytest /root/mscclpp/python/test/test_mscclpp.py -x -v
+else
+  pytest /root/mscclpp/python/test/test_mscclpp.py -x 2>&1 >/dev/null
+fi
diff --git a/test/deploy/run_tests.sh b/test/deploy/run_tests.sh
index d8c5bf386..3f1b431fc 100644
--- a/test/deploy/run_tests.sh
+++ b/test/deploy/run_tests.sh
@@ -60,6 +60,14 @@ function run_mp_ut()
     -npernode 8 /root/mscclpp/build/test/mp_unit_tests -ip_port mscclpp-it-000000:20003
 }

+function run_pytests()
+{
+  echo "==================Run python tests================================"
+  /usr/local/mpi/bin/mpirun -allow-run-as-root -tag-output -np 16 --bind-to numa \
+    -hostfile /root/mscclpp/hostfile_mpi -x MSCCLPP_DEBUG=WARN -x LD_LIBRARY_PATH=/root/mscclpp/build:$LD_LIBRARY_PATH \
+    -npernode 8 bash /root/mscclpp/pytest.sh
+}
+
 if [ $# -lt 1 ]; then
   echo "Usage: $0 <test_name>"
   exit 1
@@ -74,6 +82,10 @@ case $test_name in
     echo "==================Run mp-ut on 2 nodes================================"
     run_mp_ut
     ;;
+  pytests)
+    echo "==================Run python tests================================"
+    run_pytests
+    ;;
   *)
     echo "Unknown test name: $test_name"
     exit 1
diff --git a/test/deploy/setup.sh b/test/deploy/setup.sh
index 9be3b10c6..2b2c7f7e8 100644
--- a/test/deploy/setup.sh
+++ b/test/deploy/setup.sh
@@ -13,5 +13,11 @@ for i in $(seq 0 $(( $(nvidia-smi -L | wc -l) - 1 ))); do
   nvidia-smi -ac $(nvidia-smi --query-gpu=clocks.max.memory,clocks.max.sm --format=csv,noheader,nounits -i $i | sed 's/\ //') -i $i
 done

+if [[ "${CUDA_VERSION}" == *"11."* ]]; then
+  pip3 install -r /root/mscclpp/python/test/requirements_cu11.txt
+else
+  pip3 install -r /root/mscclpp/python/test/requirements_cu12.txt
+fi
+
 mkdir -p /var/run/sshd
 /usr/sbin/sshd -p 22345
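
Note: end to end, the pieces above are exercised from Python roughly as follows — a minimal sketch using the new test helpers (names from python/test/mscclpp_group.py and mscclpp_mpi.py; must run under mpirun with one process per GPU):

    import cupy as cp
    from mscclpp import Transport
    from mscclpp_mpi import MpiGroup
    from mscclpp_group import MscclppGroup

    mpi_group = MpiGroup(list(range(8)))                    # 8 ranks, one per GPU
    group = MscclppGroup(mpi_group)                         # bootstrap via rank-0 unique id
    peers = [r for r in range(group.nranks) if r != group.my_rank]
    connections = group.make_connection(peers, Transport.CudaIpc)
    memory = cp.zeros(1024, dtype=cp.int32)
    channels = group.make_sm_channels(memory, connections)  # one SmChannel per peer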