Skip to content

Support ReduceScatter in the NCCL interface #460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -718,11 +718,60 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
return ncclSuccess;
}

NCCL_API ncclResult_t ncclReduceScatter(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t,
cudaStream_t) {
// TODO: implement this function
WARN("ncclReduceScatter is currently unavailable");
return ncclInternalError;
// NCCL-compatible ReduceScatter entry point backed by mscclpp execution plans.
// Each rank receives `recvcount` reduced elements; the send buffer therefore
// holds nRank * recvcount elements. The call fails with ncclInternalError when
// no pre-registered "reducescatter" plan matches the message size / placement
// (no fallback path exists yet — see TODO below).
//
// NOTE(review): `op` is accepted but never read here — presumably the reduction
// operation is baked into the selected execution plan; confirm that only sum
// (or the plan's op) is supported and consider validating `op` against it.
NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype,
ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
// Per-rank output size in bytes. NOTE(review): `recvcount * ncclTypeSize(...)`
// is not checked for size_t overflow — likely fine in practice, but unguarded.
size_t bytes = recvcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
WARN(
"One or more of the following conditions is met: sendbuff or recvbuff pointer is nullptr, bytes is 0, "
"or comm is nullptr.");
return ncclInvalidArgument;
}

int rank = comm->comm->bootstrap()->getRank();
int nRank = comm->comm->bootstrap()->getNranks();

// Candidate plans registered for this collective, keyed by message-size range
// and in-place/out-of-place placement.
std::vector<executionPlanInstance>& plans = comm->executionPlans["reducescatter"];
std::shared_ptr<mscclpp::ExecutionPlan> plan;
// NCCL in-place convention for reduce-scatter: the operation is in-place when
// recvbuff aliases this rank's slice of sendbuff (sendbuff + rank * bytes).
void* basePtr = (char*)sendbuff + rank * bytes;
bool inPlace = basePtr == recvbuff;
const size_t totalBytes = bytes * nRank;
// Select the first plan whose [min, max) message-size window contains the
// total input size and whose placement flag matches.
for (const auto& p : plans) {
if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
plan = p.plan;
break;
}
}
// TODO: Fallback code for ReduceScatter
if (plan == nullptr) {
WARN("No FallBack code for ReduceScatter");
return ncclInternalError;
}

// Dispatch to the executor with the matching element type. `totalBytes` is the
// input (send) size, `bytes` the per-rank output size. ncclInt32 and ncclUint32
// intentionally share the UINT32 path (same width; plan presumably uses
// unsigned arithmetic — confirm signed-overflow semantics are acceptable).
switch (datatype) {
case ncclFloat16:
comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, totalBytes, bytes, mscclpp::DataType::FLOAT16,
*plan, stream);
break;
case ncclFloat32:
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, totalBytes, bytes, mscclpp::DataType::FLOAT32,
*plan, stream);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, totalBytes, bytes,
mscclpp::DataType::BFLOAT16, *plan, stream);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, totalBytes, bytes, mscclpp::DataType::UINT32, *plan,
stream);
break;
default:
WARN("datatype is invalid");
return ncclInvalidArgument;
}

return ncclSuccess;
}

NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
Expand Down
5 changes: 4 additions & 1 deletion python/mscclpp/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ def init_buffers(self):
for i in range(self.num_ranks):
for c in range(self.chunk_factor):
input_buffer.append(Chunk(r, i * self.chunk_factor + c, i, c))
buffers = {Buffer.input: input_buffer}
buffers = {
Buffer.input: input_buffer,
Buffer.output: input_buffer[r * self.chunk_factor : (r + 1) * self.chunk_factor],
}
rank_buffers.append(buffers)
else:
input_buffer = []
Expand Down
13 changes: 11 additions & 2 deletions python/test/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,26 @@ def build_bufs(
nelems_input = nelems if in_place else nelems // num_ranks
else:
nelems_input = nelems
nelems_output = nelems

if "reducescatter" in collective:
assert (nelems % num_ranks) == 0, "nelems %d not multiple of num_ranks %d" % (nelems, num_ranks)
nelems_output = nelems // num_ranks
else:
nelems_output = nelems

result_buf = GpuBuffer(nelems_output, dtype=dtype)
if in_place:
if "allgather" in collective:
input_buf = cp.split(result_buf, num_ranks)[rank]
elif "reducescatter" in collective:
input_buf = GpuBuffer(nelems_input, dtype=dtype)
result_buf = cp.split(input_buf, num_ranks)[rank]
else:
input_buf = result_buf
else:
input_buf = GpuBuffer(nelems_input, dtype=dtype)
test_buf = cp.zeros(nelems_output, dtype=dtype)

test_buf = cp.zeros(nelems, dtype=dtype)

return input_buf, result_buf, test_buf

Expand Down
2 changes: 1 addition & 1 deletion python/test/executor_test_verifier.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ TEST_DATA_ALL_REDUCE(int32, int)
} \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
if (i >= offset && i < offset + nem_elems_per_rank) { \
assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \
assert(abs(float(result_buf[i - offset]) - float(test_buf[i])) < 1e-3 * num_ranks); \
} \
} \
}
Expand Down
Loading