optimize allreduce5
chhwang committed Dec 6, 2023
1 parent 801536e commit a6fded6
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions test/mscclpp-test/allreduce_test.cu
@@ -371,6 +371,67 @@ __device__ void localReduceScatterSm2(int* buff, int rank, int nRanksPerNode, si
}
}

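// Local reduce-scatter over SM channels: each rank accumulates the peers' copies of its own
// chunkSize-int chunk directly into buff, using vectorized int4 accesses for the bulk of the data.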
__device__ void localReduceScatterSm3(int* buff, int rank, int nRanksPerNode, size_t chunkSize, size_t nelems,
                                      int nBlocks) {
  if (nRanksPerNode == 1) return;
  if ((int)blockIdx.x >= nBlocks) return;
  const int nPeer = nRanksPerNode - 1;
  DeviceHandle<mscclpp::SmChannel>* smChans = constSmOutOfPlaceGetChans;

  const size_t localRankIndexInNode = rank % nRanksPerNode;
  const size_t indexOffset = localRankIndexInNode * chunkSize;
  const size_t indexOffset4 = indexOffset / 4;

  int4* buff4 = (int4*)buff;

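  // Handshake: the first nPeer threads each signal one peer, and the last nPeer threads wait on the
  // matching signals, so every peer's data is visible before any remote reads begin.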
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < nPeer) {
    smChans[tid].signal();
  }
  const int waitStart = nBlocks * blockDim.x - nPeer;
  if (tid >= waitStart && tid < (int)(nBlocks * blockDim.x)) {
    smChans[tid - waitStart].wait();
  }
  reduceScatterDeviceSyncer.sync(nBlocks);

  const size_t nInt4 = nelems / 4;

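  // Main loop: walk the local chunk one full grid-stride tile (unitNInt4 int4 elements) at a time,
  // reading each peer's copy over its SM channel and accumulating into the local buffer. The peer
  // order is rotated by the local rank so that ranks do not all read from the same peer at once.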
  size_t base = 0;
  const size_t unitNInt4 = blockDim.x * nBlocks;
  for (; base + unitNInt4 < nInt4; base += unitNInt4) {
    for (int index = 0; index < nPeer; ++index) {
      int4 val;
      int peerIdx = (index + localRankIndexInNode) % nPeer;
      for (size_t idx = base + threadIdx.x + blockIdx.x * blockDim.x; idx < base + unitNInt4;
           idx += blockDim.x * nBlocks) {
        val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
        buff4[indexOffset4 + idx].w += val.w;
        buff4[indexOffset4 + idx].x += val.x;
        buff4[indexOffset4 + idx].y += val.y;
        buff4[indexOffset4 + idx].z += val.z;
      }
    }
  }
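  // Remainder: reduce the int4 elements left over after the last full tile.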
  for (int index = 0; index < nPeer; ++index) {
    int4 val;
    int peerIdx = (index + localRankIndexInNode) % nPeer;
    for (size_t idx = base + threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * nBlocks) {
      val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
      buff4[indexOffset4 + idx].w += val.w;
      buff4[indexOffset4 + idx].x += val.x;
      buff4[indexOffset4 + idx].y += val.y;
      buff4[indexOffset4 + idx].z += val.z;
    }
  }

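  // Tail: reduce the final nelems % 4 ints one int at a time.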
  const size_t nLastInts = nelems % 4;
  for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
    for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nLastInts; idx += blockDim.x * nBlocks) {
      int val = smChans[(localRankIndexInNode + peerIdx) % nPeer].read<int>(indexOffset + nInt4 * 4 + idx);
      buff[indexOffset + nInt4 * 4 + idx] += val;
    }
  }
}

__device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPerNode, int worldSize,
                                size_t nelems  // must be divisible by 3
) {
@@ -520,6 +581,39 @@ __device__ void localRingAllGatherSm(int rank, int nRanksPerNode, uint64_t size,
}
}

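// Local in-place all-gather over SM channels: each rank pulls every peer's chunk (size bytes each)
// into the corresponding slot of its own buffer, one unitSize-byte stage at a time.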
__device__ void localRingAllGatherSm2(size_t rank, size_t nRanksPerNode, size_t size, size_t nBlocks) {
  if (nRanksPerNode == 1) return;
  if (blockIdx.x >= nBlocks) return;

  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
  const size_t nPeer = nRanksPerNode - 1;

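  // Same signal/wait handshake as in localReduceScatterSm3, here over the in-place channels.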
  if (tid < nPeer) {
    constSmInPlaceChans[tid].signal();
  }
  size_t waitStart = nBlocks * blockDim.x - nPeer;
  if (tid >= waitStart && tid < nBlocks * blockDim.x) {
    constSmInPlaceChans[tid - waitStart].wait();
  }
  allGatherDeviceSyncer.sync(nBlocks);
  const size_t unitSize = 16 * blockDim.x * nBlocks;
  size_t base = 0;
  for (; base + unitSize < size; base += unitSize) {
    for (size_t i = 0; i < nPeer; ++i) {
      size_t peerIdx = (i + rank) % nPeer;
      const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
      size_t offset = size * remoteRankLocalIndex + base;
      constSmInPlaceChans[peerIdx].get(offset, unitSize, tid, blockDim.x * nBlocks);
    }
  }
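  // Copy whatever is left after the last full unitSize stage.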
  for (size_t i = 0; i < nPeer; ++i) {
    size_t peerIdx = (i + rank) % nPeer;
    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
    size_t offset = size * remoteRankLocalIndex + base;
    constSmInPlaceChans[peerIdx].get(offset, size - base, tid, blockDim.x * nBlocks);
  }
}

// This is an allgather4 equivalent
__device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
  // this allgather is a pipelined and hierarchical one and only works for two nodes
@@ -861,9 +955,15 @@ __global__ void allreduce4(int* buff, int* scratch, int rank, int nRanksPerNode,
}

__global__ void allreduce5(int* buff, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
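  // AMD path: use the localReduceScatterSm3 / localRingAllGatherSm2 variants added in this commit.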
  localReduceScatterSm3(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
  deviceSyncer.sync(gridDim.x);
  localRingAllGatherSm2(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
#else
  localReduceScatterSm2(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
  deviceSyncer.sync(gridDim.x);
  localRingAllGatherSm(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
#endif
}

__global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize,
