@@ -419,6 +419,13 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
419419void CacheTransceiver::checkContextTransferStatus (std::optional<int > const & atLeastRequestNum)
420420{
421421 bool blockAll = !atLeastRequestNum.has_value ();
422+ std::optional<int > senderFutureTimeoutMs = std::nullopt ;
423+ // If blockAll is true, we want to block and not use a timeout
424+ if (!blockAll && mCacheTransceiverConfig .has_value ())
425+ {
426+ senderFutureTimeoutMs = mCacheTransceiverConfig ->getKvTransferSenderFutureTimeoutMs ();
427+ }
428+
422429 auto syncComm = mCacheState ->getParallelConfig ().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm ;
423430 std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
424431 for (auto && [request, future] : mSenderFutures )
@@ -476,16 +483,36 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
476483 {
477484 try
478485 {
479- future.get ();
480- request->setState (LlmRequestState::kDISAGG_CONTEXT_COMPLETE );
486+ // Wait for up to a specified timeout
487+ auto status = future.wait_for (std::chrono::milliseconds (senderFutureTimeoutMs.value_or (0 )));
488+ if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value ())
489+ {
490+ future.get ();
491+ request->setState (LlmRequestState::kDISAGG_CONTEXT_COMPLETE );
492+ it = mSenderFutures .erase (it);
493+ }
494+ else if (status == std::future_status::timeout)
495+ {
496+ TLLM_LOG_WARNING (" Timed out waiting for context transfer for request %ld after %d milliseconds." ,
497+ request->mRequestId , senderFutureTimeoutMs.value ());
498+ ++it;
499+ }
500+ else
501+ {
502+ TLLM_LOG_ERROR (
503+ " Future returned unexpected status for request %ld. Marking as error" , request->mRequestId );
504+
505+ request->setState (LlmRequestState::kDISAGG_TRANS_ERROR );
506+ it = mSenderFutures .erase (it);
507+ }
481508 }
482509 catch (std::exception const & e)
483510 {
484511 TLLM_LOG_ERROR (
485512 " Error occurred during context transfer for request %ld: %s" , request->mRequestId , e.what ());
486513 request->setState (LlmRequestState::kDISAGG_TRANS_ERROR );
514+ it = mSenderFutures .erase (it);
487515 }
488- it = mSenderFutures .erase (it);
489516 }
490517 else
491518 {
0 commit comments