Skip to content

Commit fea5bfb

Browse files
authored
[None][feat] add detailed KV cache transfer time breakdown (#8521)
Signed-off-by: zhengd-nv <[email protected]>
1 parent f444fe2 commit fea5bfb

File tree

12 files changed

+129
-104
lines changed

12 files changed

+129
-104
lines changed

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 14 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1691,22 +1691,22 @@ class GenericLlmRequest
16911691
mDecodingIter = iter;
16921692
}
16931693

1694-
void setKvCacheTransferStart(TimePoint const& time)
1694+
void setKvCacheTransferStart(TimePoint time) const
16951695
{
16961696
mPerfMetrics.timingMetrics.kvCacheTransferStart = maybeToGlobalSteadyClock(time);
16971697
}
16981698

1699-
void setKvCacheTransferEnd(TimePoint const& time)
1699+
void setKvCacheTransferEnd(TimePoint time) const
17001700
{
17011701
mPerfMetrics.timingMetrics.kvCacheTransferEnd = maybeToGlobalSteadyClock(time);
17021702
}
17031703

1704-
TimePoint getKvCacheTransferStart()
1704+
TimePoint getKvCacheTransferStart() const
17051705
{
17061706
return mPerfMetrics.timingMetrics.kvCacheTransferStart;
17071707
}
17081708

1709-
TimePoint getKvCacheTransferEnd()
1709+
TimePoint getKvCacheTransferEnd() const
17101710
{
17111711
return mPerfMetrics.timingMetrics.kvCacheTransferEnd;
17121712
}
@@ -1865,13 +1865,11 @@ class GenericLlmRequest
18651865
return mUseDraftModel;
18661866
}
18671867

1868-
// If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock
1868+
// If sGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock
18691869
// time point
1870-
[[nodiscard]] TimePoint getSteadyClockNow() const
1870+
[[nodiscard]] static TimePoint getSteadyClockNow()
18711871
{
1872-
const TimePoint time_point = std::chrono::steady_clock::now();
1873-
1874-
return maybeToGlobalSteadyClock(time_point);
1872+
return maybeToGlobalSteadyClock(std::chrono::steady_clock::now());
18751873
}
18761874

18771875
RequestIdType mRequestId;
@@ -1894,7 +1892,7 @@ class GenericLlmRequest
18941892
SizeType32 mPtableCurrentPosition{0};
18951893

18961894
// The offset between local steady clock and global steady clock (at rank 0)
1897-
inline static std::optional<Duration> mGlobalSteadyClockOffset{std::nullopt};
1895+
inline static std::optional<Duration> sGlobalSteadyClockOffset{std::nullopt};
18981896

18991897
protected:
19001898
bool mIsStreaming;
@@ -2028,9 +2026,9 @@ class GenericLlmRequest
20282026

20292027
std::optional<TensorPtr> mSkipCrossAttnBlocks{std::nullopt};
20302028

2031-
// Performance metrics.
2029+
// Performance metrics. Should be updatable even from a const LlmRequest reference.
20322030
bool mReturnPerfMetrics{false};
2033-
executor::RequestPerfMetrics mPerfMetrics;
2031+
mutable executor::RequestPerfMetrics mPerfMetrics;
20342032

20352033
// Guided decoding params.
20362034
std::optional<executor::GuidedDecodingParams> mGuidedDecodingParams{std::nullopt};
@@ -2183,16 +2181,13 @@ class GenericLlmRequest
21832181
return tensor;
21842182
}
21852183

2186-
TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point) const
2184+
static TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point)
21872185
{
2188-
if (mGlobalSteadyClockOffset.has_value())
2189-
{
2190-
return time_point + *mGlobalSteadyClockOffset;
2191-
}
2192-
else
2186+
if (sGlobalSteadyClockOffset.has_value())
21932187
{
2194-
return time_point;
2188+
return time_point + *sGlobalSteadyClockOffset;
21952189
}
2190+
return time_point;
21962191
}
21972192
};
21982193

