Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <future>
#include <map>
#include <memory>
#include <set>
#include <thread>

using SizeType32 = tensorrt_llm::runtime::SizeType32;

Expand All @@ -43,6 +45,102 @@ class BaseKVCacheManager;
class CacheSender;
class CacheReceiver;

struct UniqueIdSendMessage
{
public:
UniqueIdSendMessage(RequestIdType generationRequestId, std::string const& serverUuid)
: mGenerationRequestId(generationRequestId)
, mServerUuid(serverUuid)
{
}

serializedSize() const
{
return sizeof(RequestIdType) + mServerUuid.size();
}

void serialize(std::ostream& os) const
{
os.write(reinterpret_cast<char const*>(&mGenerationRequestId), sizeof(RequestIdType));
os.write(mServerUuid.c_str(), mServerUuid.size());
}

static UniqueIdSendMessage deserialize(std::istream& is)
{
is.read(reinterpret_cast<char*>(&mGenerationRequestId), sizeof(RequestIdType));
mServerUuid.resize(is.readsome());
is.read(mServerUuid.data(), mServerUuid.size());
return UniqueIdSendMessage(mGenerationRequestId, mServerUuid);
}

RequestIdType mGenerationRequestId;
std::string mServerUuid;
};

class UniqueIdGenerator
{
public:
static int get()
{
std::lock_guard<std::mutex> lock(mMutex);
if (!mReleasedIds.empty())
{
int id = *mReleasedIds.begin();
mReleasedIds.erase(mReleasedIds.begin());
return id;
}
return mNextId++;
}

static void release(int id)
{
std::lock_guard<std::mutex> lock(mMutex);
if (id < mNextId)
{
mReleasedIds.insert(id);
}
}

private:
static std::mutex mMutex;
static int mNextId;
static std::set<int> mReleasedIds;
};

class UniqueIdServer
{
public:
UniqueIdServer()
{
mThread = std::thread(
[this]()
{
int id = UniqueIdGenerator::get();
while (true)
{
int command;
mpi::MpiComm::session().sendRecv(
&id, &command, 1, mpi::MpiType::kINT32, 0, mpi::MpiTag::kUNIQUE_ID_TAG);
if (command != 0)
{
UniqueIdGenerator::release(command);
}
else
{
id = UniqueIdGenerator::get();
}
}
});
}

private:
std::thread mThread;
};

inline std::mutex UniqueIdGenerator::mMutex;
inline int UniqueIdGenerator::mNextId = 1;
inline std::set<int> UniqueIdGenerator::mReleasedIds;

class CacheTransceiverFactory
{
public:
Expand Down Expand Up @@ -132,6 +230,8 @@ class CacheTransceiver : public BaseCacheTransceiver
// this is used to defer dependency resolution until needed.
static std::mutex mDllMutex;
void* mWrapperLibHandle{nullptr};
std::string mUuid;
std::unique_ptr<UniqueIdServer> mUniqueIdServer;
};

} // namespace tensorrt_llm::batch_manager
2 changes: 2 additions & 0 deletions cpp/include/tensorrt_llm/runtime/utils/mpiUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ class MpiComm
void sendRawTag(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
void send(void const* buffer, std::size_t size, MpiType dtype, int dest, MpiTag tag) const;
void send(runtime::IBuffer const& buf, int dest, MpiTag tag) const;
void sendRecv(
void const* sendbuf, void* recvbuf, int sendCount, int recvCount, MpiType dtype, int dest, MpiTag tag) const;

template <typename T>
void sendValue(T const& value, int dest, MpiTag tag) const
Expand Down
23 changes: 21 additions & 2 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <algorithm>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <cstddef>
#include <numeric>
#include <unordered_set>
Expand Down Expand Up @@ -117,6 +120,21 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
: mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
, mCacheTransceiverConfig{cacheTransceiverConfig}
{
// Broadcast rank 0 UUID to all other ranks
if (worldConfig.getRank() == 0)
{
boost::uuids::random_generator uuidGen;
mUuid = boost::uuids::to_string(uuidGen());
std::vector<char> uuidVec(mUuid.begin(), mUuid.end());
mMpiGroupComm->bcast(uuidVec, 0);
mUniqueIdServer = std::make_unique<UniqueIdServer>();
}
else
{
std::vector<char> uuidVec;
mMpiGroupComm->bcast(uuidVec, 0);
mUuid.assign(uuidVec.begin(), uuidVec.end());
}
using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
if (worldConfig.isTensorParallel())
{
Expand Down Expand Up @@ -199,9 +217,10 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
auto makeFormatter = [cacheManager, isMLA, this]()
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };

mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
mCacheSender
= std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter(), mUuid);
mCacheReceiver
= std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
= std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter(), mUuid);

initializeCommState();
}
Expand Down
Loading
Loading