Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
caiomcbr committed Feb 4, 2025
1 parent 3565bfd commit dd80593
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
62 changes: 57 additions & 5 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -718,11 +718,63 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
return ncclSuccess;
}

NCCL_API ncclResult_t ncclReduceScatter(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t,
cudaStream_t) {
// TODO: implement this function
WARN("ncclReduceScatter is currently unavailable");
return ncclInternalError;

/// ReduceScatter entry point driven by a pre-registered execution plan.
///
/// Selects the first plan stored under the "reducescatter" key whose
/// [minMessageSize, maxMessageSize) window contains the total send size
/// (recvcount * element size * number of ranks) and whose in-place flag
/// matches this call, then passes it to the communicator's executor together
/// with `stream`.
///
/// The call is treated as in-place when `recvbuff` equals `sendbuff` advanced
/// by `rank * recvcount * element size` bytes.
///
/// Note: `op` is not read anywhere in this function.
///
/// @param sendbuff  Source buffer; must not be null.
/// @param recvbuff  Destination buffer; must not be null.
/// @param recvcount Number of elements each rank receives.
/// @param datatype  Element type; only the cases in the switch below are handled.
/// @param op        Reduction operator (unused here).
/// @param comm      Communicator; must not be null.
/// @param stream    Stream forwarded to the executor.
/// @return ncclInvalidArgument if a buffer or `comm` is null, the per-rank byte
///         count is zero, or `datatype` has no matching case;
///         ncclInternalError if no plan matches (there is no fallback path yet);
///         ncclSuccess otherwise.
NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
                                        ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
                                        cudaStream_t stream) {
  // Bytes each rank ends up with in recvbuff.
  const size_t chunkBytes = recvcount * ncclTypeSize(datatype);

  const bool argsOk = sendbuff != nullptr && recvbuff != nullptr && chunkBytes != 0 && comm != nullptr;
  if (!argsOk) {
    WARN(
        "One or more of the following conditions is met: sendbuff or recvbuff pointer is nullptr, bytes is 0, "
        "or comm is nullptr.");
    return ncclInvalidArgument;
  }

  const int myRank = comm->comm->bootstrap()->getRank();
  const int worldSize = comm->comm->bootstrap()->getNranks();

  std::vector<executionPlanInstance>& candidates = comm->executionPlans["reducescatter"];

  // In-place means recvbuff aliases this rank's own slice of sendbuff.
  const char* ownSlice = static_cast<const char*>(sendbuff) + myRank * chunkBytes;
  const bool callIsInPlace = static_cast<const void*>(ownSlice) == recvbuff;

  // Plans are keyed on the size of the whole send buffer, not the per-rank chunk.
  const size_t sendBytes = chunkBytes * worldSize;

  // First matching plan wins.
  std::shared_ptr<mscclpp::ExecutionPlan> chosen;
  for (auto it = candidates.begin(); it != candidates.end(); ++it) {
    const auto& key = it->key;
    const bool sizeFits = sendBytes >= key.minMessageSize && sendBytes < key.maxMessageSize;
    if (!sizeFits || callIsInPlace != key.isInPlace) {
      continue;
    }
    chosen = it->plan;
    break;
  }

  // TODO: Fallback code for ReduceScatter
  if (chosen == nullptr) {
    WARN("No FallBack code for ReduceScatter");
    return ncclInternalError;
  }

  // The executor is handed mutable typed pointers, so constness is dropped once here.
  void* const src = const_cast<void*>(sendbuff);

  switch (datatype) {
    case ncclFloat16:
      comm->executor->execute(myRank, static_cast<half*>(src), static_cast<half*>(recvbuff), sendBytes, chunkBytes,
                              mscclpp::DataType::FLOAT16, *chosen, stream);
      break;
    case ncclFloat32:
      comm->executor->execute(myRank, static_cast<float*>(src), static_cast<float*>(recvbuff), sendBytes, chunkBytes,
                              mscclpp::DataType::FLOAT32, *chosen, stream);
      break;
    case ncclBfloat16:
      comm->executor->execute(myRank, static_cast<__bfloat16*>(src), static_cast<__bfloat16*>(recvbuff), sendBytes,
                              chunkBytes, mscclpp::DataType::BFLOAT16, *chosen, stream);
      break;
    case ncclInt32:
    case ncclUint32:
      // Both 32-bit integer types are dispatched as UINT32 through int pointers.
      comm->executor->execute(myRank, static_cast<int*>(src), static_cast<int*>(recvbuff), sendBytes, chunkBytes,
                              mscclpp::DataType::UINT32, *chosen, stream);
      break;
    default:
      WARN("datatype is invalid");
      return ncclInvalidArgument;
  }

  return ncclSuccess;
}

NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
Expand Down
3 changes: 2 additions & 1 deletion python/mscclpp/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def init_buffers(self):
for i in range(self.num_ranks):
for c in range(self.chunk_factor):
input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
buffers = {Buffer.input: input_buffer}
buffers = {Buffer.input: input_buffer,
Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],}
rank_buffers.append(buffers)
else:
input_buffer = []
Expand Down

0 comments on commit dd80593

Please sign in to comment.