
Commit 864914e

Initial disagg support changes
Signed-off-by: Iman Tabrizian <[email protected]>

Fixes for block size
Signed-off-by: Iman Tabrizian <[email protected]>

don't use unravel index
Signed-off-by: Iman Tabrizian <[email protected]>

Review commit
Signed-off-by: Iman Tabrizian <[email protected]>

minor fix
Signed-off-by: Iman Tabrizian <[email protected]>

Fix compile errors after rebase
Signed-off-by: Iman Tabrizian <[email protected]>

Bug fixes
Signed-off-by: Iman Tabrizian <[email protected]>

Remove print
Signed-off-by: Iman Tabrizian <[email protected]>

1 parent f0dc746 commit 864914e

13 files changed, +790 −227 lines


cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 49 additions & 0 deletions
@@ -595,6 +595,21 @@ class WindowBlockManager

     ~WindowBlockManager();

+    [[nodiscard]] bool isEnableIndexerKCache() const
+    {
+        return mEnableIndexerKCache;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const
+    {
+        return mIndexerKCacheQuantBlockSize;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheIndexHeadDim() const
+    {
+        return mIndexerKCacheIndexHeadDim;
+    }
+
     void allocatePools(bool useUvm);

     void releasePools();

@@ -1014,6 +1029,21 @@ class BlockManager
         std::optional<kvc::BaseAgentConfig> agentConfig = std::nullopt, bool enableIndexerKCache = false,
         SizeType32 indexerKCacheQuantBlockSize = 128, SizeType32 indexerKCacheIndexHeadDim = 0);

+    [[nodiscard]] bool isEnableIndexerKCache() const
+    {
+        return mWindowBlockManagers.begin()->second.isEnableIndexerKCache();
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const
+    {
+        return mWindowBlockManagers.begin()->second.getIndexerKCacheQuantBlockSize();
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheIndexHeadDim() const
+    {
+        return mWindowBlockManagers.begin()->second.getIndexerKCacheIndexHeadDim();
+    }
+
     BlockManager(BlockManager const&) = delete;
     BlockManager& operator=(BlockManager const&) = delete;

@@ -1485,6 +1515,10 @@ class BaseKVCacheManager

     [[nodiscard]] virtual bool isEnableBlockReuse() const = 0;

+    [[nodiscard]] virtual bool isEnableIndexerKCache() const = 0;
+    [[nodiscard]] virtual SizeType32 getIndexerKCacheIndexHeadDim() const = 0;
+    [[nodiscard]] virtual SizeType32 getIndexerKCacheQuantBlockSize() const = 0;
+
     // void removeToken(SizeType32 seqSlotIdx);
     virtual void rewindKVCache(LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) = 0;

@@ -1818,6 +1852,21 @@ class KVCacheManager : public BaseKVCacheManager
         return mEnableBlockReuse;
     }

+    [[nodiscard]] bool isEnableIndexerKCache() const override
+    {
+        return mBlockManager.isEnableIndexerKCache();
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheIndexHeadDim() const override
+    {
+        return mBlockManager.getIndexerKCacheIndexHeadDim();
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const override
+    {
+        return mBlockManager.getIndexerKCacheQuantBlockSize();
+    }
+
     void removeToken(LlmRequest::RequestIdType requestId);
     void rewindKVCache(LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override;
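The new accessors simply walk down the ownership chain: KVCacheManager forwards to BlockManager, which reads the values off its first WindowBlockManager (so the indexer settings are evidently uniform across attention windows). A minimal caller-side sketch, assuming a fully constructed BaseKVCacheManager reference named cacheManager (hypothetical usage, not part of this commit):

    // Query the indexer K-cache configuration, e.g. before sizing transfer buffers.
    if (cacheManager.isEnableIndexerKCache())
    {
        auto const quantBlockSize = cacheManager.getIndexerKCacheQuantBlockSize(); // defaults to 128
        auto const indexHeadDim = cacheManager.getIndexerKCacheIndexHeadDim();
        TLLM_LOG_DEBUG("indexer K cache: quantBlockSize=%d indexHeadDim=%d", quantBlockSize, indexHeadDim);
    }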

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

Lines changed: 21 additions & 7 deletions
@@ -73,7 +73,8 @@ class BlockRange
         BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, int32_t indexFromEnd)
     {

-        auto poolNum = cacheManager.getNumPools();
+        auto poolNum = cacheManager.getBlockManager().getNumPools(
+            /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
         TLLM_CHECK_WITH_INFO(poolNum == 1, "Reuse tree is not supported for multiple pools or variable window size");

         auto windowSize = cacheManager.getBlockManager().getWindowSizesMetadata().begin()->first;

@@ -136,13 +137,21 @@
         return blockHashesPerWindow;
     }

-    BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize) const
+    BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize, bool useIndexerKCache = false) const
     {
         TLLM_CHECK_WITH_INFO(
             mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d not found", windowSize);
         auto pool = mPoolsPerWindow.at(windowSize).front();
         auto blockIds = mBlockIdsPerWindow.at(windowSize);
-        return BlockRangeForWindow(mManager, windowSize, std::move(blockIds), std::move(pool));
+        if (useIndexerKCache)
+        {
+            TLLM_CHECK(mIndexerKCachePool);
+            return BlockRangeForWindow(mManager, windowSize, std::move(blockIds), std::move(mIndexerKCachePool));
+        }
+        else
+        {
+            return BlockRangeForWindow(mManager, windowSize, std::move(blockIds), std::move(pool));
+        }
     }

     std::vector<SizeType32> getWindowSizes() const

@@ -167,9 +176,8 @@
         , mRequestId(requestId)
         , mBlockIdsPerWindow(std::move(blockIdsPerWindow))
     {
-
-        // cacheManager.getBlockManager.getPrimaryPool(0);
-        auto poolNum = mManager->getNumPools();
+        auto poolNum = mManager->getBlockManager().getNumPools(
+            /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
         for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
         {
             auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);

@@ -181,21 +189,27 @@
         : mManager(&cacheManager)
         , mRequestId(requestId)
     {
-        auto poolNum = mManager->getNumPools();
+        auto poolNum = mManager->getBlockManager().getNumPools(
+            /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
         for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
         {
             auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
             mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
             mBlockIdsPerWindow[windowSize]
                 = cacheManager.getSequence(mRequestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
         }
+        if (cacheManager.isEnableIndexerKCache())
+        {
+            mIndexerKCachePool = cacheManager.getIndexerKCachePool();
+        }
     }

 private:
     BaseKVCacheManager const* mManager;
     LlmRequest::RequestIdType const mRequestId;
     std::unordered_map<SizeType32, std::vector<SizeType32>> mBlockIdsPerWindow;
     std::unordered_map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> mPoolsPerWindow;
+    runtime::ITensor::SharedPtr mIndexerKCachePool;

     static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
     static constexpr SizeType32 kFIRST_POOL_INDEX = 0;
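The new useIndexerKCache flag lets one set of block ids be walked against either backing store: the per-window primary KV pool or the single indexer K-cache pool captured at construction time. A usage sketch, assuming blockRange is a BlockRange built for a request on a manager with the indexer K cache enabled (names hypothetical):

    // Walk the same blocks twice for one window: once per backing pool.
    auto const windowSize = blockRange.getWindowSizes().front();
    auto kvRange = blockRange.getBlockRangeForWindow(windowSize);
    auto indexerRange = blockRange.getBlockRangeForWindow(windowSize, /*useIndexerKCache=*/true);

Note that std::move(mIndexerKCachePool) inside the const method degrades to a copy of the shared pointer, so the member stays valid across repeated calls.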

cpp/include/tensorrt_llm/executor/dataTransceiverState.h

Lines changed: 36 additions & 3 deletions
@@ -50,7 +50,8 @@ class CacheState final

     CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
-        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableBlockReuse = false)
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableBlockReuse = false,
+        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
               worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),

@@ -59,34 +60,45 @@
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mHasIndexerKCache = hasIndexerKCache;
+        mIndexerDimPerHead = indexerDimPerHead;
+        mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
     }

     CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
+        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
               attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mHasIndexerKCache = hasIndexerKCache;
+        mIndexerDimPerHead = indexerDimPerHead;
+        mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
     }

     CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
+        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
               attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mHasIndexerKCache = hasIndexerKCache;
+        mIndexerDimPerHead = indexerDimPerHead;
+        mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
     }

     [[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept

@@ -174,6 +186,21 @@
         return mEnableBlockReuse;
     }

+    [[nodiscard]] bool getHasIndexerKCache() const
+    {
+        return mHasIndexerKCache;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerDimPerHead() const
+    {
+        return mIndexerDimPerHead;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const
+    {
+        return mIndexerKCacheQuantBlockSize;
+    }
+
     [[nodiscard]] std::string toString() const
     {
         std::stringstream sstring;

@@ -194,6 +221,9 @@
         sstring << "dpRank:" << mParallelConfig.mDPrank << "\n";
         sstring << "dpSize:" << mParallelConfig.mDPsize << "\n";
         sstring << "enableBlockReuse:" << mEnableBlockReuse << "\n";
+        sstring << "hasIndexerKCache:" << mHasIndexerKCache << "\n";
+        sstring << "indexerDimPerHead:" << mIndexerDimPerHead << "\n";
+        sstring << "indexerKCacheQuantBlockSize:" << mIndexerKCacheQuantBlockSize << "\n";
         return sstring.str();
     }

@@ -204,6 +234,9 @@
     nvinfer1::DataType mDataType;
     AttentionConfig mAttentionConfig;
     bool mEnableBlockReuse{false};
+    bool mHasIndexerKCache{false};
+    SizeType32 mIndexerDimPerHead{0};
+    SizeType32 mIndexerKCacheQuantBlockSize{128};
 };

 struct MpiState
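The three indexer fields now ride along in CacheState, the descriptor disaggregated ranks exchange to describe their cache layout, and toString() reports them for debugging. A construction sketch using the third overload; every numeric value is illustrative only, not taken from the commit:

    // Hypothetical: advertise an indexer K cache alongside a 32-layer model.
    using tensorrt_llm::executor::kv_cache::CacheState;
    CacheState state(
        /*nbAttentionLayers=*/32, /*nbKvHeads=*/8, /*sizePerHead=*/128, /*tokensPerBlock=*/64,
        /*tensorParallelism=*/1, /*pipelineParallelism=*/1, /*contextParallelism=*/1,
        /*attentionLayerNumPerPP=*/{32}, nvinfer1::DataType::kHALF,
        CacheState::AttentionType::kDEFAULT, /*kvFactor=*/2, /*enableAttentionDP=*/false,
        /*DPrank=*/0, /*DPsize=*/0, /*enableBlockReuse=*/false,
        /*hasIndexerKCache=*/true, /*indexerDimPerHead=*/128, /*indexerKCacheQuantBlockSize=*/128);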

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 17 additions & 6 deletions
@@ -45,7 +45,8 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager
 BlockRange getBlockRangeForSending(
     BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, BlockKey const& lastBlockKey, int32_t indexFromEnd)
 {
-    auto poolNum = cacheManager->getBlockManager().getNumPools();
+    auto poolNum = cacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || lastBlockKey.uniqueTokens.size() == 0)
     {
         // disable reuse path, and vwsa don't support reuse.

@@ -87,7 +88,8 @@ BlockRange getBlockRangeForSending(
 BlockRange getBlockRangeForReceiving(
     BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse)
 {
-    auto poolNum = cacheManager->getBlockManager().getNumPools();
+    auto poolNum = cacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     if (poolNum == 1 && srcEnableBlockReuse)
     {
         // Build from all block ids, then slice off the reused blocks so we only transfer newly allocated ones.

@@ -170,7 +172,8 @@ void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::
     // if gen PP and context PP are different, cache formatter only support alternative window like gpt-oss.
     // which is one layer is WSA, and another layer is Full attention.

-    auto numPools = cacheManager->getBlockManager().getNumPools();
+    auto numPools = cacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     auto layerNum = cacheManager->getBlockManager().getNumLayers();

     auto selfPPNum = selfConfig.getParallelConfig().mPipelineParallelism;

@@ -247,7 +250,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
     auto& blockManager = mCacheManager->getBlockManager();
     auto const& lastBlockKey = session.getLastBlockKey();
     auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd);
-    auto const numPools = blockManager.getNumPools();
+    auto const numPools
+        = blockManager.getNumPools(/*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...

     bool layerWise = common::getEnvDisaggLayerwise() && numPools == 1;

@@ -555,7 +559,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
     TLLM_LOG_DEBUG("pickUpConnections size: %d connections size: %d", pickUpConnections.size(), connections.size());
     std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
     std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> outputBuffersPerWindow;
-    auto const numPools = mCacheManager->getBlockManager().getNumPools();
+    auto const numPools = mCacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
     size_t blockNum = 0;
     size_t cacheBlockSizeSum = 0;

@@ -969,7 +974,13 @@ std::unique_ptr<BaseCacheFormatter> createCacheFormatter(
 {
     if (isMLA)
     {
-        return std::make_unique<MLACacheFormatter>(cacheManager, cacheTransBufferManager);
+        std::vector<CacheTransBufferManager*> cacheTransBufferManagers = {cacheTransBufferManager};
+        auto maxNumTokens = cacheTransBufferManager->getMaxNumTokens();
+        if (cacheManager->isEnableIndexerKCache())
+        {
+            cacheTransBufferManagers.push_back(new CacheTransBufferManager(cacheManager, maxNumTokens, true));
+        }
+        return std::make_unique<MLACacheFormatter>(cacheManager, cacheTransBufferManagers);
     }
     return std::make_unique<CacheFormatter>(cacheManager, cacheTransBufferManager);
 }
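Design note: for MLA models the formatter now takes a vector of staging-buffer managers, one per pool family it must move; when the indexer K cache is enabled, a second CacheTransBufferManager is allocated with the same maxNumTokens budget (via the new getMaxNumTokens() accessor) but bound to the indexer pool. The second manager is created with a raw new, and whether MLACacheFormatter assumes ownership is not visible in this hunk. A hypothetical direct construction mirroring the new signature (both manager pointers are assumed to outlive the formatter):

    // Hypothetical: pair each pool family with its own staging manager.
    std::vector<CacheTransBufferManager*> managers{kvBufferManager, indexerBufferManager};
    auto mlaFormatter = std::make_unique<MLACacheFormatter>(cacheManager, managers);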

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

Lines changed: 11 additions & 3 deletions
@@ -189,14 +189,22 @@ bool FabricMemory::supportFbaricMemory()
 }

 CacheTransBufferManager::CacheTransBufferManager(
-    KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens)
+    KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens, bool transferIndexerKCache)
     : mCacheManager{cacheManager}
     , mBufferManager{std::make_shared<runtime::CudaStream>()}
+    , mTransferIndexerKCache{transferIndexerKCache}
+    , mMaxNumTokens{maxNumTokens}
 {
-
     // TODO: FP4 dataSize
     TLLM_CHECK(mCacheManager);
-    mDataType = mCacheManager->getPrimaryPool(0)->getDataType();
+    if (transferIndexerKCache)
+    {
+        mDataType = mCacheManager->getIndexerKCachePool()->getDataType();
+    }
+    else
+    {
+        mDataType = mCacheManager->getPrimaryPool(0)->getDataType();
+    }

     auto tokensPerBlock = mCacheManager->getBlockManager().getTokensPerBlock();
     size_t bufferSizeFromMaxNumToken = 0;
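The constructor now derives its staging data type from the pool it will actually transfer, so an indexer-focused instance can stage with a different element type than the primary KV pools. A minimal sketch constructing one manager per pool kind, mirroring what createCacheFormatter does above (variable names hypothetical):

    // One staging manager for the primary KV pools, one for the indexer K cache.
    CacheTransBufferManager kvStaging(cacheManager, maxNumTokens);
    CacheTransBufferManager indexerStaging(cacheManager, maxNumTokens, /*transferIndexerKCache=*/true);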

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h

Lines changed: 9 additions & 2 deletions
@@ -57,8 +57,8 @@ class FabricMemory
 class CacheTransBufferManager
 {
 public:
-    CacheTransBufferManager(
-        KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens = std::nullopt);
+    CacheTransBufferManager(KVCacheManager::BaseKVCacheManager* cacheManager,
+        std::optional<size_t> maxNumTokens = std::nullopt, bool transferIndexerKCache = false);

     static size_t preAllocBufferSize(std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
         SizeType32 tokensPerBlock,

@@ -82,6 +82,11 @@
     size_t getRecvBufferCount();
     size_t getSendBufferCount();

+    std::optional<size_t> getMaxNumTokens()
+    {
+        return mMaxNumTokens;
+    }
+
 private:
     struct ConcurrenceResource
     {

@@ -114,6 +119,8 @@
     KVCacheManager::BaseKVCacheManager* mCacheManager;
     runtime::BufferManager mBufferManager;
     std::vector<std::unique_ptr<FabricMemory>> mFabricMemory;
+    bool mTransferIndexerKCache;
+    std::optional<size_t> mMaxNumTokens;
 };

 } // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 3 additions & 1 deletion
@@ -137,7 +137,9 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
         kvFactor = 1;
     }
     mCacheState = std::make_unique<executor::kv_cache::CacheState>(cacheStateModelCfg, worldConfig,
-        attentionLayerNumPerPP, dataType, attentionType, kvFactor, cacheManager->isEnableBlockReuse());
+        attentionLayerNumPerPP, dataType, attentionType, kvFactor, cacheManager->isEnableBlockReuse(),
+        cacheManager->isEnableIndexerKCache(), cacheManager->getIndexerKCacheIndexHeadDim(),
+        cacheManager->getIndexerKCacheQuantBlockSize());

     if (mCacheState->getParallelConfig().mEnableAttentionDP)
     {
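With this change the transceiver advertises the local indexer configuration in the CacheState it exchanges with peers. A receiving rank could then guard against configuration skew along these lines (hypothetical helper, not part of this commit):

    // Hypothetical: verify self and peer agree on the indexer K-cache layout.
    bool indexerConfigMatches(
        executor::kv_cache::CacheState const& self, executor::kv_cache::CacheState const& peer)
    {
        return self.getHasIndexerKCache() == peer.getHasIndexerKCache()
            && self.getIndexerDimPerHead() == peer.getIndexerDimPerHead()
            && self.getIndexerKCacheQuantBlockSize() == peer.getIndexerKCacheQuantBlockSize();
    }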