cpp/include/tensorrt_llm/executor/types.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -451,7 +451,7 @@ struct RequestPerfMetrics
451451
/// @brief End time of the KV cache transfer for disaggregated serving
452452
TimePoint kvCacheTransferEnd;
453453
/// @brief KV Cache size transfer for disaggregated serving
454-
mutable size_t kvCacheSize = 0;
454+
size_t kvCacheSize = 0;
455455
};
456456

457457
struct KvCacheMetrics

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 14 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,7 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
227227
void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
228228
{
229229
NVTX3_SCOPED_RANGE(CacheFormatter_format);
230+
session.setTime(TransferSession::kTimeFormatter);
230231
auto const& llmRequest = session.getLlmRequest();
231232
TLLM_LOG_DEBUG(
232233
mpi::MpiComm::world().getRank(), "Start sending KV cache for request ID: %ld.", llmRequest.mRequestId);
@@ -249,9 +250,6 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
249250
auto const numPools = blockManager.getNumPools();
250251
// TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
251252

252-
auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime;
253-
bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point();
254-
255253
bool layerWise = common::getEnvDisaggLayerwise() && numPools == 1;
256254
if (layerWise)
257255
{
@@ -420,6 +418,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
420418
inputKvCacheBlocksPerWindow, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager);
421419

422420
bufferManager.getStream().synchronize();
421+
session.setTime(TransferSession::kTimePreprocess);
423422

424423
auto preAllocSendBuffer = mCacheTransBufferManager->getSendBuffer(cacheBufferId);
425424
if (preAllocSendBuffer != nullptr)
@@ -434,7 +433,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
434433
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
435434
TLLM_CHECK(connections.size() > (processIdx / peerDuplicateHeadFactor));
436435
TLLM_CHECK(outputSplitCaches.size() > (processIdx / peerDuplicateHeadFactor));
437-
auto startTime = llmRequest.getSteadyClockNow();
436+
auto startTime = LlmRequest::getSteadyClockNow();
438437

439438
size_t ppDomainSize = targetInfo.mDomainPPSize;
440439
size_t bufferTpRank = (processIdx / ppDomainSize) / peerDuplicateHeadFactor;
@@ -481,15 +480,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
481480
}
482481
}
483482

484-
auto endTime = llmRequest.getSteadyClockNow();
485-
double delay = 0.0;
486-
if (recordDelay)
487-
{
488-
delay = std::chrono::duration<double, std::milli>(startTime - lastTokenTime).count();
489-
}
490-
double cacheTransferTime
491-
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
492-
session.appendMeasure(delay, cacheTransferTime, size);
483+
auto endTime = LlmRequest::getSteadyClockNow();
484+
session.appendMeasure(startTime, endTime, size);
493485
};
494486

495487
if (connections.size() > 1)
@@ -534,8 +526,10 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
534526
{
535527
sendBufferFun(deviceId, 0);
536528
}
529+
session.setTime(TransferSession::kTimeTransmissions);
537530

538531
mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
532+
session.setTime(TransferSession::kTimePostprocess);
539533
}
540534
TLLM_LOG_DEBUG(
541535
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);
@@ -544,6 +538,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
544538
void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
545539
{
546540
NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
541+
session.setTime(TransferSession::kTimeFormatter);
547542
auto const& llmRequest = session.getLlmRequest();
548543
auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();
549544
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
@@ -555,9 +550,6 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
555550
auto& bufferManager = session.getBufferManager();
556551
auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());
557552

558-
auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
559-
bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
560-
561553
auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
562554

563555
TLLM_LOG_DEBUG("pickUpConnections size: %d connections size: %d", pickUpConnections.size(), connections.size());
@@ -779,6 +771,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
779771
// sync to alloc buffer
780772
bufferManager.getStream().synchronize();
781773
}
774+
session.setTime(TransferSession::kTimePreprocess);
782775

783776
runtime::ITensor::SharedPtr preAllocRecvBuffer = nullptr;
784777
if (cacheBufferId.has_value())
@@ -794,7 +787,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
794787
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
795788
TLLM_CHECK(pickUpConnections.size() > processIdx);
796789
TLLM_CHECK(recvSplitCaches.size() > processIdx);
797-
auto startTime = llmRequest.getSteadyClockNow();
790+
auto startTime = LlmRequest::getSteadyClockNow();
798791
size_t size = 0;
799792

