fix: Fix trtllm-gen prefill IMA when batch_size==1 #1912
Conversation
benchmarks/routines/attention.py (Outdated)
```python
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)

if batch_size == 1:
    # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
```
Why could `qo_indptr[-1]` be different from `s_qo`? Is it because we want to be compatible with CUDA graphs and `s_qo` will always be the maximum length?
Short answer is yes.
Longer answer: In a batch_size > 1 situation, the CUDA graph containing `prefill.trtllm_batch_context_with_kv_cache()` can be reused across multiple sequence lengths, but not when batch_size == 1. For example:
- If batch_size is 3 and we have two batches with query lengths `[100, 200, 300]` and `[16, 500, 1024]`, we can set `s_qo=1024` when we construct the CUDA graph and use the same CUDA graph for both batches.
- However, for batch_size = 1, where we have batches of query lengths `[100]` and `[1024]`, a CUDA graph must be constructed each time: first with `s_qo=100` and then with `s_qo=1024`.
Not sure whether the above is a real concern at the framework level. Nevertheless, `s_qo` goes in as the `max_q_len` input argument, where it is the maximum sequence length for the query. We may at least want to consider whether the wording in the documentation is clear 😄
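To make the reuse argument concrete, here is a minimal sketch; the `fake_prefill` helper and all shapes/lengths are hypothetical stand-ins for the real kernel call. The point is that `max_q_len` is a Python scalar frozen into the captured graph, while per-request lengths live in device tensors that can be rewritten between replays:

```python
import torch

# Hypothetical stand-in for the prefill call: max_q_len is a Python int whose
# value is frozen into the captured graph, while qo_indptr is a device tensor
# whose contents can be overwritten before each replay.
def fake_prefill(q, qo_indptr, max_q_len):
    return q * 2.0  # placeholder for the real kernel launch

s_qo = 1024  # capture-time upper bound, passed as max_q_len
q = torch.randn(3 * s_qo, 8, device="cuda")
qo_indptr = torch.tensor([0, 100, 300, 600], device="cuda", dtype=torch.int32)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = fake_prefill(q, qo_indptr, max_q_len=s_qo)

# Second batch with query lengths [16, 500, 1024]: only device-side buffers
# change, max_q_len stays 1024, so the same graph can be replayed.
qo_indptr.copy_(torch.tensor([0, 16, 516, 1540], device="cuda", dtype=torch.int32))
g.replay()
```

With batch_size == 1 and no padding, the single query's length would itself have to become the new `max_q_len` for the next batch, which is why a new capture was needed before the kernel-side fix.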
Force-pushed 4dade1b to 197a7a0 (compare)
Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?

/bot run

[FAILED] Pipeline #36750562: 1/17 passed
Walkthrough

Re-enables trtllm-gen-native for batch_size==1 in benchmark routines, updates three TRTLLM_GEN_FMHA artifact hash constants, adds a `mUseBlockSparseAttention` setting in `KernelParams`, and extends the prefill tests with `max_q_len`/`max_kv_len` parameters and a dedicated batch_size==1 case.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Runner
    participant Bench as Benchmark Routine
    participant Selector as Backend Selector
    participant Backend as trtllm-gen-native
    participant Kernel as KernelParams
    Note over Test,Bench: Parametrized batch-prefill tests (including bs==1)
    Test->>Bench: invoke testBatchPrefill(max_q_len, max_kv_len, ...)
    Bench->>Selector: request eligible backends (batch_size considered)
    Note right of Selector: bs==1 no longer auto-skipped
    Selector-->>Backend: select trtllm-gen-native when constraints met
    Backend->>Kernel: build KernelParams (mUseBlockSparseAttention = false)
    Kernel-->>Backend: return params
    Backend-->>Bench: execute prefill using returned params
    Bench-->>Test: return results
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ 1 check failed (warning), ✅ 2 checks passed
/bot run
Force-pushed 57e47ea to 003ef55 (compare)
Actionable comments posted: 0
🧹 Nitpick comments (2)
benchmarks/README.md (1)
19-19: LGTM! Documentation correctly updated. The documentation now accurately reflects that `BatchPrefillWithRaggedKVCacheWrapper` supports `trtllm_ragged_attention_deepseek` for ragged attention operations.

Optional: Fix list indentation for consistency. The static analysis tool flags that this line uses 8 spaces for indentation instead of the expected 4 for nested list items.

```diff
-        - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
```

tests/attention/test_trtllm_gen_attention.py (1)
348-361: LGTM! Function signature correctly updated. The new `max_q_len` and `max_kv_len` parameters are properly integrated into the function signature and correctly passed to `generate_seq_lens_prefill`.

Optional: Prefix unused variable with underscore. Line 360 unpacks `in_kv_lens` from `generate_seq_lens_prefill`, but the variable is never used in the function body. Consider prefixing it with an underscore to indicate it's intentionally unused:

```diff
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
+    q_lens, _in_kv_lens, seq_lens = generate_seq_lens_prefill(
         batch_size, max_q_len, max_kv_len
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- `benchmarks/README.md` (1 hunks)
- `benchmarks/routines/attention.py` (0 hunks)
- `flashinfer/artifacts.py` (3 hunks)
- `include/flashinfer/trtllm/fmha/kernelParams.h` (2 hunks)
- `tests/attention/test_trtllm_gen_attention.py` (3 hunks)
💤 Files with no reviewable changes (1)
- benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
- include/flashinfer/trtllm/fmha/kernelParams.h
- flashinfer/artifacts.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/utils.py (1)
get_compute_capability(251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md
19-19: Unordered list indentation
Expected: 4; Actual: 8
(MD007, ul-indent)
🪛 Ruff (0.14.0)
tests/attention/test_trtllm_gen_attention.py
360-360: Unpacked variable in_kv_lens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_trtllm_gen_attention.py (2)
334-335: LGTM! Parameterization enhances test flexibility. Adding `max_q_len` and `max_kv_len` as test parameters allows testing different sequence length combinations, which is essential for validating the batch_size==1 fix across various configurations.
530-578: LGTM! Dedicated batch_size=1 test addresses PR objective. The new `test_trtllm_batch_prefill_bs1` function specifically tests the batch_size==1 scenario with large sequence lengths (8192), which directly addresses the issue described in #1898. The test properly delegates to the main test function with appropriate parameters and minimal configuration to focus on the batch_size==1 edge case.
[FAILED] Pipeline #36805526: 1/17 passed
LGTM
@nvmbreughe, can I get a review on the PR? Zihao and Perkz already approved, but due to code owner review requirements, it seems like I need a review from you.
LGTM!
…special treatments

Force-pushed fa62171 to 1b7f9e8 (compare)
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
360-362: Optional: Consider using underscore for unused unpacked variable. The `in_kv_lens` variable is unpacked but never used. Consider using `_` to signal this is intentional:

```diff
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
-        batch_size, max_q_len, max_kv_len
-    )
+    q_lens, _, seq_lens = generate_seq_lens_prefill(
+        batch_size, max_q_len, max_kv_len
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- `benchmarks/README.md` (1 hunks)
- `benchmarks/routines/attention.py` (0 hunks)
- `flashinfer/artifacts.py` (3 hunks)
- `include/flashinfer/trtllm/fmha/kernelParams.h` (2 hunks)
- `tests/attention/test_trtllm_gen_attention.py` (3 hunks)
💤 Files with no reviewable changes (1)
- benchmarks/routines/attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/utils.py (1)
get_compute_capability(251-254)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md
19-19: Unordered list indentation
Expected: 4; Actual: 8
(MD007, ul-indent)
🪛 Ruff (0.14.1)
tests/attention/test_trtllm_gen_attention.py
360-360: Unpacked variable in_kv_lens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (3)
benchmarks/README.md (1)
19-19: LGTM! Documentation correctly reflects deepseek support. The addition of `trtllm_ragged_attention_deepseek` to the supported operations for `BatchPrefillWithRaggedKVCacheWrapper` is accurate and aligns with the PR's objective to fix and enable trtllm-gen prefill with deepseek attention.

tests/attention/test_trtllm_gen_attention.py (2)
334-335: Good parameterization for flexible sequence length testing. Adding `max_q_len` and `max_kv_len` parameters allows testing different sequence length scenarios while preserving existing test behavior with sensible defaults. This enables the new batch_size=1 test to use longer sequences.

Also applies to: 348-349, 361-361
530-578: Well-targeted regression test for batch_size=1 with large sequences. This test specifically validates the kernel fix for batch_size==1 by using large sequence lengths (8192+8192=16384 total), which was the failing scenario described in issue #1898. The narrow parameter space is appropriate for a focused regression test.
📌 Description
The current PR fixes the IMAs (illegal memory accesses) hit by the test and benchmark code when running trtllm-gen paged & ragged prefill with batch size 1; the issue was described in #1898.
Root cause of the issue:
`flashinfer.prefill.trtllm_ragged_attention_deepseek` and `flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require `max_q_len` to match the length of the query when batch size is 1.
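For illustration, here is a minimal sketch of what that constraint meant for a caller; the `run_prefill` helper, tensor shapes, and lengths are hypothetical stand-ins, not the actual flashinfer signature:

```python
import torch

# Hypothetical stand-in for the affected entry points
# (trtllm_batch_context_with_kv_cache / trtllm_ragged_attention_deepseek).
def run_prefill(q, qo_indptr, max_q_len):
    return q  # placeholder for the real attention computation

q_lens = [777]  # batch_size == 1
num_heads, head_dim = 8, 128
q = torch.randn(sum(q_lens), num_heads, head_dim, device="cuda", dtype=torch.half)
qo_indptr = torch.tensor([0, sum(q_lens)], device="cuda", dtype=torch.int32)

# Before the kernel-side fix, max_q_len had to equal the single request's query
# length when batch_size == 1; the caller already knows q_lens on the host, so
# no device-to-host sync is needed to satisfy this.
max_q_len = max(q_lens)
out = run_prefill(q, qo_indptr, max_q_len)
```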
Updated PR:

The issue has been addressed on the kernel side, so the requirement that "`max_q_len` match the length of the query when batch size is 1" no longer applies. The current PR updates the trtllm-gen FMHA cubins to the latest version and brings minor updates to the kernel metadata.
Unit test results after PR:
Description of previous solution:
Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the `trtllm_ragged_attention_deepseek` or `trtllm_batch_context_with_kv_cache` functions is not a viable option, because the CPU-side synchronization breaks the deterministic, fully device-side execution required during CUDA graph capture. The workaround was thus to update the test & benchmark code that calls the trtllm prefill functions, and to clearly state in the docstring that when batch_size == 1, `max_q_len` must match the query size.
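As a minimal sketch of why that option fails (the variable name is borrowed from the description; the exact exception type and message depend on the PyTorch/CUDA versions), a `.item()` read inside CUDA graph capture is rejected because it forces a device-to-host copy and a stream synchronization:

```python
import torch

cum_seq_lens_q = torch.tensor([0, 100], device="cuda", dtype=torch.int32)

g = torch.cuda.CUDAGraph()
try:
    with torch.cuda.graph(g):
        # .item() copies to the host and synchronizes the stream, which is not
        # permitted while the stream is being captured into a CUDA graph.
        max_q_len = cum_seq_lens_q[-1].item()
except RuntimeError as err:
    print("sync during capture rejected:", err)
```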
🔍 Related Issues

#1898
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Unit tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
New Features
Documentation
Tests