diff --git a/benchmarks/README.md b/benchmarks/README.md
index 6b1a30c9d1..f41d695cdc 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -16,6 +16,7 @@ Currently supports testing most attention, gemm, and fused MOE APIs:
   - `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
     - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
   - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
   - `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models.
     - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`.
 - GEMM:
diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index a75dea0928..acdf9ce7ab 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -696,10 +696,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         backends.remove("trtllm-gen")
     if "trtllm-gen-native" in backends:
         remove_trtllm_native = False
-        if batch_size == 1:
-            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
-            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
-            remove_trtllm_native = True
         if not causal:
             print("[INFO] trtllm-gen-native backend currently requires causal = True")
             remove_trtllm_native = True
@@ -1184,10 +1180,6 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         ]:
             print("[INFO] trtllm-gen-native backend does not support FP8. Skipping.")
             remove_trtllm_native = True
-        if batch_size == 1:
-            # TO-DO: trtllm-gen-native hits IMA on batch size 1. Investigate and fix.
-            print("[INFO] trtllm-gen-native backend currently requires batch size > 1")
-            remove_trtllm_native = True
         if not (head_dim_qk == 192 and head_dim_vo == 128):
             print(
                 "[INFO] trtllm-gen-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py
index 1b5cde7542..89b458e8d0 100644
--- a/flashinfer/artifacts.py
+++ b/flashinfer/artifacts.py
@@ -80,7 +80,7 @@ def get_available_cubin_files(
 
 
 class ArtifactPath:
-    TRTLLM_GEN_FMHA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
     )
@@ -95,7 +95,7 @@ class MetaInfoHash:
     DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
     TRTLLM_GEN_FMHA: str = (
-        "d26dbf837f40ff2dcd964094ab6e1b3f2424edda5979c313f5262655161fce98"
+        "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
     )
     TRTLLM_GEN_BMM: str = (
         "4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
     )
@@ -107,7 +107,7 @@ class MetaInfoHash:
 class CheckSumHash:
     TRTLLM_GEN_FMHA: str = (
-        "b2d9d40db550ef85585e980bee651ac19d3e416f10b0c8bf9de0a7f9d0bee3d4"
+        "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
     )
     TRTLLM_GEN_BMM: str = (
         "8df2aae8f3aa39d64d2c723e775640beb4ac602a6cbb02e497c2a7316e349934"
     )
diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h
index 0d592c63e0..57adc57914 100644
--- a/include/flashinfer/trtllm/fmha/kernelParams.h
+++ b/include/flashinfer/trtllm/fmha/kernelParams.h
@@ -152,6 +152,8 @@ struct KernelParams {
   int32_t mStartTokenIdxSfO;
   // The sum of sequence lengths for Q and K/V.
   int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
+  // The flag to use block sparse attention.
+  bool mUseBlockSparseAttention;
 
   // Create the TMA shape/stride for Q.
   template
@@ -699,6 +701,8 @@ struct KernelParams {
     params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
     params.mScaleSfKv = options.mScaleSfKv;
     params.ptrSoftmaxStats = options.softmaxStatsPtr;
+    // TODO: Integrate trtllm block-sparse attention kernels when needed.
+    params.mUseBlockSparseAttention = false;
     return params;
   }
 };
diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 80853c7dbf..6bd2065b3d 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -331,6 +331,8 @@ def unpack_compare_nvfp4(
 )
 @pytest.mark.parametrize("enable_pdl", [True, False, None])
 @pytest.mark.parametrize("enable_sink", [True, False])
+@pytest.mark.parametrize("max_q_len", [511])
+@pytest.mark.parametrize("max_kv_len", [2047])
 def test_trtllm_batch_prefill(
     kv_layout,
     batch_size,
@@ -343,6 +345,8 @@ def test_trtllm_batch_prefill(
     kv_dtype,
     enable_pdl,
     enable_sink,
+    max_q_len,
+    max_kv_len,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] in [11, 12]:
@@ -350,13 +354,11 @@ def test_trtllm_batch_prefill(
     # Set up test parameters
     torch.manual_seed(0)
     head_dim = 128
-    MAX_Q_LEN = 511
-    MAX_IN_KV_LEN = 2047
 
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
     q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
-        batch_size, MAX_Q_LEN, MAX_IN_KV_LEN
+        batch_size, max_q_len, max_kv_len
     )
 
     # Create query tensor and related data
@@ -525,6 +527,56 @@ def test_trtllm_batch_prefill(
     assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
+@pytest.mark.parametrize(
+    "batch_size,page_size,num_kv_heads,head_grp_size",
+    [
+        (1, 16, 8, 8),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])  # todo(Siyuan): add 127 window_left
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_q_len", [8192])
+@pytest.mark.parametrize("max_kv_len", [8192])
+def test_trtllm_batch_prefill_bs1(
+    kv_layout,
+    batch_size,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_q_len,
+    max_kv_len,
+):
+    test_trtllm_batch_prefill(
+        kv_layout,
+        batch_size,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_q_len,
+        max_kv_len,
+    )
+
+
 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
 @pytest.mark.parametrize(
     "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
@@ -998,7 +1050,6 @@ def test_trtllm_gen_prefill_deepseek(
 def test_trtllm_gen_prefill_deepseek_bs1(
     batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
 ):
-    pytest.xfail("trtllm-gen prefill triggers an IMA with bs1")
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )
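
Note on the test pattern in the diff above: the batch-size-1 coverage works by promoting the previously hard-coded MAX_Q_LEN / MAX_IN_KV_LEN limits to pytest parameters and adding a thin wrapper test that pins batch_size=1 and forwards to the shared test body. Below is a minimal self-contained sketch of that pytest pattern; the names test_prefill and test_prefill_bs1 are illustrative stand-ins, not flashinfer test names.

import pytest


@pytest.mark.parametrize("batch_size", [4, 16])
@pytest.mark.parametrize("max_q_len", [511])
@pytest.mark.parametrize("max_kv_len", [2047])
def test_prefill(batch_size, max_q_len, max_kv_len):
    # Stand-in for the real test body (tensor setup, kernel launch, reference comparison).
    assert batch_size >= 1 and max_q_len > 0 and max_kv_len > 0


@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("max_q_len", [8192])
@pytest.mark.parametrize("max_kv_len", [8192])
def test_prefill_bs1(batch_size, max_q_len, max_kv_len):
    # Forward to the shared test body so the batch-size-1 regression stays in sync with it.
    test_prefill(batch_size, max_q_len, max_kv_len)

The real cases added in this diff can be selected by name, e.g. pytest tests/attention/test_trtllm_gen_attention.py -k bs1.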