Skip to content

Commit

Permalink
adjusting formatation
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Feb 6, 2025
1 parent 1a70c7b commit 2cca534
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
19 changes: 8 additions & 11 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -718,10 +718,8 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
return ncclSuccess;
}


NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
cudaStream_t stream) {
NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype,
ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
size_t bytes = recvcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
WARN(
Expand All @@ -745,9 +743,8 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
}
}
// TODO: Fallback code for ReduceScatter
if (plan == nullptr){
WARN(
"No FallBack code for ReduceScatter");
if (plan == nullptr) {
WARN("No FallBack code for ReduceScatter");
return ncclInternalError;
}

Expand All @@ -757,17 +754,17 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
*plan, stream);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, totalBytes, bytes,
mscclpp::DataType::FLOAT32, *plan, stream);
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, totalBytes, bytes, mscclpp::DataType::FLOAT32,
*plan, stream);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, totalBytes, bytes,
mscclpp::DataType::BFLOAT16, *plan, stream);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, totalBytes, bytes, mscclpp::DataType::UINT32,
*plan, stream);
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, totalBytes, bytes, mscclpp::DataType::UINT32, *plan,
stream);
break;
default:
WARN("datatype is invalid");
Expand Down
6 changes: 4 additions & 2 deletions python/mscclpp/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,10 @@ def init_buffers(self):
for i in range(self.num_ranks):
for c in range(self.chunk_factor):
input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
buffers = {Buffer.input: input_buffer,
Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],}
buffers = {
Buffer.input: input_buffer,
Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],
}
rank_buffers.append(buffers)
else:
input_buffer = []
Expand Down

0 comments on commit 2cca534

Please sign in to comment.