800793
if (processIdx >= remainNoCoverTargetNum)
@@ -835,15 +828,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
835828
}
836829
}
837830

838-
auto endTime = llmRequest.getSteadyClockNow();
839-
double delay = 0.0;
840-
if (recordDelay)
841-
{
842-
delay = std::chrono::duration<double, std::milli>(startTime - arrivalTime).count();
843-
}
844-
double cacheTransferTime
845-
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
846-
session.appendMeasure(delay, cacheTransferTime, size);
831+
auto endTime = LlmRequest::getSteadyClockNow();
832+
session.appendMeasure(startTime, endTime, size);
847833
};
848834
if (pickUpConnections.size() > 1)
849835
{
@@ -891,6 +877,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
891877
{
892878
recvBufferFun(deviceId, 0);
893879
}
880+
session.setTime(TransferSession::kTimeTransmissions);
894881

895882
{
896883
NVTX3_SCOPED_RANGE(formatInputConcatenate);
@@ -904,6 +891,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
904891
mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
905892
}
906893
}
894+
session.setTime(TransferSession::kTimePostprocess);
907895
}
908896
}
909897

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -603,7 +603,7 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
603603
it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
604604

605605
// Gather the kv cache transfer time from all workers and update to leader rank
606-
if (!common::getEnvKVCacheTransferOutputPath().empty())
606+
if (!common::getEnvKVCacheTimeOutputPath().empty())
607607
{
608608
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
609609
updateKVCacheTransferBW(syncComm, it->first);

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 48 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
2929
#include "tensorrt_llm/runtime/common.h"
3030
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
31+
#include <chrono>
3132
#include <future>
3233
#include <map>
3334
#include <memory>
@@ -105,39 +106,65 @@ void TransferSession::setLlmRequest(LlmRequest const& llmRequest)
105106
mRequest = &llmRequest;
106107
}
107108

108-
void TransferSession::appendMeasure(double delay, double duration, size_t size)
109+
void TransferSession::setTime(TimeNames name)
109110
{
110-
if (!mRecordMeasure)
111+
if (mTimes)
111112
{
112-
return;
113+
mTimes->times.at(name) = LlmRequest::getSteadyClockNow();
114+
}
115+
}
116+
117+
void TransferSession::appendMeasure(LlmRequest::TimePoint start, LlmRequest::TimePoint end, size_t size)
118+
{
119+
if (mTimes)
120+
{
121+
mTimes->measures.emplace_back(Measure{start, end, size});
113122
}
114-
auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps
115-
mMeasures.emplace_back(Measure{delay, duration, bandwidth});
116123
}
117124

