[TRTLLM-8536][feat] Add the sparse attention framework and one use case--RocketKV support #8086
Conversation
Force-pushed from 59b51d2 to 9d4b051.
/bot run
PR_Github #20437 [ run ] triggered by Bot
PR_Github #20437 [ run ] completed with state
📝 Walkthrough

Adds end-to-end sparse (RocketKV/block) attention: new sparse kernels and params, KV-cache sparse update paths, workspace sizing (max_blocks_per_sequence), C++/CUDA kernel additions, Python backend integration, model/config wiring, bindings, examples, and tests; public APIs and some function signatures were expanded to carry sparse parameters.
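For orientation, a minimal usage sketch of the flow above. The `sparse_attention_config` keyword and the RocketKV defaults are assumptions inferred from the walkthrough and the new example script, not copied from the PR:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import RocketSparseAttentionConfig  # added by this PR

# RocketKV does not support KV-cache block reuse (see the example scripts in this PR).
kv_cache_config = KvCacheConfig(enable_block_reuse=False)

llm = LLM(
    model="/path/to/Llama-3.1-8B-Instruct",
    kv_cache_config=kv_cache_config,
    sparse_attention_config=RocketSparseAttentionConfig(),  # keyword name assumed from the walkthrough
)
outputs = llm.generate(
    ["Summarize the following long document ..."],
    SamplingParams(max_tokens=256, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```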
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant PyLLM as LLM (Python)
participant AttnMod as Attention module (Py)
participant Backend as AttentionBackend (Py)
participant KVMgr as KVCacheManager (Py)
participant Wrapper as TrtllmAttentionWrapper (Py)
participant THOP as THOP attention (C++)
participant AttentionOp as AttentionOp (C++)
participant Dispatcher as XQA/FMHA Dispatcher (C++)
participant Kernels as CUDA Kernels
User->>PyLLM: generate(prompts, sparse_attention_config)
PyLLM->>AttnMod: create attention (sparse config)
AttnMod->>Backend: get_attention_backend(..., sparse_attn_config)
Backend->>KVMgr: prepare KV resources (sparse)
Backend->>Wrapper: plan(..., sparse_kv_indices/offsets, sparse_attn_indices/offsets)
Wrapper->>THOP: call attention(..., sparse params, attention_config_params)
THOP->>AttentionOp: enqueue/generate(..., sparse params, max_blocks_per_sequence)
AttentionOp->>Dispatcher: build XQAParams (use_sparse_attention=true)
Dispatcher->>Kernels: invokeGatherKvPageOffsets(...) (if block-sparse)
Dispatcher->>Kernels: FMHA / XQA kernel launch
Kernels-->>Dispatcher: outputs
Dispatcher->>Kernels: invokeUpdateSparseKvCacheAfterFmha (post-FMHA sparse KV write)
Dispatcher-->>AttentionOp: return
AttentionOp-->>THOP: output
THOP-->>Wrapper: return
Wrapper-->>Backend: return tokens
```

```mermaid
sequenceDiagram
autonumber
participant THOP as THOP (C++)
participant AO as AttentionOp (C++)
participant Kern as Kernels (CUDA)
rect rgba(220,240,255,0.25)
note right of THOP: Context / Prefill
THOP->>AO: preprocess (sparse_kv_indices, sparse_kv_offsets, is_last_chunk)
AO->>Kern: FMHA
AO->>Kern: invokeUpdateSparseKvCacheAfterFmha (sparse path)
end
rect rgba(220,255,220,0.25)
note right of THOP: Generation
THOP->>AO: preprocess (sparse_attn_indices, sparse_attn_offsets, max_blocks_per_sequence)
AO->>Kern: gatherKvPageOffsets (block offsets -> seq lengths)
AO->>Kern: FMHA/XQA (block-sparse)
end
```

Estimated code review effort
🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Actionable comments posted: 40
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (11)
tensorrt_llm/_utils.py (1)
1-1
: Update copyright year to 2025. Per coding guidelines, the copyright header should include the current year. The file header shows 2022-2024, but this PR is from 2025.
Apply this diff:
```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1)
2-2
: Update copyright year to 2025. The copyright header shows 2020-2024, but this PR is from 2025. Per coding guidelines, update to include the current year.
Apply this diff:
```diff
- * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
```

tensorrt_llm/_torch/attention_backend/__init__.py (1)
1-16
: Add NVIDIA Apache-2.0 copyright header. According to the coding guidelines for Python files, prepend the NVIDIA Apache-2.0 copyright header with the current year (2025) to the top of this source file.
Example format:
```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# ...
```

As per coding guidelines.
cpp/tensorrt_llm/kernels/xqaDispatcher.cpp (1)
271-296
: Workspace budget must include the new sparse RocketKV buffers
The new sparse_kv_block_offsets / sparse_seq_lengths carve kv_block_offsets_size + seq_lengths_size bytes out of params.workspaces, but neither XqaDispatcher::getWorkspaceSize() nor the upstream workspace sizing has been expanded to reserve them. In common generation settings (e.g., batch 4, beam 1, num_kv_heads 16, max_blocks_per_sequence 512) this extra demand exceeds 250 KB; with the current budget derived only from max_num_tokens, buildXQALaunchParams will write past the provided buffer and corrupt memory. Please bump the workspace calculation (here and in any caller that pre-allocates params.workspaces) by those sparse buffers or an equivalent worst-case bound tied to max_blocks_per_sequence.
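To make the shortfall concrete, a rough sizing sketch of the two sparse buffers, using int32 elements and the shapes documented for invokeGatherKvPageOffsets; alignment padding is ignored here:

```python
def sparse_xqa_extra_workspace_bytes(batch_size: int,
                                     num_kv_heads: int,
                                     max_blocks_per_sequence: int,
                                     elt_bytes: int = 4) -> int:
    """Worst-case bytes the sparse path carves out of params.workspaces.

    sparse_kv_block_offsets: [num_kv_heads, batch_size, 2, max_blocks_per_sequence]
    sparse_seq_lengths:      [num_kv_heads, batch_size]
    """
    kv_block_offsets = num_kv_heads * batch_size * 2 * max_blocks_per_sequence * elt_bytes
    seq_lengths = num_kv_heads * batch_size * elt_bytes
    return kv_block_offsets + seq_lengths

# Reviewer's example: batch 4, beam 1, 16 KV heads, 512 blocks per sequence.
print(sparse_xqa_extra_workspace_bytes(4, 16, 512) / 1024)  # ~256.25 KiB, above the quoted 250 KB
```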
tensorrt_llm/_torch/pyexecutor/_util.py (2)
1-1
: Add NVIDIA Apache-2.0 header (2025). Please prepend the standard header to comply with licensing guidelines.
499-518
: Avoid passing unsupported keyword to KVCacheManager constructors. sparse_attn_config is always passed, even when the selected class is the vanilla KVCacheManager, which likely doesn’t accept it → TypeError at runtime. Gate the kwarg.
- kv_cache_manager = self._kv_cache_manager_cls( + extra_kwargs = {} + if sparse_attn_config is not None: + extra_kwargs["sparse_attn_config"] = sparse_attn_config + kv_cache_manager = self._kv_cache_manager_cls( self._kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, @@ - kv_connector_manager=self._kv_connector_manager - if not estimating_kv_cache else None, - sparse_attn_config=sparse_attn_config, + kv_connector_manager=self._kv_connector_manager if not estimating_kv_cache else None, + **extra_kwargs, )tensorrt_llm/_torch/attention_backend/utils.py (1)
1-1
: Add NVIDIA Apache-2.0 header to utils.py
Prepend the standard NVIDIA Apache-2.0 copyright header with the current year (2025) at the top of tensorrt_llm/_torch/attention_backend/utils.py.cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (1)
58-60
: Close include guard.} // namespace kernels } // namespace tensorrt_llm + +#endif // TRTLLM_SPARSEATTENTIONKERNELS_HAs per coding guidelines.
tensorrt_llm/_torch/attention_backend/sparse/kernel.py (1)
210-309
: Generation-path shape bug: documented k is (total_seq_len, …) but gen kernel treats k as (batch_size, …). Fix API or indexing.Current code computes
k_base = k_ptr + batch_idx*hidden_size + head_idx*dim_size
, i.e., assumesk.shape == (batch_size, num_kv_heads, head_dim)
. Either (A) change gen-path to address the last token in the flattenedtotal_seq_len
layout, or (B) document and enforcek
to be(batch_size, num_kv_heads, head_dim)
in generation. Option A keeps a single input contract.Option A (index last token per batch in flattened k):
- grid = (batch_size, num_kv_heads) - _update_kt_cache_gen_kernel[grid](k, + # Expect k as (total_seq_len, num_kv_heads, head_dim) + grid = (batch_size, num_kv_heads) + # Compute per-batch base offsets to the last token + cum_seq_lens = torch.cumsum(torch.cat([torch.zeros(1, device=k.device, dtype=torch.long), + seq_lens.to(torch.long)]), dim=0) + last_token_offsets = (cum_seq_lens[:-1] + (seq_lens.to(torch.long) - 1)).contiguous() + _update_kt_cache_gen_kernel[grid](k, kt_cache_tensor, kt_cache_block_offsets, - seq_lens, + last_token_offsets, # pass absolute last-token indices num_kv_heads, head_dim, kt_page_size, tokens_per_block, max_kt_blocks_per_seq, BLOCK_SIZE=1024)And update
_update_kt_cache_gen_kernel
to consume absolute token indices:- past_key_value_length = tl.load(seq_lens_ptr + batch_idx) - 1 - kt_token_idx = past_key_value_length // kt_page_size - kt_token_idx_in_page = past_key_value_length % kt_page_size + last_tok = tl.load(seq_lens_ptr + batch_idx) # absolute index in flattened K + kt_token_idx = last_tok // kt_page_size + kt_token_idx_in_page = last_tok % kt_page_size - k_base = k_ptr + batch_idx * hidden_size + head_idx * dim_size + k_base = k_ptr + last_tok * hidden_size + head_idx * dim_sizeIf you prefer Option B, update the function docstring and assert
k.shape[0] == batch_size
, then keep current indexing. Please confirm which contract you want. Based on learnings.tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
140-153
: ImportSparseAttentionConfig
so lint stops failingRuff is currently raising F821 on the new type hint because
SparseAttentionConfig
is never imported. The string annotation is OK at runtime, but the lint failure will break the build. Pull the symbol in underTYPE_CHECKING
to satisfy the checker without affecting execution:-from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING @@ -from ..attention_backend.interface import (AttentionMetadata, - AttentionRuntimeFeatures) +from ..attention_backend.interface import (AttentionMetadata, + AttentionRuntimeFeatures) + +if TYPE_CHECKING: + from ..attention_backend.config import SparseAttentionConfigAny equivalent import location is fine as long as
SparseAttentionConfig
resolves during linting.tensorrt_llm/_torch/attention_backend/vanilla.py (1)
1-18
: Add NVIDIA Apache-2.0 header.All source files must carry the NVIDIA Apache-2.0 header.
+# +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# import math from typing import OptionalAs per coding guidelines.
🧹 Nitpick comments (36)
cpp/tensorrt_llm/nanobind/common/customCasters.h (1)
29-29
: Remove unused ArrayRef include
The header <c10/util/ArrayRef.h>
isn’t referenced in this file or any Nanobind bindings—delete it.

tensorrt_llm/_utils.py (1)
216-224
: Consider enhancing the docstring. The function is correct, but the docstring could be more descriptive following Google-style format with parameter and return value documentation.
Apply this diff:
```diff
 def next_power_of_two(x):
     """
-    get next power of two
+    Get the next power of two greater than or equal to x.
+
+    Args:
+        x: Input integer value.
+
+    Returns:
+        The next power of two. Returns 1 for x <= 0.
+        If x is already a power of two, returns x unchanged.
     """
```

examples/llm-api/llm_sparse_attention.py (3)
17-78
: Add docstring and consider removing hardcoded default paths. The function needs a docstring per coding guidelines. Additionally, the hardcoded default paths (lines 23, 28) may not be appropriate for all users.
Apply this diff:
def parse_arguments(): + """ + Parse command-line arguments for sparse attention evaluation. + + Returns: + Parsed arguments namespace. + """ parser = argparse.ArgumentParser() parser.add_argument( '--model_path', type=str, - default= - "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct" + required=True, + help="Path to the model directory" ) parser.add_argument( '--input_file', type=str, - default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl" + required=True, + help="Path to the input JSONL file" )
81-129
: Add docstring for run_RocketKV function. Per coding guidelines, all functions should have Google-style docstrings.
Apply this diff:
def run_RocketKV(args): + """ + Run RocketKV sparse attention evaluation. + + Args: + args: Parsed command-line arguments containing model path, input file, + and configuration parameters. + """ data = read_input(args.input_file)
132-137
: Add docstring and improve exception message. Per coding guidelines and static analysis, add a docstring and consider using a custom exception class for better error handling.
Apply this diff:
def main(): + """ + Main entry point for the sparse attention example script. + """ args = parse_arguments() if args.algo == 'ROCKETKV': run_RocketKV(args) else: - raise ValueError(f"Invalid algorithm: {args.algo}") + raise ValueError(f"Unsupported algorithm: {args.algo}")tensorrt_llm/_torch/pyexecutor/resource_manager.py (4)
166-167
: Avoid silently swallowing unknown kwargs. __init__ accepts **kwargs but ignores them. Either document supported keys and validate, or drop the parameter to surface misconfiguration early.
Apply minimal safety now:
```diff
- **kwargs,
+ **_unused_kwargs,
```
Optionally log unexpected keys in debug builds. Would you like me to wire a strict validator?
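A possible shape for that strict validator (a sketch; the allow-list contents and the call site are assumptions):

```python
_KNOWN_EXTRA_KWARGS = {"sparse_attn_config"}  # hypothetical allow-list of accepted extras

def validate_extra_kwargs(kwargs: dict) -> None:
    """Fail fast on unknown keyword arguments instead of silently ignoring them."""
    unexpected = sorted(set(kwargs) - _KNOWN_EXTRA_KWARGS)
    if unexpected:
        raise TypeError(f"KVCacheManager got unexpected keyword argument(s): {unexpected}")
```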
603-619
: Parity with RocketKV KT cache factor.
get_cache_bytes_per_token omits KT cache overhead used by RocketKV (2 * kt_tokens_per_block / tokens_per_block). If this Python path ever sizes RocketKV, it will undercount memory.
Option: thread a kt_tokens_per_block and tokens_per_block (when sparse enabled) and fold the factor as in sparse. If this method is strictly for dense KV, please add a comment stating KT is excluded to avoid misuse.
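A sketch of how the factor could be threaded through. Parameter names and the exact way the factor combines with the dense K/V bytes are assumptions; the 2 * kt_tokens_per_block / tokens_per_block term is the one quoted above:

```python
def cache_bytes_per_token(num_kv_heads_per_layer, head_dim, dtype_bytes,
                          kt_tokens_per_block=None, tokens_per_block=None):
    """Dense K+V bytes per token, optionally inflated by the RocketKV KT-cache factor."""
    dense = 2 * sum(num_kv_heads_per_layer) * head_dim * dtype_bytes  # K and V planes
    if kt_tokens_per_block is None:
        return dense  # plain paged-KV sizing; KT cache deliberately excluded
    # Multiplicative combination is an assumption in this sketch.
    return dense * (1 + 2 * kt_tokens_per_block / tokens_per_block)
```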
620-659
: Clarify units and guard divisions in calculate_max_num_blocks.
- max_tokens is float; later ceil(max_tokens / tokens_per_block) is fine, but consider explicit floor/ceil commentary for readability.
- Guard cache_size_bytes_per_token > 0 (defensive).
- For multi-device: calling mpi_comm/MPI only when ENABLE_MULTI_DEVICE is compiled; if mapping.world_size > 1 without that build flag will raise. Add a feature guard or assert ENABLE_MULTI_DEVICE when world_size > 1.
I can add a small helper to centralize memory-to-blocks math with asserts if you want.
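One possible shape for that helper (a sketch; it assumes the caller has already reduced free memory across ranks, and it deliberately floors rather than ceils so the budget is never exceeded):

```python
import math

def memory_to_blocks(free_mem_bytes: int,
                     cache_bytes_per_token: int,
                     tokens_per_block: int) -> int:
    """Convert a memory budget into a whole number of KV-cache blocks, defensively."""
    assert cache_bytes_per_token > 0, "cache_bytes_per_token must be positive"
    assert tokens_per_block > 0, "tokens_per_block must be positive"
    max_tokens = free_mem_bytes / cache_bytes_per_token      # may be fractional
    return max(int(math.floor(max_tokens / tokens_per_block)), 0)
```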
563-602
: Rename unused kwargs and align per-token head count with runtime allocation
- Rename
**kwargs
to**_kwargs
(it’s never used).- If
num_key_value_heads
is an Iterable, usesum(num_key_value_heads)
instead of averaging to mirrorget_cache_bytes_per_token
.Optionally, delegate to
get_cache_bytes_per_token
for full consistency.cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (4)
1581-1583
: Debug logs in hot path: gate or lower frequency.TLLM_LOG_DEBUG in per-dispatch paths can be noisy. Consider gating by an env flag already read via getEnvEnablePDL() or use trace-level.
Also applies to: 1606-1610
1714-1796
: Sparse KV post-FMHA kernel: validate index layout and grid utilization.
- sparse_kv_indices indexing uses kv_head_idx * num_sparse_kv_tokens + global_sparse_idx; this assumes a [kv_heads, sum_B_tokens] flattening. Please confirm producer matches this layout. A mismatch will corrupt KV.
- grid.x is 1 and the kernel loops over tokens. This is correct but can underutilize SMs for long sparse lists. Consider setting grid.x = ceil_div(num_sparse_tokens, tokens_per_block) and letting the loop stride by gridDim.x to improve occupancy.
If shapes are stable, I can propose a host-side grid.x computation based on params.sparse_kv_offsets for better scaling.
1734-1765
: Shared memory bounds and bank use: OK, add static checks for large Dh.smem size = 2 * block.y * VECS_PER_HEAD * 16B. With Dh=256 and TCache=fp16, this is 32KB; safe. Add a compile-time/assert check to ensure size <= device limit (e.g., 96KB/SM) to prevent launch failures on configs with larger Dh.
1800-1811
: kernelSparseDispatchHeadSize: dynamic grid.x and smem calc.
- grid.x fixed at 1; consider dynamic token tiling as above.
- smem size computation is correct; add TLLM_CHECK for smem size <= device attr sharedMemPerBlockOptin when PDL is enabled to fail fast.
cpp/tensorrt_llm/nanobind/thop/bindings.cpp (1)
44-55
: Binding surface changed: document and validate new params.
- New required/optional args (attention_config_params, rotary_embedding_int_params, sparse_attention_params). Please update Python docs/examples and add runtime validation for incompatible combos (e.g., sparse params with cross-attn if unsupported).
I can draft a minimal docstring and a unit smoke test ensuring nanobind signature matches torch_ext::attention signature.
cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp (1)
575-576
: Consider usingnbDims - 1
instead of hard-coded dimension index.Line 576 hard-codes dimension index
[3]
to extractmax_blocks_per_sequence
, but line 867 in the same file useskvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1]
for the same purpose. The latter approach is more robust and consistent.Apply this diff to improve consistency and robustness:
- int const max_blocks_per_sequence - = (useKVCache() && mPagedKVCache) ? inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims.d[3] : 0; + auto const& kvCacheBlockOffsetsShape = inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims; + int const max_blocks_per_sequence = (useKVCache() && mPagedKVCache) + ? kvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1] + : 0;cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1)
2-2
: Update copyright year to include 2025.Since this file is being modified in 2025, update the copyright header to reflect the current year:
-* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp (1)
15-15
: Use CamelCase for test class name.The test class name
sparseAttentionKernelsTest
should use CamelCase starting with an uppercase letter per C++ coding guidelines. Consider renaming toSparseAttentionKernelsTest
.-class sparseAttentionKernelsTest : public ::testing::Test +class SparseAttentionKernelsTest : public ::testing::TestAnd update the TEST_F macro accordingly:
-TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest) +TEST_F(SparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)As per coding guidelines.
tensorrt_llm/_torch/pyexecutor/_util.py (2)
9-24
: Remove duplicate ModelConfig import to avoid ambiguity.Two ModelConfig imports (lines 9 and 23) are redundant and can confuse type checkers. Keep one.
-from tensorrt_llm._torch.model_config import ModelConfig @@ -from ..model_config import ModelConfig
206-212
: Vanilla backend disables capacity estimation; confirm expected behavior with sparse VANILLA.If VANILLA + sparse is used, estimation is disabled unconditionally. If that’s intentional, add a TODO linking to planned support; otherwise, guard on sparse config as well.
cpp/tensorrt_llm/common/attentionOp.h (1)
160-173
: Debug dump of context/sequence lengths: keep cost low.The toString path wraps host pointers into ITensor; ensure this is compile-time only (debug) or guarded to avoid perf overhead in production logs.
cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (2)
11-32
: Make sparse parameter pointers const; fixdata()
const‑correctness and document the public struct.Indices/offsets are read-only; expose them as
int32_t const*
. Return a tuple of const pointers. Add brief Doxygen.-struct SparseAttentionParams +//! Parameters for sparse attention indices/offsets on device. +struct SparseAttentionParams { - int32_t* sparse_kv_indices{nullptr}; // [num_kv_heads, num_sparse_kv_indices] - int32_t* sparse_attn_indices{nullptr}; // [num_kv_heads, num_sparse_attn_indices] - int32_t* sparse_kv_offsets{nullptr}; // [num_contexts + 1] - int32_t* sparse_attn_offsets{nullptr}; // [num_generations + 1] + //!< [num_kv_heads, num_sparse_kv_indices] + int32_t const* sparse_kv_indices{nullptr}; + //!< [num_kv_heads, num_sparse_attn_indices] + int32_t const* sparse_attn_indices{nullptr}; + //!< [num_contexts + 1] + int32_t const* sparse_kv_offsets{nullptr}; + //!< [num_generations + 1] + int32_t const* sparse_attn_offsets{nullptr}; @@ - auto data() const + std::tuple<int32_t const*, int32_t const*, int32_t const*, int32_t const*> data() const { return std::make_tuple(sparse_kv_indices, sparse_attn_indices, sparse_kv_offsets, sparse_attn_offsets); }As per coding guidelines.
51-56
: Document the API with Doxygen; align parameter naming and constness.Add brief Doxygen and keep parameter naming consistent (
num_kv_heads
vsnum_head_kv
). Keep as declared, but consider renaming for consistency in future PRs. No behavior change needed.-void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq] +//! Gathers KV page offsets per head/batch using sparse params. +//! \param output_kv_page_offsets [num_head_kv, batch_size, 2, max_num_pages_per_seq] +//! \param output_seq_lengths [num_head_kv, batch_size] +//! \param kv_page_offsets [batch_size, 2, max_num_pages_per_seq] +//! \param seq_lengths [batch_size] +void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]Based on coding guidelines.
cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu (1)
430-492
: Verification helper assumes a specific KVBlockArray pool layout; keep but flag brittleness.The linearization math is test-specific and may diverge from production layout. If the internal layout changes, this will silently miscompare. Consider adding static assertions or reading block strides/capacity from KVBlockArray if available, or limit comparisons to the first block for robustness.
tensorrt_llm/_torch/attention_backend/sparse/kernel.py (2)
210-309
: Add minimal docstrings and input validation to public wrappers.Provide clear contracts (shapes/dtypes), return types, and error messages. Also avoid shadowing the
tokens_per_block
arg.-def triton_update_kt_cache(k, +def triton_update_kt_cache(k, kt_cache_tensor, kt_cache_block_offsets, seq_lens, kt_page_size, tokens_per_block, max_kt_blocks_per_seq, update=True): - # inputs: - # k: (total_seq_len, num_kv_heads, head_dim) - # kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim) - # kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq) - # seq_lens: (batch_size) - # kt_page_size: int - # update: bool - - # outputs: - # kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim) + """ + Update or load KT cache. + Args: + k: (total_seq_len, num_kv_heads, head_dim) in context; see gen-path note above. + kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim), CUDA tensor. + kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq), int32/int64 CUDA tensor. + seq_lens: (batch_size), lengths per batch, CUDA tensor. + kt_page_size: int. + tokens_per_block: int. + max_kt_blocks_per_seq: int. + update: bool. If False, context path; if True, generation path. + Returns: + None (context) or kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim). + """ + assert k.is_cuda and kt_cache_tensor.is_cuda and seq_lens.is_cuda, "All inputs must be on CUDA." + assert kt_cache_tensor.size(1) == tokens_per_block, "tokens_per_block must match cache tensor."As per coding guidelines.
73-207
: Kernels: minor safety/clarity nits.
- Use
tl.full
with explicit dtype fromk_ptr
to avoid fp16→fp32→fp16 churn unless intended.- Consider early-return masking when
kv_end_idx <= kv_start_idx
in context (empty range).- Avoid recomputing
hidden_size
if you can hoist constants in wrappers.These are optional and can wait.
tensorrt_llm/_torch/attention_backend/sparse/utils.py (1)
5-12
: Unify error handling and add concise docstrings; silence Ruff TRY003.Factor repeated messages and document public helpers. Keeps messages short and consistent.
-def get_sparse_attn_kv_cache_manager( - sparse_attn_config: "SparseAttentionConfig"): - if sparse_attn_config.algorithm == "rocket": - return RocketKVCacheManager - else: - raise ValueError( - f"Unsupported sparse attention algorithm: {sparse_attn_config.algorithm}" - ) +def _unsupported(algo: str, where: str) -> ValueError: + return ValueError(f"Unsupported sparse attention algorithm in {where}: {algo}") + +def get_sparse_attn_kv_cache_manager(sparse_attn_config: SparseAttentionConfig): + """Return KV cache manager class for the given sparse algorithm.""" + if sparse_attn_config.algorithm == "rocket": + return RocketKVCacheManager + raise _unsupported(sparse_attn_config.algorithm, "kv_cache_manager") @@ -def get_vanilla_sparse_attn_attention_backend( - sparse_attn_config: "SparseAttentionConfig"): - if sparse_attn_config.algorithm == "rocket": - return RocketVanillaAttention - else: - raise ValueError( - f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_attn_config.algorithm}" - ) +def get_vanilla_sparse_attn_attention_backend(sparse_attn_config: SparseAttentionConfig): + """Return VanillaAttention backend class for the given sparse algorithm.""" + if sparse_attn_config.algorithm == "rocket": + return RocketVanillaAttention + raise _unsupported(sparse_attn_config.algorithm, "vanilla") @@ -def get_trtllm_sparse_attn_attention_backend( - sparse_attn_config: "SparseAttentionConfig"): - if sparse_attn_config.algorithm == "rocket": - return RocketTrtllmAttention - else: - raise ValueError( - f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_attn_config.algorithm}" - ) +def get_trtllm_sparse_attn_attention_backend(sparse_attn_config: SparseAttentionConfig): + """Return TrtllmAttention backend class for the given sparse algorithm.""" + if sparse_attn_config.algorithm == "rocket": + return RocketTrtllmAttention + raise _unsupported(sparse_attn_config.algorithm, "trtllm") @@ -def get_flashinfer_sparse_attn_attention_backend( - sparse_attn_config: "SparseAttentionConfig"): - raise ValueError( - f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_attn_config.algorithm}" - ) +def get_flashinfer_sparse_attn_attention_backend(sparse_attn_config: SparseAttentionConfig): + """FlashInfer sparse attention not supported yet.""" + raise _unsupported(sparse_attn_config.algorithm, "flashinfer")Based on static analysis hints.
Also applies to: 15-22, 25-32, 35-39
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
1259-1265
: Avoid unnecessary predictions when a phase is absent.Compute per-phase sparse indices only when needed to save GPU time on mixed batches.
- sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None - if self.sparse_attention_config is not None: - sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict( - q, k, metadata) - sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict( - q, k, metadata) + sparse_kv_indices = sparse_kv_offsets = None + sparse_attn_indices = sparse_attn_offsets = None + if self.sparse_attention_config is not None: + if metadata.num_contexts > 0: + sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(q, k, metadata) + if metadata.num_generations > 0: + sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(q, k, metadata)examples/longbench/eval_longbench_v2.py (2)
507-515
: Ensure summary reflects actual generation length; fix extraneous f-strings.
- You override
max_new_tokens
based on--cot
but the summary writesargs.max_new_tokens
. Sync them.- Remove
f
from f-strings without placeholders flagged by Ruff F541.- max_new_tokens = 1024 if args.cot else 256 + max_new_tokens = 1024 if args.cot else 256 + args.max_new_tokens = max_new_tokens @@ - 'max_new_tokens': args.max_new_tokens + 'max_new_tokens': args.max_new_tokensAlso drop unnecessary
f
prefixes at Lines 412, 457, 500, 739, 746 (and any similar cases) to satisfy Ruff F541. Example:- logger.info(f"Loading LongBench v2 dataset...") + logger.info("Loading LongBench v2 dataset...")Also applies to: 676-691
218-239
: Remove unusedtokenizer
parameter from build_chat and its call sites.The
tokenizer
argument is unused (Ruff ARG001). Simplify the signature and calls.-def build_chat(tokenizer, prompt, chat_template): +def build_chat(prompt, chat_template): @@ - formatted_prompt = build_chat(tokenizer, formatted_prompt, - chat_template) + formatted_prompt = build_chat(formatted_prompt, chat_template) @@ - cot_ans_prompt = build_chat(tokenizer, cot_ans_prompt, - chat_template) + cot_ans_prompt = build_chat(cot_ans_prompt, chat_template)Also applies to: 488-492, 539-542
tensorrt_llm/_torch/attention_backend/vanilla.py (3)
190-197
: Docstring/return order mismatch.The docstring says
(is_causal, attn_mask)
but the function returns(attn_mask, is_causal)
. Align the docstring to the implementation to avoid confusion.- Returns: - Tuple of (is_causal, attn_mask) + Returns: + Tuple of (attn_mask, is_causal)Also applies to: 222-223
141-156
: KV cache update path: clarify k/v types when quantized.When
has_fp8_kv_cache
,k
/v
are cast tofloat8_e4m3fn
butindex_copy_
writes viaview(dtype=access_type)
as ints; later, concatenation with past usesk_out
/v_out
slices of cache (pre-cast). You fix dtype in_single_request_attn_forward
byto(q.dtype)
. This is fine, but please add a brief comment here to explain the intended dtype transitions to avoid regressions.
16-16
: Safety: gather indices shape assumptions.
triton_index_gather
requires[row, token, head, dim]
inputs and[row, token, head]
indices. Ensuresparse_kv_indices
/sparse_indices
follow this exactly; add asserts before gathering for clearer failures.- if sparse_kv_indices is not None: + if sparse_kv_indices is not None: + assert k.dim() == 4 and v.dim() == 4 and sparse_kv_indices.dim() == 3, \ + "Expect k/v [B,T,H,D] and indices [B,T,H] for sparse KV selection" k_selected = triton_index_gather(k, sparse_kv_indices) v_selected = triton_index_gather(v, sparse_kv_indices) @@ - if sparse_indices is not None: + if sparse_indices is not None: + assert key_states.dim() == 4 and value_states.dim() == 4 and sparse_indices.dim() == 3, \ + "Expect kv [B,T,H,D] and indices [B,T,H] for sparse attention selection" key_states = triton_index_gather(key_states, sparse_indices) value_states = triton_index_gather(value_states, sparse_indices)Also applies to: 224-238
examples/longbench/eval_longbench_v1.py (4)
214-235
: Remove unusedtokenizer
parameter from build_chat and call site.Same as v2; satisfy Ruff ARG001.
-def build_chat(tokenizer, prompt, chat_template): +def build_chat(prompt, chat_template): @@ - prompt = build_chat(tokenizer, prompt, chat_template) + prompt = build_chat(prompt, chat_template)Also applies to: 390-394
542-549
: Cleaner metric key retrieval.Prefer
next(iter(metrics))
to avoid building a list and silence Ruff RUF015.- if metrics: - metric_key = list(metrics.keys())[0] + if metrics: + metric_key = next(iter(metrics)) val = metrics[metric_key]
697-707
: Drop extraneous f-strings; minor logging cleanup.Remove
f
where no placeholders exist (Ruff F541), and keep messages concise.Examples:
- logger.info( - "=========== LongBench Evaluation with TensorRT-LLM ===========") + logger.info("=========== LongBench Evaluation with TensorRT-LLM ===========") @@ - logger.info(f"Running evaluation on full LongBench datasets") + logger.info("Running evaluation on full LongBench datasets") @@ - logger.info(f"FINAL RESULTS:") + logger.info("FINAL RESULTS:")Also applies to: 718-722, 785-791
184-194
: Path validation: improve error message and doc.If
LongBench/
is vendored under--longbench_path
, the current error is fine. Consider hinting expected layout in the exception to aid users.- raise FileNotFoundError( - f"LongBench directory not found: {longbench_dir}") + raise FileNotFoundError( + f"LongBench directory not found: {longbench_dir}. Expected {longbench_dir}/config and dataset files from THUDM/LongBench." + )
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/cubin/kernelMetaInfo.h
(outdated review comments, resolved)
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/cubin/kernelMetaInfo.h
(outdated review comments, resolved)
/bot run
PR_Github #20469 [ run ] triggered by Bot
PR_Github #20469 [ run ] completed with state
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
11-15
: Fix F821: forward-ref type without import.Add a guarded import for SparseAttentionConfig to satisfy static analyzers without runtime deps.
from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ..model_config import SparseAttentionConfig
♻️ Duplicate comments (42)
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/cubin/kernelMetaInfo.h (3)
3068-3081
: Set abLayout to kPersistentSwapsAbForGen (2) for persistent and kSwapsAbForGen (1) for static swap variants.These BF16 swap-variant initializers (lines 3068–3081) reference
*SwapsAbForGen*
cubins but still have abLayout set to0
(kKeepsAbForGen) or1
(kSwapsAbForGen) when they should be2
(kPersistentSwapsAbForGen) for persistent variants and1
(kSwapsAbForGen) for static variants. The dispatcher will pass incorrectly-laid-out buffers to these kernels, causing incorrect results.Apply the diff from the previous review comment to update the abLayout values for lines 3068–3081.
4170-4183
: Set abLayout to kPersistentSwapsAbForGen (2) for persistent and kSwapsAbForGen (1) for static FP16 swap variants.These FP16 swap-variant initializers (lines 4170–4183) mirror the BF16 issue: abLayout is set to
0
or1
when it should be2
for persistent swap variants and1
for static swap variants.Apply the diff from the previous review comment to update the abLayout values for lines 4170–4183.
3072-3075
: Set abLayout to kSwapsAbForGen (1) for these P64 static swap variants.These four BF16 P64 StaticSwapsAbForGen initializers (lines 3072–3075) also have abLayout incorrectly set to
0
(kKeepsAbForGen) instead of1
(kSwapsAbForGen), causing the same buffer layout mismatch issue.Apply this diff:
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen", 182480, 512, 2, 64, 0, 2, 16, 0, 3, true, false, false, "9197439559d66629cf845bd96f28abf11060ed38633468fe7ad291e6c315f20d"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen", 182480, 512, 2, 64, 0, 2, 16, 1, 3, true, false, false, "9197439559d66629cf845bd96f28abf11060ed38633468fe7ad291e6c315f20d"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen", 175824, 512, 2, 64, 0, 2, 8, 0, 3, true, false, false, "70a2117d2aa8d255040f97cfa19c4bf189ed485ba4a953d4b99997735c706265"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen", 175824, 512, 2, 64, 0, 2, 8, 1, 3, true, false, false, "70a2117d2aa8d255040f97cfa19c4bf189ed485ba4a953d4b99997735c706265"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen", 148624, 512, 2, 64, 0, 2, 16, 0, 1, true, false, false, "876da47e60e0448fc87875d23318807afa489931112f01e2f455123e512ebcce"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen", 148624, 512, 2, 64, 0, 2, 16, 1, 1, true, false, false, "876da47e60e0448fc87875d23318807afa489931112f01e2f455123e512ebcce"}, -{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin, 
FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen", 141968, 512, 2, 64, 0, 2, 8, 0, 1, true, false, false, "dacac32324a94091467aab4e09474420408a7eaaa58933178e329205e1e026b4"}, +{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen", 141968, 512, 2, 64, 0, 2, 8, 1, 1, true, false, false, "dacac32324a94091467aab4e09474420408a7eaaa58933178e329205e1e026b4"},tensorrt_llm/_torch/attention_backend/sparse/kernel.py (2)
1-4
: Add required NVIDIA Apache-2.0 header (current year).The file is missing the mandated copyright header.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import triton import triton.language as tlAs per coding guidelines.
42-71
: Validate index bounds to prevent OOB GPU loads.The function does not validate that
indices
values are within the valid range[0, input.shape[1])
, which can cause undefined behavior on GPU.def triton_index_gather(input, indices): assert input.ndim == 4, "Input must be a 4D tensor, [row, token, head, dim]" assert indices.ndim == 3, "Indices must be a 3D tensor, [row, token, head]" + # Bounds check + min_idx = indices.min() + max_idx = indices.max() + assert min_idx >= 0 and max_idx < input.shape[1], \ + f"indices out of range: [{int(min_idx)}..{int(max_idx)}] vs tokens={input.shape[1]}"As per coding guidelines.
tensorrt_llm/_torch/pyexecutor/_util.py (1)
108-116
: Compute KV size per token using the correct manager class per model.Using
self._kv_cache_manager_cls
for both main and draft models is incorrect when their configurations differ (e.g., main uses RocketKV sparse, draft does not). Each model should use its own manager class.- kv_size_per_token = self._kv_cache_manager_cls.get_cache_size_per_token( - model_config, mapping, tokens_per_block=self._tokens_per_block) + main_cls = get_kv_cache_manager_cls(model_config) + kv_size_per_token = main_cls.get_cache_size_per_token( + model_config, mapping, tokens_per_block=self._tokens_per_block) if self._draft_model_engine is not None: draft_model_config = self._draft_model_engine.model.model_config - kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token( - draft_model_config, - mapping, - tokens_per_block=self._tokens_per_block) + draft_cls = get_kv_cache_manager_cls(draft_model_config) + kv_size_per_token += draft_cls.get_cache_size_per_token( + draft_model_config, mapping, tokens_per_block=self._tokens_per_block)cpp/tensorrt_llm/pybind/thop/bindings.cpp (1)
54-55
: Fix default for sparse_attention_params to maintain backward compatibility.Defaulting
sparse_attention_params
tostd::nullopt
will cause a TypeError when callers omit the argument, as the C++ implementation expects a vector. Provide a real default vector of four null optionals.- py::arg("spec_decoding_tensor_params"), py::arg("sparse_attention_params") = std::nullopt, + py::arg("spec_decoding_tensor_params"), + py::arg("sparse_attention_params") + = std::vector<std::optional<torch::Tensor>>{std::nullopt, std::nullopt, std::nullopt, std::nullopt}, "Multi-head attention operation", py::call_guard<py::gil_scoped_release>());tensorrt_llm/_torch/attention_backend/sparse/rocket.py (5)
1-26
: Add the required NVIDIA Apache-2.0 header.This new module is missing the mandated NVIDIA Apache-2.0 copyright header for 2025.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import mathAs per coding guidelines.
114-114
: Fix page-index conversion (wrong divisor).
convert_token_to_page_sparse_indices
divides token indices bymetadata.tokens_per_block
, which is the paged-KV block length, not the RocketKV page size. Use the RocketKV page size from the cache manager.- page_size = metadata.tokens_per_block + page_size = metadata.kv_cache_manager.page_size
135-165
: Eliminate padded-1
page ids.Deduplicating per head and padding shorter heads with
-1
leaves sentinel values innew_page_indices
. After the final transpose, the kernel sees-1
as a real page id, causing incorrect gather operations. Rebuild the routine to only emit valid (non-negative) page indices.
959-968
: Correct KT rewind page count math.
math.ceil(num_tokens - rewind_len / self.page_size)
mixes units—rewind_len / self.page_size
is fractional pages subtracted from a raw token count. Compute remaining tokens first, clamp at zero, then convert to page units.- updated_kt_token_num = math.ceil(num_tokens - - rewind_len / self.page_size) + remaining = max(num_tokens - rewind_len, 0) + updated_kt_token_num = math.ceil(remaining / self.page_size)
1000-1001
: Fix ceiling division incompute_page_count
.
(token_count + tokens_per_page) // tokens_per_page
over-allocates whentoken_count
is an exact multiple and returns1
whentoken_count == 0
. Use the standard ceiling formula.- return (token_count + tokens_per_page) // tokens_per_page + if token_count <= 0: + return 0 + return (token_count + tokens_per_page - 1) // tokens_per_pageexamples/longbench/requirements.txt (1)
1-3
: Replace GPL-onlyfuzzywuzzy
with a permissive fork.
fuzzywuzzy
is GPL‑2.0 and archived, which conflicts with TensorRT-LLM’s Apache-2.0 licensing. Please switch to a compatible fork such asthefuzz
(MIT) and adjust any imports/usages accordingly.
[suggested change]-jieba -fuzzywuzzy -rouge +jieba +thefuzz +rougeexamples/longbench/eval_longbench_v1.py (2)
1-10
: Add NVIDIA Apache-2.0 header and resolve shebang executability.Prepend the standard NVIDIA Apache-2.0 header. Keep the shebang only if the file is executable; otherwise remove it to satisfy EXE001.
#!/usr/bin/env python3 +# +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# """ LongBench v1 evaluation script with TensorRT-LLM and sparse attention.As per coding guidelines.
292-336
: Wire KV cache options, validate RocketKV backend, and preserve traceback.
- Use
--kv_cache_dtype
and--kv_cache_fraction
.- Fail fast if backend unsupported for RocketKV.
- Drop extraneous f-string and re-raise with plain
raise
.- kv_cache_config = KvCacheConfig( - enable_block_reuse=False, # RocketKV doesn't support KV cache reuse - ) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, # RocketKV doesn't support KV cache reuse + dtype=args.kv_cache_dtype, + free_gpu_memory_fraction=args.kv_cache_fraction, + ) @@ - if args.rocket_sparse: + if args.rocket_sparse: + if args.backend != "pytorch": + raise ValueError( + "RocketKV sparse attention currently supports backend='pytorch'. " + "Use --backend pytorch or disable --rocket_sparse." + ) # Configure RocketKV sparse attention @@ - logger.info(f"Using RocketKV sparse attention") + logger.info("Using RocketKV sparse attention") @@ - except Exception as e: - logger.error(f"Failed to initialize LLM: {e}") - raise e + except Exception as e: + logger.error(f"Failed to initialize LLM: {e}") + raiseexamples/longbench/eval_longbench_v2.py (4)
1-16
: Add NVIDIA Apache-2.0 header and resolve shebang executability.Same as v1: add header; retain shebang only if executable.
#!/usr/bin/env python3 +# +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# """ LongBench v2 evaluation script with TensorRT-LLM and sparse attention.As per coding guidelines.
298-310
: Preserve special tokens during truncation.Decoding with
skip_special_tokens=True
can drop chat tokens like <|eot_id|>.- prompt = tokenizer.decode(truncated_ids, skip_special_tokens=True) + prompt = tokenizer.decode(truncated_ids, skip_special_tokens=False)
346-389
: Wire KV cache options, validate RocketKV backend, remove redundant f-string, and preserve traceback.Same concerns as v1; add dtype/fraction, backend guard, plain
raise
, and drop superfluousf
.- kv_cache_config = KvCacheConfig( - enable_block_reuse=False, # RocketKV doesn't support KV cache reuse - ) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + dtype=args.kv_cache_dtype, + free_gpu_memory_fraction=args.kv_cache_fraction, + ) @@ - if args.rocket_sparse: + if args.rocket_sparse: + if args.backend != "pytorch": + raise ValueError( + "RocketKV sparse attention currently supports backend='pytorch'. " + "Use --backend pytorch or disable --rocket_sparse." + ) # Configure RocketKV sparse attention @@ - logger.info(f"Using RocketKV sparse attention") + logger.info("Using RocketKV sparse attention") @@ - except Exception as e: - logger.error(f"Failed to initialize LLM: {e}") - raise e + except Exception as e: + logger.error(f"Failed to initialize LLM: {e}") + raise
468-482
: Avoid KeyError for templates and IndexError for stop-token IDs.Use
.get()
for templates and guard tokenizer encodes.- template = config['templates'][template_key] + template = config['templates'].get(template_key) + if template is None: + raise KeyError(f"Missing template '{template_key}' under {args.longbench_path}/prompts") @@ - if chat_template == "llama3": - eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0] - extra_end_token_ids.append(eot_id) + if chat_template == "llama3": + eot_ids = tokenizer.encode("<|eot_id|>", add_special_tokens=False) + if eot_ids: + extra_end_token_ids.append(eot_ids[0]) @@ - if chat_template == "qwen": - im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0] - extra_end_token_ids.append(im_end_id) + if chat_template == "qwen": + im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False) + if im_end_ids: + extra_end_token_ids.append(im_end_ids[0])examples/longbench/README.md (3)
27-39
: Add language specifier to fenced code block.Use language hint for proper rendering.
-``` +```text sparse_attention/ @@ -``` +```
148-160
: Add language specifier to fenced code block.Same as above.
-``` +```text results/v1_experiment/ @@ -``` +```
164-170
: Add language specifier to fenced code block.Same as above.
-``` +```text results/v2_experiment/ @@ -``` +```tensorrt_llm/_torch/attention_backend/interface.py (1)
143-143
: ImportSparseAttentionConfig
for type checking.Static analysis correctly flags
SparseAttentionConfig
as undefined. Add a conditional import underTYPE_CHECKING
to resolve the type checker warning and improve IDE support.+from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tensorrt_llm.llmapi.llm_args import SparseAttentionConfig🪛 Ruff (0.13.2)
143-143: Undefined name
SparseAttentionConfig
(F821)
cpp/tensorrt_llm/thop/attentionOp.cpp (1)
230-244
: Reset all sparse pointers before conditional assignment.With cached ops, leaving the non-updated branch's pointers unchanged can carry stale addresses into future runs when the phase switches. Clear all four sparse pointer fields to
nullptr
before theif (is_context)
branch to prevent stale pointer reuse.// Prepare sparse attention parameters +op.mRuntimeSparseAttentionParams.sparse_kv_indices = nullptr; +op.mRuntimeSparseAttentionParams.sparse_kv_offsets = nullptr; +op.mRuntimeSparseAttentionParams.sparse_attn_indices = nullptr; +op.mRuntimeSparseAttentionParams.sparse_attn_offsets = nullptr; if (is_context) { op.mRuntimeSparseAttentionParams.sparse_kv_indices = sparse_kv_indices.has_value() ? sparse_kv_indices.value().data_ptr<int32_t>() : nullptr; op.mRuntimeSparseAttentionParams.sparse_kv_offsets = sparse_kv_offsets.has_value() ? sparse_kv_offsets.value().data_ptr<int32_t>() : nullptr; } else { op.mRuntimeSparseAttentionParams.sparse_attn_indices = sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr; op.mRuntimeSparseAttentionParams.sparse_attn_offsets = sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr; }cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (1)
1821-1833
: Add missing head sizes to sparse dispatch.The sparse dispatch switch only covers 16, 32, 64, 128, 256 but the dense V2 kernel (lines 1587-1603) supports many more, including common sizes 80, 96, and 144. Models using those head dimensions with sparse attention will error out unexpectedly.
Add cases for the missing head sizes:
switch (params.size_per_head) { case 16: kernelSparseDispatchHeadSize<16, T, TCache, KVCacheBuffer>(params, stream); break; case 32: kernelSparseDispatchHeadSize<32, T, TCache, KVCacheBuffer>(params, stream); break; case 64: kernelSparseDispatchHeadSize<64, T, TCache, KVCacheBuffer>(params, stream); break; +case 80: kernelSparseDispatchHeadSize<80, T, TCache, KVCacheBuffer>(params, stream); break; +case 96: kernelSparseDispatchHeadSize<96, T, TCache, KVCacheBuffer>(params, stream); break; case 128: kernelSparseDispatchHeadSize<128, T, TCache, KVCacheBuffer>(params, stream); break; +case 144: kernelSparseDispatchHeadSize<144, T, TCache, KVCacheBuffer>(params, stream); break; case 256: kernelSparseDispatchHeadSize<256, T, TCache, KVCacheBuffer>(params, stream); break; default: TLLM_CHECK_WITH_INFO( false, "updateSparseKvCacheAfterFmha kernel doesn't support head size = %d", params.size_per_head); break; }examples/llm-api/llm_sparse_attention.py (2)
8-14
: read_input: add docstring and robust I/O/JSON handling.Handle missing file/invalid JSON; document behavior.
def read_input(input_file): - results = [] - with open(input_file, 'r') as f: - for line in f: - ret = json.loads(line) - results.append(ret) - return results + """ + Read JSONL and return a list of parsed objects. + """ + results = [] + try: + with open(input_file, 'r') as f: + for line in f: + line = line.strip() + if not line: + continue + results.append(json.loads(line)) + except FileNotFoundError: + raise FileNotFoundError(f"Input file not found: {input_file}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in input file: {e}") from e + return results
1-1
: Add NVIDIA Apache-2.0 copyright header.Required by repo guidelines; add the SPDX block at file top.
Apply:
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.tensorrt_llm/_torch/pyexecutor/model_loader.py (1)
163-163
: Missing import forSparseAttentionConfig
.The type annotation on line 163 references
SparseAttentionConfig
, but this type is not imported. While the quoted annotation defers evaluation, the import is still required for runtime type checking and IDE support.Add the import at the top of the file:
+from tensorrt_llm.llmapi.llm_args import SparseAttentionConfig
Or use the appropriate import path based on where
SparseAttentionConfig
is defined in the codebase.cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp (1)
1-1
: Add NVIDIA Apache-2.0 copyright header.According to coding guidelines, all C++ source files must have the NVIDIA Apache-2.0 copyright header prepended with the current year (2025).
Prepend the following header at the top of the file:
/* * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */As per coding guidelines.
tensorrt_llm/_torch/attention_backend/utils.py (2)
1-10
: Add TYPE_CHECKING import for SparseAttentionConfig type hint.Ruff F821 flags
SparseAttentionConfig
at line 54 as undefined. Add a forward declaration under TYPE_CHECKING to satisfy static analysis without causing runtime import cycles.Apply this diff:
-from typing import Optional, Type +from typing import TYPE_CHECKING, Optional, Type from ...models.modeling_utils import QuantConfig from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from .interface import AttentionBackend, MLAParams, PositionalEmbeddingParams from .sparse import (get_flashinfer_sparse_attn_attention_backend, get_trtllm_sparse_attn_attention_backend, get_vanilla_sparse_attn_attention_backend) + +if TYPE_CHECKING: + from tensorrt_llm.llmapi.llm_args import SparseAttentionConfig
13-33
: Normalize backend_name to handle case variations.Current implementation is case-sensitive and will fail for inputs like
"vanilla"
or"flashinfer"
, silently falling back to TRTLLM. Add normalization at the function entry.Apply this diff:
def get_attention_backend(backend_name: str, - sparse_attn_config=None) -> Type[AttentionBackend]: + sparse_attn_config: Optional["SparseAttentionConfig"] = None) -> Type[AttentionBackend]: + backend_name = backend_name.upper() if backend_name == "VANILLA":tests/unittest/_torch/attention/sparse/test_rocketkv.py (1)
55-63: Test remains non-deterministic due to stochastic sampling.
As flagged in the previous review, using temperature=0.8 and top_p=0.95 without seeding makes this test flaky. Switch to greedy decoding (temperature=0.0, top_p=1.0) or set random seeds before generation to ensure reproducible results across CI runs. Apply this diff to make generation deterministic:
```diff
             sampling_params=SamplingParams(add_special_tokens=False,
                                            max_tokens=max_output_tokens,
-                                           temperature=0.8,
-                                           top_p=0.95),
+                                           temperature=0.0,
+                                           top_p=1.0),
```
tensorrt_llm/_torch/attention_backend/sparse/__init__.py (1)
1-11: Add the required NVIDIA Apache-2.0 copyright header.
All source files must begin with the NVIDIA Apache-2.0 copyright header with the current year (2025). Please prepend the standard header to the top of this file before any imports or code.
As per coding guidelines.
tensorrt_llm/llmapi/llm_args.py (1)
176-187: Fix the from_dict dispatch bug.
Line 183 queries config_classes.get("algorithm"), which always returns None because the key "algorithm" doesn't exist in the dict (it should be "Rocket"). Additionally, even if fixed, the key casing mismatch ("Rocket" vs. the default "rocket") would cause lookup failures. Apply this diff to fix the dispatch logic:
```diff
     @classmethod
     def from_dict(cls, data: dict):
         # dispatch to the correct sparse attention config
         config_classes = {
-            "Rocket": RocketSparseAttentionConfig,
+            "rocket": RocketSparseAttentionConfig,
         }
-        config_class = config_classes.get("algorithm")
+        algorithm = data.get("algorithm", cls().algorithm)
+        config_class = config_classes.get(algorithm.lower())
         if config_class is None:
-            raise ValueError(f"Invalid algorithm")
+            raise ValueError(f"Invalid algorithm: {algorithm}")
         return config_class(**data)
```
1-1: Replace pragma once with include guards and add the NVIDIA Apache-2.0 header.
Required by project guidelines; use TRTLLM_SPARSEATTENTIONKERNELS_H and the 2025 banner.
As per coding guidelines.
40-49: Guard the device qualifier to keep the header includable from host-only TUs.
Public headers should not force device-only symbols on .cpp users.
As per coding guidelines.
3-5: Make the header self-contained by adding the missing standard includes.
std::string, std::tuple, and int32_t are used, but their headers aren't included here.
```diff
 #include <cuda_runtime.h>
-#include <sstream>
+#include <cstdint>
+#include <string>
+#include <tuple>
+#include <sstream>
```
As per coding guidelines.
tensorrt_llm/_torch/attention_backend/sparse/utils.py (1)
1-2: Add the NVIDIA license header and a future annotations import.
Required by guidelines; this also prevents forward-ref evaluation at runtime. For this Python file the header lines should use # comments:
```diff
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
```
As per coding guidelines.
cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu (1)
1-15: Copyright year already flagged in previous review.
A past review comment has identified that the copyright year should be 2025 for new files in this PR. This is a duplicate of that finding.
cpp/tensorrt_llm/thop/attentionOp.h (1)
37-61: Header/impl signature mismatch will break linking; use std::optional consistently.
attentionOp.cpp defines std::optional / std::vector<std::optional<...>> parameters, but this header declares torch::optional in multiple places. Align the header to std::optional so it matches the definition. Apply:
```diff
-void attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v,
-    torch::Tensor& output, torch::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
-    torch::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
+void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v,
+    torch::Tensor& output, std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
+    std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
@@
-    torch::Tensor host_request_types, torch::optional<torch::Tensor> kv_cache_block_offsets,
-    torch::optional<torch::Tensor> host_kv_cache_block_offsets,
-    torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
-    torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
-    torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
-    torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
-    torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
-    torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
-    torch::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
+    torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_pool_pointers,
+    std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
+    std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
+    std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
+    std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
+    std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
+    std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
@@
     std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
@@
-    std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
-    std::vector<torch::optional<torch::Tensor>> sparse_attention_params);
+    std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
+    std::vector<std::optional<torch::Tensor>> sparse_attention_params);
```
After patching, rebuild and confirm only a single attention(...) symbol exists:
```bash
#!/bin/bash
rg -n 'void\s+attention\(' cpp/tensorrt_llm/thop -C2
```
tensorrt_llm/_torch/attention_backend/vanilla.py (2)
98-115: Sparse hooks must not raise; adjust types and provide safe defaults.
Raising NotImplementedError will crash when sparse_attention_config is present. Return a no-op (None, kv_len) and fix the type hints to include the kv_len. Apply:
```diff
     def _single_request_sparse_attn_predict(self, q: torch.Tensor,
                                             k: Optional[torch.Tensor],
                                             v: Optional[torch.Tensor],
                                             kv_cache_tensor: torch.Tensor,
                                             metadata: AttentionMetadata,
                                             past_seen_token: int, sample_idx: int,
-                                            **kwargs) -> Optional[torch.Tensor]:
-        raise NotImplementedError
+                                            **kwargs) -> tuple[Optional[torch.Tensor], int]:
+        kv_len = k.size(0) if k is not None else 0
+        return None, kv_len
@@
     def _single_request_sparse_kv_predict(self, q: Optional[torch.Tensor],
                                           k: Optional[torch.Tensor],
                                           v: Optional[torch.Tensor],
                                           metadata: AttentionMetadata,
                                           past_seen_token: int, sample_idx: int,
-                                          **kwargs) -> Optional[torch.Tensor]:
-        raise NotImplementedError
+                                          **kwargs) -> tuple[Optional[torch.Tensor], int]:
+        kv_len = k.size(0) if k is not None else 0
+        return None, kv_len
```
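For readers skimming the thread, a rough sketch of what a backend-specific override of these hooks could look like once they no longer raise. The class, the token budget of 64, and the norm-based scoring are purely illustrative and are not RocketKV's actual selection logic:
```python
from typing import Optional, Tuple

import torch


class IllustrativeSparseVanillaAttention:
    """Hypothetical subclass showing the hook contract: return (kept indices or None, kv_len)."""

    def _single_request_sparse_kv_predict(
            self, q: Optional[torch.Tensor], k: Optional[torch.Tensor],
            v: Optional[torch.Tensor], metadata, past_seen_token: int,
            sample_idx: int, **kwargs) -> Tuple[Optional[torch.Tensor], int]:
        kv_len = k.size(0) if k is not None else 0
        if q is None or k is None or kv_len <= 64:
            # Nothing to prune: keep the dense path.
            return None, kv_len
        # Illustrative scoring: keep the 64 keys with the largest mean |k|.
        scores = k.float().abs().mean(dim=-1)
        if scores.dim() > 1:
            scores = scores.mean(dim=tuple(range(1, scores.dim())))
        keep = torch.topk(scores, k=64).indices.sort().values
        return keep, int(keep.numel())
```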
431-434: Pass attention_window_size through to the per-request path.
Without forwarding, sliding-window masking is never applied. Apply:
```diff
-                attn_output = self._single_request_forward(
-                    single_q, single_k, single_v, attention_mask, kv_cache_tensor,
-                    past_seen_token, cache_idx, sample_idx, metadata, **kwargs)
+                attn_output = self._single_request_forward(
+                    single_q, single_k, single_v, attention_mask, kv_cache_tensor,
+                    past_seen_token, cache_idx, sample_idx, metadata,
+                    attention_window_size=attention_window_size, **kwargs)
```
🧹 Nitpick comments (25)
tensorrt_llm/_utils.py (1)
216-224: Add type hints and improve the docstring format.
The function logic is correct, but it lacks type hints and the docstring does not follow Google-style format. Apply this diff to add type hints and improve the docstring:
```diff
-def next_power_of_two(x):
+def next_power_of_two(x: int) -> int:
     """
-    get next power of two
+    Returns the next power of two greater than or equal to x.
+
+    Args:
+        x: An integer value.
+
+    Returns:
+        The smallest power of two >= x, or 1 if x <= 0.
     """
     if x <= 0:
         return 1
     if (x & (x - 1)) == 0:
         return x
     return 1 << x.bit_length()
```
As per coding guidelines (Python code should include type hints and use Google-style docstrings).
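A quick sanity check of the behavior the suggested docstring documents; the expected values follow directly from the function body quoted above:
```python
def next_power_of_two(x: int) -> int:
    if x <= 0:
        return 1
    if (x & (x - 1)) == 0:
        return x
    return 1 << x.bit_length()


assert next_power_of_two(-3) == 1   # non-positive inputs clamp to 1
assert next_power_of_two(8) == 8    # already a power of two, returned as-is
assert next_power_of_two(9) == 16   # (9).bit_length() == 4, so 1 << 4
```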
examples/longbench/eval_longbench_v1.py (1)
214-234
: Silence unused parameter in build_chat.
tokenizer
is unused; rename to_tokenizer
to satisfy linters without changing call sites.-def build_chat(tokenizer, prompt, chat_template): +def build_chat(_tokenizer, prompt, chat_template): """Build chat prompt following LongBench's approach."""examples/longbench/eval_longbench_v2.py (4)
219-238
: Silence unused parameter in build_chat.Rename
tokenizer
to_tokenizer
.-def build_chat(tokenizer, prompt, chat_template): +def build_chat(_tokenizer, prompt, chat_template): """Build chat prompt following LongBench's approach."""
523-592
: Minor cleanups: rename unused loop var; robust prompt_token_ids usage.
- Rename
i
to_
(B007).- Handle method vs attribute for
prompt_token_ids
.- for i, (sample, output) in enumerate(zip(filtered_data, outputs)): + for _, (sample, output) in enumerate(zip(filtered_data, outputs)): @@ - 'prompt_length': len(output.prompt_token_ids), + 'prompt_length': len(output.prompt_token_ids() if callable(getattr(output, "prompt_token_ids", None)) else output.prompt_token_ids),
597-639
: Expose sample-count breakdowns used later in logs.
main()
logs reference{easy,length}_samples
keys that are not produced; add counts to metrics to avoid always printing 0.metrics = { 'overall_accuracy': round(overall_accuracy * 100, 2), 'total_samples': total_samples, - 'correct_samples': correct_samples + 'correct_samples': correct_samples } @@ - for difficulty in difficulties: + for difficulty in difficulties: diff_results = [r for r in results if r['difficulty'] == difficulty] if diff_results: diff_correct = sum(1 for r in diff_results if r['is_correct']) metrics[f'{difficulty}_accuracy'] = round( (diff_correct / len(diff_results)) * 100, 2) + metrics[f'{difficulty}_samples'] = len(diff_results) @@ - for length in lengths: + for length in lengths: len_results = [r for r in results if r['length'] == length] if len_results: len_correct = sum(1 for r in len_results if r['is_correct']) metrics[f'{length}_accuracy'] = round( (len_correct / len(len_results)) * 100, 2) + metrics[f'{length}_samples'] = len(len_results)
724-768
: Remove redundant f-strings without placeholders.Clean up minor F541 instances.
- logger.info(f"Starting LongBench v2 evaluation...") + logger.info("Starting LongBench v2 evaluation...") @@ - logger.info(f"{'-'*80}") + logger.info('-' * 80) @@ - logger.info( - f"Overall accuracy: {metrics.get('overall_accuracy', 'N/A')}%") + logger.info("Overall accuracy: %s%%", metrics.get('overall_accuracy', 'N/A'))examples/llm-api/llm_sparse_attention.py (3)
20-24
: Avoid hard-coded, user-specific default model path.Make model_path required or read from env to improve portability.
- parser.add_argument( - '--model_path', - type=str, - default= - "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct" - ) + parser.add_argument('--model_path', type=str, required=True)
81-112
: Function naming style and CLI cohesion.Follow snake_case; also consider exposing '--backend' to exercise TRTLLM path too.
-def run_RocketKV(args): +def run_rocketkv(args): @@ - if args.algo == 'ROCKETKV': - run_RocketKV(args) + if args.algo == 'ROCKETKV': + run_rocketkv(args)Optionally add:
- parser.add_argument('--tensor_parallel_size', type=int, default=1) + parser.add_argument('--tensor_parallel_size', type=int, default=1) + parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch','trtllm'])and pass backend=args.backend to LLM().
120-129
: Guard output indexing to avoid IndexError on empty candidates.Use a safe fallback if no outputs returned.
- for idx, output in enumerate(outputs): - print( - f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}' - ) + for idx, output in enumerate(outputs): + text = output.outputs[0].text if output.outputs else "" + print(f'Generated text: {text!r}, ref: {reference[idx]}')cpp/tensorrt_llm/nanobind/common/customCasters.h (1)
29-29
: Remove unused ArrayRef include
The headercpp/tensorrt_llm/nanobind/common/customCasters.h
contains no references toc10::ArrayRef
; drop the#include <c10/util/ArrayRef.h>
to avoid an unnecessary dependency.cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp (1)
15-15
: Use CamelCase for class name.The test fixture class name
sparseAttentionKernelsTest
uses lowerCamelCase, but coding guidelines require CamelCase for type names (classes).Apply this diff:
-class sparseAttentionKernelsTest : public ::testing::Test +class SparseAttentionKernelsTest : public ::testing::TestAnd update the test macro on line 31:
-TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest) +TEST_F(SparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)As per coding guidelines.
tests/unittest/_torch/attention/sparse/test_rocketkv.py (1)
41-54
: Consider using Path for cleaner path construction.The nested
os.path.dirname
calls work but are less readable. Also, reusingtest_star_attention_input.jsonl
for RocketKV tests may cause confusion—consider clarifying or renaming if these are specifically RocketKV test inputs.Apply this diff for cleaner path handling:
- current_file = os.path.abspath(__file__) - current_dir = os.path.dirname(os.path.dirname( - os.path.dirname(current_file))) - input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl' + from pathlib import Path + current_dir = Path(__file__).resolve().parent.parent.parent + input_file = current_dir / 'multi_gpu' / 'test_star_attention_input.jsonl'tensorrt_llm/_torch/attention_backend/sparse/__init__.py (1)
6-11
: Consider sorting__all__
for consistency.The static analyzer suggests sorting the
__all__
list. While this is a minor style issue, alphabetically sorting public exports improves maintainability and follows common Python conventions.Apply this diff:
__all__ = [ + "get_flashinfer_sparse_attn_attention_backend", "get_sparse_attn_kv_cache_manager", - "get_vanilla_sparse_attn_attention_backend", "get_trtllm_sparse_attn_attention_backend", - "get_flashinfer_sparse_attn_attention_backend", + "get_vanilla_sparse_attn_attention_backend", ]tensorrt_llm/llmapi/llm_args.py (1)
192-197
: Remove unused method parameterbackend
.The
supports_backend
method inSparseAttentionBaseConfig
always returnsTrue
and doesn't use thebackend
parameter. Either implement backend validation or remove the parameter if all backends are universally supported.If backend validation is not needed in the base class:
- def supports_backend(self, backend: str) -> bool: + def supports_backend(self, backend: str = None) -> bool: """ Override if the speculation algorithm does not support a subset of the possible backends. """ return TrueOr implement actual validation if certain backends should be restricted.
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
256-257
: Document and safeguard VANILLA backend’s tokens_per_block override
Theif attn_backend == "VANILLA": tokens_per_block = max_num_tokens
override is intentional (estimation is disabled in_util.py
lines 216–221), but should be explicitly documented atpy_executor_creator.py:256–257
and supplemented with an OOM safeguard or configurable warning whenmax_num_tokens
is large.tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
268-271
: Use the stored config for consistency.Pass self.sparse_attention_config, not the local arg name.
- self.attn_backend = get_attention_backend( - pytorch_backend_config.attn_backend, - sparse_attn_config=sparse_attention_config) + self.attn_backend = get_attention_backend( + pytorch_backend_config.attn_backend, + sparse_attn_config=self.sparse_attention_config)cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (1)
11-32
: Prefer const-correct pointers in parameter carrier.These arrays are read-only at call sites and in kernels; mark as const to express intent and enable more compiler checks. Adjust downstream uses if needed.
- int32_t* sparse_kv_indices{nullptr}; - int32_t* sparse_attn_indices{nullptr}; - int32_t* sparse_kv_offsets{nullptr}; - int32_t* sparse_attn_offsets{nullptr}; + int32_t const* sparse_kv_indices{nullptr}; + int32_t const* sparse_attn_indices{nullptr}; + int32_t const* sparse_kv_offsets{nullptr}; + int32_t const* sparse_attn_offsets{nullptr};cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu (1)
105-114
: Avoid allocating unused dynamic shared memory.You already use typed shared TempStorage; setting smem_size > 0 reduces occupancy with no benefit.
- // Shared memory size. - size_t smem_size = sizeof(Pair) * 256; + // No dynamic shared memory needed; BlockReduce uses statically-declared storage. + size_t smem_size = 0; // Launch the kernel. gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths, kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu (3)
27-27
: Avoid file-scope using directives.File-scope
using namespace
can cause name collisions. Prefer explicit qualification or limit scope.Apply this diff to qualify usage explicitly or move into narrower scope:
-using namespace tensorrt_llm::kernels; +// Remove and qualify types explicitly, e.g., tensorrt_llm::kernels::KVBlockArray
168-184
: Consider checking CUDA errors in cleanup.The
cleanup()
function does not check return values fromcudaFree
. While cleanup in tests may silently ignore errors, consider usingTLLM_CUDA_CHECK
for consistency and to detect issues during development.Apply this diff to add error checking:
void cleanup() { if (mSparseKvIndicesDevice) - cudaFree(mSparseKvIndicesDevice); + TLLM_CUDA_CHECK(cudaFree(mSparseKvIndicesDevice)); if (mSparseKvOffsetsDevice) - cudaFree(mSparseKvOffsetsDevice); + TLLM_CUDA_CHECK(cudaFree(mSparseKvOffsetsDevice)); // ... (repeat for all cudaFree calls) }
366-426
: Remove magic number in sparse index calculation.Line 400 hardcodes
8
(total sparse tokens). This should be calculated dynamically or use a named constant to avoid breaking if test parameters change.Apply this diff:
+ int const total_sparse_tokens = hostSparseOffsets[mBatchSize]; // ... - int const sparse_idx_offset = head * 8 + global_sparse_idx; // 8 is total sparse tokens + int const sparse_idx_offset = head * total_sparse_tokens + global_sparse_idx;cpp/tensorrt_llm/common/attentionOp.cpp (3)
916-938
: XQA sparse workspace sizing: make arithmetic size_t‑safe and document layoutCurrent formula packs two buffers into one size. Cast early to size_t to avoid 32‑bit overflow and clarify intent in code comments.
Apply:
- int const XQA_NUM_BUFFERS = 8; + int const XQA_NUM_BUFFERS = 8; @@ - // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets. - size_t const sparse_attn_cache_size = useTllmGenSparseAttention() - ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads - : 0; + // Two workspaces for sparse attention packed contiguously: + // [seq_lens_kv_per_head (batch_beam)] + [kv_page_offsets_per_head (batch_beam * 2 * max_blocks_per_sequence)] + size_t const sparse_attn_cache_size = useTllmGenSparseAttention() + ? static_cast<size_t>(sizeof(int)) + * static_cast<size_t>(mNumKVHeads) + * (static_cast<size_t>(batch_beam) + + static_cast<size_t>(batch_beam) * 2u * static_cast<size_t>(max_blocks_per_sequence)) + : 0;
1692-1694
: is_last_chunk condition likely wrong for chunked prefillComparing input_seq_length to max_past_kv_length doesn’t detect “last chunk” when prefill is split; it flags true only when there’s a single chunk. Consider a condition based on cumulative processed length (or an explicit flag from the caller).
1860-1865
: Postprocess should be gated on KV cacheinvokeKvCachePostprocessing() should be a no‑op when KV cache is disabled; guard explicitly to avoid unnecessary work.
Apply:
- if (!mIsMLAEnabled) // Only for non-MLA attention - { - invokeKvCachePostprocessing(preprocessingParams, stream); - sync_check_cuda_error(stream); - } + if (!mIsMLAEnabled && useKVCache()) + { + invokeKvCachePostprocessing(preprocessingParams, stream); + sync_check_cuda_error(stream); + }tensorrt_llm/_torch/attention_backend/vanilla.py (1)
193-195: Docstring/return order mismatch.
The docstring says (is_causal, attn_mask) but the function returns (attn_mask, is_causal). Align the docstring to avoid confusion. Apply:
```diff
-        Returns:
-            Tuple of (is_causal, attn_mask)
+        Returns:
+            Tuple of (attn_mask, is_causal)
```
Also applies to: 222-222
a9917b0
to
4489018
Compare
/bot run --disable-fail-fast |
PR_Github #20651 [ run ] triggered by Bot |
PR_Github #20651 [ run ] completed with state |
/bot run --disable-fail-fast |
PR_Github #20738 [ run ] triggered by Bot |
/bot kill |
/bot run --disable-fail-fast |
PR_Github #21197 [ run ] completed with state |
add abstract level for sparse attention minor code adjustment change RocketKVCacheManager to one that inherts from KVCacheManager. fix format. Signed-off-by: Fanrong Li <[email protected]> refactor sparse attention backend. Signed-off-by: Fanrong Li <[email protected]> fix format. update code design and details fix some accuracy bugs minor test Signed-off-by: yuhangh <[email protected]> fix bugs when seq_len<budget Signed-off-by: yuhangh <[email protected]> add longbench evaluation Signed-off-by: yuhangh <[email protected]> Fix accuracy issues and abstract the forward logic Signed-off-by: yuhangh <[email protected]> fuse fanrong's updates: added gather kernel Signed-off-by: yuhangh <[email protected]> rm some unrelated changes. Signed-off-by: Fanrong Li <[email protected]> fix vanilla dense and sparse attention. Signed-off-by: Fanrong Li <[email protected]> fix rocket. Signed-off-by: Fanrong Li <[email protected]> fix. Signed-off-by: Fanrong Li <[email protected]> remove sparse attention base class and refactor vanilla. Signed-off-by: Fanrong Li <[email protected]> fix seq_len. Signed-off-by: Fanrong Li <[email protected]> refactor get_cache_size_per_token for kvCacheManager to support new added sparse manager. Signed-off-by: Fanrong Li <[email protected]> Add longbench evaluation scripts Signed-off-by: yuhangh <[email protected]> disable estimate kv cache for vanilla attention backend. Signed-off-by: Fanrong Li <[email protected]> fix longbench README. Signed-off-by: Fanrong Li <[email protected]> fix longbench_v1. Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]> Update sparse attention parameters passing logic Signed-off-by: yuhangh <[email protected]> fix rebase breaks Signed-off-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> [None][feat] add gatherKvPageOffsetsKernel (#32) * add gatherKvPageOffsetsKernel. Signed-off-by: Fanrong Li <[email protected]> * fix. Signed-off-by: Fanrong Li <[email protected]> * fix. Signed-off-by: Fanrong Li <[email protected]> --------- Signed-off-by: Fanrong Li <[email protected]> Add sparse kv indices write kernel & fix several bugs Signed-off-by: yuhangh <[email protected]> fix for rebase Signed-off-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> [None][feat] integrate block sparse attention kernels (#33) * integrate block sparse attention kernels. Signed-off-by: Fanrong Li <[email protected]> * fix. Signed-off-by: Fanrong Li <[email protected]> * Support num_kv_heads in seq_len & fix several workspace size bugs Signed-off-by: yuhangh <[email protected]> * update block sparse attention kernel to support per-head kv_len. Signed-off-by: Fanrong Li <[email protected]> * minor fix Signed-off-by: yuhangh <[email protected]> * update kernel meta info. * add more block sparse kernels. * disable rope_fusion for sparse attention. Signed-off-by: Fanrong Li <[email protected]> * fix block sparse attention kernels. * update block sparse attention kernel. Signed-off-by: Fanrong Li <[email protected]> * fix workspace issue. Signed-off-by: Fanrong Li <[email protected]> * minor fix Signed-off-by: yuhangh <[email protected]> * fix gatherKvPageOffsetsKernel. Signed-off-by: Fanrong Li <[email protected]> * remove cuda stream sync. Signed-off-by: Fanrong Li <[email protected]> --------- Signed-off-by: Fanrong Li <[email protected]> Signed-off-by: yuhangh <[email protected]> Co-authored-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> [None][feat] change the sparse indices format and update the gatherKvPageOffsetsKe… (#34) * change the sparse indices format and update the gatherKvPageOffsetsKernel. Signed-off-by: Fanrong Li <[email protected]> * update kv write & optimize logic of using tllmgen kernels Signed-off-by: yuhangh <[email protected]> --------- Signed-off-by: Fanrong Li <[email protected]> Signed-off-by: yuhangh <[email protected]> Co-authored-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> add paged kt cache (1st commit). Signed-off-by: Fanrong Li <[email protected]> minnor fix. Signed-off-by: Fanrong Li <[email protected]> fix _single_request_update_kt_cache for vanilla RocketKV. Signed-off-by: Fanrong Li <[email protected]> add paged kt cache to rocketkv trtllm. Signed-off-by: Fanrong Li <[email protected]> fix _single_request_update_kt_cache for trtllm RocketKV. Signed-off-by: Fanrong Li <[email protected]> fix k_snap length. Signed-off-by: Fanrong Li <[email protected]> fix memory issue when using paged kt cache. Signed-off-by: Fanrong Li <[email protected]> fix rebase breaks Signed-off-by: yuhangh <[email protected]> fix rebase bug. Signed-off-by: Fanrong Li <[email protected]> fix rebase bug. Signed-off-by: Fanrong Li <[email protected]> update block sparse attention kernel. 
Signed-off-by: Fanrong Li <[email protected]> fix params issue Signed-off-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> [None][feat] Do sparse attention functional clean (#43) * fix several bugs & adjust some code Signed-off-by: yuhangh <[email protected]> * minor code clean Signed-off-by: yuhangh <[email protected]> * Add simple unittest for rocketkv Signed-off-by: yuhangh <[email protected]> * Adjustment for sparse attention params and example Signed-off-by: yuhangh <[email protected]> * fix bugs introduced by last commit Signed-off-by: yuhangh <[email protected]> * Optimize Xqa_params and num_sparse_kv_tokens Signed-off-by: yuhangh <[email protected]> * Fix gather kernel & minor adjustment Signed-off-by: yuhangh <[email protected]> * Rename sparse_attention_params in xqa_params Signed-off-by: yuhangh <[email protected]> * minor Signed-off-by: yuhangh <[email protected]> --------- Signed-off-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> [None][feat] Update trtllm-gen fmha kernels and remove block sparse cubins (#44) * rm sparse kernels. Signed-off-by: Fanrong Li <[email protected]> * update new kernel. Signed-off-by: Fanrong Li <[email protected]> * update trtllm-gen fmha. Signed-off-by: Fanrong Li <[email protected]> --------- Signed-off-by: Fanrong Li <[email protected]> fix rebase conflicts Signed-off-by: yuhangh <[email protected]> minor fix Signed-off-by: yuhangh <[email protected]> pre-commit fix Signed-off-by: yuhangh <[email protected]> Signed-off-by: Fanrong Li <[email protected]> [None][fix] update trtllm sparse attention interface (#45) * update trtllm sparse attention interface. Signed-off-by: Fanrong Li <[email protected]> * fix interface. Signed-off-by: Fanrong Li <[email protected]> --------- Signed-off-by: Fanrong Li <[email protected]> fix rocketkv interface. (#47) Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>
03bfa99
to
d6a1c17
Compare
/bot run --disable-fail-fast |
PR_Github #21238 [ run ] triggered by Bot |
PR_Github #21238 [ run ] completed with state |
Signed-off-by: Fanrong Li <[email protected]>
/bot run --disable-fail-fast |
PR_Github #21300 [ run ] triggered by Bot |
PR_Github #21300 [ run ] completed with state |
/bot run |
PR_Github #21336 [ run ] triggered by Bot |
LGTM
PR_Github #21336 [ run ] completed with state |
Summary by CodeRabbit
New Features
Documentation
Tests
Description
Add sparse attention RocketKV support in this PR:
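For orientation, a minimal sketch of the intended user-facing flow is shown below. It follows the example script and config class discussed in the review; the exact constructor fields and the sparse_attention_config/backend keyword names are assumptions, not a confirmed API:
```python
# Hypothetical end-to-end sketch; verify keyword names against examples/llm-api/llm_sparse_attention.py.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_args import RocketSparseAttentionConfig

sparse_config = RocketSparseAttentionConfig()  # algorithm defaults to the RocketKV variant

llm = LLM(
    model="/path/to/Llama-3.1-8B-Instruct",    # any long-context checkpoint
    backend="pytorch",                          # the sparse path is wired into the PyTorch flow
    sparse_attention_config=sparse_config,      # assumed keyword, mirroring the example script
)

outputs = llm.generate(
    ["Summarize the following document: ..."],
    SamplingParams(max_tokens=128, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```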
Limitation
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id
(OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test
(OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast
(OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test
(OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx"
(OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe"
(OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp"
(OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test
(OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test
(OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test
(OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
--post-merge
(OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"
(OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log
(OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug
(OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"
is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.