fix: Fix trtllm-gen prefill IMA when batch_size==1 (#1912)
<!-- .github/pull_request_template.md -->
## 📌 Description
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
This PR fixes the IMAs (illegal memory accesses) hit by the test and benchmark code when
running trtllm-gen paged & ragged prefill with batch size 1 -- the issue was
described in #1898
Root cause of the issue:
`flashinfer.prefill.trtllm_ragged_attention_deepseek` and
`flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require
`max_q_len` to match the length of the query when batch size is 1.
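For illustration, here is a minimal sketch of the constraint as seen from the caller's side before this PR. The names `cum_seq_lens_q` and `max_q_len` come from the discussion in this PR; the surrounding setup (request length, dtypes) is hypothetical and not code from the repository:

```python
import torch

# Hypothetical setup: a single request with 37 query tokens (batch_size == 1).
batch_size = 1
q_len = 37
cum_seq_lens_q = torch.tensor([0, q_len], device="cuda", dtype=torch.int32)

# Before the kernel-side fix, passing a padded upper bound (e.g. max_q_len=1024)
# to the trtllm-gen prefill functions triggered an IMA when batch_size == 1;
# the caller had to make max_q_len equal the actual query length.
max_q_len = q_len  # workaround required prior to this PR
```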
**Updated PR:**
The issue has been addressed on the kernel side, so the requirement that
*`max_q_len` match the length of the query when batch size is 1* no longer
applies.
This PR updates the trtllm-gen FMHA cubins to the latest version and brings
minor updates to the kernel metadata.
Unit test results after PR:
```
$ pytest tests/attention/test_trtllm_gen_attention.py
...
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 2320 items
...
2055 passed, 264 skipped, 1 xfailed in 224.43s (0:03:44)
```
**Description of previous solution:**
~~Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the
`trtllm_ragged_attention_deepseek` or
`trtllm_batch_context_with_kv_cache` functions is not a viable option
because the CPU-side synchronization breaks the deterministic, fully
device-side execution required during CUDA graph capture. The workaround
was therefore to update the test & benchmark code that calls the trtllm
prefill functions, and to state clearly in the docstring that when
batch_size == 1, max_q_len must match the query size.~~
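As background on why the host-side read was rejected, here is a small sketch (not code from this PR) of the synchronization problem: `Tensor.item()` copies the value to the host, which forces a stream synchronization, and synchronization is not permitted while a CUDA graph is being captured. The tensor contents below are hypothetical:

```python
import torch

# Hypothetical prefix-sum of query lengths for one request of 37 tokens.
cum_seq_lens_q = torch.tensor([0, 37], device="cuda", dtype=torch.int32)

g = torch.cuda.CUDAGraph()
try:
    with torch.cuda.graph(g):
        # .item() performs a device-to-host copy and synchronizes the stream;
        # this is forbidden during CUDA graph capture and raises a RuntimeError
        # instead of being recorded into the graph.
        max_q_len = cum_seq_lens_q[-1].item()
except RuntimeError as err:
    print(f"capture failed as expected: {err}")
```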
## 🔍 Related Issues
#1898
<!-- Link any related issues here -->
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Removed the automatic batch_size=1 restriction for a native backend,
enabling its use in more scenarios while other constraints remain.
* **New Features**
* Added configurable block-sparse attention support to kernel
parameters.
* **Documentation**
* Clarified supported attention optimizations and backend capabilities
in the benchmarks docs.
* **Tests**
* Expanded tests with configurable sequence lengths and added dedicated
batch-size-1 test coverage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Zihao Ye <[email protected]>