118125
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
119126
{
120-
if (mMeasures.empty())
127+
if (!mTimes || mTimes->measures.empty())
121128
{
122129
return;
123130
}
124131
// write header if not exist
125132
if (outFile.tellp() == 0)
126133
{
127-
outFile << "RequestID";
128-
for (size_t i = 0; i < mMeasures.size(); i++)
134+
outFile << "RequestID,RequestInfo,Preparation,Preprocess,Transmissions,Postprocess";
135+
for (size_t i = 0; i < mTimes->measures.size(); i++)
129136
{
130-
outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
137+
outFile << ",Delay,Duration,Bandwidth(Gbps)";
131138
}
132139
outFile << '\n';
133140
}
134-
// write measures
141+
auto transferStart = mRequest->getPerfMetrics().timingMetrics.kvCacheTransferStart;
142+
using Milliseconds = std::chrono::duration<double, std::milli>;
143+
144+
// write measures, time is in milliseconds
135145
TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
136146
auto reqId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();
137147
outFile << reqId;
138-
for (auto const& measure : mMeasures)
148+
auto previousTime = transferStart;
149+
for (auto time : mTimes->times)
150+
{
151+
if (time == LlmRequest::TimePoint())
152+
{
153+
// timepoint is unset, skip
154+
outFile << ",0.0";
155+
continue;
156+
}
157+
double delay = Milliseconds(time - previousTime).count();
158+
previousTime = time;
159+
outFile << "," << delay;
160+
}
161+
previousTime = mTimes->times[kTimePreprocess];
162+
for (auto const& measure : mTimes->measures)
139163
{
140-
outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
164+
double delay = Milliseconds(measure.start - previousTime).count();
165+
double duration = Milliseconds(measure.end - measure.start).count();
166+
double bandwidth = static_cast<double>(measure.size) * 8.0 / duration / 1e6; // byte, ms => Gbps
167+
outFile << "," << delay << "," << duration << "," << bandwidth;
141168
}
142169
outFile << '\n' << std::flush;
143170
}
@@ -158,7 +185,7 @@ int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
158185
std::filesystem::path getTransferOutputPath(char const* tag)
159186
{
160187
namespace fs = std::filesystem;
161-
auto outputPath = common::getEnvKVCacheTransferOutputPath();
188+
auto outputPath = common::getEnvKVCacheTimeOutputPath();
162189
if (!outputPath.empty())
163190
{
164191
auto rank = mpi::MpiComm::world().getRank();
@@ -273,6 +300,7 @@ class CacheSender::Impl
273300
{
274301
std::promise<void> promise;
275302
auto future = promise.get_future();
303+
llmRequest.setKvCacheTransferStart(LlmRequest::getSteadyClockNow());
276304
{
277305
{
278306
std::scoped_lock lkResp(mSenderMutex);
@@ -309,7 +337,7 @@ class CacheSender::Impl
309337
std::unique_lock<std::mutex> lk(mMtxForMap);
310338
auto it = mRequestToSession.find(requestId);
311339
TLLM_CHECK(it != mRequestToSession.end());
312-
if (!common::getEnvKVCacheTransferOutputPath().empty())
340+
if (!common::getEnvKVCacheTimeOutputPath().empty())
313341
{
314342
if (!mMeasuresFile.is_open())
315343
{
@@ -363,7 +391,8 @@ class CacheSender::Impl
363391
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
364392
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
365393
info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
366-
!common::getEnvKVCacheTransferOutputPath().empty());
394+
!common::getEnvKVCacheTimeOutputPath().empty());
395+
session.setTime(TransferSession::kTimeRequestInfo);
367396
it = mRequestToSession.emplace(requestId, std::move(session)).first;
368397
}
369398
it->second.setConnection(peerIdx, connection);
@@ -382,6 +411,7 @@ class CacheSender::Impl
382411
}
383412
session->setLlmRequest(llmRequest);
384413
mFormatter->format(*session);
414+
llmRequest.setKvCacheTransferEnd(LlmRequest::getSteadyClockNow());
385415
}
386416

387417
bool cancelRequest(LlmRequest const& llmRequest)
@@ -751,7 +781,7 @@ class CacheReceiver::Impl
751781
void receiveSync(TransferSession& session)
752782
{
753783
mFormatter->unformat(session);
754-
if (!common::getEnvKVCacheTransferOutputPath().empty())
784+
if (!common::getEnvKVCacheTimeOutputPath().empty())
755785
{
756786
std::unique_lock<std::mutex> lock(mMeasuresFileMutex);
757787
if (!mMeasuresFile.is_open())
@@ -846,7 +876,7 @@ class CacheReceiver::Impl
846876
auto const& resource = getReceiveCacheResource(llmRequest);
847877
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
848878
contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
849-
&llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
879+
&llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
850880
}
851881

852882
std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
@@ -957,6 +987,7 @@ class CacheReceiver::Impl
957987
llmRequest.setKvCacheTransferStart(std::chrono::steady_clock::now());
958988
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
959989
auto session = sendRequestInfo(llmRequest);
990+
session.setTime(TransferSession::kTimeRequestInfo);
960991
bool isReady = receiveReadySignal(session);
961992
if (!isReady)
962993
{

0 commit comments

Comments (0)