Skip to content

Commit 7f44888

Browse files
committed
Fix NIXL support for indexer cache transfer
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent a810c12 commit 7f44888

File tree

10 files changed

+216
-117
lines changed

10 files changed

+216
-117
lines changed

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ class CacheTransceiver : public BaseCacheTransceiver
269269
std::unique_ptr<executor::kv_cache::CacheState> mCacheState;
270270
std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;
271271
std::optional<executor::CacheTransceiverConfig> mCacheTransceiverConfig;
272-
std::unique_ptr<kv_cache_manager::CacheTransBufferManager> mCacheTransBufferManager;
272+
std::vector<std::unique_ptr<kv_cache_manager::CacheTransBufferManager>> mCacheTransBufferManagers;
273+
std::vector<kv_cache_manager::CacheTransBufferManager*> mCacheTransBufferManagerPtrs;
273274
// library handle to the communicator related features,
274275
// this is used to defer dependency resolution until needed.
275276
static std::mutex mDllMutex;

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -972,19 +972,14 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
972972
}
973973

974974
std::unique_ptr<BaseCacheFormatter> createCacheFormatter(
975-
BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager, bool isMLA)
975+
BaseKVCacheManager* cacheManager, std::vector<CacheTransBufferManager*> const& cacheTransBufferManagers, bool isMLA)
976976
{
977+
TLLM_CHECK(!cacheTransBufferManagers.empty());
977978
if (isMLA)
978979
{
979-
std::vector<CacheTransBufferManager*> cacheTransBufferManagers = {cacheTransBufferManager};
980-
auto maxNumTokens = cacheTransBufferManager->getMaxNumTokens();
981-
if (cacheManager->isEnableIndexerKCache())
982-
{
983-
cacheTransBufferManagers.push_back(new CacheTransBufferManager(cacheManager, maxNumTokens, true));
984-
}
985980
return std::make_unique<MLACacheFormatter>(cacheManager, cacheTransBufferManagers);
986981
}
987-
return std::make_unique<CacheFormatter>(cacheManager, cacheTransBufferManager);
982+
return std::make_unique<CacheFormatter>(cacheManager, cacheTransBufferManagers[0]);
988983
}
989984

990985
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class CacheFormatter final : public BaseCacheFormatter
133133
CacheTransBufferManager* mCacheTransBufferManager;
134134
};
135135

136-
std::unique_ptr<BaseCacheFormatter> createCacheFormatter(
137-
BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager, bool isMLA = false);
136+
std::unique_ptr<BaseCacheFormatter> createCacheFormatter(BaseKVCacheManager* cacheManager,
137+
std::vector<CacheTransBufferManager*> const& cacheTransBufferManagers, bool isMLA = false);
138138

139139
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
168168

169169
std::optional<size_t> maxNumTokens = mCacheTransceiverConfig.value().getMaxTokensInBuffer();
170170

171-
mCacheTransBufferManager = std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens);
171+
mCacheTransBufferManagers.push_back(
172+
std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens));
173+
if (isMLA && cacheManager->isEnableIndexerKCache())
174+
{
175+
mCacheTransBufferManagers.push_back(
176+
std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens, true));
177+
}
178+
mCacheTransBufferManagerPtrs.clear();
179+
mCacheTransBufferManagerPtrs.reserve(mCacheTransBufferManagers.size());
180+
for (auto& manager : mCacheTransBufferManagers)
181+
{
182+
mCacheTransBufferManagerPtrs.push_back(manager.get());
183+
}
172184
if (backendType.value() == executor::CacheTransceiverConfig::BackendType::UCX)
173185
{
174186
std::lock_guard<std::mutex> lock(mDllMutex);
@@ -191,7 +203,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
191203
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
192204
{
193205
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
194-
mCacheTransBufferManager.get(), *mCacheState);
206+
mCacheTransBufferManagerPtrs, *mCacheState);
195207
TLLM_LOG_INFO("NIXL Connection Manager created");
196208
}
197209
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
@@ -206,7 +218,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
206218
}
207219

208220
auto makeFormatter = [cacheManager, isMLA, this]()
209-
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
221+
{ return createCacheFormatter(cacheManager, mCacheTransBufferManagerPtrs, isMLA); };
210222

211223
mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
212224
mCacheReceiver

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -834,12 +834,14 @@ class CacheReceiver::Impl
834834
}
835835

836836
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
837-
std::optional<size_t> cacheBufferId = std::nullopt;
837+
std::vector<std::optional<size_t>> cacheBufferIds;
838838
if (agentConnectionManager)
839839
{
840-
cacheBufferId = agentConnectionManager->getCacheTransBufferManager()->assignBufferIndexForRecv();
841-
TLLM_CHECK(cacheBufferId.has_value());
842-
// memory Desp , validSegmentIdx send
840+
for (auto& cacheTransBufferManager : agentConnectionManager->getCacheTransBufferManagers())
841+
{
842+
cacheBufferIds.push_back(cacheTransBufferManager->assignBufferIndexForRecv());
843+
}
844+
TLLM_CHECK(!cacheBufferIds.empty());
843845
}
844846
auto counterParts = mFormatter->getCounterparts(
845847
mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);
@@ -864,9 +866,9 @@ class CacheReceiver::Impl
864866
auto validConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin();
865867
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
866868
TLLM_CHECK(agentConnection != nullptr);
867-
TLLM_CHECK(cacheBufferId.has_value());
869+
TLLM_CHECK(!cacheBufferIds.empty());
868870
const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
869-
->sendRequestAndBufferInfo(requestInfo, cacheBufferId, validConnectionIdx);
871+
->sendRequestAndBufferInfo(requestInfo, cacheBufferIds, validConnectionIdx);
870872
}
871873
else
872874
{

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,16 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
131131

132132
for (auto transferIndexerKCache : transferringIndexerKCache)
133133
{
134+
auto activeBufferIdx = transferIndexerKCache ? 1UL : 0UL;
135+
for (auto const* connection : connections)
136+
{
137+
if (auto const* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection))
138+
{
139+
TLLM_CHECK(agentConnection->getSenderBufferCount() > activeBufferIdx);
140+
const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
141+
->setActiveSenderBufferIdx(activeBufferIdx);
142+
}
143+
}
134144
int blockNum = 0;
135145
std::vector<runtime::ITensor::SharedPtr> inputKvCacheBlocks;
136146
if (!transferIndexerKCache)
@@ -417,9 +427,10 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
417427
else
418428
{
419429
auto* agentConnnecion = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
430+
size_t activeBufferIdx = transferIndexerKCache ? 1 : 0;
420431
if (agentConnnecion != nullptr)
421432
{
422-
cacheBufferId = agentConnnecion->getCacheBufferId();
433+
cacheBufferId = agentConnnecion->getCacheBufferId(activeBufferIdx);
423434
TLLM_CHECK(cacheBufferId.has_value());
424435
}
425436
else

0 commit comments

Comments
 (0)