
Conversation

@bkryu (Collaborator) commented Oct 10, 2025

📌 Description

This PR fixes the IMAs (illegal memory accesses) hit by the test and benchmark code when running trtllm-gen paged & ragged prefill with batch size 1 -- the issue is described in #1898.

Root cause of the issue: flashinfer.prefill.trtllm_ragged_attention_deepseek and flashinfer.prefill.trtllm_batch_context_with_kv_cache both require max_q_len to match the length of the query when batch size is 1.

Updated PR:
The issue has now been addressed on the kernel side, so max_q_len is no longer required to match the query length when batch size is 1.

This PR updates the trtllm-gen FMHA cubins to the latest version and brings minor updates to the kernel metadata.

Unit test results after PR:

$ pytest tests/attention/test_trtllm_gen_attention.py 
...
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 2320 items   
...
2055 passed, 264 skipped, 1 xfailed in 224.43s (0:03:44)

Description of previous solution:
Updating max_q_len to cum_seq_lens_q[-1].item() inside trtllm_ragged_attention_deepseek or trtllm_batch_context_with_kv_cache is not a viable option because the CPU-side synchronization breaks the deterministic, fully device-side execution required during CUDA graph capture. The workaround was therefore to update the test & benchmark code that calls the trtllm prefill functions, and to state clearly in the docstrings that when batch_size == 1, max_q_len must match the query length.
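For illustration, a minimal sketch of that previous test/benchmark workaround, assuming a single request of length 8192; only cum_seq_lens_q, max_q_len, and batch_size come from the description above, everything else is made up:

import torch

# Hypothetical single-request setup, for illustration only.
batch_size = 1
q_len = 8192
cum_seq_lens_q = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")

max_q_len = 16384  # upper bound a caller might otherwise pass
if batch_size == 1:
    # Pre-fix requirement: max_q_len must equal the actual query length.
    # The .item() call forces a device-to-host sync, which is harmless in eager
    # test code but not allowed during CUDA graph capture -- hence the final fix
    # had to land on the kernel side instead.
    max_q_len = int(cum_seq_lens_q[-1].item())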

🔍 Related Issues

#1898

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Enabled an additional native attention backend for batch_size=1 where other constraints allow, expanding supported scenarios.
  • New Features

    • Added a kernel parameter to enable future block-sparse attention support (no change to current behavior).
  • Documentation

    • Updated benchmarks docs to note support for an additional attention optimization/backend.
  • Tests

    • Expanded test coverage with configurable sequence lengths and a dedicated batch-size-1 test; removed a previous expected-failure for that path.

@bkryu bkryu self-assigned this Oct 10, 2025
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)

if batch_size == 1:
    # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
Collaborator

Why could qo_indptr[-1] be different from s_qo? Is it because we want to be compatible with CUDA graphs and s_qo will always be the maximum length?

Collaborator Author

Short answer is yes.

Longer answer: in a batch_size > 1 situation, the CUDA graph containing prefill.trtllm_batch_context_with_kv_cache() can be reused across multiple sequence lengths, but not when batch_size == 1. For example,

  • If batch_size is 3 and we have two batches with query lengths [100, 200, 300] and [16, 500, 1024], we can set s_qo=1024 when we construct the CUDA graph and use the same CUDA graph for both batches.
  • However, for batch_size=1, with batches of query lengths [100] and [1024], a CUDA graph must be constructed each time -- first with s_qo=100 and then with s_qo=1024.

I'm not sure whether the above is a real concern at the framework level. Nevertheless, s_qo is passed in as the max_q_len input argument, where it is the maximum query sequence length. We may at least want to consider whether the wording in the documentation is clear 😄
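A rough sketch of the reuse pattern described above, assuming prefill_fn stands in for the real trtllm prefill call and all buffer shapes and dtypes are invented for illustration:

import torch

def prefill_fn(q, seq_lens_q, max_q_len):
    # Placeholder for trtllm_batch_context_with_kv_cache in this sketch;
    # a real call would also take KV-cache tensors and index pointers.
    return q + 0

batch_size = 3
s_qo = 1024  # maximum query length the captured graph will ever serve

# Static device buffers: graph replays always read/write these same allocations.
q = torch.zeros(batch_size * s_qo, 128, dtype=torch.float16, device="cuda")
seq_lens_q = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = prefill_fn(q, seq_lens_q, max_q_len=s_qo)

# A later batch with different per-request lengths: only the device buffers change,
# s_qo (and hence max_q_len) stays fixed, so the same graph can simply be replayed.
seq_lens_q.copy_(torch.tensor([16, 500, 1024], dtype=torch.int32, device="cuda"))
g.replay()

# With batch_size == 1, the pre-fix kernels required max_q_len to equal the actual
# query length, which would force a fresh capture per sequence length instead.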

@bkryu bkryu force-pushed the trtllm-attention-debug branch from 4dade1b to 197a7a0 on October 16, 2025 17:23
@yzh119 (Collaborator) commented Oct 16, 2025

Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?

@bkryu (Collaborator Author) commented Oct 16, 2025

> Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?

Hi @yzh119, I'm currently checking. Upgrading to the latest trtllm-gen does fix the batch size 1 unit test, but I am seeing some errors in other places. Will verify what is happening before marking the PR as ready.

@bkryu (Collaborator Author) commented Oct 16, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !83 has been created, and the CI pipeline #36750562 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #36750562: 1/17 passed

@coderabbitai bot (Contributor) commented Oct 17, 2025

Walkthrough

Re-enables trtllm-gen-native for batch_size==1 in benchmark routines, updates three TRTLLM_GEN_FMHA artifact hash constants, adds mUseBlockSparseAttention to KernelParams, extends attention tests with sequence-length parameters and a bs1 test, and updates benchmark documentation to mention deepseek support.

Changes

  • Backend constraint change (benchmarks/routines/attention.py):
    Removed the automatic skip that excluded trtllm-gen-native when batch_size == 1 in testBatchPrefillWithPagedKVCacheWrapper and testBatchPrefillWithRaggedKVCacheWrapper; other constraints remain.
  • Documentation (benchmarks/README.md):
    Added a bullet noting BatchPrefillWithRaggedKVCacheWrapper also supports trtllm_ragged_attention_deepseek (in addition to existing options).
  • Artifact hash updates (flashinfer/artifacts.py):
    Updated three string constants for TRTLLM_GEN_FMHA: ArtifactPath.TRTLLM_GEN_FMHA, MetaInfoHash.TRTLLM_GEN_FMHA, and CheckSumHash.TRTLLM_GEN_FMHA (values changed).
  • Kernel parameters (include/flashinfer/trtllm/fmha/kernelParams.h):
    Added public data member bool mUseBlockSparseAttention to KernelParams and set it to false in setKernelParams; included a TODO for future block-sparse integration.
  • Tests (tests/attention/test_trtllm_gen_attention.py):
    Parameterized test_trtllm_batch_prefill with max_q_len and max_kv_len, added test_trtllm_batch_prefill_bs1 (delegates to the parametrized test), and removed an explicit pytest.xfail for a bs1 deepseek case.

Sequence Diagram(s)

sequenceDiagram
  participant Test as Test Runner
  participant Bench as Benchmark Routine
  participant Selector as Backend Selector
  participant Backend as trtllm-gen-native
  participant Kernel as KernelParams

  Note over Test,Bench: Parametrized batch-prefill tests (including bs==1)

  Test->>Bench: invoke testBatchPrefill(max_q_len, max_kv_len, ...)
  Bench->>Selector: request eligible backends (batch_size considered)
  Note right of Selector: bs==1 no longer auto-skipped
  Selector-->>Backend: select trtllm-gen-native when constraints met
  Backend->>Kernel: build KernelParams (mUseBlockSparseAttention = false)
  Kernel-->>Backend: return params
  Backend-->>Bench: execute prefill using returned params
  Bench-->>Test: return results

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related issues

Suggested reviewers

  • yzh119

Poem

🐇 I hopped through hashes, nudged a test to try,
Reopened a path where a lone batch may fly,
A flag set false, a TODO kept near,
Sequences extend — the benches cheer,
Hop on, small kernel, the trail is clear!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 40.00%, which is insufficient; the required threshold is 80.00%. Resolution: you can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)

  • Title Check ✅ Passed: The title "fix: Fix trtllm-gen prefill IMA when batch_size==1" directly and clearly describes the main objective of the pull request. The changes focus on resolving an issue where trtllm-gen prefill functions failed when batch_size==1, with fixes applied at both the kernel level (updated FMHA cubins and kernel metadata) and the test/benchmark code level (updated parameterization to remove the batch_size==1 constraint). The title is concise, specific, and accurately reflects the primary purpose without vague terminology.
  • Description Check ✅ Passed: The pull request description comprehensively follows the provided template structure. The 📌 Description section clearly explains what the PR does, including the root cause, kernel-side resolution, and specific updates to FMHA cubins and kernel metadata, along with test results. The 🔍 Related Issues section properly links to issue #1898. The 🚀 Pull Request Checklist is fully completed with both Pre-commit Checks and Tests sections marked as done [x], and the optional Reviewer Notes section is present. All major required sections are present with substantive information rather than placeholder text.
✨ Finishing touches
  • 📝 Generate docstrings
  • 🧪 Generate unit tests (beta)
    • Create PR with unit tests
    • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@bkryu (Collaborator Author) commented Oct 17, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !83 has been updated with latest changes, and the CI pipeline #36805526 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu marked this pull request as ready for review October 17, 2025 16:59
@bkryu bkryu changed the title fix: Fix test and benchmark for trtllm-gen prefill batch size 1 fix: Fix trtllm-gen prefill IMA when batch_size==1 Oct 17, 2025
@bkryu bkryu force-pushed the trtllm-attention-debug branch from 57e47ea to 003ef55 on October 17, 2025 21:36
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
benchmarks/README.md (1)

19-19: LGTM! Documentation correctly updated.

The documentation now accurately reflects that BatchPrefillWithRaggedKVCacheWrapper supports trtllm_ragged_attention_deepseek for ragged attention operations.

Optional: Fix list indentation for consistency.

The static analysis tool flags that this line uses 8 spaces for indentation instead of the expected 4 for nested list items.

-        - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and  `trtllm_ragged_attention_deepseek`.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and  `trtllm_ragged_attention_deepseek`.
tests/attention/test_trtllm_gen_attention.py (1)

348-361: LGTM! Function signature correctly updated.

The new max_q_len and max_kv_len parameters are properly integrated into the function signature and correctly passed to generate_seq_lens_prefill.

Optional: Prefix unused variable with underscore.

Line 360 unpacks in_kv_lens from generate_seq_lens_prefill, but the variable is never used in the function body. Consider prefixing it with an underscore to indicate it's intentionally unused:

-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
+    q_lens, _in_kv_lens, seq_lens = generate_seq_lens_prefill(
         batch_size, max_q_len, max_kv_len
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 57e47ea and 003ef55.

📒 Files selected for processing (5)
  • benchmarks/README.md (1 hunks)
  • benchmarks/routines/attention.py (0 hunks)
  • flashinfer/artifacts.py (3 hunks)
  • include/flashinfer/trtllm/fmha/kernelParams.h (2 hunks)
  • tests/attention/test_trtllm_gen_attention.py (3 hunks)
💤 Files with no reviewable changes (1)
  • benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • flashinfer/artifacts.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md

19-19: Unordered list indentation
Expected: 4; Actual: 8

(MD007, ul-indent)

🪛 Ruff (0.14.0)
tests/attention/test_trtllm_gen_attention.py

360-360: Unpacked variable in_kv_lens is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_trtllm_gen_attention.py (2)

334-335: LGTM! Parameterization enhances test flexibility.

Adding max_q_len and max_kv_len as test parameters allows testing different sequence length combinations, which is essential for validating the batch_size==1 fix across various configurations.


530-578: LGTM! Dedicated batch_size=1 test addresses PR objective.

The new test_trtllm_batch_prefill_bs1 function specifically tests the batch_size==1 scenario with large sequence lengths (8192), which directly addresses the issue described in #1898. The test properly delegates to the main test function with appropriate parameters and minimal configuration to focus on the batch_size==1 edge case.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #36805526: 1/17 passed

@yzh119 (Collaborator) left a comment

LGTM

@yzh119 (Collaborator) commented Oct 20, 2025

The failed unit tests are not relevant and are scheduled to be fixed in #1950 and #1953.

@yzh119 yzh119 enabled auto-merge (squash) October 20, 2025 07:11
@bkryu (Collaborator Author) commented Oct 20, 2025

@nvmbreughe, can I get a review on this PR? Zihao and Perkz have already approved, but due to code owner review requirements, it seems I need a review from you as well.

@nvmbreughe (Contributor) left a comment

LGTM!

@bkryu bkryu force-pushed the trtllm-attention-debug branch from fa62171 to 1b7f9e8 on October 20, 2025 22:43
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)

360-362: Optional: Consider using underscore for unused unpacked variable.

The in_kv_lens variable is unpacked but never used. Consider using _ to signal this is intentional:

-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
-        batch_size, max_q_len, max_kv_len
-    )
+    q_lens, _, seq_lens = generate_seq_lens_prefill(
+        batch_size, max_q_len, max_kv_len
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 003ef55 and 1b7f9e8.

📒 Files selected for processing (5)
  • benchmarks/README.md (1 hunks)
  • benchmarks/routines/attention.py (0 hunks)
  • flashinfer/artifacts.py (3 hunks)
  • include/flashinfer/trtllm/fmha/kernelParams.h (2 hunks)
  • tests/attention/test_trtllm_gen_attention.py (3 hunks)
💤 Files with no reviewable changes (1)
  • benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md

19-19: Unordered list indentation
Expected: 4; Actual: 8

(MD007, ul-indent)

🪛 Ruff (0.14.1)
tests/attention/test_trtllm_gen_attention.py

360-360: Unpacked variable in_kv_lens is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (3)
benchmarks/README.md (1)

19-19: LGTM! Documentation correctly reflects deepseek support.

The addition of trtllm_ragged_attention_deepseek to the supported operations for BatchPrefillWithRaggedKVCacheWrapper is accurate and aligns with the PR's objective to fix and enable trtllm-gen prefill with deepseek attention.

tests/attention/test_trtllm_gen_attention.py (2)

334-335: Good parameterization for flexible sequence length testing.

Adding max_q_len and max_kv_len parameters allows testing different sequence length scenarios while preserving existing test behavior with sensible defaults. This enables the new batch_size=1 test to use longer sequences.

Also applies to: 348-349, 361-361


530-578: Well-targeted regression test for batch_size=1 with large sequences.

This test specifically validates the kernel fix for batch_size==1 by using large sequence lengths (8192+8192=16384 total), which was the failing scenario described in issue #1898. The narrow parameter space is appropriate for a focused regression test.

@yzh119 yzh119 merged commit c3f2596 into flashinfer-ai:main Oct 21, 2025
4 checks passed
@bkryu bkryu deleted the trtllm-attention-debug branch October 23, 2025 20:49