
Commit c3f2596

bkryu and yzh119 authored
fix: Fix trtllm-gen prefill IMA when batch_size==1 (#1912)
## 📌 Description

This PR fixes the IMAs (illegal memory accesses) that the test and benchmark code hit when running trtllm-gen paged & ragged prefill with batch size 1; the issue was described in #1898.

Root cause of the issue: `flashinfer.prefill.trtllm_ragged_attention_deepseek` and `flashinfer.prefill.trtllm_batch_context_with_kv_cache` both required `max_q_len` to match the length of the query when batch size is 1.

**Updated PR:** The issue has been addressed on the kernel side, so the requirement that "*`max_q_len` match the length of the query when batch size is 1*" no longer applies. This PR updates the trtllm-gen FMHA cubins to the latest version and brings minor updates to the kernel metadata.

Unit test results after this PR:

```
$ pytest tests/attention/test_trtllm_gen_attention.py
...
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 2320 items
...
2055 passed, 264 skipped, 1 xfailed in 224.43s (0:03:44)
```

**Description of previous solution:**

~~Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the `trtllm_ragged_attention_deepseek` or `trtllm_batch_context_with_kv_cache` functions is not a viable option because the CPU-side synchronization breaks the deterministic, fully device-side execution required during CUDA graph capture. The workaround was therefore to update the test & benchmark code that calls the trtllm prefill functions and to state clearly in the docstrings that when batch_size == 1, max_q_len must match the query size.~~

## 🔍 Related Issues

#1898

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Bug Fixes**
  * Removed the automatic batch_size=1 restriction for a native backend, enabling its use in more scenarios while other constraints remain.
* **New Features**
  * Added configurable block-sparse attention support to kernel parameters.
* **Documentation**
  * Clarified supported attention optimizations and backend capabilities in the benchmarks docs.
* **Tests**
  * Expanded tests with configurable sequence lengths and added dedicated batch-size-1 test coverage.

---------

Co-authored-by: Zihao Ye <[email protected]>
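For context on the rejected workaround described above, here is a minimal sketch (not FlashInfer internals; the tensor name and values are illustrative) of why deriving `max_q_len` on the host breaks CUDA graph capture:

```python
import torch

# Illustrative cumulative query-length tensor for a batch of one 4096-token query.
cum_seq_lens_q = torch.tensor([0, 4096], dtype=torch.int32, device="cuda")

# Rejected approach: .item() copies the value to the host and synchronizes the GPU,
# which is not allowed while a CUDA graph is being captured.
max_q_len_synced = int(cum_seq_lens_q[-1].item())

# Graph-capture-safe approach: the caller passes max_q_len as a plain Python int it
# already knows when it assembles the batch; with the kernel-side fix in this commit
# it no longer has to equal the single query's length when batch_size == 1.
max_q_len = 4096
```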
1 parent dad0682 · commit c3f2596

File tree

5 files changed (+63, -15 lines)


benchmarks/README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -16,6 +16,7 @@ Currently supports testing most attention, gemm, and fused MOE APIs:
 - `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
   - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
 - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
+  - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
 - `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models.
   - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`.
 - GEMM:
```

benchmarks/routines/attention.py

Lines changed: 0 additions & 8 deletions

```diff
@@ -696,10 +696,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         backends.remove("trtllm-gen")
     if "trtllm-gen-native" in backends:
         remove_trtllm_native = False
-        if batch_size == 1:
-            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
-            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
-            remove_trtllm_native = True
         if not causal:
             print("[INFO] trtllm-gen-native backend currently requires causal = True")
             remove_trtllm_native = True
@@ -1184,10 +1180,6 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         ]:
             print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.")
             remove_trtllm_native = True
-        if batch_size == 1:
-            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
-            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
-            remove_trtllm_native = True
         if not (head_dim_qk == 192 and head_dim_vo == 128):
             print(
                 "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
```

flashinfer/artifacts.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -80,7 +80,7 @@ def get_available_cubin_files(


 class ArtifactPath:
-    TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
     )
@@ -95,7 +95,7 @@ class ArtifactPath:
 class MetaInfoHash:
     DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
     TRTLLM_GEN_FMHA: str = (
-        "d26dbf837f40ff2dcd964094ab6e1b3f2424edda5979c313f5262655161fce98"
+        "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
     )
     TRTLLM_GEN_BMM: str = (
         "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
@@ -107,7 +107,7 @@ class MetaInfoHash:

 class CheckSumHash:
     TRTLLM_GEN_FMHA: str = (
-        "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
+        "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
     )
     TRTLLM_GEN_BMM: str = (
         "8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934"
```
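The updated digests are 64 hex characters, which suggests SHA-256. As a hedged illustration only (the file path, helper name, and verification flow below are hypothetical, not FlashInfer's artifact-download code), a downloaded artifact could be checked against `CheckSumHash.TRTLLM_GEN_FMHA` like this:

```python
import hashlib


def sha256_hex(path: str) -> str:
    """Return the hex SHA-256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Expected digest copied from CheckSumHash.TRTLLM_GEN_FMHA in this commit.
EXPECTED = "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"

if sha256_hex("trtllm_gen_fmha_artifact.bin") != EXPECTED:  # hypothetical path
    raise RuntimeError("trtllm-gen FMHA artifact checksum mismatch")
```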

include/flashinfer/trtllm/fmha/kernelParams.h

Lines changed: 4 additions & 0 deletions

```diff
@@ -152,6 +152,8 @@ struct KernelParams {
   int32_t mStartTokenIdxSfO;
   // The sum of sequence lengths for Q and K/V.
   int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
+  // The flag to use block sparse attention.
+  bool mUseBlockSparseAttention;

   // Create the TMA shape/stride for Q.
   template <class FmhaOptions>
@@ -699,6 +701,8 @@ struct KernelParams {
     params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
     params.mScaleSfKv = options.mScaleSfKv;
     params.ptrSoftmaxStats = options.softmaxStatsPtr;
+    // TODO: Integrate trtllm block-sparse attention kernels when needed.
+    params.mUseBlockSparseAttention = false;
     return params;
   }
 };
```

tests/attention/test_trtllm_gen_attention.py

Lines changed: 55 additions & 4 deletions

```diff
@@ -331,6 +331,8 @@ def unpack_compare_nvfp4(
 )
 @pytest.mark.parametrize("enable_pdl", [True, False, None])
 @pytest.mark.parametrize("enable_sink", [True, False])
+@pytest.mark.parametrize("max_q_len", [511])
+@pytest.mark.parametrize("max_kv_len", [2047])
 def test_trtllm_batch_prefill(
     kv_layout,
     batch_size,
@@ -343,20 +345,20 @@ def test_trtllm_batch_prefill(
     kv_dtype,
     enable_pdl,
     enable_sink,
+    max_q_len,
+    max_kv_len,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] in [11, 12]:
         pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
     # Set up test parameters
     torch.manual_seed(0)
     head_dim = 128
-    MAX_Q_LEN = 511
-    MAX_IN_KV_LEN = 2047

     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
     q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
-        batch_size, MAX_Q_LEN, MAX_IN_KV_LEN
+        batch_size, max_q_len, max_kv_len
     )

     # Create query tensor and related data
@@ -525,6 +527,56 @@ def test_trtllm_batch_prefill(
     assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
+@pytest.mark.parametrize(
+    "batch_size,page_size,num_kv_heads,head_grp_size",
+    [
+        (1, 16, 8, 8),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])  # todo(Siyuan): add 127 window_left
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_q_len", [8192])
+@pytest.mark.parametrize("max_kv_len", [8192])
+def test_trtllm_batch_prefill_bs1(
+    kv_layout,
+    batch_size,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_q_len,
+    max_kv_len,
+):
+    test_trtllm_batch_prefill(
+        kv_layout,
+        batch_size,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_q_len,
+        max_kv_len,
+    )
+
+
 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
 @pytest.mark.parametrize(
     "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
@@ -998,7 +1050,6 @@ def test_trtllm_gen_prefill_deepseek(
 def test_trtllm_gen_prefill_deepseek_bs1(
     batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
 ):
-    pytest.xfail("trtllm-gen prefill triggers an IMA with bs1")
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )
```
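To exercise just the new batch-size-1 coverage locally, the standard pytest `-k` name filter should select the `*_bs1` tests added above (exact node IDs may differ by environment):

```
$ pytest tests/attention/test_trtllm_gen_attention.py -k bs1
```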
