
Commit b33047c

Merge branch 'main' into lkomali/replace-gap-with-aiperf
2 parents: 12deeaa + 7ab02ad

File tree: 89 files changed (+7133 / -591 lines)


cpp/include/tensorrt_llm/executor/executor.h
Lines changed: 8 additions & 2 deletions

@@ -1465,16 +1465,19 @@ class CacheTransceiverConfig
         NIXL = 3
     };
     explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
-        std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt);
+        std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,
+        std::optional<int> kvTransferSenderFutureTimeoutMs = std::nullopt);

     bool operator==(CacheTransceiverConfig const& other) const;
     void setBackendType(std::optional<BackendType> backendType);
     void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
     void setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs);
+    void setKvTransferSenderFutureTimeoutMs(std::optional<int> kvTransferSenderFutureTimeoutMs);

-    [[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
     [[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
     [[nodiscard]] std::optional<BackendType> getBackendType() const;
+    [[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
+    [[nodiscard]] std::optional<int> getKvTransferSenderFutureTimeoutMs() const;

private:
    std::optional<BackendType> mBackendType;
@@ -1483,6 +1486,9 @@ class CacheTransceiverConfig
    /// transfer may be degraded.
    std::optional<size_t> mMaxTokensInBuffer;
    std::optional<int> mKvTransferTimeoutMs;
+   // @brief Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This
+   // allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms
+   std::optional<int> mKvTransferSenderFutureTimeoutMs;
};

/// @brief Configuration class for the model executor
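
For orientation, here is a minimal sketch of how the extended configuration could be constructed and queried, based only on the constructor and accessors declared above; the include path, namespace alias, and the chosen timeout values are illustrative assumptions, not part of the commit:

#include "tensorrt_llm/executor/executor.h" // assumed include path for the header shown above

#include <cassert>
#include <optional>

namespace texec = tensorrt_llm::executor;

void configureCacheTransceiver()
{
    // Construct with a 5 s KV-transfer timeout and a 100 ms sender-future timeout.
    texec::CacheTransceiverConfig config(texec::CacheTransceiverConfig::BackendType::NIXL,
        /*maxNumTokens=*/std::nullopt, /*kvTransferTimeoutMs=*/5000,
        /*kvTransferSenderFutureTimeoutMs=*/100);

    // The new field can also be set after construction; per the setter added in
    // cacheTransceiverConfig.cpp below, non-positive values are rejected.
    config.setKvTransferSenderFutureTimeoutMs(250);
    assert(config.getKvTransferSenderFutureTimeoutMs().value() == 250);
}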

cpp/tensorrt_llm/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -189,6 +189,7 @@ set(TRTLLM_LINK_LIBS
     fb_gemm_src
     gemm_swiglu_sm90_src
     cutlass_src
+    cute_dsl_src
     layers_src
     runtime_src
     testing_src

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Lines changed: 30 additions & 3 deletions

@@ -419,6 +419,13 @@ void updateKVCacheTransferBW(std::shared_ptr<CacheTransceiverComm> const& mComm,
 void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
 {
     bool blockAll = !atLeastRequestNum.has_value();
+    std::optional<int> senderFutureTimeoutMs = std::nullopt;
+    // If blockAll is true, we want to block and not use a timeout
+    if (!blockAll && mCacheTransceiverConfig.has_value())
+    {
+        senderFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
+    }
+
     auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupTPInDPComm : mGroupTensorParaComm;
     std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
     for (auto&& [request, future] : mSenderFutures)
@@ -476,16 +483,36 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLeastRequestNum)
         {
             try
             {
-                future.get();
-                request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+                // Wait for up to a specified timeout
+                auto status = future.wait_for(std::chrono::milliseconds(senderFutureTimeoutMs.value_or(0)));
+                if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
+                {
+                    future.get();
+                    request->setState(LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
+                    it = mSenderFutures.erase(it);
+                }
+                else if (status == std::future_status::timeout)
+                {
+                    TLLM_LOG_WARNING("Timed out waiting for context transfer for request %ld after %d milliseconds.",
+                        request->mRequestId, senderFutureTimeoutMs.value());
+                    ++it;
+                }
+                else
+                {
+                    TLLM_LOG_ERROR(
+                        "Future returned unexpected status for request %ld. Marking as error", request->mRequestId);
+
+                    request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+                    it = mSenderFutures.erase(it);
+                }
             }
             catch (std::exception const& e)
             {
                 TLLM_LOG_ERROR(
                     "Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
                 request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+                it = mSenderFutures.erase(it);
             }
-            it = mSenderFutures.erase(it);
         }
         else
         {
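
The change above replaces an unconditional blocking future.get() with a bounded wait, so a sender whose transfer has not completed no longer stalls the whole status check. Below is a standalone sketch of the same control flow using plain std::future; the TensorRT-LLM request, state, and logging types are replaced with illustrative stand-ins:

#include <chrono>
#include <future>
#include <iostream>
#include <map>
#include <optional>

using RequestId = long;

// Stand-in for mSenderFutures: one pending transfer per request.
std::map<RequestId, std::future<void>> senderFutures;

void checkTransfers(std::optional<int> senderFutureTimeoutMs)
{
    for (auto it = senderFutures.begin(); it != senderFutures.end();)
    {
        auto& future = it->second;
        // Bounded wait; value_or(0) makes the no-timeout case a non-blocking poll.
        auto status = future.wait_for(std::chrono::milliseconds(senderFutureTimeoutMs.value_or(0)));
        if (status == std::future_status::ready || !senderFutureTimeoutMs.has_value())
        {
            // No configured timeout falls back to the old blocking behaviour via get().
            future.get();                 // rethrows if the transfer failed
            it = senderFutures.erase(it); // transfer complete: drop the entry
        }
        else if (status == std::future_status::timeout)
        {
            std::cerr << "request " << it->first << " still transferring, will retry\n";
            ++it;                         // keep the entry and revisit on the next call
        }
        else // std::future_status::deferred or anything unexpected
        {
            it = senderFutures.erase(it); // treat as an error, as the real code does
        }
    }
}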

cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
Lines changed: 17 additions & 2 deletions

@@ -21,11 +21,13 @@
 namespace tensorrt_llm::executor
 {

-CacheTransceiverConfig::CacheTransceiverConfig(
-    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
+CacheTransceiverConfig::CacheTransceiverConfig(std::optional<BackendType> backendType,
+    std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs,
+    std::optional<int> kvTransferSenderFutureTimeoutMs)
     : mBackendType(backendType)
     , mMaxTokensInBuffer(maxNumTokens)
     , mKvTransferTimeoutMs(kvTransferTimeoutMs)
+    , mKvTransferSenderFutureTimeoutMs(kvTransferSenderFutureTimeoutMs)
 {
 }

@@ -54,6 +56,15 @@ void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs)
     mKvTransferTimeoutMs = kvTransferTimeoutMs;
 }

+void CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs(std::optional<int> kvTransferSenderFutureTimeoutMs)
+{
+    if (kvTransferSenderFutureTimeoutMs.has_value() && kvTransferSenderFutureTimeoutMs.value() <= 0)
+    {
+        TLLM_THROW("kvTransferSenderFutureTimeoutMs must be positive");
+    }
+    mKvTransferSenderFutureTimeoutMs = kvTransferSenderFutureTimeoutMs;
+}
+
 std::optional<CacheTransceiverConfig::BackendType> CacheTransceiverConfig::getBackendType() const
 {
     return mBackendType;
@@ -69,4 +80,8 @@ std::optional<int> CacheTransceiverConfig::getKvTransferTimeoutMs() const
     return mKvTransferTimeoutMs;
 }

+std::optional<int> CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs() const
+{
+    return mKvTransferSenderFutureTimeoutMs;
+}
 } // namespace tensorrt_llm::executor

cpp/tensorrt_llm/executor/serialization.cpp
Lines changed: 7 additions & 1 deletion

@@ -1290,20 +1290,26 @@ CacheTransceiverConfig Serialization::deserializeCacheTransceiverConfig(std::istream& is)
 {
     auto backendType = su::deserialize<std::optional<CacheTransceiverConfig::BackendType>>(is);
     auto maxTokensInBuffer = su::deserialize<std::optional<size_t>>(is);
-    return CacheTransceiverConfig{backendType, maxTokensInBuffer};
+    auto kvTransferTimeoutMs = su::deserialize<std::optional<int>>(is);
+    auto kvTransferSenderFutureTimeoutMs = su::deserialize<std::optional<int>>(is);
+    return CacheTransceiverConfig{backendType, maxTokensInBuffer, kvTransferTimeoutMs, kvTransferSenderFutureTimeoutMs};
 }

 void Serialization::serialize(CacheTransceiverConfig const& cacheTransceiverConfig, std::ostream& os)
 {
     su::serialize(cacheTransceiverConfig.getBackendType(), os);
     su::serialize(cacheTransceiverConfig.getMaxTokensInBuffer(), os);
+    su::serialize(cacheTransceiverConfig.getKvTransferTimeoutMs(), os);
+    su::serialize(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs(), os);
 }

 size_t Serialization::serializedSize(CacheTransceiverConfig const& cacheTransceiverConfig)
 {
     size_t totalSize = 0;
     totalSize += su::serializedSize(cacheTransceiverConfig.getBackendType());
     totalSize += su::serializedSize(cacheTransceiverConfig.getMaxTokensInBuffer());
+    totalSize += su::serializedSize(cacheTransceiverConfig.getKvTransferTimeoutMs());
+    totalSize += su::serializedSize(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs());
     return totalSize;
 }
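
The two timeout fields are appended in the same order in serialize, deserializeCacheTransceiverConfig, and serializedSize. The su:: helpers themselves are not part of this diff; the sketch below only illustrates the ordering discipline such code relies on, using a plain stream round-trip of std::optional<int>, and is not the actual implementation:

#include <iostream>
#include <optional>
#include <sstream>

// Write a presence flag, then the value if present.
void writeOptionalInt(std::optional<int> const& value, std::ostream& os)
{
    bool const hasValue = value.has_value();
    os.write(reinterpret_cast<char const*>(&hasValue), sizeof(hasValue));
    if (hasValue)
    {
        os.write(reinterpret_cast<char const*>(&value.value()), sizeof(int));
    }
}

// Read fields back in exactly the order they were written.
std::optional<int> readOptionalInt(std::istream& is)
{
    bool hasValue = false;
    is.read(reinterpret_cast<char*>(&hasValue), sizeof(hasValue));
    if (!hasValue)
    {
        return std::nullopt;
    }
    int value = 0;
    is.read(reinterpret_cast<char*>(&value), sizeof(int));
    return value;
}

int main()
{
    std::stringstream buffer;
    writeOptionalInt(5000, buffer); // kvTransferTimeoutMs
    writeOptionalInt(100, buffer);  // kvTransferSenderFutureTimeoutMs
    auto kvTransferTimeoutMs = readOptionalInt(buffer);
    auto kvTransferSenderFutureTimeoutMs = readOptionalInt(buffer);
    std::cout << kvTransferTimeoutMs.value() << " " << kvTransferSenderFutureTimeoutMs.value() << "\n";
    return 0;
}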

cpp/tensorrt_llm/kernels/CMakeLists.txt
Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,8 @@ file(GLOB_RECURSE SRC_CU *.cu)
 # selectiveScan trtllmGenKernels folder
 list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*")
 list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*")
+list(FILTER SRC_CPP EXCLUDE REGEX "cuteDslKernels/.*")
+list(FILTER SRC_CU EXCLUDE REGEX "cuteDslKernels/.*")
 list(FILTER SRC_CPP EXCLUDE REGEX "flashMLA/.*")
 list(FILTER SRC_CU EXCLUDE REGEX "flashMLA/.*")
 list(FILTER SRC_CPP EXCLUDE REGEX "contextFusedMultiHeadAttention/.*")
@@ -75,6 +77,7 @@ target_include_directories(
 add_cuda_architectures(kernels_src 89)

 add_subdirectory(cutlass_kernels)
+add_subdirectory(cuteDslKernels)
 add_subdirectory(flashMLA)
 add_subdirectory(contextFusedMultiHeadAttention)
 add_subdirectory(decoderMaskedMultiheadAttention)

cpp/tensorrt_llm/kernels/IndexerTopK.h
Lines changed: 5 additions & 5 deletions

@@ -24,12 +24,12 @@

 namespace tensorrt_llm::kernels
 {
-void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* outIndices, float* auxLogits,
-    int* auxIndices, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0,
-    int const stride1, int const next_n, int const index_topk = 2048, cudaStream_t const stream = 0);
+void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux,
+    int* outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0,
+    int const stride1, int const next_n, int const topK = 2048, cudaStream_t const stream = 0);

-void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* outIndices,
-    int const numRows, int const numColumns, int const stride0, int const stride1, int const index_topk = 2048,
+void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* indices,
+    int const numRows, int const numColumns, int const stride0, int const stride1, int const topK = 2048,
     cudaStream_t const stream = 0);

} // namespace tensorrt_llm::kernels

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
Lines changed: 6 additions & 0 deletions

@@ -51,6 +51,12 @@ namespace tensorrt_llm::kernels::mnnvl_throughput
         __VA_ARGS__;            \
         break;                  \
     }                           \
+    case 6:                     \
+    {                           \
+        constexpr int TOP_K = 6; \
+        __VA_ARGS__;            \
+        break;                  \
+    }                           \
     case 4:                     \
     {                           \
         constexpr int TOP_K = 4; \
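
The macro being edited maps a runtime top-k value onto a compile-time constant by switching over the supported values, so each case can expand the kernel body with a constexpr TOP_K in scope; the new case simply adds 6 to that set. A reduced, self-contained sketch of the dispatch pattern follows; the macro name, the case for 8, and the placeholder function are illustrative and not the real kernel code:

#include <cstdio>
#include <stdexcept>

// Placeholder for code that needs TOP_K as a compile-time constant,
// e.g. to size per-thread arrays inside a CUDA kernel.
template <int TOP_K>
void runAlltoAllDispatch()
{
    std::printf("dispatching with TOP_K = %d\n", TOP_K);
}

// Simplified analogue of the dispatch macro: every supported top-k value
// gets a case that introduces a constexpr TOP_K before expanding the body.
#define DISPATCH_TOP_K(topK, ...)                                   \
    switch (topK)                                                   \
    {                                                               \
    case 8:                                                         \
    {                                                               \
        constexpr int TOP_K = 8;                                    \
        __VA_ARGS__;                                                \
        break;                                                      \
    }                                                               \
    case 6: /* newly supported value */                             \
    {                                                               \
        constexpr int TOP_K = 6;                                    \
        __VA_ARGS__;                                                \
        break;                                                      \
    }                                                               \
    case 4:                                                         \
    {                                                               \
        constexpr int TOP_K = 4;                                    \
        __VA_ARGS__;                                                \
        break;                                                      \
    }                                                               \
    default:                                                        \
        throw std::invalid_argument("unsupported top-k value");     \
    }

void dispatch(int topK)
{
    DISPATCH_TOP_K(topK, runAlltoAllDispatch<TOP_K>());
}
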
cpp/tensorrt_llm/kernels/cuteDslKernels/CMakeLists.txt
Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved. SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+#
+
+file(GLOB_RECURSE SRC_CPP *.cpp)
+file(GLOB_RECURSE SRC_CU *.cu)
+
+add_library(cute_dsl_src OBJECT ${SRC_CPP} ${SRC_CU})
+set_property(TARGET cute_dsl_src PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET cute_dsl_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)