diff --git a/cpp/tensorrt_llm/common/customAllReduceUtils.h b/cpp/tensorrt_llm/common/customAllReduceUtils.h index 0a6c2d9d327..9a466512e4d 100644 --- a/cpp/tensorrt_llm/common/customAllReduceUtils.h +++ b/cpp/tensorrt_llm/common/customAllReduceUtils.h @@ -81,7 +81,6 @@ inline AllReduceStrategyType SelectStrategyLP(size_t seq_len, size_t hidden_size { return AllReduceStrategyType::ONESHOT; } - return AllReduceStrategyType::NCCL; } // use 1D vector to store the best strategy instead of a map for each sm version @@ -143,7 +142,7 @@ inline AllReduceStrategyType selectStrategyLookUpTable( sm_version = 100; } - // Check if the entry is out of bounds, otherwise return NCCL as fallback + // Check if the entry is out of bounds, otherwise return NCCL_SYMMETRIC as fallback if (AllReduceBestStrategyTable.find(sm_version) == AllReduceBestStrategyTable.end() || tp_index >= AllReduceBestStrategyTable.at(sm_version).size() || fusion_op_index >= AllReduceBestStrategyTable.at(sm_version).at(tp_index).size() @@ -151,7 +150,7 @@ inline AllReduceStrategyType selectStrategyLookUpTable( || num_token_index >= AllReduceBestStrategyTable.at(sm_version).at(tp_index).at(fusion_op_index).at(hidden_size_index).size()) { - return AllReduceStrategyType::NCCL; + return AllReduceStrategyType::NCCL_SYMMETRIC; } return static_cast( diff --git a/cpp/tensorrt_llm/common/ncclUtils.cpp b/cpp/tensorrt_llm/common/ncclUtils.cpp new file mode 100644 index 00000000000..76406fd8066 --- /dev/null +++ b/cpp/tensorrt_llm/common/ncclUtils.cpp @@ -0,0 +1,585 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/ncclUtils.h" + +#if ENABLE_MULTI_DEVICE + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include +#include + +namespace tensorrt_llm::common::nccl_util +{ + +//============================================================================== +// NcclCommResourceManager Implementation +//============================================================================== + +NcclCommResourceManager& NcclCommResourceManager::getInstance() noexcept +{ + static NcclCommResourceManager instance; + return instance; +} + +void NcclCommResourceManager::registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName) +{ + if (!comm) + { + TLLM_LOG_WARNING("[NCCLUtil] Attempted to register resource for null NCCL comm"); + return; + } + + std::lock_guard lock(mMutex); + auto& resources = mCommResources[comm]; + resources.emplace_back(std::move(cleanup), debugName ? debugName : "unnamed"); + + TLLM_LOG_TRACE("[NCCLUtil] Registered resource '%s' for NCCL comm %p (total: %zu)", + debugName ? 
debugName : "unnamed", static_cast(comm), resources.size()); +} + +void NcclCommResourceManager::cleanupResources(ncclComm_t comm) noexcept +{ + if (!comm) + { + return; + } + + std::vector resourcesToClean; + + { + std::lock_guard lock(mMutex); + auto it = mCommResources.find(comm); + if (it == mCommResources.end()) + { + // Nothing registered for this comm, nothing to clean up + return; + } + + // Move resources out (preserves order) and remove from map + resourcesToClean = std::move(it->second); + mCommResources.erase(it); + + TLLM_LOG_TRACE( + "[NCCLUtil] Cleaning up %zu resources for NCCL comm %p", resourcesToClean.size(), static_cast(comm)); + } + + // Clean up outside the lock to avoid deadlocks if cleanup functions try to access the manager + // Order is preserved: resources are cleaned up in registration order + for (auto& [cleanup, name] : resourcesToClean) + { + try + { + TLLM_LOG_TRACE( + "[NCCLUtil] Cleaning up resource '%s' for NCCL comm %p", name.c_str(), static_cast(comm)); + cleanup(); + } + catch (std::exception const& e) + { + TLLM_LOG_ERROR("[NCCLUtil] Exception during cleanup of resource '%s' for NCCL comm %p: %s", name.c_str(), + static_cast(comm), e.what()); + } + catch (...) + { + TLLM_LOG_ERROR("[NCCLUtil] Unknown exception during cleanup of resource '%s' for NCCL comm %p", + name.c_str(), static_cast(comm)); + } + } +} + +bool NcclCommResourceManager::hasResources(ncclComm_t comm) const noexcept +{ + std::lock_guard lock(mMutex); + return mCommResources.find(comm) != mCommResources.end(); +} + +size_t NcclCommResourceManager::getResourceCount(ncclComm_t comm) const noexcept +{ + std::lock_guard lock(mMutex); + auto it = mCommResources.find(comm); + return it != mCommResources.end() ? it->second.size() : 0; +} + +//============================================================================== +// NCCLHelper Implementation +//============================================================================== + +NCCLHelper& NCCLHelper::getInstance() +{ + static NCCLHelper instance; + return instance; +} + +NCCLHelper::NCCLHelper() + : mLibraryHandle(nullptr) + , mNCCLCommWindowRegister(nullptr) + , mNCCLMemAlloc(nullptr) + , mIsLoaded(false) +{ + loadNCCLLibrary(); +} + +NCCLHelper::~NCCLHelper() +{ + if (mLibraryHandle) + { +#ifdef _WIN32 + FreeLibrary(mLibraryHandle); +#else + dlclose(mLibraryHandle); +#endif + mLibraryHandle = nullptr; + } +} + +void NCCLHelper::loadNCCLLibrary() +{ + try + { +#ifdef _WIN32 + char const* libraryNames[] = {"nccl.dll"}; +#else + char const* libraryNames[] = {"libnccl.so"}; +#endif + + for (auto const* name : libraryNames) + { + mLibraryHandle = loadLibraryHandle(name); + if (mLibraryHandle) + { + TLLM_LOG_INFO("Successfully loaded NCCL library: %s", name); + break; + } + } + + if (!mLibraryHandle) + { + TLLM_LOG_WARNING("Failed to load NCCL library"); + return; + } + + // Load the required symbols + mNCCLCommWindowRegister + = reinterpret_cast(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister")); + + mNCCLMemAlloc = reinterpret_cast(getSymbolAddress(mLibraryHandle, "ncclMemAlloc")); + + if (mNCCLCommWindowRegister == nullptr) + { + TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported."); + } + + if (mNCCLMemAlloc == nullptr) + { + TLLM_LOG_WARNING("Failed to load ncclMemAlloc symbol, NCCL symmetric will not be supported."); + } + + if (mNCCLCommWindowRegister != nullptr && mNCCLMemAlloc != nullptr) + { + mIsLoaded = true; + } + else + { + TLLM_LOG_WARNING( + "Failed to load required 
NCCL symbols (both ncclCommWindowRegister and ncclMemAlloc are required)"); + } + } + catch (std::exception const& e) + { + TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what()); + } +} + +void* NCCLHelper::loadLibraryHandle(char const* libName) +{ +#ifdef _WIN32 + return LoadLibraryA(libName); +#else + return dlopen(libName, RTLD_LAZY | RTLD_GLOBAL); +#endif +} + +void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName) +{ + if (!handle) + { + return nullptr; + } + +#ifdef _WIN32 + return GetProcAddress(static_cast(handle), symbolName); +#else + return dlsym(handle, symbolName); +#endif +} + +NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister() +{ + return mNCCLCommWindowRegister; +} + +NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc() +{ + return mNCCLMemAlloc; +} + +bool NCCLHelper::isLoaded() const +{ + return mIsLoaded; +} + +//============================================================================== +// NCCLWindowAllocator Implementation +//============================================================================== + +NCCLWindowAllocator& NCCLWindowAllocator::getInstance() +{ + static NCCLWindowAllocator instance; + return instance; +} + +NCCLWindowBuffer NCCLWindowAllocator::requestBuffer(ncclComm_t comm, size_t size) +{ + TLLM_CHECK_WITH_INFO(comm != nullptr, "NCCL communicator cannot be null"); + TLLM_CHECK_WITH_INFO(size > 0, "Buffer size must be greater than 0"); + + std::lock_guard lock(mMutex); + + // Register cleanup callback for this communicator if not already registered + // This is cheap even if no buffers exist yet - cleanup will just return early + registerBufferCleanup(comm); + + // Check if we have an available buffer of at least the requested size for this communicator + // Use best-fit: find the smallest buffer that's >= requested size + auto& commBuffers = mBufferPool[comm]; + auto bestFit = commBuffers.end(); + size_t bestFitSize = std::numeric_limits::max(); + + for (auto it = commBuffers.begin(); it != commBuffers.end(); ++it) + { + if (!it->inUse && it->buffer.size >= size && it->buffer.size < bestFitSize) + { + bestFit = it; + bestFitSize = it->buffer.size; + } + } + + if (bestFit != commBuffers.end()) + { + bestFit->inUse = true; + TLLM_LOG_TRACE( + "[NCCLUtil] Reusing NCCL window buffer for comm %p: handle=%d, ptr=%p, size=%zu (requested: %zu)", + static_cast(comm), bestFit->buffer.handle, bestFit->buffer.ptr, bestFit->buffer.size, size); + return bestFit->buffer; + } + + // No available buffer found, allocate a new one + TLLM_LOG_TRACE( + "[NCCLUtil] Allocating new NCCL window buffer for comm %p, size=%zu", static_cast(comm), size); + int handle = static_cast(commBuffers.size()); + NCCLWindowBuffer buffer = allocateAndRegisterBuffer(comm, size, handle); + commBuffers.push_back({buffer, true}); + + return buffer; +} + +NCCLWindowBuffer NCCLWindowAllocator::searchBuffer(ncclComm_t comm, void* ptr) const +{ + if (!comm || !ptr) + { + return NCCLWindowBuffer(); + } + + std::lock_guard lock(mMutex); + return searchBufferLocked(comm, ptr); +} + +void NCCLWindowAllocator::releaseBuffer(ncclComm_t comm, void* ptr) +{ + if (!comm || !ptr) + { + return; + } + + std::lock_guard lock(mMutex); + auto commIt = mBufferPool.find(comm); + if (commIt == mBufferPool.end()) + { + TLLM_LOG_WARNING( + "[NCCLUtil] Attempted to release buffer %p for unknown comm %p", ptr, static_cast(comm)); + return; + } + + for (auto& entry : commIt->second) + { + if (entry.buffer.ptr == ptr) + { + entry.inUse = false; + 
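            // Note: releasing only clears the in-use flag. The allocation stays
            // window-registered and pooled, so a later requestBuffer() on the same
            // comm can reuse it without another ncclMemAlloc/ncclCommWindowRegister
            // round trip; the memory itself is freed only in cleanupBuffersForComm()
            // when the communicator is torn down.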
TLLM_LOG_TRACE("[NCCLUtil] Released NCCL window buffer for comm %p: ptr=%p", static_cast(comm), ptr); + return; + } + } + + TLLM_LOG_WARNING("[NCCLUtil] Attempted to release unknown buffer %p for comm %p", ptr, static_cast(comm)); +} + +ncclWindow_t NCCLWindowAllocator::getWindow(ncclComm_t comm, void* ptr) const +{ + std::lock_guard lock(mMutex); + NCCLWindowBuffer buffer = searchBufferLocked(comm, ptr); + return buffer.isValid() ? buffer.window : nullptr; +} + +size_t NCCLWindowAllocator::getSize(ncclComm_t comm, void* ptr) const +{ + std::lock_guard lock(mMutex); + NCCLWindowBuffer buffer = searchBufferLocked(comm, ptr); + return buffer.isValid() ? buffer.size : 0; +} + +NCCLWindowBuffer NCCLWindowAllocator::getBufferInfo(ncclComm_t comm, void* ptr) const +{ + std::lock_guard lock(mMutex); + return searchBufferLocked(comm, ptr); +} + +size_t NCCLWindowAllocator::getBufferCount(ncclComm_t comm) const +{ + std::lock_guard lock(mMutex); + auto commIt = mBufferPool.find(comm); + return commIt != mBufferPool.end() ? commIt->second.size() : 0; +} + +size_t NCCLWindowAllocator::getBufferInUseCount(ncclComm_t comm) const +{ + std::lock_guard lock(mMutex); + auto commIt = mBufferPool.find(comm); + if (commIt == mBufferPool.end()) + { + return 0; + } + + size_t count = 0; + for (auto const& entry : commIt->second) + { + if (entry.inUse) + { + ++count; + } + } + return count; +} + +bool NCCLWindowAllocator::isCommValid(ncclComm_t comm) const noexcept +{ + // Simply check for null - all non-null comms are valid + // We don't track cleaned-up comms because NCCL can reuse memory addresses, + // making pointer-based tracking unreliable. New comms will be registered when used. + return comm != nullptr; +} + +NCCLWindowBuffer NCCLWindowAllocator::allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle) +{ + NCCLWindowBuffer buffer; + buffer.handle = handle; + + // Get NCCL helper for dynamic symbol loading + auto& ncclHelper = NCCLHelper::getInstance(); + if (!ncclHelper.isLoaded()) + { + TLLM_THROW("NCCL library could not be loaded for dynamic symbol access"); + } + + auto ncclMemAllocFunc = ncclHelper.getNCCLMemAlloc(); + auto ncclCommWindowRegisterFunc = ncclHelper.getNCCLCommWindowRegister(); + + // Defensive checks: both function pointers must be non-null + if (ncclMemAllocFunc == nullptr) + { + TLLM_THROW("ncclMemAlloc function pointer is null, cannot allocate NCCL window buffer"); + } + + if (ncclCommWindowRegisterFunc == nullptr) + { + TLLM_THROW("ncclCommWindowRegister function pointer is null, cannot register NCCL window buffer"); + } + + // Allocate device memory using ncclMemAlloc + ncclResult_t allocResult = ncclMemAllocFunc(&buffer.ptr, size); + if (allocResult != ncclSuccess) + { + TLLM_THROW("ncclMemAlloc failed with error: %d", allocResult); + } + buffer.size = size; + + // Register the buffer with NCCL as a window + ncclResult_t regResult + = ncclCommWindowRegisterFunc(comm, buffer.ptr, size, &buffer.window, NCCL_WIN_COLL_SYMMETRIC); + if (regResult != ncclSuccess) + { + ncclMemFree(buffer.ptr); + TLLM_THROW("ncclCommWindowRegister failed with error: %d", regResult); + } + + TLLM_LOG_TRACE("[NCCLUtil] Allocated and registered NCCL window buffer: handle=%d, ptr=%p, size=%zu, window=%p", + handle, buffer.ptr, size, static_cast(buffer.window)); + + return buffer; +} + +NCCLWindowBuffer NCCLWindowAllocator::searchBufferLocked(ncclComm_t comm, void* ptr) const +{ + auto commIt = mBufferPool.find(comm); + if (commIt == mBufferPool.end()) + { + return NCCLWindowBuffer(); + } + + 
for (auto const& entry : commIt->second) + { + if (entry.buffer.ptr == ptr) + { + return entry.buffer; + } + } + + return NCCLWindowBuffer(); +} + +void NCCLWindowAllocator::registerBufferCleanup(ncclComm_t comm) +{ + // Don't register if already registered + if (mRegisteredComms.find(comm) != mRegisteredComms.end()) + { + return; + } + + mRegisteredComms.insert(comm); + + // Register cleanup with the resource manager + NcclCommResourceManager::getInstance().registerResource( + comm, [this, comm]() { this->cleanupBuffersForComm(comm); }, "NCCLWindowAllocator"); +} + +void NCCLWindowAllocator::cleanupBuffersForComm(ncclComm_t comm) noexcept +{ + if (!comm) + { + return; + } + + // Synchronize CUDA to ensure all operations using these buffers are complete + // before we deregister windows and free memory + cudaError_t cudaErr = cudaDeviceSynchronize(); + if (cudaErr != cudaSuccess) + { + TLLM_LOG_WARNING("[NCCLUtil] cudaDeviceSynchronize failed with error: %d before cleanup for comm %p", cudaErr, + static_cast(comm)); + // Continue anyway - the sync failure might be from a previous error + } + + std::lock_guard lock(mMutex); + + // Check if we've already cleaned up this communicator + if (mRegisteredComms.find(comm) == mRegisteredComms.end()) + { + // Already cleaned up or never registered + return; + } + + auto commIt = mBufferPool.find(comm); + if (commIt == mBufferPool.end()) + { + // No buffers to clean up, but mark as cleaned + mRegisteredComms.erase(comm); + return; + } + + TLLM_LOG_TRACE( + "[NCCLUtil] Cleaning up %zu NCCL window buffers for comm %p", commIt->second.size(), static_cast(comm)); + + // Check for buffers still in use - this shouldn't happen if cleanup is called properly, + // but we log a warning if it does + size_t inUseCount = 0; + for (auto const& entry : commIt->second) + { + if (entry.inUse) + { + ++inUseCount; + } + } + if (inUseCount > 0) + { + TLLM_LOG_WARNING( + "[NCCLUtil] Cleaning up %zu buffers still marked as in-use for comm %p. " + "This may indicate buffers weren't properly released before cleanup.", + inUseCount, static_cast(comm)); + } + + for (auto& entry : commIt->second) + { + if (entry.buffer.isValid()) + { + // Deregister the window - the communicator is still valid at this point + // (cleanup happens before ncclCommDestroy), but we need to be careful + // if buffers are still in use by active operations + if (entry.buffer.window && comm) + { + // Note: Even if buffer is marked inUse, we must deregister since + // the communicator is being destroyed. The communicator is valid, + // but we should handle potential errors gracefully. + ncclResult_t result = ncclCommWindowDeregister(comm, entry.buffer.window); + if (result != ncclSuccess) + { + TLLM_LOG_WARNING( + "[NCCLUtil] ncclCommWindowDeregister failed with error: %d for comm %p, " + "window %p (buffer inUse: %d)", + result, static_cast(comm), static_cast(entry.buffer.window), entry.inUse); + } + } + + // Free device memory using ncclMemFree + // This should be safe even if deregister failed + if (entry.buffer.ptr) + { + try + { + ncclResult_t ncclResult = ncclMemFree(entry.buffer.ptr); + if (ncclResult != ncclSuccess) + { + TLLM_LOG_WARNING("[NCCLUtil] ncclMemFree failed with error: %d", ncclResult); + } + } + catch (...) 
+ { + TLLM_LOG_ERROR("[NCCLUtil] Exception during ncclMemFree for ptr %p", entry.buffer.ptr); + } + } + + TLLM_LOG_TRACE( + "[NCCLUtil] Freed NCCL window buffer: ptr=%p, size=%zu", entry.buffer.ptr, entry.buffer.size); + } + } + + mBufferPool.erase(commIt); + mRegisteredComms.erase(comm); +} + +} // namespace tensorrt_llm::common::nccl_util + +#endif // ENABLE_MULTI_DEVICE diff --git a/cpp/tensorrt_llm/common/ncclUtils.h b/cpp/tensorrt_llm/common/ncclUtils.h new file mode 100644 index 00000000000..d128741e0a5 --- /dev/null +++ b/cpp/tensorrt_llm/common/ncclUtils.h @@ -0,0 +1,397 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" + +#if ENABLE_MULTI_DEVICE +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if ENABLE_MULTI_DEVICE + +#ifdef _WIN32 +#include +#else +#include +#endif + +namespace tensorrt_llm::common::nccl_util +{ + +//============================================================================== +// NCCL Helper - Dynamic Library Loading +//============================================================================== + +// Helper class for dynamically loading NCCL symbols (ncclMemAlloc, ncclCommWindowRegister) +// This allows the code to work with NCCL libraries that may or may not have these symbols +class NCCLHelper +{ +public: + static NCCLHelper& getInstance(); + + // Dynamic loading function type definition + using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int); + using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t); + + // Get function pointer for ncclCommWindowRegister + ncclCommWindowRegisterFunc getNCCLCommWindowRegister(); + + // Get function pointer for ncclMemAlloc + ncclMemAllocFunc getNCCLMemAlloc(); + + // Check if NCCL library is successfully loaded + bool isLoaded() const; + + NCCLHelper(NCCLHelper const&) = delete; + NCCLHelper& operator=(NCCLHelper const&) = delete; + NCCLHelper(NCCLHelper&&) = delete; + NCCLHelper& operator=(NCCLHelper&&) = delete; + +private: + NCCLHelper(); + ~NCCLHelper(); + + void loadNCCLLibrary(); + void* loadLibraryHandle(char const* libName); + void* getSymbolAddress(void* handle, char const* symbolName); + +#ifdef _WIN32 + HMODULE mLibraryHandle; +#else + void* mLibraryHandle; +#endif + + ncclCommWindowRegisterFunc mNCCLCommWindowRegister; + ncclMemAllocFunc mNCCLMemAlloc; + bool mIsLoaded; +}; + +//============================================================================== +// NCCL Resource Management +//============================================================================== + +// Resource cleanup function type. Called before the NCCL communicator is destroyed. 
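//
// Illustrative use (a sketch only; the lambda body and the captured pointer are
// placeholders, not part of this API): a component that owns per-communicator
// device state registers its teardown once and lets the comm deleter drive it:
//
//   NcclCommResourceManager::getInstance().registerResource(
//       comm,
//       [ptr = myScratchPtr]() { cudaFree(ptr); },
//       "myComponentScratch");
//
// The callback is invoked by cleanupResources(comm), which the ncclComm_t
// shared_ptr deleter in opUtils.cpp calls immediately before ncclCommDestroy().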
+using ResourceCleanupFunc = std::function; + +// Manages resources associated with NCCL communicators. Thread-safe singleton that maintains +// a pool of resources per NCCL comm. Resources are automatically cleaned up when the +// communicator is destroyed. +class NcclCommResourceManager +{ +public: + static NcclCommResourceManager& getInstance() noexcept; + + // Register a resource cleanup function for a specific NCCL communicator. + // The cleanup function will be called before ncclCommDestroy. + // Thread-safe: Uses global mutex to serialize all operations. + void registerResource(ncclComm_t comm, ResourceCleanupFunc cleanup, char const* debugName = nullptr); + + // Cleanup all resources associated with a communicator. Called automatically by + // the shared_ptr deleter before ncclCommDestroy. + // Thread-safe: Uses global mutex to serialize cleanup operations. + // Order-preserving: Resources are cleaned up in registration order. + void cleanupResources(ncclComm_t comm) noexcept; + + // Check if a communicator has registered resources. + bool hasResources(ncclComm_t comm) const noexcept; + + // Get the number of resources registered for a communicator. + size_t getResourceCount(ncclComm_t comm) const noexcept; + + NcclCommResourceManager(NcclCommResourceManager const&) = delete; + NcclCommResourceManager& operator=(NcclCommResourceManager const&) = delete; + NcclCommResourceManager(NcclCommResourceManager&&) = delete; + NcclCommResourceManager& operator=(NcclCommResourceManager&&) = delete; + +private: + NcclCommResourceManager() = default; + ~NcclCommResourceManager() = default; + + using ResourceEntry = std::pair; + + mutable std::mutex mMutex; + std::unordered_map> mCommResources; +}; + +// RAII helper to register a resource with a NCCL communicator. +// Automatically registers cleanup function on construction. 
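//
// Sketch of intended use (the resource, lambda, and debug name below are
// illustrative assumptions, not an existing call site):
//
//   NcclCommResource<void*> scratch(
//       comm, std::move(devicePtr),
//       [](void* p) { cudaFree(p); },
//       "scratchAllocation");
//
// Caveat: the callback registered in the constructor captures `this`, so the
// NcclCommResource object must outlive cleanupResources() for its communicator;
// it is not deregistered automatically when the object is destroyed.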
+template +class NcclCommResource +{ +public: + NcclCommResource(ncclComm_t comm, ResourceType&& resource, std::function cleanup, + char const* debugName = nullptr) + : mComm(comm) + , mResource(std::forward(resource)) + , mCleanup(std::move(cleanup)) + , mRegistered(true) + { + // Register with the manager + NcclCommResourceManager::getInstance().registerResource( + comm, + [this]() + { + if (mCleanup) + { + mCleanup(mResource); + } + }, + debugName); + } + + ResourceType& get() + { + return mResource; + } + + ResourceType const& get() const + { + return mResource; + } + + NcclCommResource(NcclCommResource const&) = delete; + NcclCommResource& operator=(NcclCommResource const&) = delete; + NcclCommResource(NcclCommResource&&) = delete; + NcclCommResource& operator=(NcclCommResource&&) = delete; + +private: + ncclComm_t mComm; + ResourceType mResource; + std::function mCleanup; + bool mRegistered; +}; + +//============================================================================== +// NCCL Window Buffer Allocation +//============================================================================== + +// Represents a buffer with an associated NCCL window +struct NCCLWindowBuffer +{ + void* ptr; // Device pointer (same as UBBuffer.addr) + int handle; // Buffer handle/index (for compatibility with UB interface) + size_t size; // Size in bytes + ncclWindow_t window; // NCCL window handle + + NCCLWindowBuffer(void* p = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr) + : ptr(p) + , handle(h) + , size(s) + , window(w) + { + } + + [[nodiscard]] bool isValid() const + { + return ptr != nullptr && handle >= 0 && size > 0 && window != nullptr; + } + + [[nodiscard]] bool invalid() const + { + return !isValid(); + } + + // Alias for compatibility with UBBuffer interface + void* addr() const + { + return ptr; + } +}; + +// Manages NCCL window-registered buffers with pooling and automatic cleanup. +// Buffers are tied to the lifetime of their associated NCCL communicator. +class NCCLWindowAllocator +{ +public: + static NCCLWindowAllocator& getInstance(); + + // Request a buffer for the given communicator and size. + // If an unused buffer of at least the requested size exists for this communicator, it will be reused. + // Uses best-fit strategy: selects the smallest available buffer that meets the size requirement. + // Otherwise, a new buffer is allocated and registered. + NCCLWindowBuffer requestBuffer(ncclComm_t comm, size_t size); + + // Search for a buffer by pointer. Returns an invalid buffer if not found. + // This matches the UBManager.search_buffer() interface. + NCCLWindowBuffer searchBuffer(ncclComm_t comm, void* ptr) const; + + // Release a buffer back to the pool for potential reuse + void releaseBuffer(ncclComm_t comm, void* ptr); + + // Get the window handle for a specific buffer pointer + ncclWindow_t getWindow(ncclComm_t comm, void* ptr) const; + + // Get the size of a specific buffer pointer + size_t getSize(ncclComm_t comm, void* ptr) const; + + // Get buffer info by pointer + NCCLWindowBuffer getBufferInfo(ncclComm_t comm, void* ptr) const; + + // Get the number of buffers allocated for a communicator + size_t getBufferCount(ncclComm_t comm) const; + + // Get the number of buffers in use for a communicator + size_t getBufferInUseCount(ncclComm_t comm) const; + + // Check if a communicator is valid (non-null) + // Note: We don't track cleaned-up comms because NCCL can reuse memory addresses. 
+ // All non-null comms are considered valid and will be registered when first used. + bool isCommValid(ncclComm_t comm) const noexcept; + + NCCLWindowAllocator(NCCLWindowAllocator const&) = delete; + NCCLWindowAllocator& operator=(NCCLWindowAllocator const&) = delete; + NCCLWindowAllocator(NCCLWindowAllocator&&) = delete; + NCCLWindowAllocator& operator=(NCCLWindowAllocator&&) = delete; + +private: + NCCLWindowAllocator() = default; + ~NCCLWindowAllocator() = default; + + // Allocate a new buffer and register it with NCCL as a window + NCCLWindowBuffer allocateAndRegisterBuffer(ncclComm_t comm, size_t size, int handle); + + // Search for a buffer by pointer (assumes mMutex is already locked) + NCCLWindowBuffer searchBufferLocked(ncclComm_t comm, void* ptr) const; + + // Register cleanup function for all buffers associated with a communicator + void registerBufferCleanup(ncclComm_t comm); + + // Cleanup all buffers for a specific communicator + void cleanupBuffersForComm(ncclComm_t comm) noexcept; + + struct BufferEntry + { + NCCLWindowBuffer buffer; + bool inUse; + }; + + mutable std::mutex mMutex; + std::unordered_map> mBufferPool; + std::unordered_set mRegisteredComms; +}; + +// RAII wrapper for NCCL window buffers +class ScopedNCCLWindowBuffer +{ +public: + ScopedNCCLWindowBuffer(ncclComm_t comm, size_t size) + : mComm(comm) + , mBuffer(NCCLWindowAllocator::getInstance().requestBuffer(comm, size)) + { + } + + ~ScopedNCCLWindowBuffer() + { + if (mBuffer.isValid()) + { + NCCLWindowAllocator::getInstance().releaseBuffer(mComm, mBuffer.ptr); + } + } + + void* getPtr() const + { + return mBuffer.ptr; + } + + size_t getSize() const + { + return mBuffer.size; + } + + ncclWindow_t getWindow() const + { + return mBuffer.window; + } + + NCCLWindowBuffer const& getBuffer() const + { + return mBuffer; + } + + ScopedNCCLWindowBuffer(ScopedNCCLWindowBuffer const&) = delete; + ScopedNCCLWindowBuffer& operator=(ScopedNCCLWindowBuffer const&) = delete; + ScopedNCCLWindowBuffer(ScopedNCCLWindowBuffer&&) = delete; + ScopedNCCLWindowBuffer& operator=(ScopedNCCLWindowBuffer&&) = delete; + +private: + ncclComm_t mComm; + NCCLWindowBuffer mBuffer; +}; + +// Creates a PyTorch tensor backed by an NCCL window buffer. +// The tensor will automatically release the buffer back to the pool when destroyed. +// This is analogous to torch_ext::create_userbuffers_tensor() but for NCCLWindowAllocator. 
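//
// Illustrative call (a sketch; `comm`, `numTokens`, and `hiddenSize` are assumed
// to come from the caller):
//
//   auto [tensor, buffer] = createNCCLWindowTensor(
//       comm, {numTokens, hiddenSize}, torch::kBFloat16);
//   // `tensor` aliases buffer.ptr; when the last reference to it goes away, the
//   // deleter calls NCCLWindowAllocator::releaseBuffer(comm, buffer.ptr), which
//   // returns the window-registered allocation to the per-comm pool (not freed).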
+inline std::pair createNCCLWindowTensor( + ncclComm_t comm, at::IntArrayRef shape, torch::ScalarType dtype) +{ + // Calculate buffer size + int64_t buffer_size + = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()) * torch::elementSize(dtype); + + // Calculate strides + std::vector strides_vec(shape.size()); + if (!shape.empty()) + { + strides_vec[shape.size() - 1] = 1; + for (int64_t i = static_cast(shape.size()) - 1; i >= 1; --i) + { + strides_vec[i - 1] = strides_vec[i] * shape[i]; + } + } + + // Request buffer from allocator + auto& allocator = NCCLWindowAllocator::getInstance(); + auto buffer = allocator.requestBuffer(comm, buffer_size); + + // Defensive validation: ensure buffer is valid before proceeding + if (!buffer.isValid()) + { + std::ostringstream oss; + oss << "Failed to allocate NCCL window buffer: invalid buffer returned from requestBuffer " + << "(comm=" << static_cast(comm) << ", buffer_size=" << buffer_size << ")"; + throw std::runtime_error(oss.str()); + } + + // Create custom deleter that releases the buffer + auto deleter = [comm, ptr = buffer.ptr](void*) { NCCLWindowAllocator::getInstance().releaseBuffer(comm, ptr); }; + + // Create tensor from the buffer + auto tensor = torch::from_blob(buffer.ptr, shape, strides_vec, deleter, torch::dtype(dtype).device(torch::kCUDA)); + + return std::make_pair(tensor, buffer); +} + +} // namespace tensorrt_llm::common::nccl_util + +#endif // ENABLE_MULTI_DEVICE diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp index 736cd1c48d0..72d966e43d7 100644 --- a/cpp/tensorrt_llm/common/opUtils.cpp +++ b/cpp/tensorrt_llm/common/opUtils.cpp @@ -16,6 +16,7 @@ */ #include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/common/ncclUtils.h" #include "tensorrt_llm/runtime/utils/mpiTags.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -112,7 +113,29 @@ std::shared_ptr getComm(std::set const& group) std::shared_ptr ncclComm(new ncclComm_t, [](ncclComm_t* comm) { - ncclCommDestroy(*comm); + if (!comm) + { + return; + } + + // STEP 1: Clean up resources and destroy NCCL communicator if it's valid + if (*comm) + { + // Clean up all registered resources FIRST + tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm); + + // Now destroy the NCCL communicator + ncclResult_t result = ncclCommDestroy(*comm); + if (result != ncclSuccess) + { + TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result); + } + + // Clear the communicator value before freeing the pointer + *comm = nullptr; + } + + // STEP 2: Always free the pointer memory (regardless of whether *comm was valid) delete comm; }); #if defined(_WIN32) diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp index e0f2d5cce2e..2e3e6dde664 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp +++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp @@ -22,16 +22,8 @@ namespace tensorrt_llm::runtime::ub { UserBufferAllocator& UserBufferAllocator::Instance() { - if (use_nccl_symmetric) - { - static NCCLUserBufferAllocator _; - return _; - } - else - { - static UserBufferAllocator _; - return _; - } + static UserBufferAllocator _; + return _; } void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig) @@ -83,167 +75,4 @@ communicator* UserBufferAllocator::comm() return mUbComm; } -void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig) 
-{ - if (!isInitialized()) - { - TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator"); - std::set group; - for (int i = 0; i < worldConfig.getSize(); i++) - { - group.insert(i); - } - mComm = getComm(group); - mIsInitialized = true; - } -} - -UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes) -{ - TLLM_CHECK(isInitialized()); - UBBuffer ub_buffer; - - auto& ncclHelper = getNCCLHelper(); - if (!ncclHelper.isLoaded()) - { - TLLM_THROW("NCCL library could not be loaded for dynamic symbol access"); - } - - auto ncclMemAllocFunc = ncclHelper.getNCCLMemAlloc(); - auto ncclCommWindowRegisterFunc = ncclHelper.getNCCLCommWindowRegister(); - - NCCLCHECK(ncclMemAllocFunc(&ub_buffer.addr, bytes)); - NCCLCHECK(ncclCommWindowRegisterFunc((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC)); - ub_buffer.handle = 5; - ub_buffer.size = bytes; - return ub_buffer; -} - -// Static member definitions -std::unique_ptr NCCLUserBufferAllocator::mNCCLHelper = nullptr; - -NCCLHelper& NCCLUserBufferAllocator::getNCCLHelper() -{ - if (!mNCCLHelper) - { - mNCCLHelper = std::make_unique(); - } - return *mNCCLHelper; -} - -// NCCLHelper implementation -NCCLHelper::NCCLHelper() - : mLibraryHandle(nullptr) - , mNCCLCommWindowRegister(nullptr) - , mNCCLMemAlloc(nullptr) - , mIsLoaded(false) -{ - loadNCCLLibrary(); -} - -NCCLHelper::~NCCLHelper() -{ - if (mLibraryHandle) - { -#ifdef _WIN32 - FreeLibrary(mLibraryHandle); -#else - dlclose(mLibraryHandle); -#endif - mLibraryHandle = nullptr; - } -} - -void NCCLHelper::loadNCCLLibrary() -{ - try - { -#ifdef _WIN32 - char const* libraryNames[] = {"nccl.dll"}; -#else - char const* libraryNames[] = {"libnccl.so"}; -#endif - - for (int i = 0; libraryNames[i] != nullptr; ++i) - { - mLibraryHandle = loadLibraryHandle(libraryNames[i]); - if (mLibraryHandle) - { - TLLM_LOG_INFO("Successfully loaded NCCL library: %s", libraryNames[i]); - break; - } - } - - if (!mLibraryHandle) - { - TLLM_LOG_WARNING("Failed to load NCCL library"); - return; - } - - // Load the required symbols - mNCCLCommWindowRegister - = reinterpret_cast(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister")); - - mNCCLMemAlloc = reinterpret_cast(getSymbolAddress(mLibraryHandle, "ncclMemAlloc")); - - if (mNCCLCommWindowRegister == nullptr) - { - TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported."); - } - - if (mNCCLMemAlloc) - { - mIsLoaded = true; - } - else - { - TLLM_LOG_WARNING("Failed to load required NCCL symbols"); - } - } - catch (std::exception const& e) - { - TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what()); - } -} - -void* NCCLHelper::loadLibraryHandle(char const* libName) -{ -#ifdef _WIN32 - return LoadLibraryA(libName); -#else - return dlopen(libName, RTLD_LAZY | RTLD_GLOBAL); -#endif -} - -void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName) -{ - if (!handle) - { - return nullptr; - } - -#ifdef _WIN32 - return GetProcAddress(static_cast(handle), symbolName); -#else - return dlsym(handle, symbolName); -#endif -} - -NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister() -{ - return mNCCLCommWindowRegister; -} - -NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc() -{ - return mNCCLMemAlloc; -} - -bool NCCLHelper::isLoaded() const -{ - return mIsLoaded; -} - -bool UserBufferAllocator::use_nccl_symmetric = false; - }; // namespace tensorrt_llm::runtime::ub diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h 
b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h index 4cc91497054..05a4b6dd4e7 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h +++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h @@ -19,11 +19,6 @@ #if ENABLE_MULTI_DEVICE #include "nccl.h" #include "userbuffers.h" -#ifdef _WIN32 -#include -#else -#include -#endif #else using ncclWindow_t = void*; #endif @@ -69,8 +64,6 @@ class UserBufferAllocator communicator* comm(); virtual UBBuffer registerUBBuffer(size_t bytes); - static bool use_nccl_symmetric; - private: communicator* mUbComm; @@ -80,55 +73,6 @@ class UserBufferAllocator tensorrt_llm::runtime::WorldConfig mWorldConfig; }; -class NCCLHelper -{ -public: - NCCLHelper(); - ~NCCLHelper(); - - // Dynamic loading function type definition - using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int); - using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t); - - // Get function pointer for ncclCommWindowRegister - ncclCommWindowRegisterFunc getNCCLCommWindowRegister(); - - // Get function pointer for ncclMemAlloc - ncclMemAllocFunc getNCCLMemAlloc(); - - // Check if NCCL library is successfully loaded - bool isLoaded() const; - -private: - void loadNCCLLibrary(); - void* loadLibraryHandle(char const* libName); - void* getSymbolAddress(void* handle, char const* symbolName); - -#ifdef _WIN32 - HMODULE mLibraryHandle; -#else - void* mLibraryHandle; -#endif - - ncclCommWindowRegisterFunc mNCCLCommWindowRegister; - ncclMemAllocFunc mNCCLMemAlloc; - bool mIsLoaded; -}; - -class NCCLUserBufferAllocator : public UserBufferAllocator -{ -public: - void initialize(tensorrt_llm::runtime::WorldConfig const& world_config) override; - UBBuffer registerUBBuffer(size_t bytes) override; - - // Get shared NCCLHelper instance - static NCCLHelper& getNCCLHelper(); - -private: - std::shared_ptr mComm; - static std::unique_ptr mNCCLHelper; -}; - #else using communicator = void; #endif diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp index a1fcd3c01fb..df2a549b8dc 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp +++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp @@ -14,6 +14,7 @@ * limitations under the License. 
*/ #include "userbuffersManager.h" +#include "tensorrt_llm/common/logger.h" namespace tensorrt_llm::runtime::ub { @@ -29,14 +30,11 @@ UserBuffersManager& UserBuffersManager::get_instance() return allocator; } -void UserBuffersManager::initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, - int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric) +void UserBuffersManager::initialize( + int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size) { std::lock_guard lock(mutex_); tensorrt_llm::runtime::WorldConfig world_config(tp_size, pp_size, cp_size, rank, gpus_per_node); -#if ENABLE_MULTI_DEVICE - UserBufferAllocator::Instance().use_nccl_symmetric = use_nccl_symmetric; -#endif tensorrt_llm::runtime::ub::ub_initialize(world_config); TLLM_CHECK(tensorrt_llm::runtime::ub::ub_is_initialized()); buffer_size_ = buffer_size; @@ -98,11 +96,10 @@ tensorrt_llm::runtime::ub::communicator* UserBuffersManager::comm() return tensorrt_llm::runtime::ub::ub_comm(); } -void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, - int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric) +void initialize_userbuffers_manager( + int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size) { - UserBuffersManager::get_instance().initialize( - tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size, use_nccl_symmetric); + UserBuffersManager::get_instance().initialize(tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size); } } // namespace tensorrt_llm::runtime::ub diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h index 1b34f8e8a17..7ec39db602c 100644 --- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h +++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h @@ -46,9 +46,8 @@ class UserBuffersManager //! @param gpus_per_node The number of GPUs per node. //! @param buffer_size The size of the buffer to allocate. All buffers allocated by this manager will have this //! size. - //! @param use_nccl_symmetric Whether to use NCCL symmetric communication. - void initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, - int64_t buffer_size, bool use_nccl_symmetric); + void initialize( + int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size); //! @brief Create a UB tensor from the given shape, strides and data type. The function will choose available UB //! buffer or create a new one if no available buffer is found. 
@@ -76,7 +75,7 @@ class UserBuffersManager int64_t buffer_size_; }; -void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, - int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric); +void initialize_userbuffers_manager( + int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size); } // namespace tensorrt_llm::runtime::ub diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp index 4241cf8d859..112364400dd 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp @@ -137,13 +137,12 @@ bool AllreducePlugin::supportsFormatCombination( int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { int base_inputs = 0; - if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::UB) + switch (mStrategy) { - base_inputs = 1; - } - else - { - base_inputs = 2; + case AllReduceStrategyType::NCCL: + case AllReduceStrategyType::UB: + case AllReduceStrategyType::NCCL_SYMMETRIC: base_inputs = 1; break; + default: base_inputs = 2; break; } int fusion_op_extra_inputs = 0; int scale_idx = 0; @@ -169,9 +168,15 @@ bool AllreducePlugin::supportsFormatCombination( TLLM_CHECK(nbInputs == (base_inputs + fusion_op_extra_inputs)); - if (mStrategy != AllReduceStrategyType::NCCL && mStrategy != AllReduceStrategyType::UB && pos == 1) + if (pos == 1) { - return (inOut[pos].type == nvinfer1::DataType::kINT64) && (inOut[pos].format == TensorFormat::kLINEAR); + switch (mStrategy) + { + case AllReduceStrategyType::NCCL: + case AllReduceStrategyType::UB: + case AllReduceStrategyType::NCCL_SYMMETRIC: break; + default: return (inOut[pos].type == nvinfer1::DataType::kINT64) && (inOut[pos].format == TensorFormat::kLINEAR); + } } if (mStrategy == AllReduceStrategyType::UB) { @@ -222,25 +227,26 @@ AllReduceStrategyType AllreducePlugin::selectImplementation( { if (!isAuto) { - TLLM_LOG_INFO("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL"); + TLLM_LOG_INFO("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL_SYMMETRIC"); } else if (forceDeterministic) { TLLM_LOG_WARNING( - "Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL. NCCL might produce " + "Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL_SYMMETRIC. 
NCCL_SYMMETRIC might " + "produce " "non-deterministic results."); } - return AllReduceStrategyType::NCCL; + return AllReduceStrategyType::NCCL_SYMMETRIC; } if (isAuto && !mIsNVLINKSupported && !forceDeterministic) { - return AllReduceStrategyType::NCCL; + return AllReduceStrategyType::NCCL_SYMMETRIC; } auto const maxWorkspaceSize = utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(worldSize); - AllReduceStrategyType strat = AllReduceStrategyType::NCCL; + AllReduceStrategyType strat = AllReduceStrategyType::NCCL_SYMMETRIC; auto const messageSizeBytes = messageSize * common::getDTypeSize(type); if (messageSizeBytes <= maxWorkspaceSize) @@ -268,7 +274,7 @@ AllReduceStrategyType AllreducePlugin::selectImplementation( } else { - strat = AllReduceStrategyType::NCCL; + strat = AllReduceStrategyType::NCCL_SYMMETRIC; } } else @@ -279,7 +285,7 @@ AllReduceStrategyType AllreducePlugin::selectImplementation( } else { - strat = AllReduceStrategyType::NCCL; + strat = AllReduceStrategyType::NCCL_SYMMETRIC; } } @@ -287,30 +293,31 @@ AllReduceStrategyType AllreducePlugin::selectImplementation( { if (!isAuto) { - TLLM_LOG_WARNING("Since not aligned, fallback to AllReduceStrategy: NCCL"); + TLLM_LOG_WARNING("Since not aligned, fallback to AllReduceStrategy: NCCL_SYMMETRIC"); } else if (forceDeterministic) { TLLM_LOG_WARNING( - "Since not aligned, fallback to AllReduceStrategy: NCCL. NCCL might produce " + "Since not aligned, fallback to AllReduceStrategy: NCCL_SYMMETRIC. NCCL_SYMMETRIC might produce " "non-deterministic results."); } - strat = AllReduceStrategyType::NCCL; + strat = AllReduceStrategyType::NCCL_SYMMETRIC; } } else { if (!isAuto) { - TLLM_LOG_WARNING("Since messageSize > maxWorkspace, fallback to AllReduceStrategy: NCCL"); + TLLM_LOG_WARNING("Since messageSize > maxWorkspace, fallback to AllReduceStrategy: NCCL_SYMMETRIC"); } else if (forceDeterministic) { TLLM_LOG_WARNING( - "Since messageSize > maxWorkspace, fallback to AllReduceStrategy: NCCL. NCCL might produce " + "Since messageSize > maxWorkspace, fallback to AllReduceStrategy: NCCL_SYMMETRIC. 
NCCL_SYMMETRIC might " + "produce " "non-deterministic results."); } - strat = AllReduceStrategyType::NCCL; + strat = AllReduceStrategyType::NCCL_SYMMETRIC; } return strat; @@ -337,6 +344,10 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe { runtimeStrategy = AllReduceStrategyType::NCCL; } + else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC) + { + runtimeStrategy = AllReduceStrategyType::NCCL_SYMMETRIC; + } else if (mStrategy == AllReduceStrategyType::UB) { runtimeStrategy = AllReduceStrategyType::UB; @@ -355,6 +366,11 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank); break; } + case AllReduceStrategyType::NCCL_SYMMETRIC: + { + TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank); + break; + } case AllReduceStrategyType::ONESHOT: { TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: ONESHOT", rank); @@ -373,14 +389,14 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe default: break; } - if (runtimeStrategy == AllReduceStrategyType::NCCL) + if (runtimeStrategy == AllReduceStrategyType::NCCL || runtimeStrategy == AllReduceStrategyType::NCCL_SYMMETRIC) { if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM || mOp == AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM) { NCCLCHECK(ncclAllReduce(inputs[0], outputs[1], size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream)); tensorrt_llm::kernels::AllReduceParams params; int fusion_ptr_idx = 0; - if (mStrategy == AllReduceStrategyType::NCCL) + if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC) { fusion_ptr_idx = 1; } diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp index 21018e241da..fbd60d1ec57 100644 --- a/cpp/tensorrt_llm/thop/allreduceOp.cpp +++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp @@ -15,10 +15,12 @@ * limitations under the License. 
*/ +#include "tensorrt_llm/common/cudaDriverWrapper.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/customAllReduceUtils.h" #include "tensorrt_llm/common/dataType.h" #include "tensorrt_llm/common/mcastDevMemUtils.h" +#include "tensorrt_llm/common/ncclUtils.h" #include "tensorrt_llm/common/opUtils.h" #include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h" #include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h" @@ -39,6 +41,7 @@ #if ENABLE_MULTI_DEVICE #include #include +#include #include #include #include @@ -51,6 +54,7 @@ #include #include +#include #include // using namespace nvinfer1; @@ -238,6 +242,9 @@ class AllreduceOp AllreduceOp( std::set group, nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps) : mGroup(std::move(group)) + , mIsNVLINKSupported(false) + , mIsP2PSupported(false) + , mIsMNNVLSupported(false) , mType(type) , mStrategy(strategy) , mOp(op) @@ -248,6 +255,9 @@ class AllreduceOp AllreduceOp(std::set group, c10::intrusive_ptr const& process_group_, nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps) : mGroup(std::move(group)) + , mIsNVLINKSupported(false) + , mIsP2PSupported(false) + , mIsMNNVLSupported(false) , mType(type) , mStrategy(strategy) , mOp(op) @@ -437,44 +447,117 @@ class AllreduceOp torch::optional const& residual, torch::optional const& norm_weight, torch::optional const& scale, torch::optional const& bias) { + // Handle ProcessGroup path first - cannot extract NCCL comm for window registration + // Use ProcessGroup's allreduce directly and return early + if (mNcclComm.index() == 1) + { + auto torchPg = std::get<1>(mNcclComm); + + torch::Tensor reduceOutput = input.clone(); + std::vector tensors{reduceOutput}; + PGCHECK_THROW(torchPg->allreduce(tensors, {c10d::ReduceOp::SUM})); + + if (mOp == AllReduceFusionOp::NONE) + { + return {reduceOutput}; + } + + // Treat any other patterns as fallback cases. + return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduceOutput); + } + + // From here on, we have a raw NCCL comm - can proceed with window registration + auto rawComm = std::get<0>(mNcclComm); + ncclComm_t comm = *rawComm; + TLLM_CHECK_WITH_INFO(comm != nullptr, "NCCL communicator is null"); + TLLM_LOG_DEBUG("[runNCCLAllReduceSymmetric] Using raw NCCL comm path (not ProcessGroup)"); + + using tensorrt_llm::common::nccl_util::NCCLWindowAllocator; + using tensorrt_llm::common::nccl_util::createNCCLWindowTensor; auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); int size = input.numel(); - auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance(); - auto ub_tensor0 = input; - auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr()); - if (ub_buffer0.invalid()) + size_t bufferSizeBytes = size * input.element_size(); + + // Using unregistered input buffers with NCCL symmetric, requires a memcpy + // This is an overhead introduced with using NCCL_SYMMTRIC over NCCL. + // Both the memcpy and the perf benefit from using NCCL_SYMMETRIC scale linear with the message size. + // But a local memcpy is cheaper than the remote operations, so with larger message sizes the benefit is + // stronger. Additionally, the perf benefit scales with the number of ranks, since multimem enables O(const.) + // versus O(N) complexity. Hence we model this cutoff with a linear model. 
The numbers below were obtained on + // GB200, scanning different message sizes and ranks. You can determine the regression onset for each number of + // ranks to a single message size. And the following formula was obtained by fitting a linear model to the + // regression onset. It is possible to override this empirical heuristic with the TLLM_NCCL_MIN_REGISTRATION + // environment variable. + double const a = -4986.43478503; + double const b = 156716.52177552; + int nRanks; + NCCLCHECK_THROW(ncclCommCount(comm, &nRanks)); + size_t minRegistrationThreshold = static_cast(std::max(0.0, a * nRanks + b)) * input.element_size(); + // Disable window registration if neither NVLink nor MNNVL is supported + // TODO replace in NCCL 2.29 with comm query + if (!mIsNVLINKSupported && !mIsMNNVLSupported) { - auto [symmetric_input, symmetric_ub_buffer0] - = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); - cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(), - cudaMemcpyDeviceToDevice, stream); - ub_buffer0 = symmetric_ub_buffer0; - ub_tensor0 = symmetric_input; + minRegistrationThreshold = std::numeric_limits::max(); + } + char const* envThreshold = std::getenv("TLLM_NCCL_MIN_REGISTRATION"); + if (envThreshold != nullptr) + { + minRegistrationThreshold = static_cast(std::atoi(envThreshold)) * input.element_size(); } - TLLM_CHECK(!ub_buffer0.invalid()); - auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type()); + // Search for existing buffer + auto& allocator = NCCLWindowAllocator::getInstance(); + auto windowBuffer0 = allocator.searchBuffer(comm, input.data_ptr()); - std::visit(overloaded{[&, norm_out_ = norm_out](std::shared_ptr& rawComm) - { - NCCLCHECK_THROW(ncclAllReduce(ub_buffer0.addr, norm_out_.mutable_data_ptr(), size, - (*getDtypeMap())[mType], ncclSum, *rawComm, stream)); - }, - [&, norm_out_ = norm_out](c10::intrusive_ptr& torchPg) - { - PGCHECK_THROW(PgHelper{torchPg}.allreduce(ub_tensor0, {c10d::ReduceOp::SUM})); - std::ignore = norm_out_.copy_(ub_tensor0, true); - }}, - mNcclComm); + torch::Tensor inputTensor = input; + void* inputPtr = input.data_ptr(); + + // If buffer is not registered, decide whether to register based on size + if (!windowBuffer0.isValid()) + { + if (bufferSizeBytes < minRegistrationThreshold) + { + // Small buffer: use input directly without window registration + TLLM_LOG_DEBUG( + "[runNCCLAllReduceSymmetric] Buffer size %zu bytes < threshold %zu bytes, " + "skipping window registration", + bufferSizeBytes, minRegistrationThreshold); + // inputTensor and inputPtr remain pointing to original input + } + else + { + // Large buffer: create window buffer and copy input (can swap inputTensor reference) + auto [symmetricInput, symmetricBuffer0] + = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type()); + TLLM_CUDA_CHECK(cudaMemcpyAsync( + symmetricBuffer0.ptr, input.data_ptr(), bufferSizeBytes, cudaMemcpyDeviceToDevice, stream)); + windowBuffer0 = symmetricBuffer0; + inputTensor = symmetricInput; // Swap to window-backed tensor + inputPtr = windowBuffer0.ptr; + } + } + else + { + // Buffer already registered - use it directly + inputPtr = windowBuffer0.ptr; + } + + // Use window-backed output buffer + auto [normOut, windowBuffer1] = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type()); + torch::Tensor outputTensor = normOut; + void* outputPtr = windowBuffer1.ptr; + + // Perform allreduce + NCCLCHECK_THROW(ncclAllReduce(inputPtr, outputPtr, 
size, (*getDtypeMap())[mType], ncclSum, comm, stream)); if (mOp == AllReduceFusionOp::NONE) { - return {norm_out}; + return {outputTensor}; } // Treat any other patterns as fallback cases. - return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out); + return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, outputTensor); } std::vector runLowPrecisionAllReduce(torch::Tensor const& input, @@ -799,16 +882,104 @@ class AllreduceOp void initGroupTopology() { - static std::map, std::tuple> cache; + static std::map, std::tuple> cache; if (cache.find(mGroup) != cache.end()) { - auto [is_NVLINK_supported, is_P2P_supported] = cache[mGroup]; + auto [is_NVLINK_supported, is_P2P_supported, is_MNNVL_supported] = cache[mGroup]; mIsNVLINKSupported = is_NVLINK_supported; mIsP2PSupported = is_P2P_supported; + mIsMNNVLSupported = is_MNNVL_supported; return; } setGroupTopology(); - cache[mGroup] = {mIsNVLINKSupported, mIsP2PSupported}; + cache[mGroup] = {mIsNVLINKSupported, mIsP2PSupported, mIsMNNVLSupported}; + } + + bool checkMNNVLSupport(int device_id) + { +#if ENABLE_MULTI_DEVICE + // 1. Check CUDA driver version (needs >= 12.0.10) + int cuda_driver_version = -1; + TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version)); + if (cuda_driver_version < 12010) + { + TLLM_LOG_DEBUG("MNNVL check: CUDA Driver version %d < 12010", cuda_driver_version); + return false; + } + + // 2. Check multicast support + CUdevice cu_device; + TLLM_CU_CHECK(cuDeviceGet(&cu_device, device_id)); + auto cuda_driver = tensorrt_llm::common::CUDADriverWrapper::getInstance(); + + int multicast_supported = 0; + TLLM_CU_CHECK(cuda_driver->cuDeviceGetAttribute( + &multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cu_device)); + if (!multicast_supported) + { + TLLM_LOG_DEBUG("MNNVL check: Device %d does not support multicast", device_id); + return false; + } + + // 3. Check fabric handle support + int fabric_handle_supported = 0; + TLLM_CU_CHECK(cuda_driver->cuDeviceGetAttribute( + &fabric_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, cu_device)); + if (!fabric_handle_supported) + { + TLLM_LOG_DEBUG("MNNVL check: Device %d does not support fabric handles", device_id); + return false; + } + + // 4. Check NVML GPU Fabric Info + nvmlDevice_t nvml_device; + NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(device_id, &nvml_device)); + + nvmlGpuFabricInfo_t fabric_info; + NVML_CHECK_THROW(nvmlDeviceGetGpuFabricInfo(nvml_device, &fabric_info)); + + // Check if fabric is fully initialized + if (fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED || fabric_info.status != NVML_SUCCESS) + { + TLLM_LOG_DEBUG( + "MNNVL check: Fabric state not complete - state=%u status=%u", fabric_info.state, fabric_info.status); + return false; + } + + // 5. 
Check NVLink links are active (similar to Python support_nvlink(True)) + unsigned int active_links = 0; + unsigned int available_links = 0; + + for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++) + { + unsigned int cap_p2p = 0; + nvmlReturn_t cap_result + = nvmlDeviceGetNvLinkCapability(nvml_device, link, NVML_NVLINK_CAP_P2P_SUPPORTED, &cap_p2p); + if (cap_result == NVML_SUCCESS && cap_p2p) + { + available_links++; + nvmlEnableState_t link_state; + if (nvmlDeviceGetNvLinkState(nvml_device, link, &link_state) == NVML_SUCCESS + && link_state == NVML_FEATURE_ENABLED) + { + active_links++; + } + } + } + + bool all_links_up = (active_links == available_links && available_links > 0); + if (!all_links_up) + { + TLLM_LOG_DEBUG( + "MNNVL check: Not all NVLink links active - active=%u available=%u", active_links, available_links); + return false; + } + + TLLM_LOG_INFO("MNNVL check: Device %d supports MNNVL (fabric_clique=%u)", device_id, fabric_info.cliqueId); + return true; +#else + return false; +#endif } void setGroupTopology() @@ -820,107 +991,189 @@ class AllreduceOp [&](c10::intrusive_ptr& torchPg) { return getLocalGroupTorch(mGroup); }}, mNcclComm); - if (mGroup.size() != local_group.size()) - { - mIsP2PSupported = false; - mIsNVLINKSupported = false; - TLLM_LOG_INFO("Found inter-node TP group for rank %d", rank); - return; - } - TLLM_LOG_INFO("TP group is intra-node for rank %d", rank); + bool is_inter_node = (mGroup.size() != local_group.size()); NvmlManager nvml_manager; mIsP2PSupported = true; mIsNVLINKSupported = true; + mIsMNNVLSupported = false; - // TODO(ytong): Should we provide group topology info instead of querying it here? - // Use cudaDeviceCanAccessPeer to determine whether p2p is supported, - // and use nvml to determine whether there are nvlink links between ranks. 
- for (int first_device_id : local_group) + // First, check NVLink within local group (intra-node) + if (!local_group.empty()) { - for (int second_device_id : local_group) + for (int first_device_id : local_group) { - if (first_device_id >= second_device_id) - { - continue; - } - - int can_access_peer = 0; - TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, first_device_id, second_device_id)); - - if (!can_access_peer) + for (int second_device_id : local_group) { - mIsP2PSupported = false; - mIsNVLINKSupported = false; - - return; - } - - nvmlDevice_t first_device; - NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(first_device_id, &first_device)); + if (first_device_id >= second_device_id) + { + continue; + } - bool is_NVLINK = false; + int can_access_peer = 0; + TLLM_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, first_device_id, second_device_id)); - for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++) - { - nvmlPciInfo_t remote_pci_info; - if (nvmlDeviceGetNvLinkRemotePciInfo_v2(first_device, link, &remote_pci_info) != NVML_SUCCESS) + if (!can_access_peer) { + mIsP2PSupported = false; + mIsNVLINKSupported = false; + TLLM_LOG_INFO( + "P2P not supported between local devices %d and %d", first_device_id, second_device_id); + // Continue checking other pairs, but mark as not supported continue; } - nvmlDevice_t remote_device; - auto const result = nvmlDeviceGetHandleByPciBusId_v2(remote_pci_info.busId, &remote_device); + nvmlDevice_t first_device; + NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(first_device_id, &first_device)); - if (result == NVML_SUCCESS) - { - // Two GPUs are connected directly through nvlink - unsigned int remote_device_id; - NVML_CHECK_THROW(nvmlDeviceGetIndex(remote_device, &remote_device_id)); + bool is_NVLINK = false; - if (remote_device_id == static_cast(second_device_id)) + for (unsigned int link = 0; link < NVML_NVLINK_MAX_LINKS; link++) + { + nvmlPciInfo_t remote_pci_info; + if (nvmlDeviceGetNvLinkRemotePciInfo_v2(first_device, link, &remote_pci_info) != NVML_SUCCESS) { - is_NVLINK = true; + continue; } - } - else if (result == NVML_ERROR_NOT_FOUND) - { - // Maybe Two GPUs are connected via nvswitch, - // now remotePciInfo represents the pci information of nvswitch, - // determine whether nvlink is supported by whether two GPUs are connected to the same - // nvswitch. - nvmlDevice_t second_device; - NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(second_device_id, &second_device)); - - for (unsigned int second_link = 0; second_link < NVML_NVLINK_MAX_LINKS; second_link++) + + nvmlDevice_t remote_device; + auto const result = nvmlDeviceGetHandleByPciBusId_v2(remote_pci_info.busId, &remote_device); + + if (result == NVML_SUCCESS) { - nvmlPciInfo_t second_remote_pci_info; - if (nvmlDeviceGetNvLinkRemotePciInfo_v2(second_device, second_link, &second_remote_pci_info) - != NVML_SUCCESS) - { - continue; - } + // Two GPUs are connected directly through nvlink + unsigned int remote_device_id; + NVML_CHECK_THROW(nvmlDeviceGetIndex(remote_device, &remote_device_id)); - if (strcmp(remote_pci_info.busId, second_remote_pci_info.busId) == 0) + if (remote_device_id == static_cast(second_device_id)) { is_NVLINK = true; - break; } } - } - else - { - NVML_CHECK_THROW(result); - } + else if (result == NVML_ERROR_NOT_FOUND) + { + // Maybe Two GPUs are connected via nvswitch, + // now remotePciInfo represents the pci information of nvswitch, + // determine whether nvlink is supported by whether two GPUs are connected to the same + // nvswitch. 
+ nvmlDevice_t second_device; + NVML_CHECK_THROW(nvmlDeviceGetHandleByIndex(second_device_id, &second_device)); + + for (unsigned int second_link = 0; second_link < NVML_NVLINK_MAX_LINKS; second_link++) + { + nvmlPciInfo_t second_remote_pci_info; + if (nvmlDeviceGetNvLinkRemotePciInfo_v2( + second_device, second_link, &second_remote_pci_info) + != NVML_SUCCESS) + { + continue; + } + + if (strcmp(remote_pci_info.busId, second_remote_pci_info.busId) == 0) + { + is_NVLINK = true; + break; + } + } + } + else + { + NVML_CHECK_THROW(result); + } - if (is_NVLINK) - { - break; + if (is_NVLINK) + { + break; + } } + + mIsNVLINKSupported &= is_NVLINK; } + } + } + + // For inter-node groups, check MNNVL support + if (is_inter_node) + { + TLLM_LOG_INFO("Found inter-node TP group for rank %d, checking MNNVL support", rank); - mIsNVLINKSupported &= is_NVLINK; + // Check MNNVL support on local device(s) + bool local_mnnvl_supported = false; + if (!local_group.empty()) + { + // Check MNNVL on first device in local group (all devices on same node should have same MNNVL status) + int check_device = *local_group.begin(); + local_mnnvl_supported = checkMNNVLSupport(check_device); + } + + // Gather MNNVL status from all ranks in the group + int local_mnnvl_status = local_mnnvl_supported ? 1 : 0; + std::vector all_mnnvl_status(mGroup.size()); + + std::visit(overloaded{[&](std::shared_ptr& comm_ptr) + { + // For NCCL comm, use MPI to gather status + // Use MPI allgather to collect MNNVL status + // Create a sub-communicator for the group + std::vector group_ranks(mGroup.begin(), mGroup.end()); + MPI_Group world_group, new_group; + MPI_Comm group_comm; + MPI_Comm_group(COMM_SESSION, &world_group); + MPI_Group_incl(world_group, group_ranks.size(), group_ranks.data(), &new_group); + MPI_Comm_create_group(COMM_SESSION, new_group, 0, &group_comm); + + if (group_comm != MPI_COMM_NULL) + { + MPI_Allgather(&local_mnnvl_status, 1, MPI_INT, all_mnnvl_status.data(), 1, MPI_INT, + group_comm); + MPI_Comm_free(&group_comm); + } + MPI_Group_free(&new_group); + MPI_Group_free(&world_group); + }, + [&](c10::intrusive_ptr& torchPg) + { + // For ProcessGroup, use allgather directly + // Note: This assumes the ProcessGroup is already set up for the correct group + std::vector input_tensors + = {torch::tensor({local_mnnvl_status}, torch::kInt32)}; + std::vector> output_tensors(1); + output_tensors[0].resize(mGroup.size()); + auto work = torchPg->allgather(output_tensors, input_tensors); + if (work) + { + work->wait(); + for (size_t i = 0; i < mGroup.size(); ++i) + { + all_mnnvl_status[i] = output_tensors[0][i].item(); + } + } + }}, + mNcclComm); + + // Check if all ranks support MNNVL + bool all_ranks_support_mnnvl = true; + for (int status : all_mnnvl_status) + { + if (status == 0) + { + all_ranks_support_mnnvl = false; + break; + } } + + // For inter-node: MNNVL support means all nodes have MNNVL + // Also need local NVLink for optimal performance + mIsMNNVLSupported = mIsNVLINKSupported && all_ranks_support_mnnvl; + mIsP2PSupported = false; // P2P doesn't work across nodes + + TLLM_LOG_INFO("Inter-node topology: local_NVLink=%d, local_MNNVL=%d, all_ranks_MNNVL=%d, final_MNNVL=%d", + mIsNVLINKSupported ? 1 : 0, local_mnnvl_status, all_ranks_support_mnnvl ? 1 : 0, + mIsMNNVLSupported ? 
1 : 0); + } + else + { + TLLM_LOG_INFO("TP group is intra-node for rank %d", rank); } } @@ -951,12 +1204,12 @@ class AllreduceOp if (ifFallbackToNCCL(seq_len, message_size_bytes, max_workspace_size)) { - return AllReduceStrategyType::NCCL; + return AllReduceStrategyType::NCCL_SYMMETRIC; } - // This rule based heuristic only chooses between NCCL and MIN_LATENCY strategies. - // From this point, all fusion patterns are supported by all these strategies: NCCL, ONESHOT, TWOSHOT and - // MIN_LATENCY. + // This rule based heuristic only chooses between NCCL_SYMMETRIC and MIN_LATENCY strategies. + // From this point, all fusion patterns are supported by all these strategies: NCCL_SYMMETRIC, ONESHOT, TWOSHOT + // and MIN_LATENCY. if (mStrategy != AllReduceStrategyType::AUTO) { // Check TWOSHOT constraint: seq_len >= tp_size @@ -973,12 +1226,11 @@ class AllreduceOp return tensorrt_llm::utils::customAllReduceUtils::selectStrategyLookUpTable( seq_len, hidden_size, mOp, mGroup.size()); } - return AllReduceStrategyType::NCCL; } bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size) { - // If messageSize is less than maxWorkspaceSize, use NCCL, regardless of the fusion type. + // If messageSize is greater than maxWorkspaceSize or topology is unsuitable, use NCCL_SYMMETRIC fallback. if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported) { return true; @@ -1006,6 +1258,7 @@ class AllreduceOp std::set mGroup; bool mIsNVLINKSupported; bool mIsP2PSupported; + bool mIsMNNVLSupported; nvinfer1::DataType mType; AllReduceStrategyType mStrategy; AllReduceFusionOp mOp; diff --git a/cpp/tests/unit_tests/multi_gpu/CMakeLists.txt b/cpp/tests/unit_tests/multi_gpu/CMakeLists.txt index 5fb79c766cd..44b8e305778 100644 --- a/cpp/tests/unit_tests/multi_gpu/CMakeLists.txt +++ b/cpp/tests/unit_tests/multi_gpu/CMakeLists.txt @@ -20,3 +20,9 @@ target_link_libraries(cacheTransceiverTest PRIVATE ${Python3_LIBRARIES}) add_gtest(mpiUtilsTest mpiUtilsTest.cpp) add_gtest(userBufferTest userBufferTest.cpp) +add_gtest(ncclUtilsTest ncclUtilsTest.cpp) +target_link_libraries(ncclUtilsTest PRIVATE ${Python3_LIBRARIES}) +if(BUILD_PYT) + target_compile_definitions(ncclUtilsTest PUBLIC BUILD_PYT) + target_link_libraries(ncclUtilsTest PUBLIC ${TORCH_LIBRARIES}) +endif() diff --git a/cpp/tests/unit_tests/multi_gpu/ncclUtilsTest.cpp b/cpp/tests/unit_tests/multi_gpu/ncclUtilsTest.cpp new file mode 100644 index 00000000000..bf4ddd21418 --- /dev/null +++ b/cpp/tests/unit_tests/multi_gpu/ncclUtilsTest.cpp @@ -0,0 +1,745 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/ncclUtils.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" + +#include +#include +#include +#include + +#if ENABLE_MULTI_DEVICE && BUILD_PYT +#include +#endif + +#if ENABLE_MULTI_DEVICE + +namespace mpi = tensorrt_llm::mpi; +namespace tr = tensorrt_llm::runtime; +namespace nccl_util = tensorrt_llm::common::nccl_util; + +using ::getComm; + +// Helper function to create a split communicator for testing +// This allows us to test cleanup behavior explicitly by controlling the lifetime +std::shared_ptr createSplitComm(ncclComm_t parentComm, int color, int key) +{ + ncclComm_t newComm; + ncclResult_t result = ncclCommSplit(parentComm, color, key, &newComm, nullptr); + if (result != ncclSuccess) + { + TLLM_THROW("ncclCommSplit failed with error: %d", result); + } + + // Create a shared_ptr with custom deleter that cleans up resources first + return std::shared_ptr(new ncclComm_t(newComm), + [](ncclComm_t* comm) + { + if (comm && *comm) + { + // STEP 1: Clean up all registered resources FIRST + tensorrt_llm::common::nccl_util::NcclCommResourceManager::getInstance().cleanupResources(*comm); + + // STEP 2: Now destroy the NCCL communicator + ncclResult_t result = ncclCommDestroy(*comm); + if (result != ncclSuccess) + { + TLLM_LOG_WARNING("ncclCommDestroy failed with error: %d", result); + } + + // STEP 3: Free the memory + delete comm; + } + }); +} + +//============================================================================== +// NcclCommResourceManager Tests +//============================================================================== + +class NcclCommResourceManagerTest : public ::testing::Test +{ +protected: + void SetUp() override + { + auto& comm = mpi::MpiComm::world(); + mWorldSize = comm.getSize(); + mRank = comm.getRank(); + + if (mWorldSize < 2) + { + GTEST_SKIP() << "Requires at least 2 ranks (got " << mWorldSize << ")"; + } + + // Set CUDA device for this rank (required before NCCL initialization) + int deviceCount = 0; + TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); + if (deviceCount > 0) + { + int deviceId = mRank % deviceCount; + TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); + } + + // Create a communicator for testing + std::set group; + for (int i = 0; i < mWorldSize; ++i) + { + group.insert(i); + } + mComm = getComm(group); + } + + void TearDown() override + { + // Communicator cleanup happens automatically via shared_ptr deleter + mComm.reset(); + } + + int mWorldSize; + int mRank; + std::shared_ptr mComm; +}; + +TEST_F(NcclCommResourceManagerTest, ResourceRegistration) +{ + auto& manager = nccl_util::NcclCommResourceManager::getInstance(); + + // Create a separate comm using split for this test + auto testComm = createSplitComm(*mComm, 0, mRank); + + // Register a resource + bool cleanupCalled = false; + manager.registerResource( + *testComm, [&cleanupCalled]() { cleanupCalled = true; }, "TestResource"); + + EXPECT_TRUE(manager.hasResources(*testComm)); + EXPECT_EQ(manager.getResourceCount(*testComm), 1); + EXPECT_FALSE(cleanupCalled); // Cleanup not called yet + + // Store the raw comm value before destruction + ncclComm_t rawComm = *testComm; + + // Cleanup should be called when comm is destroyed + testComm.reset(); + + // Verify cleanup was called + EXPECT_TRUE(cleanupCalled); + + // Verify cleanup: check that the old comm (now destroyed) no longer has 
resources + // Note: The comm is destroyed, but we can still check the manager's internal state + // The cleanup should have removed all resources for this comm + EXPECT_FALSE(manager.hasResources(rawComm)); + EXPECT_EQ(manager.getResourceCount(rawComm), 0); +} + +TEST_F(NcclCommResourceManagerTest, MultipleResources) +{ + auto& manager = nccl_util::NcclCommResourceManager::getInstance(); + + // Create a separate comm using split for this test + auto testComm = createSplitComm(*mComm, 0, mRank); + + std::vector cleanupOrder; + manager.registerResource( + *testComm, [&cleanupOrder]() { cleanupOrder.push_back(1); }, "Resource1"); + manager.registerResource( + *testComm, [&cleanupOrder]() { cleanupOrder.push_back(2); }, "Resource2"); + manager.registerResource( + *testComm, [&cleanupOrder]() { cleanupOrder.push_back(3); }, "Resource3"); + + EXPECT_EQ(manager.getResourceCount(*testComm), 3); + + // Cleanup order should be preserved - destroy comm and verify order + testComm.reset(); + + // Verify cleanup order was preserved (1, 2, 3) + EXPECT_EQ(cleanupOrder.size(), 3); + EXPECT_EQ(cleanupOrder[0], 1); + EXPECT_EQ(cleanupOrder[1], 2); + EXPECT_EQ(cleanupOrder[2], 3); +} + +TEST_F(NcclCommResourceManagerTest, ResourceCount) +{ + auto& manager = nccl_util::NcclCommResourceManager::getInstance(); + + // Create a separate comm using split for this test + auto testComm = createSplitComm(*mComm, 0, mRank); + + EXPECT_FALSE(manager.hasResources(*testComm)); + EXPECT_EQ(manager.getResourceCount(*testComm), 0); + + manager.registerResource( + *testComm, []() {}, "Test1"); + EXPECT_EQ(manager.getResourceCount(*testComm), 1); + + manager.registerResource( + *testComm, []() {}, "Test2"); + EXPECT_EQ(manager.getResourceCount(*testComm), 2); + + testComm.reset(); +} + +//============================================================================== +// NCCLWindowAllocator Tests +//============================================================================== + +class NCCLWindowAllocatorTest : public ::testing::Test +{ +protected: + void SetUp() override + { + auto& comm = mpi::MpiComm::world(); + mWorldSize = comm.getSize(); + mRank = comm.getRank(); + + if (mWorldSize < 2) + { + GTEST_SKIP() << "Requires at least 2 ranks (got " << mWorldSize << ")"; + } + + // Set CUDA device for this rank (required before NCCL initialization) + int deviceCount = 0; + TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); + if (deviceCount > 0) + { + int deviceId = mRank % deviceCount; + TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); + } + + // Check if NCCL symmetric is supported + auto& ncclHelper = nccl_util::NCCLHelper::getInstance(); + if (!ncclHelper.isLoaded()) + { + GTEST_SKIP() << "NCCL library with symmetric memory support is not available"; + } + + std::set group; + for (int i = 0; i < mWorldSize; ++i) + { + group.insert(i); + } + mComm = getComm(group); + } + + void TearDown() override + { + // Cleanup happens automatically + mComm.reset(); + } + + int mWorldSize; + int mRank; + std::shared_ptr mComm; +}; + +TEST_F(NCCLWindowAllocatorTest, BasicAllocation) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + const size_t bufferSize = 1024 * 1024; // 1MB + auto buffer = allocator.requestBuffer(*mComm, bufferSize); + + EXPECT_TRUE(buffer.isValid()); + EXPECT_NE(buffer.ptr, nullptr); + EXPECT_NE(buffer.window, nullptr); + EXPECT_EQ(buffer.size, bufferSize); + EXPECT_GE(buffer.handle, 0); + + // Verify we can search for it + auto found = allocator.searchBuffer(*mComm, buffer.ptr); + 
EXPECT_TRUE(found.isValid()); + EXPECT_EQ(found.ptr, buffer.ptr); + + // Release the buffer + allocator.releaseBuffer(*mComm, buffer.ptr); +} + +TEST_F(NCCLWindowAllocatorTest, BufferReuse) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + const size_t bufferSize = 512 * 1024; // 512KB + + // Allocate first buffer + auto buffer1 = allocator.requestBuffer(*mComm, bufferSize); + EXPECT_TRUE(buffer1.isValid()); + void* ptr1 = buffer1.ptr; + + // Release it + allocator.releaseBuffer(*mComm, ptr1); + + // Request another buffer of the same size - should reuse + auto buffer2 = allocator.requestBuffer(*mComm, bufferSize); + EXPECT_TRUE(buffer2.isValid()); + EXPECT_EQ(buffer2.ptr, ptr1); // Should be the same buffer + + allocator.releaseBuffer(*mComm, buffer2.ptr); +} + +TEST_F(NCCLWindowAllocatorTest, BestFitReuse) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + // Allocate buffers of different sizes + auto buffer1MB = allocator.requestBuffer(*mComm, 1024 * 1024); + auto buffer2MB = allocator.requestBuffer(*mComm, 2 * 1024 * 1024); + auto buffer512KB = allocator.requestBuffer(*mComm, 512 * 1024); + + void* ptr1MB = buffer1MB.ptr; + void* ptr2MB = buffer2MB.ptr; + void* ptr512KB = buffer512KB.ptr; + + // Release all + allocator.releaseBuffer(*mComm, ptr1MB); + allocator.releaseBuffer(*mComm, ptr2MB); + allocator.releaseBuffer(*mComm, ptr512KB); + + // Request 768KB - should reuse 1MB (best fit, smallest that fits) + auto buffer768KB = allocator.requestBuffer(*mComm, 768 * 1024); + EXPECT_TRUE(buffer768KB.isValid()); + EXPECT_EQ(buffer768KB.ptr, ptr1MB); // Should reuse 1MB buffer + EXPECT_EQ(buffer768KB.size, 1024 * 1024); // Original size + + allocator.releaseBuffer(*mComm, buffer768KB.ptr); +} + +TEST_F(NCCLWindowAllocatorTest, MultipleBuffers) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + const size_t bufferSize = 256 * 1024; + std::vector ptrs; + + // Allocate multiple buffers + for (int i = 0; i < 5; ++i) + { + auto buffer = allocator.requestBuffer(*mComm, bufferSize); + EXPECT_TRUE(buffer.isValid()); + ptrs.push_back(buffer.ptr); + } + + EXPECT_EQ(allocator.getBufferCount(*mComm), 5); + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 5); + + // Release all + for (auto* ptr : ptrs) + { + allocator.releaseBuffer(*mComm, ptr); + } + + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 0); + EXPECT_EQ(allocator.getBufferCount(*mComm), 5); // Buffers still exist, just not in use +} + +TEST_F(NCCLWindowAllocatorTest, SearchBuffer) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + const size_t bufferSize = 128 * 1024; + auto buffer = allocator.requestBuffer(*mComm, bufferSize); + + // Test searchBuffer + auto found = allocator.searchBuffer(*mComm, buffer.ptr); + EXPECT_TRUE(found.isValid()); + EXPECT_EQ(found.ptr, buffer.ptr); + // Compare against actual allocated size (ncclMemAlloc may allocate more than requested) + EXPECT_EQ(found.size, buffer.size); + EXPECT_GE(found.size, bufferSize); // At least the requested size + + // Test search for non-existent buffer + void* fakePtr = reinterpret_cast(0xDEADBEEF); + auto notFound = allocator.searchBuffer(*mComm, fakePtr); + EXPECT_FALSE(notFound.isValid()); + + allocator.releaseBuffer(*mComm, buffer.ptr); +} + +TEST_F(NCCLWindowAllocatorTest, GetWindowAndSize) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + const size_t bufferSize = 64 * 1024; + auto buffer = allocator.requestBuffer(*mComm, bufferSize); + + // Test getWindow + 
auto window = allocator.getWindow(*mComm, buffer.ptr); + EXPECT_NE(window, nullptr); + EXPECT_EQ(window, buffer.window); + + // Test getSize - compare against actual allocated size (ncclMemAlloc may allocate more than requested) + auto size = allocator.getSize(*mComm, buffer.ptr); + EXPECT_EQ(size, buffer.size); + EXPECT_GE(size, bufferSize); // At least the requested size + + // Test with invalid pointer + void* fakePtr = reinterpret_cast(0xDEADBEEF); + EXPECT_EQ(allocator.getWindow(*mComm, fakePtr), nullptr); + EXPECT_EQ(allocator.getSize(*mComm, fakePtr), 0); + + allocator.releaseBuffer(*mComm, buffer.ptr); +} + +TEST_F(NCCLWindowAllocatorTest, GetBufferInfo) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + const size_t bufferSize = 32 * 1024; + auto buffer = allocator.requestBuffer(*mComm, bufferSize); + + auto info = allocator.getBufferInfo(*mComm, buffer.ptr); + EXPECT_TRUE(info.isValid()); + EXPECT_EQ(info.ptr, buffer.ptr); + EXPECT_EQ(info.size, buffer.size); + EXPECT_EQ(info.handle, buffer.handle); + EXPECT_EQ(info.window, buffer.window); + + allocator.releaseBuffer(*mComm, buffer.ptr); +} + +TEST_F(NCCLWindowAllocatorTest, ScopedBuffer) +{ + const size_t bufferSize = 16 * 1024; + + { + nccl_util::ScopedNCCLWindowBuffer scopedBuffer(*mComm, bufferSize); + EXPECT_TRUE(scopedBuffer.getBuffer().isValid()); + EXPECT_NE(scopedBuffer.getPtr(), nullptr); + // Compare against actual allocated size (ncclMemAlloc may allocate more than requested) + EXPECT_EQ(scopedBuffer.getSize(), scopedBuffer.getBuffer().size); + EXPECT_GE(scopedBuffer.getSize(), bufferSize); // At least the requested size + EXPECT_NE(scopedBuffer.getWindow(), nullptr); + + // Buffer should be in use + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 1); + } + + // Buffer should be released when scoped buffer goes out of scope + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 0); +} + +TEST_F(NCCLWindowAllocatorTest, CleanupOnCommDestroy) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + // Create a separate comm using split for this test + auto testComm = createSplitComm(*mComm, 0, mRank); + + // Store the raw comm value before destruction + ncclComm_t rawComm = *testComm; + + // Allocate some buffers + const size_t bufferSize = 8 * 1024; + auto buffer1 = allocator.requestBuffer(*testComm, bufferSize); + auto buffer2 = allocator.requestBuffer(*testComm, bufferSize * 2); + + EXPECT_EQ(allocator.getBufferCount(*testComm), 2); + EXPECT_EQ(allocator.getBufferInUseCount(*testComm), 2); + + // Verify buffers are valid + EXPECT_TRUE(buffer1.isValid()); + EXPECT_TRUE(buffer2.isValid()); + + // Manually release buffers before cleanup to avoid warnings + allocator.releaseBuffer(*testComm, buffer1.ptr); + allocator.releaseBuffer(*testComm, buffer2.ptr); + + // Verify buffers are released but still exist in pool + EXPECT_EQ(allocator.getBufferInUseCount(*testComm), 0); + EXPECT_EQ(allocator.getBufferCount(*testComm), 2); // Buffers still exist, just not in use + + // Destroy the communicator - buffers should be cleaned up automatically + testComm.reset(); + + // Verify cleanup: check that the old comm (now destroyed) no longer has buffers + // Note: The comm is destroyed, but we can still check the allocator's internal state + // The cleanup should have removed all buffers for this comm + EXPECT_EQ(allocator.getBufferCount(rawComm), 0); + 
EXPECT_EQ(allocator.getBufferInUseCount(rawComm), 0); + // Note: isCommValid only checks for null, not cleaned-up state, because NCCL can reuse addresses + // The real check is that buffers are gone, which we verify above +} + +TEST_F(NCCLWindowAllocatorTest, CommValidity) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + // Valid comm should be valid + EXPECT_TRUE(allocator.isCommValid(*mComm)); + + // Null comm should be invalid + EXPECT_FALSE(allocator.isCommValid(nullptr)); +} + +//============================================================================== +// Integration Tests +//============================================================================== + +TEST_F(NCCLWindowAllocatorTest, MultipleComms) +{ + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + // Create two different communicators using split (different colors) + auto comm1 = createSplitComm(*mComm, 0, mRank); + auto comm2 = createSplitComm(*mComm, 1, mRank); + + const size_t bufferSize = 4 * 1024; + + // Allocate buffers from both comms + auto buffer1 = allocator.requestBuffer(*comm1, bufferSize); + auto buffer2 = allocator.requestBuffer(*comm2, bufferSize); + + EXPECT_TRUE(buffer1.isValid()); + EXPECT_TRUE(buffer2.isValid()); + + // Buffers should be tracked separately per comm + EXPECT_EQ(allocator.getBufferCount(*comm1), 1); + EXPECT_EQ(allocator.getBufferCount(*comm2), 1); + EXPECT_NE(buffer1.ptr, buffer2.ptr); // Different buffers from different comms + + allocator.releaseBuffer(*comm1, buffer1.ptr); + allocator.releaseBuffer(*comm2, buffer2.ptr); + + // Clean up comms + comm1.reset(); + comm2.reset(); +} + +#if ENABLE_MULTI_DEVICE && BUILD_PYT +//============================================================================== +// createNCCLWindowTensor Tests +//============================================================================== + +class CreateNCCLWindowTensorTest : public ::testing::Test +{ +protected: + void SetUp() override + { + auto& comm = mpi::MpiComm::world(); + mWorldSize = comm.getSize(); + mRank = comm.getRank(); + + if (mWorldSize < 2) + { + GTEST_SKIP() << "Requires at least 2 ranks (got " << mWorldSize << ")"; + } + + // Set CUDA device for this rank (required before NCCL initialization) + int deviceCount = 0; + TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); + if (deviceCount > 0) + { + int deviceId = mRank % deviceCount; + TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); + } + + // Check if NCCL symmetric is supported + auto& ncclHelper = nccl_util::NCCLHelper::getInstance(); + if (!ncclHelper.isLoaded()) + { + GTEST_SKIP() << "NCCL library with symmetric memory support is not available"; + } + + std::set group; + for (int i = 0; i < mWorldSize; ++i) + { + group.insert(i); + } + mComm = getComm(group); + } + + void TearDown() override + { + mComm.reset(); + } + + int mWorldSize; + int mRank; + std::shared_ptr mComm; +}; + +TEST_F(CreateNCCLWindowTensorTest, BasicTensorCreation) +{ + using nccl_util::createNCCLWindowTensor; + + // Create a tensor with shape [4, 8] and float32 dtype + std::vector shape = {4, 8}; + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + + // Verify tensor properties + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.dtype(), torch::kFloat32); + EXPECT_EQ(tensor.device().type(), torch::kCUDA); + EXPECT_EQ(tensor.dim(), 2); + EXPECT_EQ(tensor.size(0), 4); + EXPECT_EQ(tensor.size(1), 8); + EXPECT_EQ(tensor.numel(), 4 * 8); + + // Verify buffer properties + EXPECT_TRUE(buffer.isValid()); + 
EXPECT_NE(buffer.ptr, nullptr); + // ncclMemAlloc may allocate more than requested, so check at least the requested size + EXPECT_GE(buffer.size, 4 * 8 * sizeof(float)); + EXPECT_NE(buffer.window, nullptr); + + // Verify tensor data pointer matches buffer pointer + EXPECT_EQ(tensor.data_ptr(), buffer.ptr); + + // Tensor should be in use + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 1); +} + +TEST_F(CreateNCCLWindowTensorTest, DifferentDtypes) +{ + using nccl_util::createNCCLWindowTensor; + + std::vector shape = {10}; + + // Test float32 + { + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + EXPECT_EQ(tensor.dtype(), torch::kFloat32); + // ncclMemAlloc may allocate more than requested, so check at least the requested size + EXPECT_GE(buffer.size, 10 * sizeof(float)); + EXPECT_EQ(tensor.data_ptr(), buffer.ptr); + } + + // Test float16 + { + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat16); + EXPECT_EQ(tensor.dtype(), torch::kFloat16); + // ncclMemAlloc may allocate more than requested, so check at least the requested size + EXPECT_GE(buffer.size, 10 * sizeof(at::Half)); + EXPECT_EQ(tensor.data_ptr(), buffer.ptr); + } + + // Test int32 + { + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kInt32); + EXPECT_EQ(tensor.dtype(), torch::kInt32); + // ncclMemAlloc may allocate more than requested, so check at least the requested size + EXPECT_GE(buffer.size, 10 * sizeof(int32_t)); + EXPECT_EQ(tensor.data_ptr(), buffer.ptr); + } +} + +TEST_F(CreateNCCLWindowTensorTest, DifferentShapes) +{ + using nccl_util::createNCCLWindowTensor; + + // 1D tensor + { + std::vector shape = {100}; + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + EXPECT_EQ(tensor.dim(), 1); + EXPECT_EQ(tensor.size(0), 100); + // ncclMemAlloc may allocate more than requested, so check at least the requested size + EXPECT_GE(buffer.size, 100 * sizeof(float)); + } + + // 3D tensor + { + std::vector shape = {2, 3, 4}; + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + EXPECT_EQ(tensor.dim(), 3); + EXPECT_EQ(tensor.size(0), 2); + EXPECT_EQ(tensor.size(1), 3); + EXPECT_EQ(tensor.size(2), 4); + // ncclMemAlloc may allocate more than requested, so check at least the requested size + EXPECT_GE(buffer.size, 2 * 3 * 4 * sizeof(float)); + } + + // 4D tensor + { + std::vector shape = {1, 2, 3, 4}; + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + EXPECT_EQ(tensor.dim(), 4); + EXPECT_EQ(tensor.numel(), 1 * 2 * 3 * 4); + // ncclMemAlloc may allocate more than requested, so check at least the requested size + EXPECT_GE(buffer.size, 1 * 2 * 3 * 4 * sizeof(float)); + } +} + +TEST_F(CreateNCCLWindowTensorTest, TensorDeleterReleasesBuffer) +{ + using nccl_util::createNCCLWindowTensor; + + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + { + std::vector shape = {16, 16}; + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 1); + EXPECT_TRUE(buffer.isValid()); + void* bufferPtr = buffer.ptr; + + // Tensor goes out of scope - deleter should release the buffer + } + + // Buffer should be released (not in use anymore) + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 0); + + // Buffer should still exist in the pool (for reuse) + EXPECT_GE(allocator.getBufferCount(*mComm), 1); +} + 
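+// A minimal usage sketch: window-backed memory from ScopedNCCLWindowBuffer can be handed
+// straight to ncclAllReduce. The test name, element count, and float dtype are arbitrary
+// choices for this example; the utilities used (ScopedNCCLWindowBuffer, getPtr) come from
+// ncclUtils, and the default stream is assumed.
+TEST_F(CreateNCCLWindowTensorTest, ScopedBufferAllReduceSketch)
+{
+    const size_t numElems = 1024;
+    nccl_util::ScopedNCCLWindowBuffer sendBuf(*mComm, numElems * sizeof(float));
+    nccl_util::ScopedNCCLWindowBuffer recvBuf(*mComm, numElems * sizeof(float));
+
+    // Zero the send buffer so the reduced result is known (a sum of zeros is zero).
+    TLLM_CUDA_CHECK(cudaMemset(sendBuf.getPtr(), 0, numElems * sizeof(float)));
+
+    // Window-registered pointers are valid NCCL send/recv buffers.
+    EXPECT_EQ(ncclAllReduce(sendBuf.getPtr(), recvBuf.getPtr(), numElems, ncclFloat, ncclSum, *mComm, nullptr),
+        ncclSuccess);
+    TLLM_CUDA_CHECK(cudaDeviceSynchronize());
+
+    float host = -1.0F;
+    TLLM_CUDA_CHECK(cudaMemcpy(&host, recvBuf.getPtr(), sizeof(float), cudaMemcpyDeviceToHost));
+    EXPECT_EQ(host, 0.0F);
+}
+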
+TEST_F(CreateNCCLWindowTensorTest, MultipleTensors) +{ + using nccl_util::createNCCLWindowTensor; + + auto& allocator = nccl_util::NCCLWindowAllocator::getInstance(); + + std::vector shape = {8, 8}; + auto [tensor1, buffer1] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + auto [tensor2, buffer2] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + auto [tensor3, buffer3] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + + EXPECT_EQ(allocator.getBufferInUseCount(*mComm), 3); + EXPECT_NE(buffer1.ptr, buffer2.ptr); + EXPECT_NE(buffer2.ptr, buffer3.ptr); + EXPECT_NE(buffer1.ptr, buffer3.ptr); + + // All tensors should be valid + EXPECT_TRUE(tensor1.defined()); + EXPECT_TRUE(tensor2.defined()); + EXPECT_TRUE(tensor3.defined()); +} + +TEST_F(CreateNCCLWindowTensorTest, TensorStrides) +{ + using nccl_util::createNCCLWindowTensor; + + std::vector shape = {3, 4, 5}; + auto [tensor, buffer] = createNCCLWindowTensor(*mComm, shape, torch::kFloat32); + + // Verify strides are correct (row-major order) + EXPECT_EQ(tensor.stride(0), 4 * 5); // stride for first dimension + EXPECT_EQ(tensor.stride(1), 5); // stride for second dimension + EXPECT_EQ(tensor.stride(2), 1); // stride for third dimension +} + +#endif // ENABLE_MULTI_DEVICE && BUILD_PYT + +#endif // ENABLE_MULTI_DEVICE diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 811f11fce5b..aaac2256c90 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2844,11 +2844,17 @@ def _init_userbuffers(self, hidden_size): # Disable UB for unsupported platforms if not ub.ub_supported(): return False - use_nccl_symmetric = self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC" - ub.initialize_userbuffers_manager( - self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size, - self.mapping.rank, self.mapping.gpus_per_node, - hidden_size * self.max_num_tokens * 2, use_nccl_symmetric) + # NCCL_SYMMETRIC strategy no longer requires UserBuffer allocator initialization. + # It uses NCCLWindowAllocator from ncclUtils directly. 
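+        # The NCCLWindowAllocator sizes and reuses window buffers per allreduce call,
+        # so no fixed user-buffer pool has to be reserved up front for this strategy.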
+ if self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC": + # Skip UB initialization for NCCL_SYMMETRIC - it uses NCCLWindowAllocator directly + return False + ub.initialize_userbuffers_manager(self.mapping.tp_size, + self.mapping.pp_size, + self.mapping.cp_size, + self.mapping.rank, + self.mapping.gpus_per_node, + hidden_size * self.max_num_tokens * 2) return True diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 282febd262e..706917828c2 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -4022,7 +4022,10 @@ def create_allreduce_plugin( pfc = trt.PluginFieldCollection(pfc) ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) plug_inputs = [tensor] - if all_reduce_params.strategy != AllReduceStrategy.NCCL and all_reduce_params.strategy != AllReduceStrategy.UB: + if all_reduce_params.strategy not in { + AllReduceStrategy.NCCL, AllReduceStrategy.UB, + AllReduceStrategy.NCCL_SYMMETRIC + }: plug_inputs.append(workspace) if all_reduce_params.fusion_op != AllReduceFusionOp.NONE: if all_reduce_params.has_bias() == 1: @@ -4094,7 +4097,7 @@ def allreduce( workspace = None if all_reduce_params.strategy != AllReduceStrategy.NCCL and all_reduce_params.strategy != AllReduceStrategy.UB: if current_all_reduce_helper().workspace is None: - all_reduce_params.strategy = AllReduceStrategy.NCCL + all_reduce_params.strategy = AllReduceStrategy.NCCL_SYMMETRIC else: workspace = current_all_reduce_helper().workspace.trt_tensor if all_reduce_params.strategy == AllReduceStrategy.UB: diff --git a/tests/integration/defs/cpp/test_multi_gpu.py b/tests/integration/defs/cpp/test_multi_gpu.py index 3b384dd58e8..7cf92efaadb 100644 --- a/tests/integration/defs/cpp/test_multi_gpu.py +++ b/tests/integration/defs/cpp/test_multi_gpu.py @@ -127,6 +127,24 @@ def run_user_buffer_tests(build_dir: _pl.Path, nprocs=2, timeout=300): timeout=timeout) +def run_nccl_utils_tests(build_dir: _pl.Path, nprocs=2, timeout=300): + tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu" + mgpu_env = get_multi_gpu_env() + + nccl_utils_test = [ + "mpirun", + "-n", + f"{nprocs}", + "--allow-run-as-root", + "ncclUtilsTest", + ] + + _cpp.run_command(nccl_utils_test, + cwd=tests_dir, + env=mgpu_env, + timeout=timeout) + + def run_llama_executor_leader_tests(build_dir: _pl.Path, timeout=1500): tests_dir = build_dir / "tests" / "e2e_tests" @@ -505,6 +523,15 @@ def test_user_buffer(build_google_tests, nprocs, build_dir): run_user_buffer_tests(build_dir=build_dir, nprocs=nprocs, timeout=300) +@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"], + indirect=True) +@pytest.mark.parametrize("nprocs", [2, 8], ids=["2proc", "8proc"]) +def test_nccl_utils(build_google_tests, nprocs, build_dir): + + if platform.system() != "Windows": + run_nccl_utils_tests(build_dir=build_dir, nprocs=nprocs, timeout=300) + + @pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"], indirect=True) @pytest.mark.parametrize("multi_gpu_model", ["t5"], indirect=True) diff --git a/tests/microbenchmarks/all_reduce.py b/tests/microbenchmarks/all_reduce.py index 837b0348129..bd5ceb8826b 100644 --- a/tests/microbenchmarks/all_reduce.py +++ b/tests/microbenchmarks/all_reduce.py @@ -176,6 +176,7 @@ def allreduce_benchmark( ] strategies = [ AllReduceStrategy.NCCL, + AllReduceStrategy.NCCL_SYMMETRIC, AllReduceStrategy.ONESHOT, AllReduceStrategy.TWOSHOT, AllReduceStrategy.AUTO, @@ -242,6 +243,9 @@ def allreduce_benchmark( # print the dataframe if mapping.rank == 0: 
pd.set_option('display.max_rows', None) + pd.set_option('display.max_columns', None) + pd.set_option('display.width', None) + pd.set_option('display.max_colwidth', None) print(df) # # save the dataframe to a csv file diff --git a/tests/scripts/allreduce_perf/allreduce_heuristic_code_gen.py b/tests/scripts/allreduce_perf/allreduce_heuristic_code_gen.py index 11c114e9cf3..e7aeb994b6e 100644 --- a/tests/scripts/allreduce_perf/allreduce_heuristic_code_gen.py +++ b/tests/scripts/allreduce_perf/allreduce_heuristic_code_gen.py @@ -28,6 +28,7 @@ class Constants: tp_size_list = [2, 4, 8] strategy_name_to_enum = { 'NCCL': 0, + 'NCCL_SYMMETRIC': 8, 'ONESHOT': 4, 'TWOSHOT': 5, } @@ -84,10 +85,10 @@ def generate_heuristic_look_up_table(df: pd.DataFrame) -> str: hidden_size_count = len(Constants.hidden_size_list) num_tokens_count = len(Constants.num_tokens_list) - # Initialize lookup table with default values (NCCL = 0) + # Initialize lookup table with default values (NCCL_SYMMETRIC = 8) strategy_table = np.full( (tp_size_count, fusion_count, hidden_size_count, num_tokens_count), - Constants.strategy_name_to_enum['NCCL'], + Constants.strategy_name_to_enum['NCCL_SYMMETRIC'], dtype=int) # Fill the lookup table with best strategies diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py index c01fe9205ca..5051998c5a6 100644 --- a/tests/unittest/_torch/multi_gpu/test_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py @@ -123,7 +123,7 @@ def e2m1_and_ufp8sf_scale_to_float_v2(e2m1_tensor, dtype=dtype, mapping=mapping, tensor_parallel_mode=TensorParallelMode.ROW, - allreduce_strategy=AllReduceStrategy.NCCL, + allreduce_strategy=AllReduceStrategy.NCCL_SYMMETRIC, ).cuda() allreduce = AllReduce(mapping=mapping) norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda() diff --git a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py index 56cf5a9562e..524fed462e3 100644 --- a/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py +++ b/tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py @@ -108,7 +108,7 @@ def row_linear_residual_norm_fusion_forward( ub.initialize_userbuffers_manager( tensor_parallel_size, 1, 1, tensor_parallel_rank, torch.cuda.device_count(), - x_list[0].nelement() * x_list[0].element_size(), True) + x_list[0].nelement() * x_list[0].element_size()) elif strategy == AllReduceStrategy.MNNVL: os.environ["TLLM_TEST_MNNVL"] = "1" diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py index c547c8a3e89..6de03d19086 100644 --- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py +++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py @@ -43,8 +43,7 @@ def create_tp_mapping(tp_size, rank): def init_userbuffers_allocator(tp_size, rank, max_ub_size): ub.initialize_userbuffers_manager(tp_size, 1, 1, rank, - torch.cuda.device_count(), max_ub_size, - False) + torch.cuda.device_count(), max_ub_size) def create_userbuffers_tensor(shape, dtype):