diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 600c781..83698d4 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -23,6 +23,42 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
 find_package(Torch REQUIRED)
 find_package(NVSHMEM REQUIRED)
 
+# === NIXL Configuration ===
+set(PPLX_ENABLE_P2P $ENV{PPLX_ENABLE_P2P})
+
+# Multi-channel worker configuration from environment (default 8)
+set(PPLX_NIXL_NUM_CHANNELS $ENV{PPLX_NIXL_NUM_CHANNELS})
+if(NOT PPLX_NIXL_NUM_CHANNELS OR PPLX_NIXL_NUM_CHANNELS EQUAL 0)
+  set(PPLX_NIXL_NUM_CHANNELS 8)
+endif()
+message(STATUS "PPLX_NIXL_NUM_CHANNELS=${PPLX_NIXL_NUM_CHANNELS}")
+
+add_library(nixl_interface INTERFACE)
+
+target_link_libraries(nixl_interface INTERFACE
+  serdes
+  stream
+  ucx_utils
+  nixl_common
+  nixl_build
+  nixl
+)
+
+# Enable P2P/NVLINK support if requested
+if(PPLX_ENABLE_P2P)
+  target_compile_definitions(nixl_interface INTERFACE
+    PPLX_ENABLE_P2P
+    PPLX_ATOMIC_SCOPE=__NV_THREAD_SCOPE_SYSTEM)
+  message(STATUS "P2P/NVLINK support: ENABLED (using system-scope atomics)")
+else()
+  target_compile_definitions(nixl_interface INTERFACE
+    PPLX_ATOMIC_SCOPE=__NV_THREAD_SCOPE_DEVICE)
+  message(STATUS "P2P/NVLINK support: DISABLED (using device-scope atomics)")
+endif()
+
+# Multi-channel NIXL configuration (compile-time constant from env)
+target_compile_definitions(nixl_interface INTERFACE PPLX_NIXL_NUM_CHANNELS=${PPLX_NIXL_NUM_CHANNELS})
+
 if(WITH_TESTS)
   enable_testing()
   find_package(MPI REQUIRED)
@@ -46,8 +82,18 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 
 # CUDA-specific compile options function
 function(set_cuda_compile_options target)
+  # Base CUDA compile options
+  set(CUDA_COMPILE_FLAGS "--threads=32" "-O3")
+
+  # Workaround for CUDA 12.9 compilation bug
+  # See: https://github.com/openucx/ucx/pull/10960
+  if(CUDAToolkit_VERSION_MAJOR EQUAL 12 AND CUDAToolkit_VERSION_MINOR EQUAL 9)
+    message(STATUS "Detected CUDA 12.9: applying _LIBCUDACXX_ATOMIC_UNSAFE_AUTOMATIC_STORAGE workaround")
+    list(APPEND CUDA_COMPILE_FLAGS "-D_LIBCUDACXX_ATOMIC_UNSAFE_AUTOMATIC_STORAGE")
+  endif()
+
   target_compile_options(${target} PRIVATE
-    $<$<COMPILE_LANGUAGE:CUDA>:--threads=32 -O3>)
+    $<$<COMPILE_LANGUAGE:CUDA>:${CUDA_COMPILE_FLAGS}>)
 endfunction()
 
 # === Library targets ===
@@ -70,6 +116,9 @@ target_link_libraries(pplx_kernels PUBLIC
   nvshmem::nvshmem_host
   nvshmem::nvshmem_device
 )
+
+# Always link NIXL (headers/libraries always available)
+target_link_libraries(pplx_kernels PUBLIC nixl_interface)
 set_target_properties(pplx_kernels PROPERTIES
   LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels
   CUDA_SEPARABLE_COMPILATION ON
diff --git a/csrc/all_to_all/CMakeLists.txt b/csrc/all_to_all/CMakeLists.txt
index 9cd333f..5d7e966 100644
--- a/csrc/all_to_all/CMakeLists.txt
+++ b/csrc/all_to_all/CMakeLists.txt
@@ -6,6 +6,7 @@ add_library(all_to_all_common STATIC
 target_link_libraries(all_to_all_common PUBLIC
   CUDA::cudart
+  nixl_interface
 )
 
 add_library(all_to_all_intranode_lib STATIC
@@ -21,6 +22,9 @@ target_link_libraries(all_to_all_intranode_lib INTERFACE
   nvshmem::nvshmem_host
 )
 target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
+
+target_link_libraries(all_to_all_intranode_lib PUBLIC nixl_interface)
+
 set_cuda_compile_options(all_to_all_intranode_lib)
 
 add_library(all_to_all_internode_lib STATIC
@@ -32,10 +36,9 @@ target_link_libraries(all_to_all_internode_lib PUBLIC
   all_to_all_common
   CUDA::cudart
 )
-target_link_libraries(all_to_all_internode_lib INTERFACE
-  nvshmem::nvshmem_host
-)
-target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
+
+target_link_libraries(all_to_all_internode_lib PUBLIC nixl_interface)
+
 set_cuda_compile_options(all_to_all_internode_lib)
 
 if(WITH_TESTS)
@@ -52,6 +55,8 @@ if(WITH_TESTS)
     MPI::MPI_CXX
     nvshmem::nvshmem_host
   )
+
+  target_link_libraries(test_all_to_all PUBLIC nixl_interface)
   set_cuda_compile_options(test_all_to_all)
   add_test(NAME AllToAllTest COMMAND ${MPIEXEC_EXECUTABLE} -np 4 $<TARGET_FILE:test_all_to_all>)
@@ -71,4 +76,6 @@ if (WITH_BENCHMARKS)
     MPI::MPI_CXX
     nvshmem::nvshmem_host
   )
+
+  target_link_libraries(bench_all_to_all PUBLIC nixl_interface)
 endif()
diff --git a/csrc/all_to_all/internode.cpp b/csrc/all_to_all/internode.cpp
index b73e547..e9d4e48 100644
--- a/csrc/all_to_all/internode.cpp
+++ b/csrc/all_to_all/internode.cpp
@@ -1,14 +1,53 @@
-#include
-
 #include
+#include
 #include
+#include
+#include
+#include
+#include
+#include
 
 #include "all_to_all/internode.h"
 #include "core/utils.h"
+#include "core/cuda_utils.h"
 
 using namespace pplx;
 
+constexpr size_t SUPER_REGION_ALIGNMENT = 256;
+constexpr int MAX_METADATA_CHECK_ATTEMPTS = 10000;
+constexpr int MAX_NOTIFICATION_POLL_ATTEMPTS = 10000;
+constexpr int RETRY_SLEEP_MS = 10;
+constexpr int METADATA_EXCHANGE_INITIAL_SLEEP_MS = 200;
+constexpr int CLEANUP_SLEEP_MS = 100;
+
+static std::unique_ptr<nixlAgent> createNixlAgent(const std::string& agentName) {
+  nixlAgentConfig config(
+      true,
+      false,
+      0,
+      nixl_thread_sync_t::NIXL_THREAD_SYNC_RW,
+      1,
+      0,
+      100000,
+      false,
+      std::chrono::microseconds(100000000)
+  );
+
+  return std::make_unique<nixlAgent>(agentName, config);
+}
+
+static nixl_xfer_dlist_t createSingleDescDlist(void* ptr, size_t size, int deviceId) {
+  nixl_xfer_dlist_t dlist(VRAM_SEG);
+  dlist.addDesc(nixlBlobDesc(
+      reinterpret_cast<uintptr_t>(ptr),
+      size,
+      deviceId,
+      ""
+  ));
+  return dlist;
+}
+
 AllToAllInterNode::AllToAllInterNode(
     size_t maxNumTokens,
     size_t numExperts,
@@ -31,41 +70,76 @@ AllToAllInterNode::AllToAllInterNode(
           hiddenDimBytes,
           hiddenDimScaleBytes
       ),
-      maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens) {
+      maxBatchTokens(numLocalExperts * numDPGroups * maxNumTokens)
+      , numChannels(PPLX_NIXL_NUM_CHANNELS)
+      , agent(createNixlAgent(std::to_string(rank)))
+      , originalDataDlist(VRAM_SEG)
+      , originalCountersDlist(VRAM_SEG)
+      , hPeerDataSuperBases(worldSize, nullptr)
+      , hPeerCountersLocalBases(worldSize, nullptr)
+{
   // Buffers for token counts.
   numTokensPerDP = mallocZeroBuffer<uint32_t>(numLocalExperts * numDPGroups);
-  numTokensBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * numLocalExperts * numDPGroups);
-  PPLX_ASSERT(numTokensBuffer != nullptr, "failed to allocate numTokensBuffer");
-  cudaMemset(numTokensBuffer, 0, sizeof(uint64_t) * numLocalExperts * numDPGroups);
-
-  numDispatchRecvBuffer =
-      (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * numLocalExperts * numDPGroups);
-  PPLX_ASSERT(numDispatchRecvBuffer != nullptr, "failed to allocate numDispatchRecvBuffer");
-  cudaMemset(numDispatchRecvBuffer, 0, sizeof(uint64_t) * numLocalExperts * numDPGroups);
-
-  combineSignalBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * maxNumTokens);
-  PPLX_ASSERT(combineSignalBuffer != nullptr, "failed to allocate combineSignalBuffer");
-  cudaMemset(combineSignalBuffer, 0, sizeof(uint64_t) * maxNumTokens);
-
-  combineSyncBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * worldSize);
-  PPLX_ASSERT(combineSyncBuffer != nullptr, "failed to allocate combineSyncBuffer");
-  cudaMemset(combineSyncBuffer, 0, sizeof(uint64_t) * worldSize);
-
-  // Buffers for dispatch.
- const size_t perTokenBytes = - round_up(hiddenDimBytes + hiddenDimScaleBytes + sizeof(uint32_t), 16); - xDispatchIn = (std::byte *)nvshmem_malloc(maxNumTokens * perTokenBytes); - PPLX_ASSERT(xDispatchIn != nullptr, "failed to allocate xDispatchIn"); - xDispatchOut = (std::byte *)nvshmem_malloc(maxBatchTokens * perTokenBytes); - PPLX_ASSERT(xDispatchOut != nullptr, "failed to allocate xDispatchOut"); - - // Buffers for combine. The allocations are a bit wider to accommodate all - // possible data types (primarily float for testing and bfloat16 for prod). - xCombineIn = (std::byte *)nvshmem_malloc(maxBatchTokens * hiddenDim * sizeof(float)); - PPLX_ASSERT(xCombineIn != nullptr, "failed to allocate xCombineIn"); - xCombineOut = (std::byte *)nvshmem_malloc(maxNumTokens * numExperts * hiddenDim * sizeof(float)); - PPLX_ASSERT(xCombineOut != nullptr, "failed to allocate xCombineOut"); + const size_t perTokenBytes = round_up(hiddenDimBytes + hiddenDimScaleBytes + sizeof(uint32_t), 16); + + size_t dispatchInSize = maxNumTokens * perTokenBytes; + size_t dispatchOutSize = maxBatchTokens * perTokenBytes; + size_t combineInSize = maxBatchTokens * hiddenDim * sizeof(float); + size_t combineOutSize = maxNumTokens * numExperts * hiddenDim * sizeof(float); + size_t counter_span = numLocalExperts * numDPGroups * sizeof(uint64_t); + + // Allocate super-regions and wire up NIXL + // Initialize unified memory layout (DATA super-region + counter blocks) + initializeMemoryLayout( + dispatchInSize, + dispatchOutSize, + combineInSize, + combineOutSize, + counter_span, + maxNumTokens, + worldSize); + + // Allocate DATA super-region based on calculated layout + dataSuperBase = mallocZeroBuffer(memLayout.dataSuperTotalSize); + PPLX_ASSERT(dataSuperBase != nullptr, "Failed to allocate DATA super-region"); + + // Point buffer pointers to offsets within super-region + xDispatchIn = dataSuperBase + memLayout.dispatchInOffset; + xDispatchOut = dataSuperBase + memLayout.dispatchOutOffset; + xCombineIn = dataSuperBase + memLayout.combineInOffset; + xCombineOut = dataSuperBase + memLayout.combineOutOffset; + + // Allocate COUNTER blocks based on calculated layout + countersRemoteBase = mallocZeroBuffer(memLayout.counterTotalSize); + PPLX_ASSERT(countersRemoteBase != nullptr, "Failed to allocate REMOTE counter block"); + + countersLocalBase = mallocZeroBuffer(memLayout.counterTotalSize); + PPLX_ASSERT(countersLocalBase != nullptr, "Failed to allocate LOCAL counter block"); + + // Point counter/signal pointers to their offsets for host-side convenience + nixlTokenRemoteCounters = reinterpret_cast( + countersRemoteBase + memLayout.counterTokenOffset); + nixlRecvRemoteCounters = reinterpret_cast( + countersRemoteBase + memLayout.counterRecvOffset); + combineSignalBuffer = reinterpret_cast( + countersRemoteBase + memLayout.counterSignalOffset); + combineSyncBuffer = reinterpret_cast( + countersRemoteBase + memLayout.counterSyncOffset); + + nixlTokenLocalCounters = reinterpret_cast( + countersLocalBase + memLayout.counterTokenOffset); + nixlRecvLocalCounters = reinterpret_cast( + countersLocalBase + memLayout.counterRecvOffset); + combineLocalSignalBuffer = reinterpret_cast( + countersLocalBase + memLayout.counterSignalOffset); + combineLocalSyncBuffer = reinterpret_cast( + countersLocalBase + memLayout.counterSyncOffset); + + initializeNixlAgent(); + wireupNixl(); + createAllNixlRequests(); + setupPeerArrays(); // Buffers for token tracking. 
sourceIndex = mallocZeroBuffer(maxBatchTokens); @@ -77,15 +151,17 @@ AllToAllInterNode::AllToAllInterNode( } AllToAllInterNode::~AllToAllInterNode() { + CUDACHECK(cudaDeviceSynchronize()); + destroyPeerArrays(); + destroyAllNixlRequests(); + unwireNixl(); + destroyNixlAgent(); + CUDACHECK(cudaFree(numTokensPerDP)); - nvshmem_free(numTokensBuffer); - nvshmem_free(numDispatchRecvBuffer); - nvshmem_free(combineSignalBuffer); - nvshmem_free(combineSyncBuffer); - nvshmem_free(xDispatchIn); - nvshmem_free(xDispatchOut); - nvshmem_free(xCombineIn); - nvshmem_free(xCombineOut); + + CUDACHECK(cudaFree(dataSuperBase)); + CUDACHECK(cudaFree(countersRemoteBase)); + CUDACHECK(cudaFree(countersLocalBase)); CUDACHECK(cudaFree(sourceIndex)); CUDACHECK(cudaFree(sourceExpert)); @@ -94,3 +170,407 @@ AllToAllInterNode::~AllToAllInterNode() { CUDACHECK(cudaFree(sourceToken)); CUDACHECK(cudaFree(tokenIndex)); } + +void AllToAllInterNode::initializeNixlAgent() { + const char* masterPortStr = std::getenv("MASTER_PORT"); + PPLX_ASSERT(masterPortStr != nullptr, "MASTER_PORT environment variable not set"); + + const int baseEnv = std::stoi(masterPortStr); + PPLX_ASSERT(baseEnv >= 0 && baseEnv <= 65535, "MASTER_PORT out of range"); + + PPLX_ASSERT(agent != nullptr, "Failed to create NIXL agent"); + + nixl_mem_list_t mems; + nixl_b_params_t initParams; + nixl_status_t status = agent->getPluginParams("UCX", mems, initParams); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to get UCX plugin parameters"); + + initParams["ucx_error_handling_mode"] = "none"; + initParams["num_workers"] = std::to_string(numChannels); + + status = agent->createBackend("UCX", initParams, backend); + PPLX_ASSERT(status == NIXL_SUCCESS && backend != nullptr, "Failed to create UCX backend"); + + int deviceId; + CUDACHECK(cudaGetDevice(&deviceId)); + + // Register super-regions with NIXL agent + nixl_reg_dlist_t dataSuperDlist(VRAM_SEG); + dataSuperDlist.addDesc(nixlBlobDesc( + reinterpret_cast(dataSuperBase), + memLayout.dataSuperTotalSize, + deviceId, + "data_super")); + originalDataDlist = dataSuperDlist; + + status = agent->registerMem(dataSuperDlist); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to register DATA super-region"); + + nixl_reg_dlist_t countersRemoteDlist(VRAM_SEG); + countersRemoteDlist.addDesc(nixlBlobDesc( + reinterpret_cast(countersRemoteBase), + memLayout.counterTotalSize, + deviceId, + "counters_remote")); + originalCountersDlist = countersRemoteDlist; + + status = agent->registerMem(countersRemoteDlist); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to register REMOTE counter block"); + + // Create extraParams with backend for GPU signal operations + nixl_opt_args_t extraParams; + extraParams.backends.push_back(backend); + + size_t signalSize = 0; + status = agent->getGpuSignalSize(signalSize, &extraParams); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to query GPU signal size"); + PPLX_ASSERT(signalSize == sizeof(uint64_t), "GPU signal size is not 8 bytes"); + + status = agent->prepGpuSignal(originalCountersDlist, &extraParams); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to prepare GPU signal"); +} + +void AllToAllInterNode::destroyNixlAgent() { + if (!agent) return; + + agent->deregisterMem(originalDataDlist); + agent->deregisterMem(originalCountersDlist); +} + +void AllToAllInterNode::wireupNixl() { + int deviceId; + CUDACHECK(cudaGetDevice(&deviceId)); + + PeerInfo localInfo = { + .deviceId = deviceId, + .rank = static_cast(rank), + .dataBase = dataSuperBase, + .dataSize = memLayout.dataSuperTotalSize, + 
.countersRemoteBase = countersRemoteBase, + .countersRemoteSize = memLayout.counterTotalSize, + .countersLocalBase = countersLocalBase, + .countersLocalSize = memLayout.counterTotalSize + }; + + #ifdef PPLX_ENABLE_P2P + { + CUDACHECK(cudaIpcGetMemHandle(&localInfo.dataSuperHandle, dataSuperBase)); + CUDACHECK(cudaIpcGetMemHandle(&localInfo.countersLocalHandle, countersLocalBase)); + + std::ifstream bootIdFile("/proc/sys/kernel/random/boot_id"); + if (bootIdFile.is_open()) { + std::string bootId; + std::getline(bootIdFile, bootId); + strncpy(localInfo.bootId, bootId.c_str(), sizeof(localInfo.bootId) - 1); + localInfo.bootId[sizeof(localInfo.bootId) - 1] = '\0'; + } + + struct stat st; + if (stat("/proc/self/ns/ipc", &st) == 0) { + localInfo.ipcNamespaceInode = st.st_ino; + } + } + #endif + + nixl_status_t status = agent->sendLocalMD(); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to send local metadata"); + + std::this_thread::sleep_for(std::chrono::milliseconds(METADATA_EXCHANGE_INITIAL_SLEEP_MS)); + + for (unsigned peerRank = 0; peerRank < worldSize; ++peerRank) { + if (peerRank == rank) continue; + + const std::string remoteAgentName = std::to_string(peerRank); + + status = agent->fetchRemoteMD(remoteAgentName); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to fetch metadata"); + + nixl_xfer_dlist_t emptyDescs(VRAM_SEG); + int checkAttempts = 0; + while (checkAttempts < MAX_METADATA_CHECK_ATTEMPTS) { + status = agent->checkRemoteMD(remoteAgentName, emptyDescs); + if (status == NIXL_SUCCESS) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(RETRY_SLEEP_MS)); + ++checkAttempts; + } + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to check metadata"); + } + + peerInfos.resize(worldSize); + peerInfos[rank] = localInfo; + + const std::string wireupPrefix = "pplx_super_region:"; + + for (unsigned targetRank = 0; targetRank < worldSize; ++targetRank) { + if (targetRank == rank) continue; + std::string payload(reinterpret_cast(&localInfo), sizeof(PeerInfo)); + const std::string remoteAgentName = std::to_string(targetRank); + status = agent->genNotif(remoteAgentName, wireupPrefix + payload); + PPLX_ASSERT(status == NIXL_SUCCESS, "Failed to send notification"); + } + + const int expectedNotifs = worldSize > 0 ? 
static_cast(worldSize) - 1 : 0; + std::vector received(worldSize, false); + int receivedCount = 0; + + for (int attempt = 0; attempt < MAX_NOTIFICATION_POLL_ATTEMPTS && receivedCount < expectedNotifs; ++attempt) { + nixl_notifs_t notifMap; + agent->getNotifs(notifMap); + + for (const auto& [senderName, messages] : notifMap) { + for (const auto& msg : messages) { + PPLX_ASSERT(msg.rfind(wireupPrefix, 0) == 0, "Unexpected NIXL notification"); + + std::string data = msg.substr(wireupPrefix.size()); + PPLX_ASSERT(data.size() == sizeof(PeerInfo), "Received PeerInfo payload of unexpected size"); + + PeerInfo receivedPeer{}; + memcpy(&receivedPeer, data.data(), sizeof(PeerInfo)); + + PPLX_ASSERT(receivedPeer.rank < worldSize, "Received PeerInfo with invalid rank"); + PPLX_ASSERT(receivedPeer.rank != rank, "Received unexpected self PeerInfo notification"); + PPLX_ASSERT(!received[receivedPeer.rank], "Received duplicate PeerInfo"); + + peerInfos[receivedPeer.rank] = receivedPeer; + received[receivedPeer.rank] = true; + ++receivedCount; + } + } + + if (receivedCount < expectedNotifs) + std::this_thread::sleep_for(std::chrono::milliseconds(RETRY_SLEEP_MS)); + } + PPLX_ASSERT(receivedCount >= expectedNotifs, "Missing PeerInfo from some ranks"); +} + +void AllToAllInterNode::unwireNixl() { + if (!agent) return; + + agent->invalidateLocalMD(); + std::this_thread::sleep_for(std::chrono::milliseconds(CLEANUP_SLEEP_MS)); +} + +void AllToAllInterNode::createAllNixlRequests() { + // Unified request architecture: + // - channel_data_reqs_flat[channel * worldSize + dest]: Data write requests + // - channel_signal_reqs_flat[channel * worldSize + dest]: Signal write requests + // Each request covers entire super-region with offset selection at nixlPostXferReq time + + gpuDataReqs.resize(numChannels * worldSize, nullptr); + gpuSignalReqs.resize(numChannels * worldSize, nullptr); + + int localDev = -1; + CUDACHECK(cudaGetDevice(&localDev)); + + nixl_xfer_dlist_t srcDataSuper = createSingleDescDlist( + dataSuperBase, memLayout.dataSuperTotalSize, localDev); + + nixl_xfer_dlist_t dummySrcDlist = createSingleDescDlist( + countersRemoteBase, sizeof(uint64_t), localDev); + + for (int workerIdx = 0; workerIdx < (int)numChannels; ++workerIdx) { + for (unsigned destRank = 0; destRank < worldSize; ++destRank) { + if (destRank == rank) continue; + + const size_t flatIdx = (size_t)workerIdx * worldSize + (size_t)destRank; + const std::string dstName = std::to_string(destRank); + + nixl_opt_args_t workerParams; + workerParams.customParam = "worker_id=" + std::to_string(workerIdx); + + nixl_xfer_dlist_t dstDataSuper = createSingleDescDlist( + peerInfos[destRank].dataBase, + peerInfos[destRank].dataSize, + peerInfos[destRank].deviceId); + + nixlXferReqH* dataCpuReq; + nixl_status_t status = agent->createXferReq( + NIXL_WRITE, srcDataSuper, dstDataSuper, dstName, + dataCpuReq, &workerParams); + + PPLX_ASSERT(status == NIXL_SUCCESS && dataCpuReq != nullptr, "Failed to create unified DATA request"); + + nixlGpuXferReqH dataGpuReq; + status = agent->createGpuXferReq(*dataCpuReq, dataGpuReq); + PPLX_ASSERT(status == NIXL_SUCCESS && dataGpuReq != nullptr, "Failed to create GPU DATA request"); + gpuDataReqs[flatIdx] = dataGpuReq; + cpuXferReqs.push_back(dataCpuReq); + + nixl_xfer_dlist_t dstSignalDlist = createSingleDescDlist( + peerInfos[destRank].countersRemoteBase, + sizeof(uint64_t), + peerInfos[destRank].deviceId); + + nixlXferReqH* signalCpuReq; + status = agent->createXferReq( + NIXL_WRITE, dummySrcDlist, dstSignalDlist, 
dstName, + signalCpuReq, &workerParams); + + PPLX_ASSERT(status == NIXL_SUCCESS && signalCpuReq != nullptr, "Failed to create unified SIGNAL request"); + + nixlGpuXferReqH signalGpuReq; + status = agent->createGpuXferReq(*signalCpuReq, signalGpuReq); + PPLX_ASSERT(status == NIXL_SUCCESS && signalGpuReq != nullptr, "Failed to create GPU SIGNAL request"); + gpuSignalReqs[flatIdx] = signalGpuReq; + cpuXferReqs.push_back(signalCpuReq); + } + } + + nixlGpuXferReqH* dUnifiedDataReqs = nullptr; + nixlGpuXferReqH* dUnifiedSignalReqs = nullptr; + + size_t unifiedSize = numChannels * worldSize * sizeof(nixlGpuXferReqH); + CUDACHECK(cudaMalloc(&dUnifiedDataReqs, unifiedSize)); + CUDACHECK(cudaMemcpy(dUnifiedDataReqs, gpuDataReqs.data(), + unifiedSize, cudaMemcpyHostToDevice)); + + CUDACHECK(cudaMalloc(&dUnifiedSignalReqs, unifiedSize)); + CUDACHECK(cudaMemcpy(dUnifiedSignalReqs, gpuSignalReqs.data(), + unifiedSize, cudaMemcpyHostToDevice)); + + dChannelDataReqsFlat = dUnifiedDataReqs; + dChannelSignalReqsFlat = dUnifiedSignalReqs; +} + +void AllToAllInterNode::destroyAllNixlRequests() { + if (agent) { + for (auto gpuReq : gpuDataReqs) { + if (!gpuReq) continue; + agent->releaseGpuXferReq(gpuReq); + } + for (auto gpuReq : gpuSignalReqs) { + if (!gpuReq) continue; + agent->releaseGpuXferReq(gpuReq); + } + for (auto cpuReq : cpuXferReqs) { + if (!cpuReq) continue; + agent->releaseXferReq(cpuReq); + } + } + + CUDACHECK(cudaFree(dChannelDataReqsFlat)); + CUDACHECK(cudaFree(dChannelSignalReqsFlat)); +} + +void AllToAllInterNode::setupPeerArrays() { + // Populate self-rank entries in host vectors + hPeerDataSuperBases[rank] = dataSuperBase; + hPeerCountersLocalBases[rank] = countersLocalBase; + + #ifdef PPLX_ENABLE_P2P + if (!peerInfos.empty() && worldSize > 1) { + PPLX_ASSERT(rank < peerInfos.size() && peerInfos[rank].bootId[0] != '\0', + "P2P enabled but PeerInfo lacks super-region IPC handles"); + + for (int peer = 0; peer < worldSize; ++peer) { + if (peer == rank) continue; + + // Check if peer is on same node + bool sameNode = (strcmp(peerInfos[peer].bootId, peerInfos[rank].bootId) == 0 && + peerInfos[peer].ipcNamespaceInode == peerInfos[rank].ipcNamespaceInode); + + if (!sameNode) { + hPeerDataSuperBases[peer] = nullptr; + hPeerCountersLocalBases[peer] = nullptr; + continue; + } + + cudaError_t err = cudaIpcOpenMemHandle(&hPeerDataSuperBases[peer], + peerInfos[peer].dataSuperHandle, + cudaIpcMemLazyEnablePeerAccess); + if (err != cudaSuccess) { + std::cerr << "[P2P] cudaIpcOpenMemHandle(data_super) failed for peer " << peer + << ": " << cudaGetErrorString(err) << std::endl; + hPeerDataSuperBases[peer] = nullptr; + hPeerCountersLocalBases[peer] = nullptr; + continue; + } + + err = cudaIpcOpenMemHandle(&hPeerCountersLocalBases[peer], + peerInfos[peer].countersLocalHandle, + cudaIpcMemLazyEnablePeerAccess); + if (err != cudaSuccess) { + std::cerr << "[P2P] cudaIpcOpenMemHandle(counters_local) failed for peer " << peer + << ": " << cudaGetErrorString(err) << std::endl; + cudaIpcCloseMemHandle(hPeerDataSuperBases[peer]); + hPeerDataSuperBases[peer] = nullptr; + hPeerCountersLocalBases[peer] = nullptr; + continue; + } + } + } + #endif + + CUDACHECK(cudaMalloc(&peerDataBases, sizeof(*peerDataBases) * worldSize)); + CUDACHECK(cudaMalloc(&peerCounterBases, sizeof(*peerCounterBases) * worldSize)); + CUDACHECK(cudaMemcpy(peerDataBases, hPeerDataSuperBases.data(), + sizeof(*peerDataBases) * worldSize, cudaMemcpyHostToDevice)); + CUDACHECK(cudaMemcpy(peerCounterBases, hPeerCountersLocalBases.data(), + 
sizeof(*peerCounterBases) * worldSize, cudaMemcpyHostToDevice)); +} + +void AllToAllInterNode::destroyPeerArrays() { + #ifdef PPLX_ENABLE_P2P + size_t closed_data = 0, closed_counters = 0; + for (int peer = 0; peer < worldSize; ++peer) { + if (peer == rank) continue; + + if (hPeerDataSuperBases[peer]) { + cudaError_t err = cudaIpcCloseMemHandle(hPeerDataSuperBases[peer]); + if (err != cudaSuccess) { + std::cerr << "[P2P] cleanup close(data_super) peer " << peer << ": " + << cudaGetErrorString(err) << std::endl; + } else { + ++closed_data; + } + } + if (hPeerCountersLocalBases[peer]) { + cudaError_t err = cudaIpcCloseMemHandle(hPeerCountersLocalBases[peer]); + if (err != cudaSuccess) { + std::cerr << "[P2P] cleanup close(counters_local) peer " << peer << ": " + << cudaGetErrorString(err) << std::endl; + } else { + ++closed_counters; + } + } + } + #endif + + CUDACHECK(cudaFree(peerDataBases)); + CUDACHECK(cudaFree(peerCounterBases)); +} + +void AllToAllInterNode::initializeMemoryLayout( + size_t dispatchInSize, + size_t dispatchOutSize, + size_t combineInSize, + size_t combineOutSize, + size_t counterSpan, + size_t maxNumTokens, + size_t worldSize) { + // Initialize DATA super-region offsets + memLayout.dispatchInOffset = 0; + memLayout.dispatchOutOffset = round_up(dispatchInSize, SUPER_REGION_ALIGNMENT); + memLayout.combineInOffset = round_up( + memLayout.dispatchOutOffset + dispatchOutSize, + SUPER_REGION_ALIGNMENT); + memLayout.combineOutOffset = round_up( + memLayout.combineInOffset + combineInSize, + SUPER_REGION_ALIGNMENT); + memLayout.dataSuperTotalSize = memLayout.combineOutOffset + combineOutSize; + + // Initialize COUNTER block offsets (shared layout for both remote and local blocks) + memLayout.counterTokenOffset = 0; + memLayout.counterRecvOffset = round_up( + memLayout.counterTokenOffset + counterSpan, + SUPER_REGION_ALIGNMENT); + memLayout.counterSignalOffset = round_up( + memLayout.counterRecvOffset + counterSpan, + SUPER_REGION_ALIGNMENT); + memLayout.counterSyncOffset = round_up( + memLayout.counterSignalOffset + sizeof(uint64_t) * maxNumTokens, + SUPER_REGION_ALIGNMENT); + memLayout.counterTotalSize = memLayout.counterSyncOffset + sizeof(uint64_t) * worldSize; +} diff --git a/csrc/all_to_all/internode.h b/csrc/all_to_all/internode.h index 28aa939..b284cef 100644 --- a/csrc/all_to_all/internode.h +++ b/csrc/all_to_all/internode.h @@ -2,11 +2,18 @@ #include #include +#include +#include #include +#include +#include #include "all_to_all/all_to_all.h" #include "core/buffer.h" +#include "core/nixl_utils.cuh" +#include + namespace pplx { /// @brief All-to-all broadcast kernel. @@ -122,14 +129,88 @@ class AllToAllInterNode final : public AllToAll { /// @section Pre-allocated symmetric shared memory workspace. 
uint32_t *numTokensPerDP = nullptr; - uint64_t *numTokensBuffer = nullptr; - uint64_t *numDispatchRecvBuffer = nullptr; uint64_t *combineSignalBuffer = nullptr; uint64_t *combineSyncBuffer = nullptr; + uint64_t *combineLocalSignalBuffer = nullptr; + uint64_t *combineLocalSyncBuffer = nullptr; std::byte *xDispatchIn = nullptr; std::byte *xDispatchOut = nullptr; std::byte *xCombineIn = nullptr; std::byte *xCombineOut = nullptr; + + struct alignas(16) PeerInfo { + int deviceId; + int rank; + + void* dataBase; ///< Base of DATA super-region (dispatch/combine buffers) + size_t dataSize; ///< Size of DATA super-region + void* countersRemoteBase; ///< Base of REMOTE counter block (RDMA targets) + size_t countersRemoteSize; ///< Size of REMOTE counter block + void* countersLocalBase; ///< Base of LOCAL counter block (self/P2P fast-path data) + size_t countersLocalSize; ///< Size of LOCAL counter block + + #ifdef PPLX_ENABLE_P2P + cudaIpcMemHandle_t dataSuperHandle; + cudaIpcMemHandle_t countersLocalHandle; + char bootId[40]; + ino_t ipcNamespaceInode; + #endif + }; + + /// @section NIXL-specific members + const size_t numChannels; + std::vector peerInfos; + + std::unique_ptr agent; + nixlBackendH* backend = nullptr; + + std::vector gpuDataReqs; + std::vector gpuSignalReqs; + std::vector cpuXferReqs; + + nixlGpuXferReqH* dChannelDataReqsFlat = nullptr; + nixlGpuXferReqH* dChannelSignalReqsFlat = nullptr; + + nixl_reg_dlist_t originalDataDlist; + nixl_reg_dlist_t originalCountersDlist; + + uint64_t* nixlTokenRemoteCounters = nullptr; + uint64_t* nixlTokenLocalCounters = nullptr; + uint64_t* nixlRecvRemoteCounters = nullptr; + uint64_t* nixlRecvLocalCounters = nullptr; + + std::byte** peerDataBases = nullptr; + std::byte** peerCounterBases = nullptr; + + MemLayout memLayout; + std::byte* dataSuperBase = nullptr; + std::byte* countersRemoteBase = nullptr; + std::byte* countersLocalBase = nullptr; + + std::vector hPeerDataSuperBases; + std::vector hPeerCountersLocalBases; + + void initializeNixlAgent(); + void destroyNixlAgent(); + + void wireupNixl(); + void unwireNixl(); + + void createAllNixlRequests(); + void destroyAllNixlRequests(); + + void setupPeerArrays(); + void destroyPeerArrays(); + + void initializeMemoryLayout( + size_t dispatchInSize, + size_t dispatchOutSize, + size_t combineInSize, + size_t combineOutSize, + size_t counterSpan, + size_t maxNumTokens, + size_t worldSize); + }; } // namespace pplx diff --git a/csrc/all_to_all/internode_combine.cu b/csrc/all_to_all/internode_combine.cu index 0e02a06..4ae620c 100644 --- a/csrc/all_to_all/internode_combine.cu +++ b/csrc/all_to_all/internode_combine.cu @@ -1,9 +1,9 @@ -#include "core/nvshmem_utils.h" +#include "core/device_utils.cuh" #include "core/utils.h" +#include "core/nixl_utils.cuh" #include "internode.h" #include -#include #include using namespace pplx; @@ -35,21 +35,51 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( const uint32_t *sourceOffset, const uint32_t *sourceGroup, uint64_t *combineSignalBuffer, + uint64_t *combineSignalLocalBuffer, uint64_t *combineSyncBuffer, + uint64_t *combineSyncLocalBuffer, uint32_t &globalTokenIndex, std::byte *xBufferIn, - std::byte *xBufferOut + std::byte *xBufferOut, + const MemLayout memLayout, + std::byte **peerDataBases, + std::byte **peerCounterBases, + size_t numChannels, + const nixlGpuXferReqH* dataReqs, + const nixlGpuXferReqH* signalReqs ) { + const unsigned numLocalExperts = numExperts / worldSize; const size_t stride = hiddenDim * sizeof(T); constexpr unsigned 
WARP_SIZE = 32; uint32_t warpId = threadIdx.x / WARP_SIZE; + const size_t channelId = select_channel(numChannels); + if (DO_SEND) { const size_t numSendTokens = __ldg(&globalTokenIndex); for (unsigned i = blockIdx.x * blockDim.x + threadIdx.x; i < worldSize; i += gridDim.x * blockDim.x) { - nvshmemx_signal_op(&combineSyncBuffer[rank], 1, NVSHMEM_SIGNAL_SET, i); + const size_t syncCounterOffset = memLayout.counterSyncOffset + counter_byte_offset(rank); + + if (peerCounterBases[i] != nullptr) { + uint64_t one = 1ull; + __nv_atomic_store( + (uint64_t*)(peerCounterBases[i] + syncCounterOffset), + &one, + __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); + } else { + size_t reqIdx = channel_request_index(channelId, i, worldSize); + nixl_status_t status = nixlGpuPostSignalXferReq( + signalReqs[reqIdx], + /*signal_desc_index=*/0, + /*signal_inc=*/1, + /*signal_offset=*/syncCounterOffset, + /*channel_id=*/0, + /*is_no_delay=*/true, + /*status=*/nullptr); + PPLX_DEVICE_ASSERT_POST_STATUS(status); + } } // Dispatch the tokens from the expert to the DP groups. @@ -84,10 +114,52 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( for (unsigned i = warpId; i < dpSize; i += NUM_WARPS) { const int dstRank = dp * dpSize + i; const unsigned index = dstExpert * maxNumTokens + source; - std::byte *dstPtr = xBufferOut + index * stride; - nvshmemx_putmem_signal_nbi_warp( - dstPtr, xTokenPtr, stride, &combineSignalBuffer[source], 1, NVSHMEM_SIGNAL_ADD, dstRank - ); + + // Unified offset calculation for both P2P and NIXL paths + const size_t dstTokenOffset = index * stride; + const size_t srcTokenOffset = token * stride; + const size_t dstDataOffset = memLayout.combineOutOffset + dstTokenOffset; + const size_t srcDataOffset = memLayout.combineInOffset + srcTokenOffset; + + const size_t signalCounterOffset = memLayout.counterSignalOffset + counter_byte_offset(source); + + if (peerCounterBases[dstRank] != nullptr) { + device::warp_streaming_copy_and_sync( + peerDataBases[dstRank] + dstDataOffset, + xTokenPtr, stride, threadIdx.x % WARP_SIZE); + if (threadIdx.x % WARP_SIZE == 0) { + __nv_atomic_add( + (uint64_t*)(peerCounterBases[dstRank] + signalCounterOffset), + 1ull, + __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); + } + } else { + if (threadIdx.x % WARP_SIZE == 0) { + size_t sizeBytes = stride; + size_t reqIdx = channel_request_index(channelId, dstRank, worldSize); + nixl_status_t stData = nixlGpuPostSingleWriteXferReq( + dataReqs[reqIdx], + /*desc_index=*/0, + /*local_offset=*/srcDataOffset, + /*remote_offset=*/dstDataOffset, + /*size=*/sizeBytes, + /*channel_id=*/0, + /*is_no_delay=*/false, + /*status=*/nullptr); + PPLX_DEVICE_ASSERT_POST_STATUS(stData); + + nixl_status_t stSignal = nixlGpuPostSignalXferReq( + signalReqs[reqIdx], + /*signal_desc_index=*/0, + /*signal_inc=*/1, + /*signal_offset=*/signalCounterOffset, + /*channel_id=*/0, + /*is_no_delay=*/true, + /*status=*/nullptr); + PPLX_DEVICE_ASSERT_POST_STATUS(stSignal); + } + } + __syncwarp(); } } } @@ -102,9 +174,27 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( // Compute the weighed sum of the input tokens. const size_t localNumTokens = boundM ? 
__ldg(boundM) : m; for (unsigned i = blockIdx.x; i < localNumTokens; i += gridDim.x) { - nvshmem_uint64_wait_until(&combineSignalBuffer[i], NVSHMEM_CMP_EQ, expertsPerToken); + // Split counters: remote (peers) and local-only (self) + uint64_t *remoteSignalCounter = combineSignalBuffer + i; + uint64_t *localOnlyCounter = combineSignalLocalBuffer + i; + + // Wait on sum(remote + local) == expertsPerToken + // Gate busy-wait to the first thread in the block + if (threadIdx.x == 0) { + device::nixl_wait_until_sum_eq( + remoteSignalCounter, + localOnlyCounter, + expertsPerToken + ); + } + __syncthreads(); + + if (threadIdx.x == 0) { + uint64_t zero = 0ull; + pplx::st_flag_release(remoteSignalCounter, 0ull); + __nv_atomic_store(localOnlyCounter, &zero, __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); + } __syncthreads(); - combineSignalBuffer[i] = 0; U *dstPtr = outTokens + i * outTokensStrideElem; constexpr unsigned VEC_SIZE = 8; @@ -136,8 +226,17 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( for (unsigned i = blockIdx.x * blockDim.x + threadIdx.x; i < worldSize; i += gridDim.x * blockDim.x) { - nvshmem_uint64_wait_until(&combineSyncBuffer[i], NVSHMEM_CMP_EQ, 1); - combineSyncBuffer[i] = 0; + uint64_t *remoteSyncCounter = combineSyncBuffer + i; + uint64_t *localSyncCounter = combineSyncLocalBuffer + i; + + device::nixl_wait_until_sum_eq( + remoteSyncCounter, + localSyncCounter, + 1); + + uint64_t zero = 0ull; + pplx::st_flag_release(remoteSyncCounter, 0ull); + __nv_atomic_store(localSyncCounter, &zero, __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); } if (blockIdx.x == 0 && threadIdx.x == 0) { @@ -169,6 +268,15 @@ void AllToAllInterNode::combine( dim3 dimGrid(numBlocks, 1, 1); dim3 dimBlock(NUM_WARPS * 32, 1, 1); + uint64_t *combineSignalBufferArg = combineSignalBuffer; + uint64_t *combineSyncBufferArg = combineSyncBuffer; + PPLX_ASSERT(combineSignalBuffer != nullptr, "NIXL combine remote signal buffer not initialized"); + PPLX_ASSERT(combineSyncBuffer != nullptr, "NIXL combine remote sync buffer not initialized"); + PPLX_ASSERT(combineLocalSignalBuffer != nullptr, "NIXL combine local signal buffer not initialized"); + PPLX_ASSERT(combineLocalSyncBuffer != nullptr, "NIXL combine local sync buffer not initialized"); + uint64_t *combineSignalLocalBufferArg = combineLocalSignalBuffer; + uint64_t *combineSyncLocalBufferArg = combineLocalSyncBuffer; + void *args[] = { const_cast(&outTokens.data), const_cast(&outTokens.strideElem), @@ -194,11 +302,21 @@ void AllToAllInterNode::combine( &sourceIndex, &sourceOffset, &sourceGroup, - &combineSignalBuffer, - &combineSyncBuffer, + + &combineSignalBufferArg, + &combineSignalLocalBufferArg, + &combineSyncBufferArg, + &combineSyncLocalBufferArg, &tokenIndex, &xCombineIn, - &xCombineOut}; + &xCombineOut, + &memLayout, + &peerDataBases, + &peerCounterBases, + const_cast(&numChannels), + &dChannelDataReqsFlat, + &dChannelSignalReqsFlat, + }; nvtxRangePush("combine"); switch (splitMode) { diff --git a/csrc/all_to_all/internode_dispatch.cu b/csrc/all_to_all/internode_dispatch.cu index d7374e6..36fc4db 100644 --- a/csrc/all_to_all/internode_dispatch.cu +++ b/csrc/all_to_all/internode_dispatch.cu @@ -1,11 +1,11 @@ #include #include -#include #include #include "all_to_all/internode.h" #include "core/device_utils.cuh" #include "core/utils.h" +#include "core/nixl_utils.cuh" using namespace pplx; @@ -47,10 +47,18 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( uint32_t *sourceGroup, uint32_t *sourceToken, uint64_t *numTokensBuffer, + 
uint64_t *numTokensLocalBuffer, uint64_t *numRecvBuffer, + uint64_t *numRecvLocalBuffer, uint32_t &globalTokenIndex, std::byte *xBufferIn, - std::byte *xBufferOut + std::byte *xBufferOut, + const MemLayout memLayout, + std::byte **peerDataBases, + std::byte **peerCounterBases, + size_t numChannels, + const nixlGpuXferReqH* dataReqs, + const nixlGpuXferReqH* signalReqs ) { // Determine the rank, DP rank and per-rank constants. const unsigned numLocalExperts = numExperts / worldSize; @@ -62,6 +70,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( const unsigned WARP_SIZE = 32; const unsigned warpId = threadIdx.x / WARP_SIZE; const unsigned laneId = threadIdx.x % WARP_SIZE; + const size_t channelId = select_channel(numChannels); // Determine the number of tokens populated which are to be sent. const unsigned numSendTokens = boundM ? __ldg(boundM) : m; @@ -98,10 +107,30 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( } unsigned numTokensPerExpert = device::warp_sum(count); - uint64_t *dstCount = &numTokensBuffer[dstLocalExpert * numDPGroups + dpGroup]; if (laneId == 0) { - nvshmemx_signal_op(dstCount, numTokensPerExpert + 1, NVSHMEM_SIGNAL_SET, dstRank); + const uint32_t targetGroup = dstLocalExpert * numDPGroups + dpGroup; + const uint64_t tokenValue = numTokensPerExpert + 1; + const size_t tokenCounterOffset = memLayout.counterTokenOffset + counter_byte_offset(targetGroup); + + if (peerCounterBases[dstRank] != nullptr) { + uint64_t value = tokenValue; + __nv_atomic_store( + (uint64_t*)(peerCounterBases[dstRank] + tokenCounterOffset), + &value, + __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); + } else { + size_t reqIdx = channel_request_index(channelId, dstRank, worldSize); + nixl_status_t st = nixlGpuPostSignalXferReq( + signalReqs[reqIdx], + /*signal_desc_index=*/0, + /*signal_inc=*/(uint64_t)tokenValue, + /*signal_offset=*/tokenCounterOffset, + /*channel_id=*/0, + /*is_no_delay=*/true, + /*status=*/nullptr); + PPLX_DEVICE_ASSERT_POST_STATUS(st); + } } } @@ -153,16 +182,50 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( const uint32_t group = dstLocalExpert * numDPGroups + dpGroup; const unsigned loc = group * maxNumTokens + index; - std::byte *destPointer = xBufferOut + loc * tokenStride; - nvshmemx_putmem_signal_nbi_warp( - destPointer, - xInPtr, - tokenStride, - &numRecvBuffer[group], - 1, - NVSHMEM_SIGNAL_ADD, - dstRank - ); + // Unified offset calculation for both P2P and NIXL paths + const size_t dstTokenOffset = loc * tokenStride; + const size_t srcTokenOffset = i * tokenStride; + const size_t dstDataOffset = memLayout.dispatchOutOffset + dstTokenOffset; + const size_t srcDataOffset = memLayout.dispatchInOffset + srcTokenOffset; + + const size_t recvCounterOffset = memLayout.counterRecvOffset + counter_byte_offset(group); + + if (peerCounterBases[dstRank] != nullptr) { + device::warp_streaming_copy_and_sync( + peerDataBases[dstRank] + dstDataOffset, + xInPtr, tokenStride, laneId); + if (laneId == 0) { + __nv_atomic_add( + (uint64_t*)(peerCounterBases[dstRank] + recvCounterOffset), + 1ull, + __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); + } + } else { + size_t reqIdx = channel_request_index(channelId, dstRank, worldSize); + if (laneId == 0) { + size_t size = tokenStride; + nixl_status_t rd = nixlGpuPostSingleWriteXferReq( + dataReqs[reqIdx], + /*desc_index=*/0, + /*local_offset=*/srcDataOffset, + /*remote_offset=*/dstDataOffset, + /*size=*/size, + /*channel_id=*/0, + /*is_no_delay=*/false, + /*status=*/nullptr); + 
PPLX_DEVICE_ASSERT_POST_STATUS(rd); + nixl_status_t rs = nixlGpuPostSignalXferReq( + signalReqs[reqIdx], + /*signal_desc_index=*/0, + /*signal_inc=*/1, + /*signal_offset=*/recvCounterOffset, + /*channel_id=*/0, + /*is_no_delay=*/true, + /*status=*/nullptr); + PPLX_DEVICE_ASSERT_POST_STATUS(rs); + } + } + __syncwarp(); } } } @@ -188,13 +251,27 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( // Fetch the token count per DP, which is non-zero to indicate receipt. // Afterwards, wait for exactly that many tokens to be sent to us. - nvshmem_uint64_wait_until(&numTokensBuffer[group], NVSHMEM_CMP_NE, 0); - size_t numTokens = numTokensBuffer[group] - 1; - nvshmem_uint64_wait_until(&numRecvBuffer[group], NVSHMEM_CMP_EQ, numTokens); + const uint32_t dp = group % numDPGroups; + const uint32_t counterIdx = expert * numDPGroups + dp; + uint64_t *tokenRemote = numTokensBuffer + counterIdx; + uint64_t *tokenLocal = numTokensLocalBuffer + counterIdx; + uint64_t tokenSum = device::nixl_wait_until_sum_ne_zero(tokenRemote, tokenLocal); + size_t numTokens = tokenSum - 1; + + uint64_t *recvRemote = numRecvBuffer + counterIdx; + uint64_t *recvLocal = numRecvLocalBuffer + counterIdx; + device::nixl_wait_until_sum_eq( + recvRemote, + recvLocal, + numTokens + ); numTokensPerDP[group] = numTokens; - numTokensBuffer[group] = 0; - numRecvBuffer[group] = 0; + uint64_t zero = 0ull; + pplx::st_flag_release(tokenRemote, 0ull); + __nv_atomic_store(tokenLocal, &zero, __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); + pplx::st_flag_release(recvRemote, 0ull); + __nv_atomic_store(recvLocal, &zero, __NV_ATOMIC_RELEASE, PPLX_ATOMIC_SCOPE); sharedExpert[group - firstGroup] = atomicAdd(&outNumTokensPerExpert[expert], numTokens); sharedToken[group - firstGroup] = atomicAdd(&globalTokenIndex, numTokens); } @@ -277,6 +354,15 @@ void AllToAllInterNode::dispatch( const size_t sharedMemorySend = sizeof(uint32_t) * numExperts; const size_t sharedMemoryRecv = sizeof(uint32_t) * expertsPerBlock * 2; + PPLX_ASSERT(nixlTokenRemoteCounters != nullptr, "NIXL token remote counters not initialized"); + PPLX_ASSERT(nixlRecvRemoteCounters != nullptr, "NIXL recv remote counters not initialized"); + PPLX_ASSERT(nixlTokenLocalCounters != nullptr, "NIXL token local counters not initialized"); + PPLX_ASSERT(nixlRecvLocalCounters != nullptr, "NIXL recv local counters not initialized"); + uint64_t *numTokensBufferArg = nixlTokenRemoteCounters; + uint64_t *numRecvBufferArg = nixlRecvRemoteCounters; + uint64_t *numTokensLocalBufferArg = nixlTokenLocalCounters; + uint64_t *numRecvLocalBufferArg = nixlRecvLocalCounters; + void *args[] = { const_cast(&outNumTokensPerExpert.data), const_cast(&outNumTokensPerExpert.strideElem), @@ -311,11 +397,19 @@ void AllToAllInterNode::dispatch( &sourceOffset, &sourceGroup, &sourceToken, - &numTokensBuffer, - &numDispatchRecvBuffer, + &numTokensBufferArg, + &numTokensLocalBufferArg, + &numRecvBufferArg, + &numRecvLocalBufferArg, &tokenIndex, &xDispatchIn, &xDispatchOut, + &memLayout, + &peerDataBases, + &peerCounterBases, + const_cast(&numChannels), + &dChannelDataReqsFlat, + &dChannelSignalReqsFlat, }; nvtxRangePush("dispatch"); diff --git a/csrc/core/CMakeLists.txt b/csrc/core/CMakeLists.txt index 821035d..28816b2 100644 --- a/csrc/core/CMakeLists.txt +++ b/csrc/core/CMakeLists.txt @@ -11,4 +11,7 @@ target_link_libraries(core_lib INTERFACE nvshmem::nvshmem_host ) target_include_directories(core_lib PRIVATE ${NVSHMEM_INCLUDE_DIR}) + +target_link_libraries(core_lib PUBLIC nixl_interface) + 
 set_cuda_compile_options(core_lib)
diff --git a/csrc/core/atomic.cuh b/csrc/core/atomic.cuh
index 6f6e7f6..8e965c0 100644
--- a/csrc/core/atomic.cuh
+++ b/csrc/core/atomic.cuh
@@ -20,10 +20,20 @@ __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr) {
   return flag;
 }
 
+__forceinline__ __device__ uint64_t ld_flag_acquire(uint64_t *flag_addr) {
+  uint64_t flag;
+  asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(flag) : "l"(flag_addr));
+  return flag;
+}
+
 __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr, uint32_t flag) {
   asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
 }
 
+__forceinline__ __device__ void st_flag_release(uint64_t *flag_addr, uint64_t flag) {
+  asm volatile("st.release.sys.global.u64 [%1], %0;" ::"l"(flag), "l"(flag_addr) : "memory");
+}
+
 __forceinline__ __device__ uint32_t add_flag_release(uint32_t *addr, uint32_t val) {
   uint32_t flag;
   asm volatile("atom.release.sys.global.add.u32 %0, [%1], %2;" : "=r"(flag) : "l"(addr), "r"(val));
diff --git a/csrc/core/device_utils.cuh b/csrc/core/device_utils.cuh
index c60aa87..697456e 100644
--- a/csrc/core/device_utils.cuh
+++ b/csrc/core/device_utils.cuh
@@ -1,6 +1,10 @@
 #pragma once
 
+#include
 #include "core/common_utils.h"
+#include "core/atomic.cuh"
+#include
+#include
 
 #define PPLX_ENABLE_DEVICE_ASSERT 0
 
@@ -12,8 +16,22 @@
       asm("trap;");                                                                 \
     }                                                                               \
   } while (0)
+
+// Assert macro for NIXL post operations (accepts SUCCESS or IN_PROG)
+#define PPLX_DEVICE_ASSERT_POST_STATUS(status)                                      \
+  do {                                                                              \
+    if ((status) != NIXL_SUCCESS && (status) != NIXL_IN_PROG) {                     \
+      printf("PPLX NIXL Post Assert Failed: expected SUCCESS(0) or IN_PROG(1), got=%d at %s:%d\n", \
+             (int)(status), __FILE__, __LINE__);                                    \
+      asm("trap;");                                                                 \
+    }                                                                               \
+  } while (0)
 #else
 #define PPLX_DEVICE_ASSERT(cond)
+#define PPLX_DEVICE_ASSERT_POST_STATUS(status)                                      \
+  do {                                                                              \
+    (void)(status);                                                                 \
+  } while (0)
 #endif
 
 namespace pplx {
@@ -56,5 +74,73 @@ __forceinline__ __device__ float half_warp_reduce_max(float value) {
   return value;
 }
 
+__device__ __forceinline__ int4 ld_nc_global(const int4* ptr) {
+  int4 ret;
+  asm volatile("ld.global.nc.L1::no_allocate.L2::256B.v4.s32 {%0, %1, %2, %3}, [%4];"
+               : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
+               : "l"(ptr));
+  return ret;
+}
+
+__device__ __forceinline__ void st_na_global(int4* ptr, const int4& value) {
+  asm volatile("st.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};"
+               :: "l"(ptr), "r"(value.x), "r"(value.y),
+                  "r"(value.z), "r"(value.w) : "memory");
+}
+
+__device__ __forceinline__ void warp_streaming_copy_and_sync(
+    void* dst_ptr, const void* src_ptr, size_t size, int lane_id
+) {
+  const int num_int4 = size / sizeof(int4);
+  int4* dst = reinterpret_cast<int4*>(dst_ptr);
+  const int4* src = reinterpret_cast<const int4*>(src_ptr);
+
+  const int warp_stride = 32;
+  int i = lane_id;
+  for (; i < num_int4; i += warp_stride * 8) {
+#pragma unroll 4
+    for (int u = 0; u < 8; ++u) {
+      int j = i + u * warp_stride;
+      if (j < num_int4) {
+        int4 val = pplx::device::ld_nc_global(src + j);
+        pplx::device::st_na_global(dst + j, val);
+      }
+    }
+  }
+  __syncwarp();
+}
+
+// Wait until the sum of a remote-modified counter and a local-only counter equals expected_value.
+__device__ inline void nixl_wait_until_sum_eq( + uint64_t* remote_ptr, + uint64_t* local_ptr, + uint64_t expected_value +) { + uint64_t r; + uint64_t l; + + do { + r = pplx::ld_flag_acquire(remote_ptr); + __nv_atomic_load(local_ptr, &l, __NV_ATOMIC_ACQUIRE, PPLX_ATOMIC_SCOPE); + } while (r + l != expected_value); +} + +// Wait until the sum of remote+local counters becomes non-zero, returns the sum. +__device__ inline uint64_t nixl_wait_until_sum_ne_zero( + uint64_t* remote_ptr, + uint64_t* local_ptr +) { + uint64_t r; + uint64_t l; + + do { + r = pplx::ld_flag_acquire(remote_ptr); + __nv_atomic_load(local_ptr, &l, __NV_ATOMIC_ACQUIRE, PPLX_ATOMIC_SCOPE); + } while (r + l == 0); + + return r + l; +} + + } // namespace device } // namespace pplx diff --git a/csrc/core/nixl_utils.cuh b/csrc/core/nixl_utils.cuh new file mode 100644 index 0000000..1c34d97 --- /dev/null +++ b/csrc/core/nixl_utils.cuh @@ -0,0 +1,39 @@ +#pragma once + +#include + +namespace pplx { + +struct MemLayout { + size_t dispatchInOffset; + size_t dispatchOutOffset; + size_t combineInOffset; + size_t combineOutOffset; + size_t dataSuperTotalSize; + + size_t counterTokenOffset; + size_t counterRecvOffset; + size_t counterSignalOffset; + size_t counterSyncOffset; + size_t counterTotalSize; +}; + +// Channel selector: uses block index for deterministic channel assignment +__device__ __forceinline__ size_t select_channel(size_t numChannels) { + #if defined(__CUDA_ARCH__) + return static_cast(blockIdx.x % numChannels); + #else + return 0; + #endif +} + +// Compute flat index for per-channel request arrays +__device__ __forceinline__ size_t channel_request_index(size_t channel_id, int dest_rank, int worldSize) { + return channel_id * worldSize + static_cast(dest_rank); +} + +__device__ __forceinline__ size_t counter_byte_offset(size_t element_index) { + return element_index * sizeof(uint64_t); +} + +} // namespace pplx diff --git a/csrc/core/utils.h b/csrc/core/utils.h index e825e01..aa989ef 100644 --- a/csrc/core/utils.h +++ b/csrc/core/utils.h @@ -3,6 +3,7 @@ #include "core/common_utils.h" #include +#include #define _PPLX_ASSERT_MSG(msg) \ do { \ @@ -22,3 +23,7 @@ #define PPLX_UNREACHABLE(msg) \ _PPLX_ASSERT_MSG(__FILE__ "(" _PPLX_EXPAND_AND_STRINGIFY(__LINE__) "): " msg "\n"); \ __builtin_trap(); + +namespace pplx { + +} // namespace pplx diff --git a/docs/install-driver-and-dependencies.md b/docs/install-driver-and-dependencies.md index 803b631..8e42eaa 100644 --- a/docs/install-driver-and-dependencies.md +++ b/docs/install-driver-and-dependencies.md @@ -6,6 +6,11 @@ Here's a summary of the software and drivers required for running pplx-kernels o |---------------------------|-------------|--------------------------|---------------------| | NVIDIA Driver | Y | Y | Y | | modprobe.d/nvidia.conf | | Y | | +| CUDA Toolkit | Y | Y | Y | +| RDMA-Core | | Y | Y | +| UCX Library | | Y | Y | +| NIXL Library | | Y | Y | +| DOCA (GPUNetIO) | | Y | | | GDRCopy Driver | | Y | Y | | GDRCopy Library | | Y | Y | | NVSHMEM Library | Y | Y | Y | @@ -18,6 +23,10 @@ Here's a summary of the software and drivers required for running pplx-kernels o ## NVIDIA Driver Config +### NIXL + UCX + DOCA Configuration + +For ConnectX systems, internode communication can use NIXL + UCX with DOCA GPUNetIO. The following sections detail the installation and configuration. + To use IBGDA, NVIDIA Driver needs to be configured to allow GPU to initiate communication. 
```bash @@ -26,6 +35,127 @@ sudo update-initramfs -u sudo reboot ``` +#### System Packages for NIXL/UCX/DOCA + +Install common build dependencies (Ubuntu 24.04 example): + +```bash +sudo apt-get update +sudo apt-get install -y \ + build-essential cmake gcc g++ ninja-build pkg-config \ + autoconf automake libtool git valgrind \ + linux-headers-$(uname -r) \ + meson python3-pip python3-dev python3-docutils \ + libnl-3-dev libnl-route-3-dev \ + libglib2.0-dev libssl-dev libzip-dev libjson-c-dev \ + libpcap-dev libjsoncpp-dev curl +``` + +#### RDMA-Core + +```bash +git clone https://github.com/linux-rdma/rdma-core.git +cd rdma-core +git checkout v58.0 +mkdir build && cd build +cmake .. -DCMAKE_INSTALL_PREFIX=/opt/rdma +make -j"$(nproc)" +sudo make install +``` + +Set `RDMA_HOME=/opt/rdma`. + +#### DOCA + +Install DOCA with GPUNetIO and set `DOCA_HOME` to its install prefix. Refer to NVIDIA DOCA SDK documentation for obtaining packages and prerequisites. + +Example install layout expectation: +- Headers at `$DOCA_HOME/include` +- Libraries at `$DOCA_HOME/lib/x86_64-linux-gnu` + +#### UCX + +Build UCX with CUDA, GDRCopy, and DOCA support. + +```bash +git clone https://github.com/openucx/ucx.git +cd ucx +./autogen.sh +./contrib/configure-release-mt --prefix=/opt/ucx \ + --with-verbs=/opt/rdma \ + --with-cuda=/usr/local/cuda \ + --with-gdrcopy=/opt/gdrcopy \ + --with-doca=${DOCA_HOME:-no} +make -j"$(nproc)" +sudo make install +``` + +Set `UCX_HOME=/opt/ucx`. + +#### NIXL + +Follow the build instructions in the NIXL repository: [NIXL repository](https://github.com/ai-dynamo/nixl). Example: + +```bash +pip3 install --upgrade meson pybind11 +git clone https://github.com/ai-dynamo/nixl.git +cd nixl +meson setup build -Ducx_path=/opt/ucx -Dprefix=/opt/nixl --buildtype=release +sudo ninja install -C build +``` + +Set `NIXL_HOME=/opt/nixl`. + +#### Runtime Environment + +Set up environment variables (adjust paths to your installation): + +```bash +export CUDA_HOME=/usr/local/cuda +export RDMA_HOME=/opt/rdma +export DOCA_HOME=/opt/doca +export UCX_HOME=/opt/ucx +export NIXL_HOME=/opt/nixl + +export LD_LIBRARY_PATH="$UCX_HOME/lib:$NIXL_HOME/lib:$CUDA_HOME/lib64:${DOCA_HOME:+$DOCA_HOME/lib/x86_64-linux-gnu:}$LD_LIBRARY_PATH" +export LD_PRELOAD="${DOCA_HOME:+$DOCA_HOME/lib/x86_64-linux-gnu/libdoca_common.so:$DOCA_HOME/lib/x86_64-linux-gnu/libdoca_gpunetio.so:$DOCA_HOME/lib/x86_64-linux-gnu/libdoca_verbs.so}" +``` + +#### UCX Network Device Selection + +On some systems UCX may not automatically pick the desired NIC per GPU. Explicitly set `UCX_NET_DEVICES` to map each GPU to the correct NIC and port. 
+ +Shell example (single-node, 1:1 GPU-to-NIC mapping): + +```bash +# Adjust NIC names (mlx5_X) and port (:1) as needed +export UCX_NET_DEVICES="cuda${LOCAL_RANK}-mlx5_${LOCAL_RANK}:1" +``` + +Python example (custom mapping list): + +```python +import os +local_rank = int(os.environ.get("LOCAL_RANK", "0")) +pxb_nics = ["mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11"] +os.environ['UCX_NET_DEVICES'] = f'cuda{local_rank}-{pxb_nics[local_rank]}:1' +``` + +Project example (as used in `tests/bench_all_to_all.py`): + +```python +import os +# Map each GPU to a specific mlx5 device and include additional NICs +pxb_nics_eos = ["mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11"] +tcp_nics_eos = ',ibp154s0,ibp192s0,ibp206s0,ibp220s0,ibp94s0' +local_rank = int(os.environ.get("LOCAL_RANK", "0")) +os.environ['UCX_NET_DEVICES'] = f'cuda{local_rank}-{pxb_nics_eos[local_rank]}:1' + tcp_nics_eos +``` + +Adjust device names to match your system (`mlx5_*`, `ibp*`, and the `:1` port). + +This configuration aligns pplx-kernels internode usage with NIXL + UCX. Adjust paths and options to match your deployment. + ## GDRCopy GDRCopy is needed for multi-node. diff --git a/tests/bench_all_to_all.py b/tests/bench_all_to_all.py index 970a4d9..ca8666c 100644 --- a/tests/bench_all_to_all.py +++ b/tests/bench_all_to_all.py @@ -33,6 +33,9 @@ def bench_all_to_all( device = pgi.device num_dp = pgi.world_size // dp_size dp_rank = pgi.rank // dp_size + pxb_nics_eos = ["mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11"] + tcp_nics_eos = ',ibp154s0,ibp192s0,ibp206s0,ibp220s0,ibp94s0' + os.environ['UCX_NET_DEVICES'] = f'cuda{pgi.local_rank}-{pxb_nics_eos[pgi.local_rank]}:1' + tcp_nics_eos # Generate the same rank data for each DP group rng = torch.Generator()
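
The changes above expose a small configuration surface: `csrc/CMakeLists.txt` reads `PPLX_ENABLE_P2P` and `PPLX_NIXL_NUM_CHANNELS` from the environment at configure time, `AllToAllInterNode::initializeNixlAgent()` asserts that `MASTER_PORT` is set on every rank, and NIC selection follows the `UCX_NET_DEVICES` guidance in the docs. Below is a minimal environment sketch, assuming the `/opt/...` install prefixes from `docs/install-driver-and-dependencies.md`; the port value and NIC mapping are illustrative only.

```bash
# Build-time knobs read by csrc/CMakeLists.txt (example values)
export PPLX_ENABLE_P2P=1           # compile the CUDA IPC / NVLink fast path (system-scope atomics)
export PPLX_NIXL_NUM_CHANNELS=8    # number of UCX workers, baked in as a compile-time constant

# Runtime library paths (prefixes follow docs/install-driver-and-dependencies.md)
export CUDA_HOME=/usr/local/cuda
export UCX_HOME=/opt/ucx
export NIXL_HOME=/opt/nixl
export LD_LIBRARY_PATH="$UCX_HOME/lib:$NIXL_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"

# Required on every rank: initializeNixlAgent() asserts MASTER_PORT is present
# (the specific port number here is only an example).
export MASTER_PORT=29500

# Per-rank NIC selection, as described in the UCX Network Device Selection section
export UCX_NET_DEVICES="cuda${LOCAL_RANK}-mlx5_${LOCAL_RANK}:1"
```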