diff --git a/src/executor/execution_kernel.cu b/src/executor/execution_kernel.cu
index a60317c5..759ab8d1 100644
--- a/src/executor/execution_kernel.cu
+++ b/src/executor/execution_kernel.cu
@@ -8,12 +8,12 @@ namespace mscclpp {
 template <typename PacketType>
 void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
-                                   size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan,
-                                   size_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
+                                   DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
+                                   cudaStream_t stream, uint32_t flag) {
   switch (dataType) {
     case DataType::INT32:
       executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-          rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
+          rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
           ,
           NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -23,7 +23,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
       break;
     case DataType::UINT32:
       executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-          rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
+          rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
           ,
           NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -33,7 +33,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
       break;
    case DataType::FLOAT16:
      executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-          rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
+          rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
           ,
           NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -43,7 +43,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
       break;
     case DataType::FLOAT32:
       executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-          rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
+          rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
           ,
           NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -53,7 +53,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
       break;
     case DataType::BFLOAT16:
       executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-          rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
+          rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
           ,
           NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -65,12 +65,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
 }
 
 template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
-                                                        void* scratch, size_t scratchSize, DataType dataType,
-                                                        DeviceExecutionPlan* plan, size_t sharedMemSize,
-                                                        cudaStream_t stream, uint32_t flag);
+                                                        void* scratch, DataType dataType, DeviceExecutionPlan* plan,
+                                                        size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
 template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
-                                                       void* scratch, size_t scratchSize, DataType dataType,
-                                                       DeviceExecutionPlan* plan, size_t sharedMemSize,
-                                                       cudaStream_t stream, uint32_t flag);
+                                                       void* scratch, DataType dataType, DeviceExecutionPlan* plan,
+                                                       size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
 }  // namespace mscclpp
 #endif
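Note on the signature change above: the kernel no longer needs scratchSize because the double-buffer half is now chosen on the host and folded into every SCRATCH offset before launch. A minimal standalone sketch of the two schemes, with hypothetical helper names (only the two formulas themselves are taken from this patch):

    #include <cstddef>
    #include <cstdint>

    // Old scheme: each device-side packet handler re-derived the base offset
    // from the flag parity and the total scratch size.
    size_t deviceSideBase(uint32_t flag, size_t scratchSize) {
      return (flag & 0x1) ? 0 : scratchSize >> 1;
    }

    // New scheme: the host computes one base offset per execution and bakes it
    // into each operation, so the kernel only ever sees final offsets.
    size_t hostSideBase(uint32_t flag, size_t singleInstanceSize, bool isUsingDoubleScratchBuffer) {
      return (isUsingDoubleScratchBuffer && (flag % 2) == 0) ? singleInstanceSize : 0;
    }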
diff --git a/src/executor/execution_plan.cc b/src/executor/execution_plan.cc
index 56c881bd..bf5abc66 100644
--- a/src/executor/execution_plan.cc
+++ b/src/executor/execution_plan.cc
@@ -36,7 +36,7 @@ auto getOpType = [](const std::string& str) {
     return mscclpp::OperationType::WAIT;
   } else if (str == "flush") {
     return mscclpp::OperationType::FLUSH;
-  } else if (str == "re") {
+  } else if (str == "reduce") {
     return mscclpp::OperationType::REDUCE;
   } else if (str == "rs") {
     return mscclpp::OperationType::REDUCE_SEND;
@@ -176,37 +176,27 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
   return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
 }
 
-size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const {
+void ExecutionPlan::Impl::calcScratchBufferSizeAndOffset(int rank, size_t inputSize, size_t outputSize, int flag) {
   size_t sizePerRank = 0;
   if (this->inputChunks.at(rank) != 0)
-    sizePerRank = inputSize / this->inputChunks.at(rank);
+    sizePerRank = std::min(inputSize, this->maxMessageSize) / this->inputChunks.at(rank);
   else if (this->outputChunks.at(rank) != 0)
-    sizePerRank = outputSize / this->outputChunks.at(rank);
+    sizePerRank = std::min(outputSize, this->maxMessageSize) / this->outputChunks.at(rank);
   else
     throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
 
+  this->scratchBufferSize = sizePerRank * this->scratchChunks.at(rank);
+  this->scratchBufferOffset = (this->isUsingDoubleScratchBuffer && (flag % 2) == 0) ? this->scratchBufferSize : 0;
   if (this->isUsingPacket) {
-    return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
+    this->scratchBufferSize *= 2; /* data + flag */
   }
-  return sizePerRank * this->scratchChunks.at(rank);
-}
-
-size_t ExecutionPlan::Impl::getMaxScratchBufferSize(int rank) const {
-  if (this->maxMessageSize == std::numeric_limits<size_t>::max()) {
-    return std::numeric_limits<size_t>::max();
+  if (this->isUsingDoubleScratchBuffer) {
+    this->scratchBufferSize *= 2; /* double buffer */
   }
-  size_t sizePerChunk = 0;
-  if (this->inputChunks.at(rank) != 0)
-    sizePerChunk = maxMessageSize / this->inputChunks.at(rank);
-  else if (this->outputChunks.at(rank) != 0)
-    sizePerChunk = maxMessageSize / this->outputChunks.at(rank);
-  else
-    throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
-
-  return this->getScratchBufferSize(rank, sizePerChunk * this->inputChunks.at(rank),
-                                    sizePerChunk * this->outputChunks.at(rank));
 }
 
+size_t ExecutionPlan::Impl::getScratchBufferSize() const { return this->scratchBufferSize; }
+
 std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
   return this->operations.at(rank)[threadblock];
 }
@@ -215,8 +205,9 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper
 
 int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }
 
-void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
-                                            size_t constDstOffset) {
+void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
+                                            size_t constDstOffset, int selfRank, size_t inputBufferSize,
+                                            size_t outputBufferSize, int flag) {
   std::ifstream file(this->planPath);
   json obj = json::parse(file);
   if (this->name != obj["name"]) {
@@ -230,6 +221,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
   this->inputSize = inputSize;
   this->outputSize = outputSize;
   this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
+  this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
   this->minMessageSize = obj.value("min_message_size", 0);
   this->maxMessageSize = obj.value("max_message_size", std::numeric_limits<size_t>::max());
   this->isInPlace = obj["inplace"];
@@ -243,11 +235,13 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
     this->chunkGroups[rank] = gpu["chunkGroups"];
   }
   this->setupChannels(gpus);
-  this->setupOperations(gpus, contsSrcOffset, constDstOffset);
+  this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
+  this->setupOperations(gpus, constSrcOffset, constDstOffset);
 }
 
-void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
-                                                 size_t constDstOffset) {
+void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
+                                                 size_t constDstOffset, int selfRank, size_t inputBufferSize,
+                                                 size_t outputBufferSize, int flag) {
   std::ifstream file(this->planPath);
   json obj = json::parse(file);
   if (this->name != obj["name"]) {
@@ -257,6 +251,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output
   if (protocol == "LL") {
     this->isUsingPacket = true;
   }
+  this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
   const auto& gpus = obj["gpus"];
 
   for (const auto& gpu : gpus) {
@@ -269,7 +264,8 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output
 
   this->inputSize = inputSize;
   this->outputSize = outputSize;
-  this->setupOperations(gpus, contsSrcOffset, constDstOffset);
+  this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
+  this->setupOperations(gpus, constSrcOffset, constDstOffset);
 }
 
 void ExecutionPlan::Impl::parseChannels(
@@ -373,6 +369,15 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) {
   }
 }
 
+void ExecutionPlan::Impl::checkChannelsPerOperation(int channels) {
+  if (channels > MAX_CHANNEL_PER_OPERATION) {
+    throw Error("Executor plan has " + std::to_string(channels) +
+                    " channels per operation, exceeding executor support (" +
+                    std::to_string(MAX_CHANNEL_PER_OPERATION) + ")",
+                ErrorCode::ExecutorError);
+  }
+}
+
 void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffset, size_t constDstOffset) {
   auto getConstOffset = [&](BufferType type) -> size_t {
     switch (type) {
@@ -381,7 +386,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
       case BufferType::OUTPUT:
         return constDstOffset;
       case BufferType::SCRATCH:
-        return 0;
+        return this->scratchBufferOffset;
       default:
         throw Error("Invalid buffer type", ErrorCode::ExecutorError);
     }
@@ -424,6 +429,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
           chunkIndexes.push_back((uint32_t)op["srcoff"]);
         } else {
           operation.nInputs = op["i_cids"].size();
+          checkChannelsPerOperation(operation.nInputs);
           for (int i = 0; i < operation.nInputs; i++) {
             BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]);
             BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]);
@@ -440,6 +446,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
         // will have either srcs or i_cids
         if (op.contains("srcs")) {
           operation.nInputs = op["srcs"].size();
+          checkChannelsPerOperation(operation.nInputs);
           operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
           for (int i = 0; i < operation.nInputs; i++) {
             operation.inputOffsets[i] =
@@ -450,6 +457,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
         }
         if (op.contains("o_cids")) {
           operation.nOutputs = op["o_cids"].size();
+          checkChannelsPerOperation(operation.nOutputs);
           for (int i = 0; i < operation.nOutputs; i++) {
             if (operation.channelType == mscclpp::ChannelType::NVLS) {
               BufferType dstBufferType = convertToBufferType(op["dstbuff"]);
@@ -471,6 +479,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
         // will have either dsts or o_cids
         if (op.contains("dsts")) {
           operation.nOutputs = op["dsts"].size();
+          checkChannelsPerOperation(operation.nOutputs);
           operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
           for (int i = 0; i < operation.nOutputs; i++) {
             operation.outputOffsets[i] =
@@ -484,6 +493,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
         }
         if (op.contains("srcoff")) {
           operation.srcOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcoff"]);
+          if (operation.srcBufferType == BufferType::SCRATCH) {
+            operation.srcOffset += this->scratchBufferOffset;
+          }
           chunkIndexes.push_back((uint32_t)op["srcoff"]);
         }
         if (op.contains("dstbuff")) {
@@ -491,6 +503,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
         }
         if (op.contains("dstoff")) {
           operation.dstOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dstoff"]);
+          if (operation.dstBufferType == BufferType::SCRATCH) {
+            operation.dstOffset += this->scratchBufferOffset;
+          }
           chunkIndexes.push_back((uint32_t)op["dstoff"]);
         }
         if (op.contains("cnt")) {
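For concreteness, here is the arithmetic calcScratchBufferSizeAndOffset performs, replayed as a standalone sketch with made-up plan parameters (none of these numbers come from a real plan). The offset is captured before the packet doubling on purpose: device-side packet handlers multiply scratch offsets by 2 (data + flag), which lands the even-flag instance exactly at the halfway byte of the allocation:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
      const size_t inputSize = 1 << 20;  // hypothetical 1 MiB message
      const size_t maxMessageSize = 1 << 20;
      const size_t inputChunks = 8, scratchChunks = 8;
      const bool isUsingPacket = true, isUsingDoubleScratchBuffer = true;
      const int flag = 2;                // even flag -> second buffer instance

      size_t sizePerRank = std::min(inputSize, maxMessageSize) / inputChunks;  // 128 KiB
      size_t scratchBufferSize = sizePerRank * scratchChunks;                  // 1 MiB
      // Taken before the packet *2: the device doubles scratch offsets itself.
      size_t scratchBufferOffset =
          (isUsingDoubleScratchBuffer && flag % 2 == 0) ? scratchBufferSize : 0;  // 1 MiB
      if (isUsingPacket) scratchBufferSize *= 2;               // data + flag -> 2 MiB
      if (isUsingDoubleScratchBuffer) scratchBufferSize *= 2;  // double buffer -> 4 MiB
      std::printf("size=%zu offset=%zu\n", scratchBufferSize, scratchBufferOffset);
      return 0;
    }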
diff --git a/src/executor/executor.cc b/src/executor/executor.cc
index d2e5ac7e..bf0a2bd3 100644
--- a/src/executor/executor.cc
+++ b/src/executor/executor.cc
@@ -140,7 +140,8 @@ struct Executor::Impl {
   ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize,
                                          size_t outputMessageSize, size_t constSrcOffset, size_t constDstOffset,
-                                         size_t sendMemRange, size_t recvMemRange, const ExecutionPlan& plan) {
+                                         size_t sendMemRange, size_t recvMemRange, const ExecutionPlan& plan,
+                                         int flag) {
     ExecutionContextKey key = {sendbuff, recvbuff, sendMemRange, recvMemRange, plan.impl_->name};
     DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset};
     if (this->contexts.find(key) != this->contexts.end()) {
@@ -152,7 +153,8 @@ struct Executor::Impl {
       return this->contexts[key];
     }
     plan.impl_->operationsReset();
-    plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
+    plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset, rank,
+                                       sendMemRange, recvMemRange, flag);
     this->setupDeviceExecutionPlan(this->contexts[key], devicePlanKey, rank, plan);
     this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey] =
         allocExtSharedCuda<char>(devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
@@ -164,12 +166,11 @@ struct Executor::Impl {
     }
 
     plan.impl_->reset();
-    plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
+    plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset, rank,
+                                  sendMemRange, recvMemRange, flag);
 
     ExecutionContext context;
-    size_t maxScratchBufferSize = plan.impl_->getMaxScratchBufferSize(rank);
-    size_t scratchBufferSize =
-        std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange, recvMemRange), maxScratchBufferSize);
+    size_t scratchBufferSize = plan.impl_->getScratchBufferSize();
     std::shared_ptr<char> scratchBuffer;
     if (isNvlsSupported()) {
       scratchBuffer = allocSharedPhysicalCuda<char>(scratchBufferSize);
@@ -372,8 +373,19 @@ struct Executor::Impl {
       DeviceExecutionPlan deviceExecutionPlan = {};
       std::vector<Operation> ops = plan.impl_->getOperations(rank, threadblock);
       deviceExecutionPlan.nOperations = ops.size();
+      if (deviceExecutionPlan.nOperations > MAX_OPERATION) {
+        throw Error("Executor plan has " + std::to_string(deviceExecutionPlan.nOperations) +
+                        " operations, exceeding executor support (" + std::to_string(MAX_OPERATION) + ")",
+                    ErrorCode::ExecutorError);
+      }
       deviceExecutionPlan.nSmChannels = plan.impl_->threadblockSMChannelMap.at(rank).at(threadblock).size();
       deviceExecutionPlan.nProxyChannels = plan.impl_->threadblockProxyChannelMap.at(rank).at(threadblock).size();
+      if (deviceExecutionPlan.nSmChannels > MAX_CHANNEL || deviceExecutionPlan.nProxyChannels > MAX_CHANNEL) {
+        throw Error("Executor plan has " +
+                        std::to_string(std::max(deviceExecutionPlan.nSmChannels, deviceExecutionPlan.nProxyChannels)) +
+                        " channels, exceeding executor support (" + std::to_string(MAX_CHANNEL) + ")",
+                    ErrorCode::ExecutorError);
+      }
       int chanIndex = 0;
       for (const auto& [index, _] : plan.impl_->threadblockSMChannelMap.at(rank).at(threadblock)) {
         deviceExecutionPlan.channels.smChannels[chanIndex++] = mscclpp::deviceHandle(context.smChannels[index]);
@@ -400,8 +412,7 @@ struct Executor::Impl {
   }
 
   void launchKernel(ExecutionContext& context, int rank, void* sendbuff, void* recvbuff, DataType dataType,
-                    cudaStream_t stream, PacketType packetType) {
-    static uint32_t flag = 0;
+                    cudaStream_t stream, PacketType packetType, uint32_t flag) {
     DeviceExecutionPlanKey key = context.currentDevicePlan;
     int nthreadblocks = context.deviceExecutionPlans[key].size();
 #if defined(ENABLE_NPKIT)
@@ -419,16 +430,16 @@ struct Executor::Impl {
 #endif
     switch (packetType) {
       case PacketType::LL16:
-        ExecutionKernel::launchKernel<LL16Packet>(
-            rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
-            context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
-            sharedMemSize, stream, ++flag);
+        ExecutionKernel::launchKernel<LL16Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
+                                                  (void*)context.scratchBuffer.get(), dataType,
+                                                  (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
+                                                  sharedMemSize, stream, flag);
         break;
       case PacketType::LL8:
-        ExecutionKernel::launchKernel<LL8Packet>(
-            rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff, (void*)context.scratchBuffer.get(),
-            context.scratchBufferSize, dataType, (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
-            sharedMemSize, stream, ++flag);
+        ExecutionKernel::launchKernel<LL8Packet>(rank, nthreadblocks, context.nthreadsPerBlock, sendbuff, recvbuff,
+                                                 (void*)context.scratchBuffer.get(), dataType,
+                                                 (DeviceExecutionPlan*)context.deviceExecutionPlansBuffers[key].get(),
+                                                 sharedMemSize, stream, flag);
         break;
       default:
        throw Error("Invalid packet type", ErrorCode::ExecutorError);
@@ -441,17 +452,18 @@ Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<
 void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
                        [[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
                        cudaStream_t stream, PacketType packetType) {
+  static uint32_t flag = 1;
   size_t sendMemRange, recvMemRange;
   CUdeviceptr sendBasePtr, recvBasePtr;
   MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendMemRange, (CUdeviceptr)sendbuff));
   MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvMemRange, (CUdeviceptr)recvbuff));
   size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
   size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
-
   ExecutionContext context =
       this->impl_->setupExecutionContext(rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, recvBuffSize,
-                                         offsetIn, offsetOut, sendMemRange, recvMemRange, plan);
-  this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
+                                         offsetIn, offsetOut, sendMemRange, recvMemRange, plan, flag);
+  this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType, flag);
+  flag++;
 }
 
 Executor::~Executor() = default;
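Executor::execute now owns the flag: it starts at 1 (odd, so the first run uses offset 0), is passed both to context setup, where host-side offsets pick the scratch half, and to the launch, where packets are tagged with the same value, and it increments once per collective. A toy model of just the parity behavior (sizes invented):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t flag = 1;                        // mirrors the static in execute()
      const unsigned long long half = 1 << 20;  // hypothetical single-instance size
      for (int call = 0; call < 4; ++call) {
        unsigned long long offset = (flag % 2 == 0) ? half : 0;
        std::printf("call %d: flag=%u scratch offset=%llu\n", call, flag, offset);
        flag++;                                 // parity flips for the next call
      }
      return 0;
    }

Because consecutive flags alternate parity, two back-to-back executions of a double-buffered plan write disjoint scratch halves, which is what lets one launch start before the previous one has fully drained.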
diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp
index 98bed37e..68e27d2b 100644
--- a/src/include/execution_kernel.hpp
+++ b/src/include/execution_kernel.hpp
@@ -307,15 +307,14 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
 }
 
 template <typename PacketType>
-MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmChannel>* smChannels,
+MSCCLPP_DEVICE_INLINE void handlePutPacket(DeviceHandle<SmChannel>* smChannels,
                                            DeviceHandle<SimpleProxyChannel>* proxyChannels,
                                            uint8_t* dstChannelIndexes, uint32_t* dstOffsets, uint32_t* srcOffsets,
                                            int nDstChannels, uint32_t size, ChannelType chType, uint32_t flag) {
-  const size_t scratchBaseOffset = flag & 0x1 ? 0 : scratchSize >> 1;
   if (chType == ChannelType::SM) {
     for (int index = 0; index < nDstChannels; ++index) {
-      smChannels[dstChannelIndexes[index]].putPackets<PacketType>(
-          scratchBaseOffset + dstOffsets[index] * 2, srcOffsets[index], size, threadIdx.x, blockDim.x, flag);
+      smChannels[dstChannelIndexes[index]].putPackets<PacketType>(dstOffsets[index] * 2, srcOffsets[index], size,
+                                                                  threadIdx.x, blockDim.x, flag);
     }
   }
   if (chType == ChannelType::PROXY) {
@@ -324,20 +323,19 @@ MSCCLPP_DEVICE_INLINE void handlePutPacket(size_t scratchSize, DeviceHandle<SmCh
 template <typename T, typename PacketType, bool SendToRemote = true>
 MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
-                                                  T* inputBuff, size_t inputBuffSize, uint32_t* inputOffsets, int nSrcs,
+                                                  T* inputBuff, uint32_t* inputOffsets, int nSrcs,
                                                   DeviceHandle<SmChannel>* smChannels, uint8_t* outputChannelIndexes,
                                                   uint32_t* outputOffsets, int nDstChannels, size_t size,
                                                   uint32_t flag) {
   size_t nPackets = size * 2 / sizeof(PacketType);
-  const size_t intputBaseOffset = flag & 0x1 ? 0 : inputBuffSize >> 1;
   const uint32_t srcOffset = srcOffsetByBytes / sizeof(PacketPayload<PacketType>);
   const uint32_t dstOffset = dstOffsetByBytes / sizeof(PacketPayload<PacketType>);
   PacketPayload<PacketType>* srcPacketPayload = (PacketPayload<PacketType>*)src + srcOffset;
@@ -345,7 +343,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
   for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
     PacketPayload<PacketType> data = {};
     for (int index = 0; index < nSrcs; ++index) {
-      PacketType* pkt = (PacketType*)((char*)inputBuff + intputBaseOffset + 2 * inputOffsets[index]);
+      PacketType* pkt = (PacketType*)((char*)inputBuff + 2 * inputOffsets[index]);
       PacketPayload<PacketType> val = pkt[idx].read(flag);
       data = add_vectors<T>(data, val);
     }
@@ -355,7 +353,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
     if (SendToRemote) {
       PacketType pkt(data, flag);
       for (int index = 0; index < nDstChannels; ++index) {
-        size_t offset = (intputBaseOffset + outputOffsets[index] * 2) / sizeof(PacketType);
+        size_t offset = outputOffsets[index] * 2 / sizeof(PacketType);
         smChannels[outputChannelIndexes[index]].write(offset + idx, pkt);
       }
     }
@@ -363,10 +361,9 @@ MSCCLPP_DEVICE_INLINE void handleReduceSendPacket(T* dst, uint32_t dstOffsetByBy
 }
 
 template <typename PacketType>
-MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize, uint32_t dstOffset,
-                                            uint32_t srcOffset, size_t size, uint32_t flag) {
-  const size_t inputScratchBaseOffset = flag & 0x1 ? 0 : srcSize >> 1;
-  PacketType* srcPackets = (PacketType*)((char*)src + inputScratchBaseOffset + 2 * srcOffset);
+MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset, size_t size,
+                                            uint32_t flag) {
+  PacketType* srcPackets = (PacketType*)((char*)src + 2 * srcOffset);
   PacketPayload<PacketType>* result = (PacketPayload<PacketType>*)((char*)dst + dstOffset);
   size_t nPackets = size * 2 / sizeof(PacketType);
   for (size_t idx = threadIdx.x; idx < nPackets; idx += blockDim.x) {
@@ -376,18 +373,17 @@ MSCCLPP_DEVICE_INLINE void handleCopyPacket(void* dst, void* src, size_t srcSize
 }
 
 template <typename PacketType>
-MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, size_t dstSize, uint32_t dstOffset,
-                                                   uint32_t srcOffset, size_t size, uint32_t flag) {
-  const size_t outputScratchBaseOffset = flag & 0x1 ? 0 : dstSize >> 1;
-  dstOffset = dstOffset * 2 + outputScratchBaseOffset;
+MSCCLPP_DEVICE_INLINE void handleTransformToPacket(void* dst, void* src, uint32_t dstOffset, uint32_t srcOffset,
+                                                   size_t size, uint32_t flag) {
+  dstOffset = dstOffset * 2;
   mscclpp::putPackets<PacketType>(dst, dstOffset, src, srcOffset, size, threadIdx.x, blockDim.x, flag);
 }
 
-template <typename T>
+template <typename T, bool SendToRemote = true>
 MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T* src, uint32_t srcOffsetByBytes,
-                                            T* input, uint32_t* inputOffsets, DeviceHandle<SmChannel>* smChannels,
-                                            uint8_t* outputChannelIndexes, uint32_t* outputOffsets, int nOutChannels,
-                                            uint32_t size) {
+                                            T* input, uint32_t* inputOffsets, int nInputs,
+                                            DeviceHandle<SmChannel>* smChannels, uint8_t* outputChannelIndexes,
+                                            uint32_t* outputOffsets, int nOutChannels, uint32_t size) {
   const size_t nInt4 = size / sizeof(int4);
   const size_t srcOffset4 = srcOffsetByBytes / sizeof(int4);
   const size_t dstOffset4 = dstOffsetByBytes / sizeof(int4);
@@ -396,15 +392,17 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
   int4* input4 = (int4*)input;
   for (size_t idx = threadIdx.x; idx < nInt4; idx += blockDim.x) {
     int4 tmp = src4[srcOffset4 + idx];
-    for (int index = 0; index < nOutChannels; ++index) {
+    for (int index = 0; index < nInputs; ++index) {
       size_t offset = inputOffsets[index] / sizeof(int4);
       int4 val = input4[offset + idx];
       tmp = add_vectors<T>(tmp, val);
     }
     dst4[dstOffset4 + idx] = tmp;
-    for (int index = 0; index < nOutChannels; ++index) {
-      size_t offset = outputOffsets[index] / sizeof(int4);
-      smChannels[outputChannelIndexes[index]].write(offset + idx, tmp);
+    if constexpr (SendToRemote) {
+      for (int index = 0; index < nOutChannels; ++index) {
+        size_t offset = outputOffsets[index] / sizeof(int4);
+        smChannels[outputChannelIndexes[index]].write(offset + idx, tmp);
+      }
     }
   }
   // handle rest of data
@@ -413,14 +411,16 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
   const size_t endIdx = (srcOffsetByBytes + size) / sizeof(T);
   for (size_t idx = threadIdx.x + startIdx; idx < endIdx; idx += blockDim.x) {
     T tmp = src[idx];
-    for (int index = 0; index < nOutChannels; ++index) {
+    for (int index = 0; index < nInputs; ++index) {
       size_t offset = inputOffsets[index] / sizeof(T);
       tmp = add_elements(tmp, input[offset + idx]);
     }
     dst[idx] = tmp;
-    for (int index = 0; index < nOutChannels; ++index) {
-      size_t offset = outputOffsets[index] / sizeof(T);
-      smChannels[outputChannelIndexes[index]].write(offset + idx, tmp);
+    if constexpr (SendToRemote) {
+      for (int index = 0; index < nOutChannels; ++index) {
+        size_t offset = outputOffsets[index] / sizeof(T);
+        smChannels[outputChannelIndexes[index]].write(offset + idx, tmp);
+      }
     }
   }
 }
@@ -464,7 +464,7 @@ MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* dst, T* src, uint32_t d
 template <typename T, typename PacketType = LL16Packet>
 __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch,
-                                size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag
+                                DeviceExecutionPlan* plan, uint32_t flag
 #if defined(ENABLE_NPKIT)
                                 ,
                                 NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
@@ -568,34 +568,40 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
                                  op.inputChannelIndexes, op.outputOffsets, op.inputOffsets, op.nOutputs,
                                  op.nInputs, op.size, false);
       } else if (op.type == OperationType::PUT_PACKET) {
-        handlePutPacket<PacketType>(scratchSize, smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets,
-                                    op.inputOffsets, op.nOutputs, op.size, op.channelType, flag);
+        handlePutPacket<PacketType>(smChannels, proxyChannels, op.outputChannelIndexes, op.outputOffsets,
+                                    op.inputOffsets, op.nOutputs, op.size, op.channelType, flag);
       } else if (op.type == OperationType::REDUCE_SEND_PACKET) {
         T* dst = getBuffer(input, output, scratch, op.dstBufferType);
         T* src = getBuffer(input, output, scratch, op.srcBufferType);
-        handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize,
-                                              op.inputOffsets, op.nInputs, smChannels, op.outputChannelIndexes,
-                                              op.outputOffsets, op.nOutputs, op.size, flag);
+        handleReduceSendPacket<T, PacketType>(dst, op.dstOffset, src, op.srcOffset, scratch, op.inputOffsets,
+                                              op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
+                                              op.nOutputs, op.size, flag);
       } else if (op.type == OperationType::REDUCE_PACKET) {
         T* dst = getBuffer(input, output, scratch, op.dstBufferType);
         T* src = getBuffer(input, output, scratch, op.srcBufferType);
-        handleReduceSendPacket<T, PacketType, false>(dst, op.dstOffset, src, op.srcOffset, scratch, scratchSize,
-                                                     op.inputOffsets, op.nInputs, smChannels, op.outputChannelIndexes,
-                                                     op.outputOffsets, op.nOutputs, op.size, flag);
+        handleReduceSendPacket<T, PacketType, false>(dst, op.dstOffset, src, op.srcOffset, scratch, op.inputOffsets,
+                                                     op.nInputs, smChannels, op.outputChannelIndexes, op.outputOffsets,
+                                                     op.nOutputs, op.size, flag);
       } else if (op.type == OperationType::COPY_PACKET) {
         T* dst = getBuffer(input, output, scratch, op.dstBufferType);
         T* src = getBuffer(input, output, scratch, op.srcBufferType);
-        handleCopyPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
+        handleCopyPacket<PacketType>(dst, src, op.dstOffset, op.srcOffset, op.size, flag);
      } else if (op.type == OperationType::TRANSFORM_TO_PACKET) {
         T* dst = getBuffer(input, output, scratch, op.dstBufferType);
         T* src = getBuffer(input, output, scratch, op.srcBufferType);
-        handleTransformToPacket<PacketType>(dst, src, scratchSize, op.dstOffset, op.srcOffset, op.size, flag);
+        handleTransformToPacket<PacketType>(dst, src, op.dstOffset, op.srcOffset, op.size, flag);
       } else if (op.type == OperationType::REDUCE_SEND) {
         T* dst = getBuffer(input, output, scratch, op.dstBufferType);
         T* src = getBuffer(input, output, scratch, op.srcBufferType);
         T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
-        handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, smChannels, op.outputChannelIndexes,
-                         op.outputOffsets, op.nOutputs, op.size);
+        handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, smChannels,
+                         op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
+      } else if (op.type == OperationType::REDUCE) {
+        T* dst = getBuffer(input, output, scratch, op.dstBufferType);
+        T* src = getBuffer(input, output, scratch, op.srcBufferType);
+        T* tmp = getBuffer(input, output, scratch, op.inputBufferType);
+        handleReduceSend<T, false>(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, op.nInputs, smChannels,
+                                   op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size);
       }
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
       else if (op.type == OperationType::MULTI_LOAD_REDUCE_STORE) {
@@ -622,12 +628,12 @@ class ExecutionKernel {
 #if defined(MSCCLPP_DEVICE_HIP)
   template <typename PacketType>
   static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
-                           size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
-                           cudaStream_t stream, uint32_t flag = 0) {
+                           DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream,
+                           uint32_t flag) {
     switch (dataType) {
       case DataType::INT32:
         executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-            rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
+            rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
             ,
             NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -637,7 +643,7 @@ class ExecutionKernel {
         break;
       case DataType::UINT32:
         executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-            rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
+            rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
             ,
             NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -647,7 +653,7 @@ class ExecutionKernel {
         break;
       case DataType::FLOAT16:
         executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-            rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
+            rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
             ,
             NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -657,7 +663,7 @@ class ExecutionKernel {
         break;
       case DataType::FLOAT32:
         executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-            rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
+            rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
             ,
             NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -667,7 +673,7 @@ class ExecutionKernel {
         break;
       case DataType::BFLOAT16:
         executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
-            rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
+            rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
 #if defined(ENABLE_NPKIT)
             ,
             NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
@@ -680,8 +686,8 @@ class ExecutionKernel {
 #else   // !defined(MSCCLPP_DEVICE_HIP)
   template <typename PacketType>
   static void launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
-                           size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
-                           cudaStream_t stream, uint32_t flag = 0);
+                           DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize, cudaStream_t stream,
+                           uint32_t flag);
 #endif  // !defined(MSCCLPP_DEVICE_HIP)
 };
 }  // namespace mscclpp
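The new REDUCE branch reuses the REDUCE_SEND handler with SendToRemote = false, so the remote-write loops are discarded at compile time rather than guarded at runtime. The pattern, boiled down to a self-contained sketch (names hypothetical):

    // C++17 'if constexpr': the untaken branch is dropped from the
    // instantiation, so reduceMaybeSend<false> contains no send code at all.
    template <bool SendToRemote = true>
    void reduceMaybeSend(int* dst, const int* src, const int* peer, int n) {
      for (int i = 0; i < n; ++i) {
        dst[i] = src[i] + peer[i];  // the local reduction always happens
        if constexpr (SendToRemote) {
          // a remote channel write would go here (compiled out for REDUCE)
        }
      }
    }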
diff --git a/src/include/execution_plan.hpp b/src/include/execution_plan.hpp
index 080a7688..0f4170a4 100644
--- a/src/include/execution_plan.hpp
+++ b/src/include/execution_plan.hpp
@@ -72,14 +72,15 @@ struct ExecutionPlan::Impl {
   std::vector<NvlsInfo> getNvlsInfos(int rank, size_t sendBuffserSize = 0, size_t recvBufferSize = 0) const;
   std::vector<int> getConnectedPeers(int rank) const;
   std::vector<BufferType> getConnectedBufferTypes(int rank) const;
-  size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;
-  size_t getMaxScratchBufferSize(int rank) const;
+  size_t getScratchBufferSize() const;
   std::vector<Operation> getOperations(int rank, int threadblock) const;
   int getThreadblockCount(int rank) const;
   int getNThreadsPerBlock() const;
-  void loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
-  void lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
+  void loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset, int rank,
+                         size_t inputBufferSize, size_t outputBufferSize, int flag);
+  void lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset,
+                              int rank, size_t inputBufferSize, size_t outputBufferSize, int flag);
   void setupChannels(const nlohmann::json& gpus);
   void setupOperations(const nlohmann::json& gpus, size_t contsSrcOffset, size_t constDstOffset);
@@ -108,6 +109,9 @@ struct ExecutionPlan::Impl {
   size_t inputSize;
   size_t outputSize;
   int nThreadsPerBlock;
+  bool isUsingDoubleScratchBuffer;
+  size_t scratchBufferSize;
+  size_t scratchBufferOffset;
   size_t minMessageSize;
   size_t maxMessageSize;
   bool isInPlace;
@@ -117,6 +121,8 @@ struct ExecutionPlan::Impl {
   size_t getOffset(int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex, uint32_t alignment = 16) const;
   size_t getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks,
                        const std::vector<uint32_t> offsets) const;
+  void calcScratchBufferSizeAndOffset(int rank, size_t inputSize, size_t outputSize, int flag);
+  void checkChannelsPerOperation(int channels);
   size_t getUpperBoundChunkSize(int rank, size_t inputSize, size_t outputSize) const;
 
   // helper functions to setup the channels
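Putting the pieces together, a call-site fragment (not a complete program; the executor, plan, buffers, and stream are assumed to exist, and execute's signature is the one used in this patch):

    // Two successive executions of a plan with use_double_scratch_buffer set:
    // the internal static flag is 1 for the first call and 2 for the second,
    // so the second call's SCRATCH offsets land in the other buffer instance
    // and the two launches cannot overwrite each other's in-flight packets.
    executor.execute(rank, sendbuff, recvbuff, bytes, bytes, DataType::FLOAT16, plan, stream, PacketType::LL16);
    executor.execute(rank, sendbuff, recvbuff, bytes, bytes, DataType::FLOAT16, plan, stream, PacketType::LL16);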