Commit b09d5a6

[https://nvbugs/5625990][fix] Fix block copy from GPU to GPU for partial reuse in the KV cache manager

`KVCacheTransferManager::onboard` only covers memory movement between CPU and GPU, not GPU to GPU. Use the `mBufferManager` to copy block content instead. This fixes the incorrect partial block copy behavior exposed by the test case `accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse_low_memory_available_partial_reuse`.

Signed-off-by: eopXD <[email protected]>
1 parent d59e2cb commit b09d5a6
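
For context on what the fix changes at the memory level: onboarding and offloading move block contents between CPU and GPU memory, while the partial-reuse path needs a copy between two blocks that both already live in GPU memory. The sketch below illustrates that distinction with the plain CUDA runtime; pointers and sizes are placeholders, and the actual change (in the diffs below) goes through `ITensor` views and `runtime::BufferManager` rather than raw `cudaMemcpyAsync`.

    // Sketch: host<->device transfers (offload/onboard) versus a device-to-device copy.
    // Illustrative only; block sizes and pointers are placeholders.
    #include <cuda_runtime.h>
    #include <cstddef>

    void illustrateCopyKinds(cudaStream_t stream, std::size_t blockBytes)
    {
        void* hostBlock = nullptr; // stand-in for an offloaded (host) block
        void* gpuBlockA = nullptr; // stand-in for a matched primary-pool block
        void* gpuBlockB = nullptr; // stand-in for a freshly allocated primary-pool block
        cudaMallocHost(&hostBlock, blockBytes);
        cudaMalloc(&gpuBlockA, blockBytes);
        cudaMalloc(&gpuBlockB, blockBytes);

        // Offload and onboard: host <-> device transfers, which is what `onboard` covers.
        cudaMemcpyAsync(hostBlock, gpuBlockA, blockBytes, cudaMemcpyDeviceToHost, stream);
        cudaMemcpyAsync(gpuBlockA, hostBlock, blockBytes, cudaMemcpyHostToDevice, stream);

        // Partial reuse of a block that is still in use: device -> device, no host round trip.
        cudaMemcpyAsync(gpuBlockB, gpuBlockA, blockBytes, cudaMemcpyDeviceToDevice, stream);

        cudaStreamSynchronize(stream);
        cudaFreeHost(hostBlock);
        cudaFree(gpuBlockA);
        cudaFree(gpuBlockB);
    }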

5 files changed: +74 additions, -10 deletions

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 4 additions & 0 deletions
@@ -995,6 +995,10 @@ class WindowBlockManager
     double mTotalInputTokens;
     // Whether blocks that are partially matched should be reused.
     bool mEnablePartialReuse;
+    // Number of partial matched blocks reused through a copy
+    SizeType32 mCopiedReusedPartialBlocks;
+    // Number of partial matched blocks reused directly without a copy
+    SizeType32 mDirectlyReusedPartialBlocks;
     // Whether partially matched blocks that are already in use should be copied and reused.
     bool mCopyOnPartialReuse;
     // The kv cache connector manager

cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h

Lines changed: 28 additions & 0 deletions
@@ -62,6 +62,25 @@ class KVCacheTransferManager
     static tr::ITensor::SharedPtr computeBlockPointer(
         BlockPtr const& block, std::vector<KVCacheBlockPool> const& pools, size_t poolIdx);
 
+    /*!
+     * \brief Synchronize pending onboard transfers for the given blocks.
+     *
+     * \details For the src block (offloadedBlock), we wait for any pending
+     *          writes before reading from it. For the dst block (block), we
+     *          wait for any pending reads and writes before overwriting it.
+     * \param offloadedBlock Offloaded block (to be onboarded)
+     * \param block Block (to be copied content onto)
+     */
+    void syncPendingOnboardTransfers(BlockPtr const& offloadedBlock, BlockPtr const& block);
+
+    /*!
+     * \brief Record pending onboard transfers for the given blocks.
+     *
+     * \param offloadedBlock Offloaded block (to be onboarded)
+     * \param block Block (to be copied content onto)
+     */
+    void recordPendingOnboardTransfers(BlockPtr const& offloadedBlock, BlockPtr const& block);
+
     /*!
      * \brief The key method that copies the src block to the dst block.
      *
@@ -79,6 +98,15 @@ class KVCacheTransferManager
         int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
         std::string const& directory = "");
 
+    /*!
+     * \brief Directly copy a block from gpu to gpu.
+     *
+     * \param src Source block
+     * \param dst Destination block
+     * \param pools Pools describing memory layout for KV blocks
+     */
+    void copyBlockGPUToGPU(BlockPtr const& src, BlockPtr const& dst, std::vector<KVCacheBlockPool> const& pools);
+
     runtime::BufferManager mBufferManager;
     runtime::BufferManager mOnboardManager;
     runtime::BufferManager mOffloadManager;
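
The Doxygen comments above describe an ordering contract: wait for pending writes on the source before reading it, wait for pending reads and writes on the destination before overwriting it, and record new events once the copy has been issued. As a standalone illustration of that pattern, here is a minimal sketch using the plain CUDA runtime instead of the library's `CudaStream`/`CudaEvent` wrappers and block bookkeeping; the struct and function names are hypothetical.

    // Minimal sketch of the wait-then-record ordering described above,
    // using raw CUDA events in place of the KV cache manager's bookkeeping.
    // Event cleanup is omitted for brevity.
    #include <cuda_runtime.h>

    struct PendingEvents
    {
        cudaEvent_t pendingRead{nullptr};  // last outstanding read of the block
        cudaEvent_t pendingWrite{nullptr}; // last outstanding write to the block
    };

    // Before copying src -> dst on copyStream: the copy reads src and overwrites dst,
    // so it must be ordered after any pending write to src and any pending read or
    // write touching dst.
    void syncBeforeCopy(cudaStream_t copyStream, PendingEvents const& src, PendingEvents const& dst)
    {
        if (src.pendingWrite) cudaStreamWaitEvent(copyStream, src.pendingWrite, 0);
        if (dst.pendingRead) cudaStreamWaitEvent(copyStream, dst.pendingRead, 0);
        if (dst.pendingWrite) cudaStreamWaitEvent(copyStream, dst.pendingWrite, 0);
    }

    // After enqueuing the copy: record that src now has a pending read and dst a
    // pending write, so later transfers touching the same blocks can wait on them.
    void recordAfterCopy(cudaStream_t copyStream, PendingEvents& src, PendingEvents& dst)
    {
        cudaEventCreateWithFlags(&src.pendingRead, cudaEventDisableTiming);
        cudaEventCreateWithFlags(&dst.pendingWrite, cudaEventDisableTiming);
        cudaEventRecord(src.pendingRead, copyStream);
        cudaEventRecord(dst.pendingWrite, copyStream);
    }

In the implementation below, `syncPendingOnboardTransfers` and `recordPendingOnboardTransfers` play these two roles around either `copyBlock` (host transfers) or `copyGPUtoGPU` (device-to-device).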

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 8 additions & 1 deletion
@@ -672,6 +672,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
     , mReusedTokens{0.0}
     , mTotalInputTokens{0.0}
     , mEnablePartialReuse{enablePartialReuse}
+    , mCopiedReusedPartialBlocks{0}
+    , mDirectlyReusedPartialBlocks{0}
     , mCopyOnPartialReuse{copyOnPartialReuse}
     , mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
     , mEnableIndexerKCache{enableIndexerKCache}
@@ -765,6 +767,9 @@ WindowBlockManager::~WindowBlockManager()
     TLLM_LOG_DEBUG("%s - reused tokens: %.0f ", mLogPrefix.c_str(), mReusedTokens);
     TLLM_LOG_DEBUG("%s - reused tokens percentage (%%): %.2f ", mLogPrefix.c_str(),
         100.0 * mReusedTokens / mTotalInputTokens);
+    TLLM_LOG_DEBUG("%s - copied reused partial blocks: %lu ", mLogPrefix.c_str(), mCopiedReusedPartialBlocks);
+    TLLM_LOG_DEBUG(
+        "%s - directly reused partial blocks: %lu ", mLogPrefix.c_str(), mDirectlyReusedPartialBlocks);
 }
 
 bool BlockManager::verifyQueueIntegrity(SizeType32 windowSize)
@@ -1246,7 +1251,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
                 // Somebody else is using block or it is not a leaf, copy reusable tokens
                 auto newBlock = getFreeBlock(
                     sequence, matchingBlock->getPriority(), matchingBlock->getDurationMs(), mode, directory);
-                mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched, mode, directory);
+                mTransferManager->copyBlockGPUToGPU(matchingBlock, newBlock, mPools);
                 // TODO: (optional) Send out event
                 matchingBlock = newBlock;
                 if (blockItr != blockKeys.end())
@@ -1257,6 +1262,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
                 matchingBlock->setHash();
                 TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Copied partially filled block %d", mLogPrefix.c_str(),
                     matchingBlockId);
+                ++mCopiedReusedPartialBlocks;
             }
             else
             {
@@ -1266,6 +1272,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
                     matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs);
                 TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(),
                     matchingBlockId);
+                ++mDirectlyReusedPartialBlocks;
             }
             searchRoot = nullptr; // no matching needed for following blocks
         }

cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp

Lines changed: 34 additions & 5 deletions
@@ -97,6 +97,19 @@ tr::ITensor::SharedPtr KVCacheTransferManager::computeBlockPointer(
     return blockTensor;
 }
 
+// Directly copy a block from gpu to gpu without using the buffer manager.
+void KVCacheTransferManager::copyGPUtoGPU(
+    BlockPtr const& src, BlockPtr const& dst, std::vector<KVCacheBlockPool> const& pools)
+{
+    for (size_t poolIdx = 0; poolIdx < pools.size(); ++poolIdx)
+    {
+        auto srcPtr = computeBlockPointer(src, pools, poolIdx);
+        auto dstPtr = computeBlockPointer(dst, pools, poolIdx);
+        mBufferManager.copy(*srcPtr, *dstPtr);
+    }
+    TLLM_LOG_DEBUG("GPU-to-GPU copy for from block %d to block %d", src->getBlockId(), dst->getBlockId());
+}
+
 void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
     std::vector<KVCacheBlockPool> const& pools, bool isOffload, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
@@ -241,9 +254,7 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
 // Failing to do so will lead to corrupted blocks eventually.
 //
 
-void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
-    std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
-    std::string const& directory)
+void KVCacheTransferManager::syncPendingOnboardTransfers(BlockPtr const& offloadedBlock, BlockPtr const& block)
 {
     // Wait for any pending writes before reading from offloadedBlock
     auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
@@ -266,9 +277,10 @@
         mOnboardManager.getStream().wait(blockPendingWriteItr->second);
         mPendingWrites.erase(blockPendingWriteItr);
     }
+}
 
-    copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
-
+void KVCacheTransferManager::recordPendingOnboardTransfers(BlockPtr const& offloadedBlock, BlockPtr const& block)
+{
     // Record new pending read from offloadedBlock
     mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
     mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
@@ -277,6 +289,23 @@
     mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
 }
 
+void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
+    std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
+    std::string const& directory)
+{
+    syncPendingOnboardTransfers(offloadedBlock, block);
+    copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
+    recordPendingOnboardTransfers(offloadedBlock, block);
+}
+
+void KVCacheTransferManager::copyBlockGPUToGPU(
+    BlockPtr const& offloadedBlock, BlockPtr const& block, std::vector<KVCacheBlockPool> const& pools)
+{
+    syncPendingOnboardTransfers(offloadedBlock, block);
+    copyGPUtoGPU(offloadedBlock, block, pools);
+    recordPendingOnboardTransfers(offloadedBlock, block);
+}
+
 void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
     std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
     std::string const& directory)
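
Since the new `copyGPUtoGPU` relies on `mBufferManager.copy` to move data between two device-resident blocks, a standalone sketch of that call may help readers unfamiliar with the runtime helpers. Everything below is illustrative: the tensor shape, dtype, and stream setup are placeholders, and the sketch assumes the `tensorrt_llm::runtime::BufferManager`, `CudaStream`, and `ITensor` helpers behave as they do elsewhere in the codebase; it is not the KV cache manager's code path.

    // Sketch: device-to-device tensor copy through tensorrt_llm::runtime::BufferManager.
    // The shape and dtype are illustrative; the real block layout comes from KVCacheBlockPool.
    #include "tensorrt_llm/runtime/bufferManager.h"
    #include "tensorrt_llm/runtime/cudaStream.h"

    #include <memory>

    namespace tr = tensorrt_llm::runtime;

    void copyOneBlockDeviceToDevice()
    {
        auto stream = std::make_shared<tr::CudaStream>();
        tr::BufferManager manager{stream};

        // Two GPU tensors standing in for the source and destination KV cache blocks.
        auto src = manager.gpu(tr::ITensor::makeShape({2, 64, 8, 128}), nvinfer1::DataType::kHALF);
        auto dst = manager.gpu(tr::ITensor::makeShape({2, 64, 8, 128}), nvinfer1::DataType::kHALF);

        // Enqueue the GPU-to-GPU copy on the manager's stream, then wait for it to finish.
        manager.copy(*src, *dst);
        stream->synchronize();
    }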

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 0 additions & 4 deletions
@@ -1133,10 +1133,6 @@ def test_auto_dtype_vswa_reuse_low_memory_available_no_partial_reuse(self):
         task = MMLU(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip(
-        reason=
-        "Currently failing due to accuracy drop, https://nvbugspro.nvidia.com/bug/5625990"
-    )
     def test_auto_dtype_vswa_reuse_low_memory_available_partial_reuse(self):
         # NOTE: Test with VSWA kv cache config.
         kv_cache_config = KvCacheConfig(
