diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index a4495cf5..6e6eab92 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -718,10 +718,8 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t return ncclSuccess; } - -NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, - size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, - cudaStream_t stream) { +NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype, + ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) { size_t bytes = recvcount * ncclTypeSize(datatype); if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) { WARN( @@ -745,9 +743,8 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, } } // TODO: Fallback code for ReduceScatter - if (plan == nullptr){ - WARN( - "No FallBack code for ReduceScatter"); + if (plan == nullptr) { + WARN("No FallBack code for ReduceScatter"); return ncclInternalError; } @@ -757,8 +754,8 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, *plan, stream); break; case ncclFloat32: - comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, totalBytes, bytes, - mscclpp::DataType::FLOAT32, *plan, stream); + comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, totalBytes, bytes, mscclpp::DataType::FLOAT32, + *plan, stream); break; case ncclBfloat16: comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, totalBytes, bytes, @@ -766,8 +763,8 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, break; case ncclInt32: case ncclUint32: - comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, totalBytes, bytes, mscclpp::DataType::UINT32, - *plan, stream); + comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, totalBytes, bytes, mscclpp::DataType::UINT32, *plan, + stream); break; default: WARN("datatype is invalid"); diff --git a/python/mscclpp/language/collectives.py b/python/mscclpp/language/collectives.py index 16f1d927..8687ff1d 100644 --- a/python/mscclpp/language/collectives.py +++ b/python/mscclpp/language/collectives.py @@ -202,8 +202,10 @@ def init_buffers(self): for i in range(self.num_ranks): for c in range(self.chunk_factor): input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c)) - buffers = {Buffer.input: input_buffer, - Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],} + buffers = { + Buffer.input: input_buffer, + Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor], + } rank_buffers.append(buffers) else: input_buffer = []