diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 655f5d0ba..0376c50f4 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -331,11 +331,6 @@ class RegisteredMemory { /// @return The size of the memory block. size_t size(); - /// Get the pitch of the memory block. - /// - /// @return The pitch of the memory block. - size_t pitch(); - /// Get the rank of the process that owns the memory block. /// /// @return The rank of the process that owns the memory block. @@ -384,12 +379,14 @@ class Connection { /// /// @param dst The destination @ref RegisteredMemory. /// @param dstOffset The offset in bytes from the start of the destination @ref RegisteredMemory. + /// @param dstPitch The pitch of the destination @ref RegisteredMemory in bytes. /// @param src The source @ref RegisteredMemory. /// @param srcOffset The offset in bytes from the start of the source @ref RegisteredMemory. + /// @param srcPitch The pitch of the source @ref RegisteredMemory in bytes. /// @param width The width of the 2D region to write in bytes. /// @param height The height of the 2D region. - virtual void write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, - uint64_t width, uint64_t height) = 0; + virtual void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, + uint64_t srcOffset, uint64_t srcPitch, uint64_t width, uint64_t height) = 0; /// Update a 8-byte value in a destination @ref RegisteredMemory and synchronize the change with the remote process. /// /// @param dst The destination @ref RegisteredMemory. diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index 16a3d3b4d..c8719acab 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace mscclpp { @@ -40,6 +41,8 @@ class ProxyService : public BaseProxyService { /// @return The ID of the semaphore. SemaphoreId addSemaphore(std::shared_ptr connection); + void addPitch(SemaphoreId id, std::pair pitch); + /// Register a memory region with the proxy service. /// @param memory The memory region to register. /// @return The ID of the memory region. @@ -65,6 +68,7 @@ class ProxyService : public BaseProxyService { Communicator& communicator_; std::vector> semaphores_; std::vector memories_; + std::unordered_map> pitches_; Proxy proxy_; int deviceNumaNode; diff --git a/src/communicator.cc b/src/communicator.cc index c0ee00031..0480f0231 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -57,12 +57,6 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t std::make_shared(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl)); } -MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, size_t pitchSize, - TransportFlags transports) { - return RegisteredMemory( - std::make_shared(ptr, size, pitchSize, pimpl->bootstrap_->getRank(), transports, *pimpl)); -} - struct MemorySender : public Setuppable { MemorySender(RegisteredMemory memory, int remoteRank, int tag) : memory_(memory), remoteRank_(remoteRank), tag_(tag) {} diff --git a/src/connection.cc b/src/connection.cc index b1d8ea628..4f0025f35 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -58,18 +58,18 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); } -void CudaIpcConnection::write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, - uint64_t width, uint64_t height) { +void CudaIpcConnection::write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, + uint64_t srcOffset, uint64_t srcPitch, uint64_t width, uint64_t height) { validateTransport(dst, remoteTransport()); validateTransport(src, transport()); char* dstPtr = (char*)dst.data(); char* srcPtr = (char*)src.data(); - MSCCLPP_CUDATHROW(cudaMemcpy2DAsync(dstPtr + dstOffset, dst.pitch(), srcPtr + srcOffset, src.pitch(), width, height, + MSCCLPP_CUDATHROW(cudaMemcpy2DAsync(dstPtr + dstOffset, dstPitch, srcPtr + srcOffset, srcPitch, width, height, cudaMemcpyDeviceToDevice, stream_)); INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, width %lu height %lu dstPitch %lu srcPitch %lu", - srcPtr + srcOffset, dstPtr + dstOffset, width, height, dst.pitch(), src.pitch()); + srcPtr + srcOffset, dstPtr + dstOffset, width, height, dstPitch, srcPitch); } void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) { @@ -141,7 +141,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } -void IBConnection::write2D(RegisteredMemory, uint64_t, RegisteredMemory, uint64_t, uint64_t, uint64_t) { +void IBConnection::write2D(RegisteredMemory, uint64_t, uint64_t, RegisteredMemory, uint64_t, uint64_t, uint64_t, + uint64_t) { throw Error("write2D is not supported", ErrorCode::InvalidUsage); } diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 17204177d..970e1b1ea 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -44,8 +44,8 @@ class CudaIpcConnection : public ConnectionBase { void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; - void write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t width, - uint64_t height) override; + void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, uint64_t srcOffset, + uint64_t srcPitch, uint64_t width, uint64_t height) override; void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override; void flush() override; @@ -69,8 +69,8 @@ class IBConnection : public ConnectionBase { void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; - void write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t width, - uint64_t height) override; + void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, uint64_t srcOffset, + uint64_t srcPitch, uint64_t width, uint64_t height) override; void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override; void flush() override; diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index cf421c533..779cd7965 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -34,13 +34,11 @@ struct TransportInfo { struct RegisteredMemory::Impl { void* data; size_t size; - size_t pitch; // for 2D int rank; uint64_t hostHash; TransportFlags transports; std::vector transportInfos; - Impl(void* data, size_t size, size_t pitch, int rank, TransportFlags transports, Communicator::Impl& commImpl); Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); Impl(const std::vector& data); diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index e3cf2a7cb..56470a0b2 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -29,6 +29,10 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr pitch) { + pitches_[id] = pitch; +} + MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) { memories_.push_back(memory); return memories_.size() - 1; @@ -63,8 +67,9 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId]; RegisteredMemory& src = memories_[trigger->fields.srcMemoryId]; if (trigger->fields2D.multiDimensionFlag) { - semaphore->connection()->write2D(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, - trigger->fields2D.width, trigger->fields2D.height); + std::pair& pitch = pitches_[trigger->fields.chanId]; + semaphore->connection()->write2D(dst, trigger->fields.dstOffset, pitch.first, src, trigger->fields.srcOffset, + pitch.second, trigger->fields2D.width, trigger->fields2D.height); } else { semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, trigger->fields.size); diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 48712e5cc..bb1ae3563 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -15,16 +15,7 @@ namespace mscclpp { RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) - : Impl(data, size, size, rank, transports, commImpl) {} - -RegisteredMemory::Impl::Impl(void* data, size_t size, size_t pitch, int rank, TransportFlags transports, - Communicator::Impl& commImpl) - : data(data), - size(size), - pitch(pitch), - rank(rank), - hostHash(commImpl.rankToHash_.at(rank)), - transports(transports) { + : data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports) { if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; transportInfo.transport = Transport::CudaIpc; @@ -69,8 +60,6 @@ MSCCLPP_API_CPP void* RegisteredMemory::data() { return pimpl->data; } MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; } -MSCCLPP_API_CPP size_t RegisteredMemory::pitch() { return pimpl->pitch; } - MSCCLPP_API_CPP int RegisteredMemory::rank() { return pimpl->rank; } MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->transports; } @@ -78,7 +67,6 @@ MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->tr MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { std::vector result; std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&pimpl->pitch), sizeof(pimpl->pitch), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); @@ -111,8 +99,6 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { auto it = serialization.begin(); std::copy_n(it, sizeof(this->size), reinterpret_cast(&this->size)); it += sizeof(this->size); - std::copy_n(it, sizeof(this->pitch), reinterpret_cast(&this->pitch)); - it += sizeof(this->pitch); std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); it += sizeof(this->rank); std::copy_n(it, sizeof(this->hostHash), reinterpret_cast(&this->hostHash)); diff --git a/test/mp_unit/communicator_tests.cu b/test/mp_unit/communicator_tests.cu index c7ca523ef..4ce14cc8a 100644 --- a/test/mp_unit/communicator_tests.cu +++ b/test/mp_unit/communicator_tests.cu @@ -61,16 +61,7 @@ void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, mscc const std::vector& remoteRanks, mscclpp::RegisteredMemory& localMemory, std::unordered_map& remoteMemories) { - registerMemoryPairs(buff, buffSize, buffSize, transport, tag, remoteRanks, localMemory, remoteMemories); -} - -// Register a local memory with pitch and receive corresponding remote memories -void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, size_t pitchSize, - mscclpp::TransportFlags transport, int tag, - const std::vector& remoteRanks, - mscclpp::RegisteredMemory& localMemory, - std::unordered_map& remoteMemories) { - localMemory = communicator->registerMemory(buff, buffSize, pitchSize, transport); + localMemory = communicator->registerMemory(buff, buffSize, transport); std::unordered_map> futureRemoteMemories; for (int remoteRank : remoteRanks) { if (remoteRank != communicator->bootstrap()->getRank()) { @@ -105,9 +96,7 @@ void CommunicatorTest::SetUp() { devicePtr.resize(numBuffers); localMemory.resize(numBuffers); - local2DMemory.resize(numBuffers); remoteMemory.resize(numBuffers); - remote2DMemory.resize(numBuffers); std::vector remoteRanks; for (int i = 0; i < gEnv->worldSize; i++) { @@ -121,18 +110,11 @@ void CommunicatorTest::SetUp() { registerMemoryPairs(devicePtr[n].get(), deviceBufferSize, mscclpp::Transport::CudaIpc | ibTransport, 0, remoteRanks, localMemory[n], remoteMemory[n]); } - - for (size_t n = 0; n < numBuffers; n++) { - registerMemoryPairs(devicePtr[n].get(), deviceBufferSize, deviceBufferPitchSize, mscclpp::Transport::CudaIpc, 0, - remoteRanks, local2DMemory[n], remote2DMemory[n]); - } } void CommunicatorTest::TearDown() { remoteMemory.clear(); - remote2DMemory.clear(); localMemory.clear(); - local2DMemory.clear(); devicePtr.clear(); CommunicatorTestBase::TearDown(); } @@ -168,8 +150,9 @@ void CommunicatorTest::writeTileToRemote(size_t rowIndex, size_t colIndex, size_ for (int i = 0; i < gEnv->worldSize; i++) { if (i != gEnv->rank) { auto& conn = connections.at(i); - auto& peerMemory = remote2DMemory[n].at(i); - conn->write2D(peerMemory, offset, local2DMemory[n], offset, width * sizeof(int), height); + auto& peerMemory = remoteMemory[n].at(i); + conn->write2D(peerMemory, offset, deviceBufferPitchSize, localMemory[n], offset, deviceBufferPitchSize, + width * sizeof(int), height); conn->flush(); } } diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index cbad17cf0..11d72e3da 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -99,10 +99,6 @@ class CommunicatorTestBase : public MultiProcessTest { void registerMemoryPairs(void* buff, size_t buffSize, mscclpp::TransportFlags transport, int tag, const std::vector& remoteRanks, mscclpp::RegisteredMemory& localMemory, std::unordered_map& remoteMemories); - // Register a local memory with pitch and receive corresponding remote memories - void registerMemoryPairs(void* buff, size_t buffSize, size_t pitch, mscclpp::TransportFlags transport, int tag, - const std::vector& remoteRanks, mscclpp::RegisteredMemory& localMemory, - std::unordered_map& remoteMemories); // Register a local memory an receive one corresponding remote memory void registerMemoryPair(void* buff, size_t buffSize, mscclpp::TransportFlags transport, int tag, int remoteRank, mscclpp::RegisteredMemory& localMemory, mscclpp::RegisteredMemory& remoteMemory); @@ -128,9 +124,7 @@ class CommunicatorTest : public CommunicatorTestBase { const int deviceBufferPitchSize = 512; std::vector> devicePtr; std::vector localMemory; - std::vector local2DMemory; std::vector> remoteMemory; - std::vector> remote2DMemory; }; class ProxyChannelOneToOneTest : public CommunicatorTestBase { diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu index db59da3e8..cd660210c 100644 --- a/test/mp_unit/proxy_channel_tests.cu +++ b/test/mp_unit/proxy_channel_tests.cu @@ -25,16 +25,16 @@ void ProxyChannelOneToOneTest::setupMeshConnections( void ProxyChannelOneToOneTest::setupMeshConnections( std::vector>& proxyChannels, bool useIbOnly, void* sendBuff, - size_t sendBuffBytes, size_t pitchSize, void* recvBuff, size_t recvBuffBytes) { + size_t sendBuffBytes, size_t pitch, void* recvBuff, size_t recvBuffBytes) { const int rank = communicator->bootstrap()->getRank(); const int worldSize = communicator->bootstrap()->getNranks(); const bool isInPlace = (recvBuff == nullptr); mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport); - mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, pitchSize, transport); + mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, transport); mscclpp::RegisteredMemory recvBufRegMem; if (!isInPlace) { - recvBufRegMem = communicator->registerMemory(recvBuff, recvBuffBytes, pitchSize, transport); + recvBufRegMem = communicator->registerMemory(recvBuff, recvBuffBytes, transport); } for (int r = 0; r < worldSize; r++) { @@ -59,6 +59,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections( communicator->setup(); mscclpp::SemaphoreId cid = channelService->addSemaphore(conn); + channelService->addPitch(cid, std::pair(pitch, pitch)); communicator->setup(); proxyChannels.emplace_back(mscclpp::deviceHandle(