diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 3daadf8a..a4495cf5 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -718,11 +718,63 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
   return ncclSuccess;
 }
 
-NCCL_API ncclResult_t ncclReduceScatter(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t,
-                                        cudaStream_t) {
-  // TODO: implement this function
-  WARN("ncclReduceScatter is currently unavailable");
-  return ncclInternalError;
+
+NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
+                                        size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+                                        cudaStream_t stream) {
+  size_t bytes = recvcount * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
+    WARN(
+        "One or more of the following conditions is met: sendbuff or recvbuff pointer is nullptr, bytes is 0, "
+        "or comm is nullptr.");
+    return ncclInvalidArgument;
+  }
+
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+
+  std::vector<executionPlanInstance>& plans = comm->executionPlans["reducescatter"];
+  std::shared_ptr<mscclpp::ExecutionPlan> plan;
+  void* basePtr = (char*)sendbuff + rank * bytes;
+  bool inPlace = basePtr == recvbuff;
+  const size_t totalBytes = bytes * nRank;
+  for (const auto& p : plans) {
+    if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
+      plan = p.plan;
+      break;
+    }
+  }
+  // TODO: add a fallback implementation for ReduceScatter when no execution plan matches
+  if (plan == nullptr) {
+    WARN(
+        "No fallback implementation available for ncclReduceScatter");
+    return ncclInternalError;
+  }
+
+  switch (datatype) {
+    case ncclFloat16:
+      comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, totalBytes, bytes, mscclpp::DataType::FLOAT16,
+                              *plan, stream);
+      break;
+    case ncclFloat32:
+      comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, totalBytes, bytes,
+                              mscclpp::DataType::FLOAT32, *plan, stream);
+      break;
+    case ncclBfloat16:
+      comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, totalBytes, bytes,
+                              mscclpp::DataType::BFLOAT16, *plan, stream);
+      break;
+    case ncclInt32:
+    case ncclUint32:
+      comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, totalBytes, bytes, mscclpp::DataType::UINT32,
+                              *plan, stream);
+      break;
+    default:
+      WARN("datatype is invalid");
+      return ncclInvalidArgument;
+  }
+
+  return ncclSuccess;
 }
 
 NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
diff --git a/python/mscclpp/language/collectives.py b/python/mscclpp/language/collectives.py
index 55fe5188..16f1d927 100644
--- a/python/mscclpp/language/collectives.py
+++ b/python/mscclpp/language/collectives.py
@@ -202,7 +202,8 @@ def init_buffers(self):
                 for i in range(self.num_ranks):
                     for c in range(self.chunk_factor):
                         input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
-                buffers = {Buffer.input: input_buffer}
+                buffers = {Buffer.input: input_buffer,
+                           Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor]}
                 rank_buffers.append(buffers)
             else:
                 input_buffer = []