
Conversation

@lfr-0531 (Collaborator) commented on Sep 30, 2025

Summary by CodeRabbit

  • New Features

    • Added optional sparse attention (RocketKV) end-to-end: kernel support, KV cache manager, planner wiring, and runtime flags across C++ and Python/Torch backends.
    • New config types (RocketSparseAttentionConfig) and updated attention APIs to accept grouped config params and sparse indices/offsets.
    • Triton/Python helpers for sparse KT-cache operations and Rocket-based sparse attention backends.
  • Documentation

    • LongBench README and example scripts for sparse attention evaluations.
  • Tests

    • New unit tests for sparse kernels and sparse KV cache; integration tests for RocketKV.

Description

Add sparse attention (RocketKV) support in this PR (see the usage sketch after the list):

  • Add RocketKV support to the vanilla attention backend
  • Add RocketKV support to the trtllm attention backend (Blackwell only)
  • Add LongBench tests
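
A minimal usage sketch (illustrative only: the constructor arguments of RocketSparseAttentionConfig and the exact LLM keyword are assumptions based on the CodeRabbit walkthrough; see examples/llm-api/llm_sparse_attention.py for the actual API):

# Hedged sketch of enabling RocketKV via the LLM API.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import RocketSparseAttentionConfig

# Default construction is an assumption; the config exposes RocketKV-specific knobs.
sparse_config = RocketSparseAttentionConfig()
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # the only model evaluated so far
    sparse_attention_config=sparse_config,     # new llm_args field added in this PR (name assumed)
)
outputs = llm.generate(["Summarize the following long document: ..."])
print(outputs[0].outputs[0].text)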

Limitation

  • This is still an experimental feature. Performance is not yet optimized; we will improve it in a near-future PR.
  • This feature has only been evaluated with the Llama-3.1-8B model. More models will be supported in the future.
  • There is no comprehensive accuracy verification yet; we have only done a simple accuracy check.

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
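
For example, an illustrative combination of the documented flags is /bot run --stage-list "A10-PyTorch-1" --disable-fail-fast, which runs only the listed stage and continues past individual failures (stage names vary per pipeline; this is not an exhaustive invocation).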

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@heyuhhh force-pushed the feat/sparse_attention_func branch from 59b51d2 to 9d4b051 on September 30, 2025 12:25
@lfr-0531 marked this pull request as ready for review on October 1, 2025 02:21
@lfr-0531 requested review from a team as code owners on October 1, 2025 02:21
@lfr-0531 (Collaborator, Author) commented on Oct 1, 2025

/bot run

@tensorrt-cicd (Collaborator): PR_Github #20437 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator): PR_Github #20437 [ run ] completed with state DISABLED
L0 testing is limited to prioritized users. User lfr-0531 is not in the prioritized list. L0 testing cannot be triggered.

@coderabbitai bot (Contributor) commented on Oct 1, 2025

📝 Walkthrough

Adds end-to-end sparse (RocketKV/block) attention: new sparse kernels and params, KV-cache sparse update paths, workspace sizing (max_blocks_per_sequence), C++/CUDA kernel additions, Python backend integration, model/config wiring, bindings, examples, and tests; public APIs and some function signatures were expanded to carry sparse parameters.

Changes

Cohort / File(s) Summary
Core C++ attention & XQA wiring
cpp/tensorrt_llm/common/attentionOp.h, cpp/tensorrt_llm/common/attentionOp.cpp, cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h, cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h, cpp/tensorrt_llm/kernels/xqaDispatcher.cpp, cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h, cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h
Introduces SparseAttentionParams and use_sparse_attention flags into XQA paths; adds preprocessingParams sparse fields and is_last_chunk; extends AttentionOp getWorkspaceSizeForGeneration signatures (adds max_blocks_per_sequence), increases XQA buffer count; wires sparse KV indices/offsets through preprocessing and post-FMHA sparse KV update paths.
Sparse attention kernels (CUDA)
cpp/tensorrt_llm/kernels/sparseAttentionKernels.h, cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu
Adds SparseAttentionParams struct and new CUDA kernel+launcher gatherKvPageOffsets plus related device-side Pair/PairReduceOp utilities and public invokeGatherKvPageOffsets declaration.
FMHA kernel metadata & runner params
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/cubin/*_cubin.cpp (LFS pointer updates), cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h, cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h, .../kernelMetaInfo.h
Updates cubin LFS pointers; switches kernel meta entries to Swaps/PersistentSwaps variants; adds mUseBlockSparseAttention flag into fmha runner/kernel params.
THOP attention public API & op wiring
cpp/tensorrt_llm/thop/attentionOp.h, cpp/tensorrt_llm/thop/attentionOp.cpp, cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp
Reworks attention(...) API to accept attention_config_params, rotary params, and sparse_attention_params (optional tensors); threads sparse params into Runner/RunnerBase run and workspace computations; passes max_blocks_per_sequence into generation workspace sizing.
Bindings and caster includes
cpp/tensorrt_llm/pybind/thop/bindings.cpp, cpp/tensorrt_llm/nanobind/thop/bindings.cpp, cpp/tensorrt_llm/pybind/common/customCasters.h, cpp/tensorrt_llm/nanobind/common/customCasters.h
Reworks Python/nanobind attention bindings to accept config vectors and sparse_attention_params; adds c10::ArrayRef include in custom casters.
Unfused / XQA launch params
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h
Adds launch params for sparse (sparse_kv_block_offsets, sparse_seq_lengths) and reserves workspace space when use_sparse_attention is enabled.
Unfused sparse KV update kernels
cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h, .../unfusedAttentionKernels_2_template.h
Declares and implements updateSparseKvCacheAfterFmha dispatch and kernels; adds invokeUpdateSparseKvCacheAfterFmha and debug helper; integrates sparse path into KvCachePostprocessing.
Python Torch backends & sparse package
tensorrt_llm/_torch/attention_backend/{__init__.py,interface.py,utils.py,trtllm.py,vanilla.py}, tensorrt_llm/_torch/attention_backend/sparse/{__init__.py,utils.py,kernel.py,rocket.py}
Adds sparse_attention_config plumbing across backends; new sparse package with Rocket implementations, KV cache manager, Triton-backed kernels (triton_index_gather, triton_update_kt_cache), triton helpers, and selection utils; updates TrtllmAttentionWrapper.plan to accept sparse indices and adds abstract sparse predict hooks.
Model, modules, executor, resource manager
tensorrt_llm/_torch/model_config.py, tensorrt_llm/_torch/modules/attention.py, tensorrt_llm/_torch/pyexecutor/{model_engine.py,model_loader.py,py_executor_creator.py,_util.py,resource_manager.py}
Adds sparse_attention_config to ModelConfig and propagates it through attention creation and PyExecutor stack; selects appropriate KV cache manager; refactors KV cache sizing APIs and resource manager interfaces to support cache-per-token calculations and **kwargs for manager constructors.
Python sparse API & utils
tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/attention_backend/__init__.py, tensorrt_llm/_utils.py
Introduces RocketSparseAttentionConfig and SparseAttentionConfig type, exposes RocketSparseAttentionConfig in llmapi, re-exports sparse KV cache manager, and adds next_power_of_two utility.
Examples & LongBench
examples/llm-api/llm_sparse_attention.py, examples/longbench/{README.md,eval_longbench_v1.py,eval_longbench_v2.py,requirements.txt}
Adds example/demo for sparse attention and two LongBench evaluation scripts (v1/v2) with docs and dependencies.
Tests
cpp/tests/unit_tests/kernels/{sparseAttentionKernelsTest.cpp,sparseKvCacheTest.cu}, cpp/tests/unit_tests/kernels/CMakeLists.txt, tests/unittest/_torch/attention/sparse/test_rocketkv.py
Adds CUDA/GTest unit tests for gatherKvPageOffsets and sparse KV cache update; adds Python unit test for RocketKV attention correctness; updates CMake test targets.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant PyLLM as LLM (Python)
  participant AttnMod as Attention module (Py)
  participant Backend as AttentionBackend (Py)
  participant KVMgr as KVCacheManager (Py)
  participant Wrapper as TrtllmAttentionWrapper (Py)
  participant THOP as THOP attention (C++)
  participant AttentionOp as AttentionOp (C++)
  participant Dispatcher as XQA/FMHA Dispatcher (C++)
  participant Kernels as CUDA Kernels

  User->>PyLLM: generate(prompts, sparse_attention_config)
  PyLLM->>AttnMod: create attention (sparse config)
  AttnMod->>Backend: get_attention_backend(..., sparse_attn_config)
  Backend->>KVMgr: prepare KV resources (sparse)
  Backend->>Wrapper: plan(..., sparse_kv_indices/offsets, sparse_attn_indices/offsets)
  Wrapper->>THOP: call attention(..., sparse params, attention_config_params)
  THOP->>AttentionOp: enqueue/generate(..., sparse params, max_blocks_per_sequence)
  AttentionOp->>Dispatcher: build XQAParams (use_sparse_attention=true)
  Dispatcher->>Kernels: invokeGatherKvPageOffsets(...) (if block-sparse)
  Dispatcher->>Kernels: FMHA / XQA kernel launch
  Kernels-->>Dispatcher: outputs
  Dispatcher->>Kernels: invokeUpdateSparseKvCacheAfterFmha (post-FMHA sparse KV write)
  Dispatcher-->>AttentionOp: return
  AttentionOp-->>THOP: output
  THOP-->>Wrapper: return
  Wrapper-->>Backend: return tokens
sequenceDiagram
  autonumber
  participant THOP as THOP (C++)
  participant AO as AttentionOp (C++)
  participant Kern as Kernels (CUDA)

  rect rgba(220,240,255,0.25)
    note right of THOP: Context / Prefill
    THOP->>AO: preprocess (sparse_kv_indices, sparse_kv_offsets, is_last_chunk)
    AO->>Kern: FMHA
    AO->>Kern: invokeUpdateSparseKvCacheAfterFmha (sparse path)
  end

  rect rgba(220,255,220,0.25)
    note right of THOP: Generation
    THOP->>AO: preprocess (sparse_attn_indices, sparse_attn_offsets, max_blocks_per_sequence)
    AO->>Kern: gatherKvPageOffsets (block offsets -> seq lengths)
    AO->>Kern: FMHA/XQA (block-sparse)
  end

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • QiJune

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

  • Description Check — ⚠️ Warning: The PR description lacks the required summary section (e.g., the @coderabbitai summary placeholder), leaves the “## Test Coverage” section empty, and still contains instructional comments at the top instead of actual content, so it does not conform to the repository’s description template. Resolution: Please remove the instructional comment block, add a brief summary under the designated summary placeholder, and populate the “## Test Coverage” section with a list of relevant tests that exercise the new RocketKV and sparse attention paths.
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 36.76%, which is insufficient. The required threshold is 80.00%. Resolution: You can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (1 passed)

  • Title Check — ✅ Passed: The title clearly and succinctly describes the primary enhancement introduced by this pull request: the addition of a sparse attention framework with a concrete RocketKV support use case, aligning directly with the PR’s main objective of integrating sparse attention functionality. It avoids extraneous details while conveying the core feature and its experimental context. The inclusion of the issue identifier and feature tag is standard practice and does not detract from clarity. Overall, the title enables a quick understanding of the PR’s purpose for reviewers scanning the commit history.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 40

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (11)
tensorrt_llm/_utils.py (1)

1-1: Update copyright year to 2025.

Per coding guidelines, the copyright header should include the current year. The file header shows 2022-2024, but this PR is from 2025.

Apply this diff:

-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h (1)

2-2: Update copyright year to 2025.

The copyright header shows 2020-2024, but this PR is from 2025. Per coding guidelines, update to include the current year.

Apply this diff:

- * Copyright (c) 2020-2024, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
tensorrt_llm/_torch/attention_backend/__init__.py (1)

1-16: Add NVIDIA Apache-2.0 copyright header.

According to the coding guidelines for Python files, prepend the NVIDIA Apache-2.0 copyright header with the current year (2025) to the top of this source file.

Example format:

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# ...

As per coding guidelines.

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp (1)

271-296: Workspace budget must include the new sparse RocketKV buffers

The new sparse_kv_block_offsets/sparse_seq_lengths carve kv_block_offsets_size + seq_lengths_size bytes out of params.workspaces, but neither XqaDispatcher::getWorkspaceSize() nor the upstream workspace sizing has been expanded to reserve them. In common generation settings (e.g., batch 4, beam 1, num_kv_heads 16, max_blocks_per_sequence 512) this extra demand exceeds 250 KB; with the current budget derived only from max_num_tokens, buildXQALaunchParams will write past the provided buffer and corrupt memory. Please bump the workspace calculation (here and in any caller that pre-allocates params.workspaces) by those sparse buffers or an equivalent worst-case bound tied to max_blocks_per_sequence.
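
A back-of-envelope check of that figure (a sketch; the buffer layout below is an assumption based on the gather-output shapes described elsewhere in this review, with int32 elements):

# Assumed layout: block offsets [num_kv_heads, batch * beam, 2, max_blocks_per_sequence],
# sequence lengths [num_kv_heads, batch * beam]; 4 bytes per int32 element.
num_kv_heads, batch_beam, max_blocks = 16, 4, 512
kv_block_offsets_size = num_kv_heads * batch_beam * 2 * max_blocks * 4  # 262,144 bytes
seq_lengths_size = num_kv_heads * batch_beam * 4                        # 256 bytes
print((kv_block_offsets_size + seq_lengths_size) // 1024)               # ~256 KiB, i.e. > 250 KB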

tensorrt_llm/_torch/pyexecutor/_util.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025).

Please prepend the standard header to comply with licensing guidelines.


499-518: Avoid passing unsupported keyword to KVCacheManager constructors.

sparse_attn_config is always passed, even when the selected class is the vanilla KVCacheManager, which likely doesn’t accept it → TypeError at runtime. Gate the kwarg.

-            kv_cache_manager = self._kv_cache_manager_cls(
+            extra_kwargs = {}
+            if sparse_attn_config is not None:
+                extra_kwargs["sparse_attn_config"] = sparse_attn_config
+            kv_cache_manager = self._kv_cache_manager_cls(
                 self._kv_cache_config,
                 tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
@@
-                kv_connector_manager=self._kv_connector_manager
-                if not estimating_kv_cache else None,
-                sparse_attn_config=sparse_attn_config,
+                kv_connector_manager=self._kv_connector_manager if not estimating_kv_cache else None,
+                **extra_kwargs,
             )
tensorrt_llm/_torch/attention_backend/utils.py (1)

1-1: Add NVIDIA Apache-2.0 header to utils.py
Prepend the standard NVIDIA Apache-2.0 copyright header with the current year (2025) at the top of tensorrt_llm/_torch/attention_backend/utils.py.

cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (1)

58-60: Close include guard.

 } // namespace kernels
 } // namespace tensorrt_llm
+
+#endif // TRTLLM_SPARSEATTENTIONKERNELS_H

As per coding guidelines.

tensorrt_llm/_torch/attention_backend/sparse/kernel.py (1)

210-309: Generation-path shape bug: documented k is (total_seq_len, …) but gen kernel treats k as (batch_size, …). Fix API or indexing.

Current code computes k_base = k_ptr + batch_idx*hidden_size + head_idx*dim_size, i.e., assumes k.shape == (batch_size, num_kv_heads, head_dim). Either (A) change gen-path to address the last token in the flattened total_seq_len layout, or (B) document and enforce k to be (batch_size, num_kv_heads, head_dim) in generation. Option A keeps a single input contract.

Option A (index last token per batch in flattened k):

-        grid = (batch_size, num_kv_heads)
-        _update_kt_cache_gen_kernel[grid](k,
+        # Expect k as (total_seq_len, num_kv_heads, head_dim)
+        grid = (batch_size, num_kv_heads)
+        # Compute per-batch base offsets to the last token
+        cum_seq_lens = torch.cumsum(torch.cat([torch.zeros(1, device=k.device, dtype=torch.long),
+                                               seq_lens.to(torch.long)]), dim=0)
+        last_token_offsets = (cum_seq_lens[:-1] + (seq_lens.to(torch.long) - 1)).contiguous()
+        _update_kt_cache_gen_kernel[grid](k,
                                           kt_cache_tensor,
                                           kt_cache_block_offsets,
-                                          seq_lens,
+                                          last_token_offsets,  # pass absolute last-token indices
                                           num_kv_heads,
                                           head_dim,
                                           kt_page_size,
                                           tokens_per_block,
                                           max_kt_blocks_per_seq,
                                           BLOCK_SIZE=1024)

And update _update_kt_cache_gen_kernel to consume absolute token indices:

-    past_key_value_length = tl.load(seq_lens_ptr + batch_idx) - 1
-    kt_token_idx = past_key_value_length // kt_page_size
-    kt_token_idx_in_page = past_key_value_length % kt_page_size
+    last_tok = tl.load(seq_lens_ptr + batch_idx)  # absolute index in flattened K
+    kt_token_idx = last_tok // kt_page_size
+    kt_token_idx_in_page = last_tok % kt_page_size
-    k_base = k_ptr + batch_idx * hidden_size + head_idx * dim_size
+    k_base = k_ptr + last_tok * hidden_size + head_idx * dim_size

If you prefer Option B, update the function docstring and assert k.shape[0] == batch_size, then keep current indexing. Please confirm which contract you want. Based on learnings.

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

140-153: Import SparseAttentionConfig so lint stops failing

Ruff is currently raising F821 on the new type hint because SparseAttentionConfig is never imported. The string annotation is OK at runtime, but the lint failure will break the build. Pull the symbol in under TYPE_CHECKING to satisfy the checker without affecting execution:

-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
@@
-from ..attention_backend.interface import (AttentionMetadata,
-                                           AttentionRuntimeFeatures)
+from ..attention_backend.interface import (AttentionMetadata,
+                                           AttentionRuntimeFeatures)
+
+if TYPE_CHECKING:
+    from ..attention_backend.config import SparseAttentionConfig

Any equivalent import location is fine as long as SparseAttentionConfig resolves during linting.

tensorrt_llm/_torch/attention_backend/vanilla.py (1)

1-18: Add NVIDIA Apache-2.0 header.

All source files must carry the NVIDIA Apache-2.0 header.

+#
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
 import math
 from typing import Optional

As per coding guidelines.

🧹 Nitpick comments (36)
cpp/tensorrt_llm/nanobind/common/customCasters.h (1)

29-29: Remove unused ArrayRef include
The header <c10/util/ArrayRef.h> isn’t referenced in this file or any Nanobind bindings—delete it.

tensorrt_llm/_utils.py (1)

216-224: Consider enhancing the docstring.

The function is correct, but the docstring could be more descriptive following Google-style format with parameter and return value documentation.

Apply this diff:

 def next_power_of_two(x):
     """
-    get next power of two
+    Get the next power of two greater than or equal to x.
+    
+    Args:
+        x: Input integer value.
+        
+    Returns:
+        The next power of two. Returns 1 for x <= 0.
+        If x is already a power of two, returns x unchanged.
     """
examples/llm-api/llm_sparse_attention.py (3)

17-78: Add docstring and consider removing hardcoded default paths.

The function needs a docstring per coding guidelines. Additionally, the hardcoded default paths (lines 23, 28) may not be appropriate for all users.

Apply this diff:

 def parse_arguments():
+    """
+    Parse command-line arguments for sparse attention evaluation.
+    
+    Returns:
+        Parsed arguments namespace.
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument(
         '--model_path',
         type=str,
-        default=
-        "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
+        required=True,
+        help="Path to the model directory"
     )
     parser.add_argument(
         '--input_file',
         type=str,
-        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
+        required=True,
+        help="Path to the input JSONL file"
     )

81-129: Add docstring for run_RocketKV function.

Per coding guidelines, all functions should have Google-style docstrings.

Apply this diff:

 def run_RocketKV(args):
+    """
+    Run RocketKV sparse attention evaluation.
+    
+    Args:
+        args: Parsed command-line arguments containing model path, input file,
+              and configuration parameters.
+    """
     data = read_input(args.input_file)

132-137: Add docstring and improve exception message.

Per coding guidelines and static analysis, add a docstring and consider using a custom exception class for better error handling.

Apply this diff:

 def main():
+    """
+    Main entry point for the sparse attention example script.
+    """
     args = parse_arguments()
     if args.algo == 'ROCKETKV':
         run_RocketKV(args)
     else:
-        raise ValueError(f"Invalid algorithm: {args.algo}")
+        raise ValueError(f"Unsupported algorithm: {args.algo}")
tensorrt_llm/_torch/pyexecutor/resource_manager.py (4)

166-167: Avoid silently swallowing unknown kwargs.

__init__ accepts **kwargs but ignores them. Either document supported keys and validate, or drop the parameter to surface misconfiguration early.

Apply minimal safety now:

-        **kwargs,
+        **_unused_kwargs,

Optionally log unexpected keys in debug builds. Would you like me to wire a strict validator?


603-619: Parity with RocketKV KT cache factor.

get_cache_bytes_per_token omits KT cache overhead used by RocketKV (2 * kt_tokens_per_block / tokens_per_block). If this Python path ever sizes RocketKV, it will undercount memory.

Option: thread a kt_tokens_per_block and tokens_per_block (when sparse enabled) and fold the factor as in sparse. If this method is strictly for dense KV, please add a comment stating KT is excluded to avoid misuse.


620-659: Clarify units and guard divisions in calculate_max_num_blocks.

  • max_tokens is float; later ceil(max_tokens / tokens_per_block) is fine, but consider explicit floor/ceil commentary for readability.
  • Guard cache_size_bytes_per_token > 0 (defensive).
  • For multi-device: mpi_comm/MPI is only available when ENABLE_MULTI_DEVICE is compiled; mapping.world_size > 1 without that build flag will raise. Add a feature guard or assert ENABLE_MULTI_DEVICE when world_size > 1.

I can add a small helper to centralize memory-to-blocks math with asserts if you want.
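
A minimal sketch of such a helper (names and the floor-vs-ceil policy are illustrative, not the repository's API; the code under review currently ceils the final division):

def memory_to_blocks(free_mem_bytes: int, cache_bytes_per_token: int,
                     tokens_per_block: int) -> int:
    # Defensive guards so zero/negative sizes fail loudly instead of dividing by zero.
    assert cache_bytes_per_token > 0, "cache_bytes_per_token must be positive"
    assert tokens_per_block > 0, "tokens_per_block must be positive"
    max_tokens = free_mem_bytes // cache_bytes_per_token  # floor: whole tokens only
    return max_tokens // tokens_per_block                 # floor: never over-allocate blocks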


563-602: Rename unused kwargs and align per-token head count with runtime allocation

  • Rename **kwargs to **_kwargs (it’s never used).
  • If num_key_value_heads is an Iterable, use sum(num_key_value_heads) instead of averaging to mirror get_cache_bytes_per_token.

Optionally, delegate to get_cache_bytes_per_token for full consistency.

cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (4)

1581-1583: Debug logs in hot path: gate or lower frequency.

TLLM_LOG_DEBUG in per-dispatch paths can be noisy. Consider gating by an env flag already read via getEnvEnablePDL() or use trace-level.

Also applies to: 1606-1610


1714-1796: Sparse KV post-FMHA kernel: validate index layout and grid utilization.

  • sparse_kv_indices indexing uses kv_head_idx * num_sparse_kv_tokens + global_sparse_idx; this assumes a [kv_heads, sum_B_tokens] flattening. Please confirm producer matches this layout. A mismatch will corrupt KV.
  • grid.x is 1 and the kernel loops over tokens. This is correct but can underutilize SMs for long sparse lists. Consider setting grid.x = ceil_div(num_sparse_tokens, tokens_per_block) and letting the loop stride by gridDim.x to improve occupancy.

If shapes are stable, I can propose a host-side grid.x computation based on params.sparse_kv_offsets for better scaling.


1734-1765: Shared memory bounds and bank use: OK, add static checks for large Dh.

smem size = 2 * block.y * VECS_PER_HEAD * 16B. With Dh=256 and TCache=fp16, this is 32KB; safe. Add a compile-time/assert check to ensure size <= device limit (e.g., 96KB/SM) to prevent launch failures on configs with larger Dh.
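
Worked numbers behind the 32KB figure (assuming 16-byte vectors, i.e. VECS_PER_HEAD = 256 / 8 = 32 for fp16, and block.y = 32): smem = 2 * 32 * 32 * 16 B = 32,768 B = 32 KB, comfortably under the 96KB/SM opt-in limit.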


1800-1811: kernelSparseDispatchHeadSize: dynamic grid.x and smem calc.

  • grid.x fixed at 1; consider dynamic token tiling as above.
  • smem size computation is correct; add TLLM_CHECK for smem size <= device attr sharedMemPerBlockOptin when PDL is enabled to fail fast.
cpp/tensorrt_llm/nanobind/thop/bindings.cpp (1)

44-55: Binding surface changed: document and validate new params.

  • New required/optional args (attention_config_params, rotary_embedding_int_params, sparse_attention_params). Please update Python docs/examples and add runtime validation for incompatible combos (e.g., sparse params with cross-attn if unsupported).

I can draft a minimal docstring and a unit smoke test ensuring nanobind signature matches torch_ext::attention signature.

cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp (1)

575-576: Consider using nbDims - 1 instead of hard-coded dimension index.

Line 576 hard-codes dimension index [3] to extract max_blocks_per_sequence, but line 867 in the same file uses kvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1] for the same purpose. The latter approach is more robust and consistent.

Apply this diff to improve consistency and robustness:

-    int const max_blocks_per_sequence
-        = (useKVCache() && mPagedKVCache) ? inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims.d[3] : 0;
+    auto const& kvCacheBlockOffsetsShape = inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims;
+    int const max_blocks_per_sequence = (useKVCache() && mPagedKVCache) 
+        ? kvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1] 
+        : 0;
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h (1)

2-2: Update copyright year to include 2025.

Since this file is being modified in 2025, update the copyright header to reflect the current year:

-* Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+* Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp (1)

15-15: Use CamelCase for test class name.

The test class name sparseAttentionKernelsTest should use CamelCase starting with an uppercase letter per C++ coding guidelines. Consider renaming to SparseAttentionKernelsTest.

-class sparseAttentionKernelsTest : public ::testing::Test
+class SparseAttentionKernelsTest : public ::testing::Test

And update the TEST_F macro accordingly:

-TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
+TEST_F(SparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)

As per coding guidelines.

tensorrt_llm/_torch/pyexecutor/_util.py (2)

9-24: Remove duplicate ModelConfig import to avoid ambiguity.

Two ModelConfig imports (lines 9 and 23) are redundant and can confuse type checkers. Keep one.

-from tensorrt_llm._torch.model_config import ModelConfig
@@
-from ..model_config import ModelConfig

206-212: Vanilla backend disables capacity estimation; confirm expected behavior with sparse VANILLA.

If VANILLA + sparse is used, estimation is disabled unconditionally. If that’s intentional, add a TODO linking to planned support; otherwise, guard on sparse config as well.

cpp/tensorrt_llm/common/attentionOp.h (1)

160-173: Debug dump of context/sequence lengths: keep cost low.

The toString path wraps host pointers into ITensor; ensure this is compile-time only (debug) or guarded to avoid perf overhead in production logs.

cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (2)

11-32: Make sparse parameter pointers const; fix data() const‑correctness and document the public struct.

Indices/offsets are read-only; expose them as int32_t const*. Return a tuple of const pointers. Add brief Doxygen.

-struct SparseAttentionParams
+//! Parameters for sparse attention indices/offsets on device.
+struct SparseAttentionParams
 {
-    int32_t* sparse_kv_indices{nullptr};   // [num_kv_heads, num_sparse_kv_indices]
-    int32_t* sparse_attn_indices{nullptr}; // [num_kv_heads, num_sparse_attn_indices]
-    int32_t* sparse_kv_offsets{nullptr};   // [num_contexts + 1]
-    int32_t* sparse_attn_offsets{nullptr}; // [num_generations + 1]
+    //!< [num_kv_heads, num_sparse_kv_indices]
+    int32_t const* sparse_kv_indices{nullptr};
+    //!< [num_kv_heads, num_sparse_attn_indices]
+    int32_t const* sparse_attn_indices{nullptr};
+    //!< [num_contexts + 1]
+    int32_t const* sparse_kv_offsets{nullptr};
+    //!< [num_generations + 1]
+    int32_t const* sparse_attn_offsets{nullptr};
@@
-    auto data() const
+    std::tuple<int32_t const*, int32_t const*, int32_t const*, int32_t const*> data() const
     {
         return std::make_tuple(sparse_kv_indices, sparse_attn_indices, sparse_kv_offsets, sparse_attn_offsets);
     }

As per coding guidelines.


51-56: Document the API with Doxygen; align parameter naming and constness.

Add brief Doxygen and keep parameter naming consistent (num_kv_heads vs num_head_kv). Keep as declared, but consider renaming for consistency in future PRs. No behavior change needed.

-void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
+//! Gathers KV page offsets per head/batch using sparse params.
+//! \param output_kv_page_offsets [num_head_kv, batch_size, 2, max_num_pages_per_seq]
+//! \param output_seq_lengths     [num_head_kv, batch_size]
+//! \param kv_page_offsets        [batch_size, 2, max_num_pages_per_seq]
+//! \param seq_lengths            [batch_size]
+void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]

Based on coding guidelines.

cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu (1)

430-492: Verification helper assumes a specific KVBlockArray pool layout; keep but flag brittleness.

The linearization math is test-specific and may diverge from production layout. If the internal layout changes, this will silently miscompare. Consider adding static assertions or reading block strides/capacity from KVBlockArray if available, or limit comparisons to the first block for robustness.

tensorrt_llm/_torch/attention_backend/sparse/kernel.py (2)

210-309: Add minimal docstrings and input validation to public wrappers.

Provide clear contracts (shapes/dtypes), return types, and error messages. Also avoid shadowing the tokens_per_block arg.

-def triton_update_kt_cache(k,
+def triton_update_kt_cache(k,
                            kt_cache_tensor,
                            kt_cache_block_offsets,
                            seq_lens,
                            kt_page_size,
                            tokens_per_block,
                            max_kt_blocks_per_seq,
                            update=True):
-    # inputs:
-    # k: (total_seq_len, num_kv_heads, head_dim)
-    # kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim)
-    # kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq)
-    # seq_lens: (batch_size)
-    # kt_page_size: int
-    # update: bool
-
-    # outputs:
-    # kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim)
+    """
+    Update or load KT cache.
+    Args:
+        k: (total_seq_len, num_kv_heads, head_dim) in context; see gen-path note above.
+        kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim), CUDA tensor.
+        kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq), int32/int64 CUDA tensor.
+        seq_lens: (batch_size), lengths per batch, CUDA tensor.
+        kt_page_size: int.
+        tokens_per_block: int.
+        max_kt_blocks_per_seq: int.
+        update: bool. If False, context path; if True, generation path.
+    Returns:
+        None (context) or kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim).
+    """
+    assert k.is_cuda and kt_cache_tensor.is_cuda and seq_lens.is_cuda, "All inputs must be on CUDA."
+    assert kt_cache_tensor.size(1) == tokens_per_block, "tokens_per_block must match cache tensor."

As per coding guidelines.


73-207: Kernels: minor safety/clarity nits.

  • Use tl.full with explicit dtype from k_ptr to avoid fp16→fp32→fp16 churn unless intended.
  • Consider early-return masking when kv_end_idx <= kv_start_idx in context (empty range).
  • Avoid recomputing hidden_size if you can hoist constants in wrappers.

These are optional and can wait.

tensorrt_llm/_torch/attention_backend/sparse/utils.py (1)

5-12: Unify error handling and add concise docstrings; silence Ruff TRY003.

Factor repeated messages and document public helpers. Keeps messages short and consistent.

-def get_sparse_attn_kv_cache_manager(
-        sparse_attn_config: "SparseAttentionConfig"):
-    if sparse_attn_config.algorithm == "rocket":
-        return RocketKVCacheManager
-    else:
-        raise ValueError(
-            f"Unsupported sparse attention algorithm: {sparse_attn_config.algorithm}"
-        )
+def _unsupported(algo: str, where: str) -> ValueError:
+    return ValueError(f"Unsupported sparse attention algorithm in {where}: {algo}")
+
+def get_sparse_attn_kv_cache_manager(sparse_attn_config: SparseAttentionConfig):
+    """Return KV cache manager class for the given sparse algorithm."""
+    if sparse_attn_config.algorithm == "rocket":
+        return RocketKVCacheManager
+    raise _unsupported(sparse_attn_config.algorithm, "kv_cache_manager")
@@
-def get_vanilla_sparse_attn_attention_backend(
-        sparse_attn_config: "SparseAttentionConfig"):
-    if sparse_attn_config.algorithm == "rocket":
-        return RocketVanillaAttention
-    else:
-        raise ValueError(
-            f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_attn_config.algorithm}"
-        )
+def get_vanilla_sparse_attn_attention_backend(sparse_attn_config: SparseAttentionConfig):
+    """Return VanillaAttention backend class for the given sparse algorithm."""
+    if sparse_attn_config.algorithm == "rocket":
+        return RocketVanillaAttention
+    raise _unsupported(sparse_attn_config.algorithm, "vanilla")
@@
-def get_trtllm_sparse_attn_attention_backend(
-        sparse_attn_config: "SparseAttentionConfig"):
-    if sparse_attn_config.algorithm == "rocket":
-        return RocketTrtllmAttention
-    else:
-        raise ValueError(
-            f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_attn_config.algorithm}"
-        )
+def get_trtllm_sparse_attn_attention_backend(sparse_attn_config: SparseAttentionConfig):
+    """Return TrtllmAttention backend class for the given sparse algorithm."""
+    if sparse_attn_config.algorithm == "rocket":
+        return RocketTrtllmAttention
+    raise _unsupported(sparse_attn_config.algorithm, "trtllm")
@@
-def get_flashinfer_sparse_attn_attention_backend(
-        sparse_attn_config: "SparseAttentionConfig"):
-    raise ValueError(
-        f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_attn_config.algorithm}"
-    )
+def get_flashinfer_sparse_attn_attention_backend(sparse_attn_config: SparseAttentionConfig):
+    """FlashInfer sparse attention not supported yet."""
+    raise _unsupported(sparse_attn_config.algorithm, "flashinfer")

Based on static analysis hints.

Also applies to: 15-22, 25-32, 35-39

tensorrt_llm/_torch/attention_backend/trtllm.py (1)

1259-1265: Avoid unnecessary predictions when a phase is absent.

Compute per-phase sparse indices only when needed to save GPU time on mixed batches.

-        sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
-        if self.sparse_attention_config is not None:
-            sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
-                q, k, metadata)
-            sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
-                q, k, metadata)
+        sparse_kv_indices = sparse_kv_offsets = None
+        sparse_attn_indices = sparse_attn_offsets = None
+        if self.sparse_attention_config is not None:
+            if metadata.num_contexts > 0:
+                sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(q, k, metadata)
+            if metadata.num_generations > 0:
+                sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(q, k, metadata)
examples/longbench/eval_longbench_v2.py (2)

507-515: Ensure summary reflects actual generation length; fix extraneous f-strings.

  • You override max_new_tokens based on --cot but the summary writes args.max_new_tokens. Sync them.
  • Remove f from f-strings without placeholders flagged by Ruff F541.
-    max_new_tokens = 1024 if args.cot else 256
+    max_new_tokens = 1024 if args.cot else 256
+    args.max_new_tokens = max_new_tokens
@@
-            'max_new_tokens': args.max_new_tokens
+            'max_new_tokens': args.max_new_tokens

Also drop unnecessary f prefixes at Lines 412, 457, 500, 739, 746 (and any similar cases) to satisfy Ruff F541. Example:

-    logger.info(f"Loading LongBench v2 dataset...")
+    logger.info("Loading LongBench v2 dataset...")

Also applies to: 676-691


218-239: Remove unused tokenizer parameter from build_chat and its call sites.

The tokenizer argument is unused (Ruff ARG001). Simplify the signature and calls.

-def build_chat(tokenizer, prompt, chat_template):
+def build_chat(prompt, chat_template):
@@
-            formatted_prompt = build_chat(tokenizer, formatted_prompt,
-                                          chat_template)
+            formatted_prompt = build_chat(formatted_prompt, chat_template)
@@
-                cot_ans_prompt = build_chat(tokenizer, cot_ans_prompt,
-                                            chat_template)
+                cot_ans_prompt = build_chat(cot_ans_prompt, chat_template)

Also applies to: 488-492, 539-542

tensorrt_llm/_torch/attention_backend/vanilla.py (3)

190-197: Docstring/return order mismatch.

The docstring says (is_causal, attn_mask) but the function returns (attn_mask, is_causal). Align the docstring to the implementation to avoid confusion.

-        Returns:
-            Tuple of (is_causal, attn_mask)
+        Returns:
+            Tuple of (attn_mask, is_causal)

Also applies to: 222-223


141-156: KV cache update path: clarify k/v types when quantized.

When has_fp8_kv_cache, k/v are cast to float8_e4m3fn but index_copy_ writes via view(dtype=access_type) as ints; later, concatenation with past uses k_out/v_out slices of cache (pre-cast). You fix dtype in _single_request_attn_forward by to(q.dtype). This is fine, but please add a brief comment here to explain the intended dtype transitions to avoid regressions.


16-16: Safety: gather indices shape assumptions.

triton_index_gather requires [row, token, head, dim] inputs and [row, token, head] indices. Ensure sparse_kv_indices/sparse_indices follow this exactly; add asserts before gathering for clearer failures.

-        if sparse_kv_indices is not None:
+        if sparse_kv_indices is not None:
+            assert k.dim() == 4 and v.dim() == 4 and sparse_kv_indices.dim() == 3, \
+                "Expect k/v [B,T,H,D] and indices [B,T,H] for sparse KV selection"
             k_selected = triton_index_gather(k, sparse_kv_indices)
             v_selected = triton_index_gather(v, sparse_kv_indices)
@@
-        if sparse_indices is not None:
+        if sparse_indices is not None:
+            assert key_states.dim() == 4 and value_states.dim() == 4 and sparse_indices.dim() == 3, \
+                "Expect kv [B,T,H,D] and indices [B,T,H] for sparse attention selection"
             key_states = triton_index_gather(key_states, sparse_indices)
             value_states = triton_index_gather(value_states, sparse_indices)

Also applies to: 224-238

examples/longbench/eval_longbench_v1.py (4)

214-235: Remove unused tokenizer parameter from build_chat and call site.

Same as v2; satisfy Ruff ARG001.

-def build_chat(tokenizer, prompt, chat_template):
+def build_chat(prompt, chat_template):
@@
-        prompt = build_chat(tokenizer, prompt, chat_template)
+        prompt = build_chat(prompt, chat_template)

Also applies to: 390-394


542-549: Cleaner metric key retrieval.

Prefer next(iter(metrics)) to avoid building a list and silence Ruff RUF015.

-            if metrics:
-                metric_key = list(metrics.keys())[0]
+            if metrics:
+                metric_key = next(iter(metrics))
                 val = metrics[metric_key]

697-707: Drop extraneous f-strings; minor logging cleanup.

Remove f where no placeholders exist (Ruff F541), and keep messages concise.

Examples:

-    logger.info(
-        "=========== LongBench Evaluation with TensorRT-LLM ===========")
+    logger.info("=========== LongBench Evaluation with TensorRT-LLM ===========")
@@
-        logger.info(f"Running evaluation on full LongBench datasets")
+        logger.info("Running evaluation on full LongBench datasets")
@@
-        logger.info(f"FINAL RESULTS:")
+        logger.info("FINAL RESULTS:")

Also applies to: 718-722, 785-791


184-194: Path validation: improve error message and doc.

If LongBench/ is vendored under --longbench_path, the current error is fine. Consider hinting expected layout in the exception to aid users.

-        raise FileNotFoundError(
-            f"LongBench directory not found: {longbench_dir}")
+        raise FileNotFoundError(
+            f"LongBench directory not found: {longbench_dir}. Expected {longbench_dir}/config and dataset files from THUDM/LongBench."
+        )

@lfr-0531 (Collaborator, Author) commented on Oct 1, 2025

/bot run

@tensorrt-cicd (Collaborator): PR_Github #20469 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator): PR_Github #20469 [ run ] completed with state FAILURE

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 14

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

11-15: Fix F821: forward-ref type without import.

Add a guarded import for SparseAttentionConfig to satisfy static analyzers without runtime deps.

 from typing import Any, Callable, Dict, List, Optional, Tuple
 
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from ..model_config import SparseAttentionConfig
♻️ Duplicate comments (42)
cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/cubin/kernelMetaInfo.h (3)

3068-3081: Set abLayout to kPersistentSwapsAbForGen (2) for persistent and kSwapsAbForGen (1) for static swap variants.

These BF16 swap-variant initializers (lines 3068–3081) reference *SwapsAbForGen* cubins but still have abLayout set to 0 (kKeepsAbForGen) or 1 (kSwapsAbForGen) when they should be 2 (kPersistentSwapsAbForGen) for persistent variants and 1 (kSwapsAbForGen) for static variants. The dispatcher will pass incorrectly-laid-out buffers to these kernels, causing incorrect results.

Apply the diff from the previous review comment to update the abLayout values for lines 3068–3081.


4170-4183: Set abLayout to kPersistentSwapsAbForGen (2) for persistent and kSwapsAbForGen (1) for static FP16 swap variants.

These FP16 swap-variant initializers (lines 4170–4183) mirror the BF16 issue: abLayout is set to 0 or 1 when it should be 2 for persistent swap variants and 1 for static swap variants.

Apply the diff from the previous review comment to update the abLayout values for lines 4170–4183.


3072-3075: Set abLayout to kSwapsAbForGen (1) for these P64 static swap variants.

These four BF16 P64 StaticSwapsAbForGen initializers (lines 3072–3075) also have abLayout incorrectly set to 0 (kKeepsAbForGen) instead of 1 (kSwapsAbForGen), causing the same buffer layout mismatch issue.

Apply this diff:

-{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen", 182480, 512, 2, 64, 0, 2, 16, 0, 3, true, false, false, "9197439559d66629cf845bd96f28abf11060ed38633468fe7ad291e6c315f20d"},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ16Kv128StaticSwapsAbForGen", 182480, 512, 2, 64, 0, 2, 16, 1, 3, true, false, false, "9197439559d66629cf845bd96f28abf11060ed38633468fe7ad291e6c315f20d"},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen", 175824, 512, 2, 64, 0, 2, 8, 0, 3, true, false, false, "70a2117d2aa8d255040f97cfa19c4bf189ed485ba4a953d4b99997735c706265"},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen", 175824, 512, 2, 64, 0, 2, 8, 1, 3, true, false, false, "70a2117d2aa8d255040f97cfa19c4bf189ed485ba4a953d4b99997735c706265"},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen", 148624, 512, 2, 64, 0, 2, 16, 0, 1, true, false, false, "876da47e60e0448fc87875d23318807afa489931112f01e2f455123e512ebcce"},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 16, 128, 16, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ16Kv128StaticSwapsAbForGen", 148624, 512, 2, 64, 0, 2, 16, 1, 1, true, false, false, "876da47e60e0448fc87875d23318807afa489931112f01e2f455123e512ebcce"},
-{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen", 141968, 512, 2, 64, 0, 2, 8, 0, 1, true, false, false, "dacac32324a94091467aab4e09474420408a7eaaa58933178e329205e1e026b4"},
+{ DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16, 8, 128, 8, 256, 128, 128, 128, kSM_100f, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin, FmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen_cubin_len, "fmhaSm100fKernel_QkvBfloat16OBfloat16H128PagedKvDenseP64MultiCtasKvVarSeqQ8Kv128StaticSwapsAbForGen", 141968, 512, 2, 64, 0, 2, 8, 1, 1, true, false, false, "dacac32324a94091467aab4e09474420408a7eaaa58933178e329205e1e026b4"},
tensorrt_llm/_torch/attention_backend/sparse/kernel.py (2)

1-4: Add required NVIDIA Apache-2.0 header (current year).

The file is missing the mandated copyright header.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import triton
 import triton.language as tl

As per coding guidelines.


42-71: Validate index bounds to prevent OOB GPU loads.

The function does not validate that indices values are within the valid range [0, input.shape[1]), which can cause undefined behavior on GPU.

 def triton_index_gather(input, indices):
     assert input.ndim == 4, "Input must be a 4D tensor, [row, token, head, dim]"
     assert indices.ndim == 3, "Indices must be a 3D tensor, [row, token, head]"
+    # Bounds check
+    min_idx = indices.min()
+    max_idx = indices.max()
+    assert min_idx >= 0 and max_idx < input.shape[1], \
+        f"indices out of range: [{int(min_idx)}..{int(max_idx)}] vs tokens={input.shape[1]}"

As per coding guidelines.

tensorrt_llm/_torch/pyexecutor/_util.py (1)

108-116: Compute KV size per token using the correct manager class per model.

Using self._kv_cache_manager_cls for both main and draft models is incorrect when their configurations differ (e.g., main uses RocketKV sparse, draft does not). Each model should use its own manager class.

-        kv_size_per_token = self._kv_cache_manager_cls.get_cache_size_per_token(
-            model_config, mapping, tokens_per_block=self._tokens_per_block)
+        main_cls = get_kv_cache_manager_cls(model_config)
+        kv_size_per_token = main_cls.get_cache_size_per_token(
+            model_config, mapping, tokens_per_block=self._tokens_per_block)
         if self._draft_model_engine is not None:
             draft_model_config = self._draft_model_engine.model.model_config
-            kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
-                draft_model_config,
-                mapping,
-                tokens_per_block=self._tokens_per_block)
+            draft_cls = get_kv_cache_manager_cls(draft_model_config)
+            kv_size_per_token += draft_cls.get_cache_size_per_token(
+                draft_model_config, mapping, tokens_per_block=self._tokens_per_block)
cpp/tensorrt_llm/pybind/thop/bindings.cpp (1)

54-55: Fix default for sparse_attention_params to maintain backward compatibility.

Defaulting sparse_attention_params to std::nullopt will cause a TypeError when callers omit the argument, as the C++ implementation expects a vector. Provide a real default vector of four null optionals.

-        py::arg("spec_decoding_tensor_params"), py::arg("sparse_attention_params") = std::nullopt,
+        py::arg("spec_decoding_tensor_params"),
+        py::arg("sparse_attention_params")
+            = std::vector<std::optional<torch::Tensor>>{std::nullopt, std::nullopt, std::nullopt, std::nullopt},
         "Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
tensorrt_llm/_torch/attention_backend/sparse/rocket.py (5)

1-26: Add the required NVIDIA Apache-2.0 header.

This new module is missing the mandated NVIDIA Apache-2.0 copyright header for 2025.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math

As per coding guidelines.


114-114: Fix page-index conversion (wrong divisor).

convert_token_to_page_sparse_indices divides token indices by metadata.tokens_per_block, which is the paged-KV block length, not the RocketKV page size. Use the RocketKV page size from the cache manager.

-    page_size = metadata.tokens_per_block
+    page_size = metadata.kv_cache_manager.page_size

135-165: Eliminate padded -1 page ids.

Deduplicating per head and padding shorter heads with -1 leaves sentinel values in new_page_indices. After the final transpose, the kernel sees -1 as a real page id, causing incorrect gather operations. Rebuild the routine to only emit valid (non-negative) page indices.
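One possible shape for the rebuilt routine, as a minimal sketch: pack only the valid ids into a flat tensor plus per-head offsets instead of padding to a rectangle, so no -1 sentinel is ever materialized. The name page_indices_per_head (a list of deduplicated 1-D index tensors) is an assumption for illustration, not taken from the PR.

import torch

def pack_page_indices(page_indices_per_head):
    """Pack per-head page ids into (flat_indices, offsets) with no -1 padding."""
    counts = torch.tensor([t.numel() for t in page_indices_per_head])
    offsets = torch.zeros(len(page_indices_per_head) + 1, dtype=torch.long)
    offsets[1:] = torch.cumsum(counts, dim=0)        # head h owns flat[offsets[h]:offsets[h+1]]
    flat_indices = torch.cat(page_indices_per_head)  # only valid (non-negative) page ids
    return flat_indices, offsets

The consumer then reads flat_indices[offsets[h]:offsets[h+1]] per head, so the gather kernel never sees a sentinel value.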


959-968: Correct KT rewind page count math.

math.ceil(num_tokens - rewind_len / self.page_size) mixes units—rewind_len / self.page_size is fractional pages subtracted from a raw token count. Compute remaining tokens first, clamp at zero, then convert to page units.

-        updated_kt_token_num = math.ceil(num_tokens -
-                                         rewind_len / self.page_size)
+        remaining = max(num_tokens - rewind_len, 0)
+        updated_kt_token_num = math.ceil(remaining / self.page_size)

1000-1001: Fix ceiling division in compute_page_count.

(token_count + tokens_per_page) // tokens_per_page over-allocates when token_count is an exact multiple and returns 1 when token_count == 0. Use the standard ceiling formula.

-        return (token_count + tokens_per_page) // tokens_per_page
+        if token_count <= 0:
+            return 0
+        return (token_count + tokens_per_page - 1) // tokens_per_page
examples/longbench/requirements.txt (1)

1-3: Replace GPL-only fuzzywuzzy with a permissive fork.

fuzzywuzzy is GPL‑2.0 and archived, which conflicts with TensorRT-LLM’s Apache-2.0 licensing. Please switch to a compatible fork such as thefuzz (MIT) and adjust any imports/usages accordingly.
[suggested change]

-jieba
-fuzzywuzzy
-rouge
+jieba
+thefuzz
+rouge
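For the call-site side of the swap, a minimal sketch (the example strings are illustrative): thefuzz keeps the fuzzywuzzy-compatible fuzz API, so only the import line needs to change.

from thefuzz import fuzz  # previously: from fuzzywuzzy import fuzz

score = fuzz.ratio("sparse attention", "sparse attenton")
print(score)  # similarity in [0, 100]; near 100 for near-identical strings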
examples/longbench/eval_longbench_v1.py (2)

1-10: Add NVIDIA Apache-2.0 header and resolve shebang executability.

Prepend the standard NVIDIA Apache-2.0 header. Keep the shebang only if the file is executable; otherwise remove it to satisfy EXE001.

 #!/usr/bin/env python3
+#
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
 """
 LongBench v1 evaluation script with TensorRT-LLM and sparse attention.

As per coding guidelines.


292-336: Wire KV cache options, validate RocketKV backend, and preserve traceback.

  • Use --kv_cache_dtype and --kv_cache_fraction.
  • Fail fast if backend unsupported for RocketKV.
  • Drop extraneous f-string and re-raise with plain raise.
-        kv_cache_config = KvCacheConfig(
-            enable_block_reuse=False,  # RocketKV doesn't support KV cache reuse
-        )
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,  # RocketKV doesn't support KV cache reuse
+            dtype=args.kv_cache_dtype,
+            free_gpu_memory_fraction=args.kv_cache_fraction,
+        )
@@
-        if args.rocket_sparse:
+        if args.rocket_sparse:
+            if args.backend != "pytorch":
+                raise ValueError(
+                    "RocketKV sparse attention currently supports backend='pytorch'. "
+                    "Use --backend pytorch or disable --rocket_sparse."
+                )
             # Configure RocketKV sparse attention
@@
-            logger.info(f"Using RocketKV sparse attention")
+            logger.info("Using RocketKV sparse attention")
@@
-    except Exception as e:
-        logger.error(f"Failed to initialize LLM: {e}")
-        raise e
+    except Exception as e:
+        logger.error(f"Failed to initialize LLM: {e}")
+        raise
examples/longbench/eval_longbench_v2.py (4)

1-16: Add NVIDIA Apache-2.0 header and resolve shebang executability.

Same as v1: add header; retain shebang only if executable.

 #!/usr/bin/env python3
+#
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
 """
 LongBench v2 evaluation script with TensorRT-LLM and sparse attention.

As per coding guidelines.


298-310: Preserve special tokens during truncation.

Decoding with skip_special_tokens=True can drop chat tokens like <|eot_id|>.

-        prompt = tokenizer.decode(truncated_ids, skip_special_tokens=True)
+        prompt = tokenizer.decode(truncated_ids, skip_special_tokens=False)

346-389: Wire KV cache options, validate RocketKV backend, remove redundant f-string, and preserve traceback.

Same concerns as v1; add dtype/fraction, backend guard, plain raise, and drop superfluous f.

-        kv_cache_config = KvCacheConfig(
-            enable_block_reuse=False,  # RocketKV doesn't support KV cache reuse
-        )
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,
+            dtype=args.kv_cache_dtype,
+            free_gpu_memory_fraction=args.kv_cache_fraction,
+        )
@@
-        if args.rocket_sparse:
+        if args.rocket_sparse:
+            if args.backend != "pytorch":
+                raise ValueError(
+                    "RocketKV sparse attention currently supports backend='pytorch'. "
+                    "Use --backend pytorch or disable --rocket_sparse."
+                )
             # Configure RocketKV sparse attention
@@
-            logger.info(f"Using RocketKV sparse attention")
+            logger.info("Using RocketKV sparse attention")
@@
-    except Exception as e:
-        logger.error(f"Failed to initialize LLM: {e}")
-        raise e
+    except Exception as e:
+        logger.error(f"Failed to initialize LLM: {e}")
+        raise

468-482: Avoid KeyError for templates and IndexError for stop-token IDs.

Use .get() for templates and guard tokenizer encodes.

-    template = config['templates'][template_key]
+    template = config['templates'].get(template_key)
+    if template is None:
+        raise KeyError(f"Missing template '{template_key}' under {args.longbench_path}/prompts")
@@
-    if chat_template == "llama3":
-        eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
-        extra_end_token_ids.append(eot_id)
+    if chat_template == "llama3":
+        eot_ids = tokenizer.encode("<|eot_id|>", add_special_tokens=False)
+        if eot_ids:
+            extra_end_token_ids.append(eot_ids[0])
@@
-    if chat_template == "qwen":
-        im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
-        extra_end_token_ids.append(im_end_id)
+    if chat_template == "qwen":
+        im_end_ids = tokenizer.encode("<|im_end|>", add_special_tokens=False)
+        if im_end_ids:
+            extra_end_token_ids.append(im_end_ids[0])
examples/longbench/README.md (3)

27-39: Add language specifier to fenced code block.

Use language hint for proper rendering.

-```
+```text
 sparse_attention/
@@
-```
+```

148-160: Add language specifier to fenced code block.

Same as above.

-```
+```text
 results/v1_experiment/
@@
-```
+```

164-170: Add language specifier to fenced code block.

Same as above.

-```
+```text
 results/v2_experiment/
@@
-```
+```
tensorrt_llm/_torch/attention_backend/interface.py (1)

143-143: Import SparseAttentionConfig for type checking.

Static analysis correctly flags SparseAttentionConfig as undefined. Add a conditional import under TYPE_CHECKING to resolve the type checker warning and improve IDE support.

+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import SparseAttentionConfig
🪛 Ruff (0.13.2)

143-143: Undefined name SparseAttentionConfig

(F821)

cpp/tensorrt_llm/thop/attentionOp.cpp (1)

230-244: Reset all sparse pointers before conditional assignment.

With cached ops, leaving the non-updated branch's pointers unchanged can carry stale addresses into future runs when the phase switches. Clear all four sparse pointer fields to nullptr before the if (is_context) branch to prevent stale pointer reuse.

 // Prepare sparse attention parameters
+op.mRuntimeSparseAttentionParams.sparse_kv_indices = nullptr;
+op.mRuntimeSparseAttentionParams.sparse_kv_offsets = nullptr;
+op.mRuntimeSparseAttentionParams.sparse_attn_indices = nullptr;
+op.mRuntimeSparseAttentionParams.sparse_attn_offsets = nullptr;
 if (is_context)
 {
     op.mRuntimeSparseAttentionParams.sparse_kv_indices
         = sparse_kv_indices.has_value() ? sparse_kv_indices.value().data_ptr<int32_t>() : nullptr;
     op.mRuntimeSparseAttentionParams.sparse_kv_offsets
         = sparse_kv_offsets.has_value() ? sparse_kv_offsets.value().data_ptr<int32_t>() : nullptr;
 }
 else
 {
     op.mRuntimeSparseAttentionParams.sparse_attn_indices
         = sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
     op.mRuntimeSparseAttentionParams.sparse_attn_offsets
         = sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
 }
cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h (1)

1821-1833: Add missing head sizes to sparse dispatch.

The sparse dispatch switch only covers 16, 32, 64, 128, 256 but the dense V2 kernel (lines 1587-1603) supports many more, including common sizes 80, 96, and 144. Models using those head dimensions with sparse attention will error out unexpectedly.

Add cases for the missing head sizes:

 switch (params.size_per_head)
 {
 case 16: kernelSparseDispatchHeadSize<16, T, TCache, KVCacheBuffer>(params, stream); break;
 case 32: kernelSparseDispatchHeadSize<32, T, TCache, KVCacheBuffer>(params, stream); break;
 case 64: kernelSparseDispatchHeadSize<64, T, TCache, KVCacheBuffer>(params, stream); break;
+case 80: kernelSparseDispatchHeadSize<80, T, TCache, KVCacheBuffer>(params, stream); break;
+case 96: kernelSparseDispatchHeadSize<96, T, TCache, KVCacheBuffer>(params, stream); break;
 case 128: kernelSparseDispatchHeadSize<128, T, TCache, KVCacheBuffer>(params, stream); break;
+case 144: kernelSparseDispatchHeadSize<144, T, TCache, KVCacheBuffer>(params, stream); break;
 case 256: kernelSparseDispatchHeadSize<256, T, TCache, KVCacheBuffer>(params, stream); break;
 default:
     TLLM_CHECK_WITH_INFO(
         false, "updateSparseKvCacheAfterFmha kernel doesn't support head size = %d", params.size_per_head);
     break;
 }
examples/llm-api/llm_sparse_attention.py (2)

8-14: read_input: add docstring and robust I/O/JSON handling.

Handle missing file/invalid JSON; document behavior.

 def read_input(input_file):
-    results = []
-    with open(input_file, 'r') as f:
-        for line in f:
-            ret = json.loads(line)
-            results.append(ret)
-    return results
+    """
+    Read JSONL and return a list of parsed objects.
+    """
+    results = []
+    try:
+        with open(input_file, 'r') as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                results.append(json.loads(line))
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Input file not found: {input_file}")
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Invalid JSON in input file: {e}") from e
+    return results

1-1: Add NVIDIA Apache-2.0 copyright header.

Required by repo guidelines; add the SPDX block at file top.

Apply:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/pyexecutor/model_loader.py (1)

163-163: Missing import for SparseAttentionConfig.

The type annotation on line 163 references SparseAttentionConfig, but this type is not imported. The quoted annotation defers evaluation, yet the import is still needed by tools that resolve the annotation at runtime (e.g., typing.get_type_hints) and for IDE support.

Add the import at the top of the file:

+from tensorrt_llm.llmapi.llm_args import SparseAttentionConfig

Or use the appropriate import path based on where SparseAttentionConfig is defined in the codebase.

cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp (1)

1-1: Add NVIDIA Apache-2.0 copyright header.

According to coding guidelines, all C++ source files must have the NVIDIA Apache-2.0 copyright header prepended with the current year (2025).

Prepend the following header at the top of the file:

/*
 * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

As per coding guidelines.

tensorrt_llm/_torch/attention_backend/utils.py (2)

1-10: Add TYPE_CHECKING import for SparseAttentionConfig type hint.

Ruff F821 flags SparseAttentionConfig at line 54 as undefined. Add a forward declaration under TYPE_CHECKING to satisfy static analysis without causing runtime import cycles.

Apply this diff:

-from typing import Optional, Type
+from typing import TYPE_CHECKING, Optional, Type

 from ...models.modeling_utils import QuantConfig
 from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from .interface import AttentionBackend, MLAParams, PositionalEmbeddingParams
 from .sparse import (get_flashinfer_sparse_attn_attention_backend,
                      get_trtllm_sparse_attn_attention_backend,
                      get_vanilla_sparse_attn_attention_backend)
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import SparseAttentionConfig

13-33: Normalize backend_name to handle case variations.

The current implementation is case-sensitive: inputs like "vanilla" or "flashinfer" match no branch and silently fall back to TRTLLM. Add normalization at the function entry.

Apply this diff:

 def get_attention_backend(backend_name: str,
-                          sparse_attn_config=None) -> Type[AttentionBackend]:
+                          sparse_attn_config: Optional["SparseAttentionConfig"] = None) -> Type[AttentionBackend]:
+    backend_name = backend_name.upper()
     if backend_name == "VANILLA":
tests/unittest/_torch/attention/sparse/test_rocketkv.py (1)

55-63: Test remains non-deterministic due to stochastic sampling.

As flagged in the previous review, using temperature=0.8 and top_p=0.95 without seeding makes this test flaky. Switch to greedy decoding (temperature=0.0, top_p=1.0) or set random seeds before generation to ensure reproducible results across CI runs.

Apply this diff to make generation deterministic:

             sampling_params=SamplingParams(add_special_tokens=False,
                                            max_tokens=max_output_tokens,
-                                           temperature=0.8,
-                                           top_p=0.95),
+                                           temperature=0.0,
+                                           top_p=1.0),
tensorrt_llm/_torch/attention_backend/sparse/__init__.py (1)

1-11: Add the required NVIDIA Apache-2.0 copyright header.

All source files must begin with the NVIDIA Apache-2.0 copyright header with the current year (2025). Please prepend the standard header to the top of this file before any imports or code.

As per coding guidelines.

tensorrt_llm/llmapi/llm_args.py (1)

176-187: Fix the from_dict dispatch bug.

Line 183 calls config_classes.get("algorithm"), which always returns None because the dict's only key is "Rocket"; the lookup should use the algorithm value from data instead. Even with that fixed, the casing mismatch between the "Rocket" key and the default "rocket" value would still cause lookup failures.

Apply this diff to fix the dispatch logic:

     @classmethod
     def from_dict(cls, data: dict):
         # dispatch to the correct sparse attention config
         config_classes = {
-            "Rocket": RocketSparseAttentionConfig,
+            "rocket": RocketSparseAttentionConfig,
         }
 
-        config_class = config_classes.get("algorithm")
+        algorithm = data.get("algorithm", cls().algorithm)
+        config_class = config_classes.get(algorithm.lower())
         if config_class is None:
-            raise ValueError(f"Invalid algorithm")
+            raise ValueError(f"Invalid algorithm: {algorithm}")
 
         return config_class(**data)
cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (3)

1-1: Replace pragma once with include guards and add NVIDIA Apache-2.0 header.

Required by project guidelines; use TRTLLM_SPARSEATTENTIONKERNELS_H and the 2025 banner.

As per coding guidelines.


40-49: Guard device qualifier to keep header includable from host-only TUs.

Public headers should not force device-only symbols on .cpp users.

As per coding guidelines.


3-5: Make header self-contained: add missing standard includes.

std::string, std::tuple, and int32_t are used but their headers aren’t included here.

 #include <cuda_runtime.h>
-#include <sstream>
+#include <cstdint>
+#include <string>
+#include <tuple>
+#include <sstream>

As per coding guidelines.

tensorrt_llm/_torch/attention_backend/sparse/utils.py (1)

1-2: Add NVIDIA license header and future annotations import.

Required by guidelines; also prevents forward-ref evaluation at runtime.

+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations

As per coding guidelines.

cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu (1)

1-15: Copyright year already flagged in previous review.

A past review comment has identified that the copyright year should be 2025 for new files in this PR. This is a duplicate of that finding.

cpp/tensorrt_llm/thop/attentionOp.h (1)

37-61: Header/impl signature mismatch will break linking; use std::optional consistently

attentionOp.cpp defines std::optional/ std::vector<std::optional<...>>; this header declares torch::optional in multiple params. Align to std::optional to match the definition.

Apply:

-void attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v,
-    torch::Tensor& output, torch::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
-    torch::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
+void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v,
+    torch::Tensor& output, std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
+    std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
@@
-    torch::Tensor host_request_types, torch::optional<torch::Tensor> kv_cache_block_offsets,
-    torch::optional<torch::Tensor> host_kv_cache_block_offsets,
-    torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
-    torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
-    torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
-    torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
-    torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
-    torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
-    torch::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
+    torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_pool_pointers,
+    std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
+    std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
+    std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
+    std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
+    std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
+    std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
@@
-    std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
+    std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
@@
-    std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
-    std::vector<torch::optional<torch::Tensor>> sparse_attention_params);
+    std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
+    std::vector<std::optional<torch::Tensor>> sparse_attention_params);

After patching, rebuild and confirm only a single attention(...) symbol exists:

#!/bin/bash
rg -n 'void\s+attention\(' cpp/tensorrt_llm/thop -C2
tensorrt_llm/_torch/attention_backend/vanilla.py (2)

98-115: Sparse hooks must not raise; adjust types and provide safe defaults

Raising NotImplementedError will crash when sparse_attention_config is present. Return a no‑op (None, kv_len) and fix type hints to include the kv_len.

Apply:

-    def _single_request_sparse_attn_predict(self, q: torch.Tensor,
+    def _single_request_sparse_attn_predict(self, q: torch.Tensor,
                                             k: Optional[torch.Tensor],
                                             v: Optional[torch.Tensor],
                                             kv_cache_tensor: torch.Tensor,
                                             metadata: AttentionMetadata,
                                             past_seen_token: int,
                                             sample_idx: int,
-                                            **kwargs) -> Optional[torch.Tensor]:
-        raise NotImplementedError
+                                            **kwargs) -> tuple[Optional[torch.Tensor], int]:
+        kv_len = k.size(0) if k is not None else 0
+        return None, kv_len
@@
-    def _single_request_sparse_kv_predict(self, q: Optional[torch.Tensor],
+    def _single_request_sparse_kv_predict(self, q: Optional[torch.Tensor],
                                           k: Optional[torch.Tensor],
                                           v: Optional[torch.Tensor],
                                           metadata: AttentionMetadata,
                                           past_seen_token: int, sample_idx: int,
-                                          **kwargs) -> Optional[torch.Tensor]:
-        raise NotImplementedError
+                                          **kwargs) -> tuple[Optional[torch.Tensor], int]:
+        kv_len = k.size(0) if k is not None else 0
+        return None, kv_len

431-434: Pass attention_window_size through to per‑request path

Without forwarding, sliding‑window masking is never applied.

Apply:

-            attn_output = self._single_request_forward(
-                single_q, single_k, single_v, attention_mask, kv_cache_tensor,
-                past_seen_token, cache_idx, sample_idx, metadata, **kwargs)
+            attn_output = self._single_request_forward(
+                single_q, single_k, single_v, attention_mask, kv_cache_tensor,
+                past_seen_token, cache_idx, sample_idx, metadata,
+                attention_window_size=attention_window_size, **kwargs)
🧹 Nitpick comments (25)
tensorrt_llm/_utils.py (1)

216-224: Add type hints and improve docstring format.

The function logic is correct, but it lacks type hints and the docstring does not follow Google-style format.

Apply this diff to add type hints and improve the docstring:

-def next_power_of_two(x):
+def next_power_of_two(x: int) -> int:
     """
-    get next power of two
+    Returns the next power of two greater than or equal to x.
+    
+    Args:
+        x: An integer value.
+    
+    Returns:
+        The smallest power of two >= x, or 1 if x <= 0.
     """
     if x <= 0:
         return 1
     if (x & (x - 1)) == 0:
         return x
     return 1 << x.bit_length()

As per coding guidelines (Python code should include type hints and use Google-style docstrings).

examples/longbench/eval_longbench_v1.py (1)

214-234: Silence unused parameter in build_chat.

tokenizer is unused; rename to _tokenizer to satisfy linters without changing call sites.

-def build_chat(tokenizer, prompt, chat_template):
+def build_chat(_tokenizer, prompt, chat_template):
     """Build chat prompt following LongBench's approach."""
examples/longbench/eval_longbench_v2.py (4)

219-238: Silence unused parameter in build_chat.

Rename tokenizer to _tokenizer.

-def build_chat(tokenizer, prompt, chat_template):
+def build_chat(_tokenizer, prompt, chat_template):
     """Build chat prompt following LongBench's approach."""

523-592: Minor cleanups: rename unused loop var; robust prompt_token_ids usage.

  • Rename i to _ (B007).
  • Handle method vs attribute for prompt_token_ids.
-    for i, (sample, output) in enumerate(zip(filtered_data, outputs)):
+    for _, (sample, output) in enumerate(zip(filtered_data, outputs)):
@@
-            'prompt_length': len(output.prompt_token_ids),
+            'prompt_length': len(output.prompt_token_ids() if callable(getattr(output, "prompt_token_ids", None)) else output.prompt_token_ids),

597-639: Expose sample-count breakdowns used later in logs.

main() logs reference {easy,length}_samples keys that are not produced; add counts to metrics to avoid always printing 0.

     metrics = {
         'overall_accuracy': round(overall_accuracy * 100, 2),
         'total_samples': total_samples,
-        'correct_samples': correct_samples
+        'correct_samples': correct_samples
     }
@@
-    for difficulty in difficulties:
+    for difficulty in difficulties:
         diff_results = [r for r in results if r['difficulty'] == difficulty]
         if diff_results:
             diff_correct = sum(1 for r in diff_results if r['is_correct'])
             metrics[f'{difficulty}_accuracy'] = round(
                 (diff_correct / len(diff_results)) * 100, 2)
+            metrics[f'{difficulty}_samples'] = len(diff_results)
@@
-    for length in lengths:
+    for length in lengths:
         len_results = [r for r in results if r['length'] == length]
         if len_results:
             len_correct = sum(1 for r in len_results if r['is_correct'])
             metrics[f'{length}_accuracy'] = round(
                 (len_correct / len(len_results)) * 100, 2)
+            metrics[f'{length}_samples'] = len(len_results)

724-768: Remove redundant f-strings without placeholders.

Clean up minor F541 instances.

-    logger.info(f"Starting LongBench v2 evaluation...")
+    logger.info("Starting LongBench v2 evaluation...")
@@
-    logger.info(f"{'-'*80}")
+    logger.info('-' * 80)
@@
-        logger.info(
-            f"Overall accuracy: {metrics.get('overall_accuracy', 'N/A')}%")
+        logger.info("Overall accuracy: %s%%", metrics.get('overall_accuracy', 'N/A'))
examples/llm-api/llm_sparse_attention.py (3)

20-24: Avoid hard-coded, user-specific default model path.

Make model_path required or read from env to improve portability.

-    parser.add_argument(
-        '--model_path',
-        type=str,
-        default=
-        "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
-    )
+    parser.add_argument('--model_path', type=str, required=True)

81-112: Function naming style and CLI cohesion.

Follow snake_case; also consider exposing '--backend' to exercise TRTLLM path too.

-def run_RocketKV(args):
+def run_rocketkv(args):
@@
-    if args.algo == 'ROCKETKV':
-        run_RocketKV(args)
+    if args.algo == 'ROCKETKV':
+        run_rocketkv(args)

Optionally add:

-    parser.add_argument('--tensor_parallel_size', type=int, default=1)
+    parser.add_argument('--tensor_parallel_size', type=int, default=1)
+    parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch','trtllm'])

and pass backend=args.backend to LLM().


120-129: Guard output indexing to avoid IndexError on empty candidates.

Use a safe fallback if no outputs returned.

-    for idx, output in enumerate(outputs):
-        print(
-            f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
-        )
+    for idx, output in enumerate(outputs):
+        text = output.outputs[0].text if output.outputs else ""
+        print(f'Generated text: {text!r}, ref: {reference[idx]}')
cpp/tensorrt_llm/nanobind/common/customCasters.h (1)

29-29: Remove unused ArrayRef include.

The header cpp/tensorrt_llm/nanobind/common/customCasters.h contains no references to c10::ArrayRef; drop the #include <c10/util/ArrayRef.h> to avoid an unnecessary dependency.

cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp (1)

15-15: Use CamelCase for class name.

The test fixture class name sparseAttentionKernelsTest uses lowerCamelCase, but coding guidelines require CamelCase for type names (classes).

Apply this diff:

-class sparseAttentionKernelsTest : public ::testing::Test
+class SparseAttentionKernelsTest : public ::testing::Test

And update the test macro on line 31:

-TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
+TEST_F(SparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)

As per coding guidelines.

tests/unittest/_torch/attention/sparse/test_rocketkv.py (1)

41-54: Consider using Path for cleaner path construction.

The nested os.path.dirname calls work but are less readable. Also, reusing test_star_attention_input.jsonl for RocketKV tests may cause confusion—consider clarifying or renaming if these are specifically RocketKV test inputs.

Apply this diff for cleaner path handling:

-    current_file = os.path.abspath(__file__)
-    current_dir = os.path.dirname(os.path.dirname(
-        os.path.dirname(current_file)))
-    input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl'
+    from pathlib import Path
+    current_dir = Path(__file__).resolve().parent.parent.parent
+    input_file = current_dir / 'multi_gpu' / 'test_star_attention_input.jsonl'
tensorrt_llm/_torch/attention_backend/sparse/__init__.py (1)

6-11: Consider sorting __all__ for consistency.

The static analyzer suggests sorting the __all__ list. While this is a minor style issue, alphabetically sorting public exports improves maintainability and follows common Python conventions.

Apply this diff:

 __all__ = [
+    "get_flashinfer_sparse_attn_attention_backend",
     "get_sparse_attn_kv_cache_manager",
-    "get_vanilla_sparse_attn_attention_backend",
     "get_trtllm_sparse_attn_attention_backend",
-    "get_flashinfer_sparse_attn_attention_backend",
+    "get_vanilla_sparse_attn_attention_backend",
 ]
tensorrt_llm/llmapi/llm_args.py (1)

192-197: Remove unused method parameter backend.

The supports_backend method in SparseAttentionBaseConfig always returns True and doesn't use the backend parameter. Either implement backend validation or remove the parameter if all backends are universally supported.

If backend validation is not needed in the base class:

-    def supports_backend(self, backend: str) -> bool:
+    def supports_backend(self, backend: str = None) -> bool:
         """
         Override if the speculation algorithm does not support
         a subset of the possible backends.
         """
         return True

Or implement actual validation if certain backends should be restricted.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

256-257: Document and safeguard VANILLA backend’s tokens_per_block override.

The if attn_backend == "VANILLA": tokens_per_block = max_num_tokens override is intentional (estimation is disabled in _util.py lines 216–221), but it should be explicitly documented at py_executor_creator.py:256–257 and supplemented with an OOM safeguard or a configurable warning when max_num_tokens is large.
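A minimal sketch, assuming hypothetical helper and threshold values (the 64K limit is illustrative, not taken from the repository), of how the override could be documented and guarded:

import logging

logger = logging.getLogger(__name__)

def resolve_tokens_per_block(attn_backend: str, tokens_per_block: int,
                             max_num_tokens: int) -> int:
    """Return tokens_per_block, applying the VANILLA-backend override with a warning."""
    if attn_backend == "VANILLA":
        # KV cache size estimation is disabled for VANILLA, so a block must span the
        # whole sequence; warn when this is likely to be memory-hungry.
        if max_num_tokens > 65536:
            logger.warning(
                "VANILLA backend overrides tokens_per_block to max_num_tokens=%d; "
                "this may substantially increase KV cache memory usage.",
                max_num_tokens)
        return max_num_tokens
    return tokens_per_block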

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

268-271: Use the stored config for consistency.

Pass self.sparse_attention_config, not the local arg name.

-        self.attn_backend = get_attention_backend(
-            pytorch_backend_config.attn_backend,
-            sparse_attn_config=sparse_attention_config)
+        self.attn_backend = get_attention_backend(
+            pytorch_backend_config.attn_backend,
+            sparse_attn_config=self.sparse_attention_config)
cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (1)

11-32: Prefer const-correct pointers in parameter carrier.

These arrays are read-only at call sites and in kernels; mark as const to express intent and enable more compiler checks. Adjust downstream uses if needed.

-    int32_t* sparse_kv_indices{nullptr};
-    int32_t* sparse_attn_indices{nullptr};
-    int32_t* sparse_kv_offsets{nullptr};
-    int32_t* sparse_attn_offsets{nullptr};
+    int32_t const* sparse_kv_indices{nullptr};
+    int32_t const* sparse_attn_indices{nullptr};
+    int32_t const* sparse_kv_offsets{nullptr};
+    int32_t const* sparse_attn_offsets{nullptr};
cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu (1)

105-114: Avoid allocating unused dynamic shared memory.

You already use typed shared TempStorage; setting smem_size > 0 reduces occupancy with no benefit.

-    // Shared memory size.
-    size_t smem_size = sizeof(Pair) * 256;
+    // No dynamic shared memory needed; BlockReduce uses statically-declared storage.
+    size_t smem_size = 0;
 
     // Launch the kernel.
     gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
         kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu (3)

27-27: Avoid file-scope using directives.

File-scope using namespace can cause name collisions. Prefer explicit qualification or limit scope.

Apply this diff to qualify usage explicitly or move into narrower scope:

-using namespace tensorrt_llm::kernels;
+// Remove and qualify types explicitly, e.g., tensorrt_llm::kernels::KVBlockArray

168-184: Consider checking CUDA errors in cleanup.

The cleanup() function does not check return values from cudaFree. While cleanup in tests may silently ignore errors, consider using TLLM_CUDA_CHECK for consistency and to detect issues during development.

Apply this diff to add error checking:

 void cleanup()
 {
     if (mSparseKvIndicesDevice)
-        cudaFree(mSparseKvIndicesDevice);
+        TLLM_CUDA_CHECK(cudaFree(mSparseKvIndicesDevice));
     if (mSparseKvOffsetsDevice)
-        cudaFree(mSparseKvOffsetsDevice);
+        TLLM_CUDA_CHECK(cudaFree(mSparseKvOffsetsDevice));
     // ... (repeat for all cudaFree calls)
 }

366-426: Remove magic number in sparse index calculation.

Line 400 hardcodes 8 (total sparse tokens). This should be calculated dynamically or use a named constant to avoid breaking if test parameters change.

Apply this diff:

+    int const total_sparse_tokens = hostSparseOffsets[mBatchSize];
     // ...
-    int const sparse_idx_offset = head * 8 + global_sparse_idx; // 8 is total sparse tokens
+    int const sparse_idx_offset = head * total_sparse_tokens + global_sparse_idx;
cpp/tensorrt_llm/common/attentionOp.cpp (3)

916-938: XQA sparse workspace sizing: make arithmetic size_t‑safe and document layout

Current formula packs two buffers into one size. Cast early to size_t to avoid 32‑bit overflow and clarify intent in code comments.

Apply:

-        int const XQA_NUM_BUFFERS = 8;
+        int const XQA_NUM_BUFFERS = 8;
@@
-        // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets.
-        size_t const sparse_attn_cache_size = useTllmGenSparseAttention()
-            ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads
-            : 0;
+        // Two workspaces for sparse attention packed contiguously:
+        // [seq_lens_kv_per_head (batch_beam)] + [kv_page_offsets_per_head (batch_beam * 2 * max_blocks_per_sequence)]
+        size_t const sparse_attn_cache_size = useTllmGenSparseAttention()
+            ? static_cast<size_t>(sizeof(int))
+                * static_cast<size_t>(mNumKVHeads)
+                * (static_cast<size_t>(batch_beam)
+                   + static_cast<size_t>(batch_beam) * 2u * static_cast<size_t>(max_blocks_per_sequence))
+            : 0;

1692-1694: is_last_chunk condition likely wrong for chunked prefill

Comparing input_seq_length to max_past_kv_length doesn’t detect “last chunk” when prefill is split; it flags true only when there’s a single chunk. Consider a condition based on cumulative processed length (or an explicit flag from the caller).
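A minimal sketch of such a condition, written in Python only for illustration (the actual check lives in attentionOp.cpp); total_prompt_length is an assumed input the caller would need to supply.

def is_last_chunk(past_kv_length: int, input_seq_length: int,
                  total_prompt_length: int) -> bool:
    """True when this chunk completes the prompt, even when prefill is chunked."""
    return past_kv_length + input_seq_length >= total_prompt_length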


1860-1865: Postprocess should be gated on KV cache

invokeKvCachePostprocessing() should be a no‑op when KV cache is disabled; guard explicitly to avoid unnecessary work.

Apply:

-        if (!mIsMLAEnabled) // Only for non-MLA attention
-        {
-            invokeKvCachePostprocessing(preprocessingParams, stream);
-            sync_check_cuda_error(stream);
-        }
+        if (!mIsMLAEnabled && useKVCache())
+        {
+            invokeKvCachePostprocessing(preprocessingParams, stream);
+            sync_check_cuda_error(stream);
+        }
tensorrt_llm/_torch/attention_backend/vanilla.py (1)

193-195: Docstring/return order mismatch

Docstring says (is_causal, attn_mask) but function returns (attn_mask, is_causal). Align the docstring to avoid confusion.

Apply:

-        Returns:
-            Tuple of (is_causal, attn_mask)
+        Returns:
+            Tuple of (attn_mask, is_causal)

Also applies to: 222-222

@lfr-0531 lfr-0531 force-pushed the feat/sparse_attention_func branch 2 times, most recently from a9917b0 to 4489018 Compare October 5, 2025 15:14
@lfr-0531
Collaborator Author

lfr-0531 commented Oct 5, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #20651 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #20651 [ run ] completed with state FAILURE

@juney-nvidia juney-nvidia changed the title [None][feat] add RocketKV support (experimental) [None][feat] Add the sparse attention framework and one use case--RocketKV support Oct 7, 2025
@lfr-0531
Collaborator Author

lfr-0531 commented Oct 7, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #20738 [ run ] triggered by Bot

@heyuhhh
Collaborator

heyuhhh commented Oct 13, 2025

/bot kill

@heyuhhh
Collaborator

heyuhhh commented Oct 13, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #21197 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16001 completed with status: 'FAILURE'

heyuhhh and others added 11 commits October 13, 2025 08:26
add abstract level for sparse attention

minor code adjustment

change RocketKVCacheManager to one that inherts from KVCacheManager.

fix format.

Signed-off-by: Fanrong Li <[email protected]>

refactor sparse attention backend.

Signed-off-by: Fanrong Li <[email protected]>

fix format.

update code design and details

fix some accuracy bugs

minor test

Signed-off-by: yuhangh <[email protected]>

fix bugs when seq_len<budget

Signed-off-by: yuhangh <[email protected]>

add longbench evaluation

Signed-off-by: yuhangh <[email protected]>

Fix accuracy issues and abstract the forward logic

Signed-off-by: yuhangh <[email protected]>

fuse fanrong's updates: added gather kernel

Signed-off-by: yuhangh <[email protected]>

rm some unrelated changes.

Signed-off-by: Fanrong Li <[email protected]>

fix vanilla dense and sparse attention.

Signed-off-by: Fanrong Li <[email protected]>

fix rocket.

Signed-off-by: Fanrong Li <[email protected]>

fix.

Signed-off-by: Fanrong Li <[email protected]>

remove sparse attention base class and refactor vanilla.

Signed-off-by: Fanrong Li <[email protected]>

fix seq_len.

Signed-off-by: Fanrong Li <[email protected]>

refactor get_cache_size_per_token for kvCacheManager to support new added sparse manager.

Signed-off-by: Fanrong Li <[email protected]>

Add longbench evaluation scripts

Signed-off-by: yuhangh <[email protected]>

disable estimate kv cache for vanilla attention backend.

Signed-off-by: Fanrong Li <[email protected]>

fix longbench README.

Signed-off-by: Fanrong Li <[email protected]>

fix longbench_v1.

Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>

Update sparse attention parameters passing logic

Signed-off-by: yuhangh <[email protected]>

fix rebase breaks

Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] add gatherKvPageOffsetsKernel (#32)

* add gatherKvPageOffsetsKernel.

Signed-off-by: Fanrong Li <[email protected]>

* fix.

Signed-off-by: Fanrong Li <[email protected]>

* fix.

Signed-off-by: Fanrong Li <[email protected]>

---------

Signed-off-by: Fanrong Li <[email protected]>

Add sparse kv indices write kernel & fix several bugs

Signed-off-by: yuhangh <[email protected]>

fix for rebase

Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] integrate block sparse attention kernels (#33)

* integrate block sparse attention kernels.

Signed-off-by: Fanrong Li <[email protected]>

* fix.

Signed-off-by: Fanrong Li <[email protected]>

* Support num_kv_heads in seq_len & fix several workspace size bugs

Signed-off-by: yuhangh <[email protected]>

* update block sparse attention kernel to support per-head kv_len.

Signed-off-by: Fanrong Li <[email protected]>

* minor fix

Signed-off-by: yuhangh <[email protected]>

* update kernel meta info.

* add more block sparse kernels.

* disable rope_fusion for sparse attention.

Signed-off-by: Fanrong Li <[email protected]>

* fix block sparse attention kernels.

* update block sparse attention kernel.

Signed-off-by: Fanrong Li <[email protected]>

* fix workspace issue.

Signed-off-by: Fanrong Li <[email protected]>

* minor fix

Signed-off-by: yuhangh <[email protected]>

* fix gatherKvPageOffsetsKernel.

Signed-off-by: Fanrong Li <[email protected]>

* remove cuda stream sync.

Signed-off-by: Fanrong Li <[email protected]>

---------

Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Co-authored-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] change the sparse indices format and update the gatherKvPageOffsetsKe… (#34)

* change the sparse indices format and update the gatherKvPageOffsetsKernel.

Signed-off-by: Fanrong Li <[email protected]>

* update kv write & optimize logic of using tllmgen kernels

Signed-off-by: yuhangh <[email protected]>

---------

Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Co-authored-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

add paged kt cache (1st commit).

Signed-off-by: Fanrong Li <[email protected]>

minnor fix.

Signed-off-by: Fanrong Li <[email protected]>

fix _single_request_update_kt_cache for vanilla RocketKV.

Signed-off-by: Fanrong Li <[email protected]>

add paged kt cache to rocketkv trtllm.

Signed-off-by: Fanrong Li <[email protected]>

fix _single_request_update_kt_cache for trtllm RocketKV.

Signed-off-by: Fanrong Li <[email protected]>

fix k_snap length.

Signed-off-by: Fanrong Li <[email protected]>

fix memory issue when using paged kt cache.

Signed-off-by: Fanrong Li <[email protected]>

fix rebase breaks

Signed-off-by: yuhangh <[email protected]>

fix rebase bug.

Signed-off-by: Fanrong Li <[email protected]>

fix rebase bug.

Signed-off-by: Fanrong Li <[email protected]>

update block sparse attention kernel.

Signed-off-by: Fanrong Li <[email protected]>

fix params issue

Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] Do sparse attention functional clean (#43)

* fix several bugs & adjust some code

Signed-off-by: yuhangh <[email protected]>

* minor code clean

Signed-off-by: yuhangh <[email protected]>

* Add simple unittest for rocketkv

Signed-off-by: yuhangh <[email protected]>

* Adjustment for sparse attention  params and example

Signed-off-by: yuhangh <[email protected]>

* fix bugs introduced by last commit

Signed-off-by: yuhangh <[email protected]>

* Optimize Xqa_params and num_sparse_kv_tokens

Signed-off-by: yuhangh <[email protected]>

* Fix gather kernel & minor adjustment

Signed-off-by: yuhangh <[email protected]>

* Rename sparse_attention_params in xqa_params

Signed-off-by: yuhangh <[email protected]>

* minor

Signed-off-by: yuhangh <[email protected]>

---------

Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][feat] Update trtllm-gen fmha kernels and remove block sparse cubins (#44)

* rm sparse kernels.

Signed-off-by: Fanrong Li <[email protected]>

* update new kernel.

Signed-off-by: Fanrong Li <[email protected]>

* update trtllm-gen fmha.

Signed-off-by: Fanrong Li <[email protected]>

---------

Signed-off-by: Fanrong Li <[email protected]>

fix rebase conflicts

Signed-off-by: yuhangh <[email protected]>

minor fix

Signed-off-by: yuhangh <[email protected]>

pre-commit fix

Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>

[None][fix] update trtllm sparse attention interface (#45)

* update trtllm sparse attention interface.

Signed-off-by: Fanrong Li <[email protected]>

* fix interface.

Signed-off-by: Fanrong Li <[email protected]>

---------

Signed-off-by: Fanrong Li <[email protected]>

fix rocketkv interface. (#47)

Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: yuhangh <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>
@lfr-0531 lfr-0531 force-pushed the feat/sparse_attention_func branch from 03bfa99 to d6a1c17 Compare October 13, 2025 16:01
@lfr-0531
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #21238 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #21238 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16030 completed with status: 'FAILURE'

@lfr-0531
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #21300 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #21300 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16080 completed with status: 'FAILURE'

@lfr-0531
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #21336 [ run ] triggered by Bot

Collaborator

@QiJune QiJune left a comment


LGTM

@lfr-0531 lfr-0531 enabled auto-merge (squash) October 14, 2025 12:39
@tensorrt-cicd
Collaborator

PR_Github #21336 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #16105 completed with status: 'SUCCESS'

@lfr-0531 lfr-0531 merged commit 0d20a8f into NVIDIA:main Oct 14, 2025
5 checks passed