
Commit accdc20

update and pass existing tests
1 parent 18004a8 commit accdc20

File tree: 3 files changed (+65 additions, -17 deletions)

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 43 additions & 12 deletions
@@ -26,6 +26,7 @@
 #include <sstream>
 #include <unordered_map>
 
+#include "tvm/ffi/error.h"
 #include "tvm_ffi_utils.h"
 
 using tvm::ffi::Optional;
@@ -163,6 +164,9 @@ void trtllm_paged_attention_launcher(
       use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
   runner_params.mMultiCtasKvMode = use_multi_block;
 
+  runner_params.cumSeqLensQPtr = cum_seq_lens_q;
+  runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;
+
   size_t max_batch_size = 8192;   // todo(Yingyi): get from dlfw
   size_t max_num_qo_heads = 256;  // todo(Yingyi): get from dlfw, in total 8MB
   size_t num_semaphores =
@@ -213,22 +217,49 @@ void trtllm_paged_attention_decode(
     TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
     Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
     int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
-    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
+    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
+    Optional<int64_t> optional_max_q_len,
+    Optional<TensorView> cum_seq_lens_q,
+    Optional<TensorView> cum_seq_lens_kv
+) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
   for (int i = 0; i < key_cache.ndim(); i++) {
     TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i));
   }
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
-  // NOTE(Zihao): query is [B, Q, H, D]
-  // where Q is the number of query tokens per request, used in MTP
-  // based on profiled results, always use decode mode for MTP (q_len is small)
-  // example: when kv_len = 10000, q < 200, decode mode is faster
-  int batch_size = query.size(0);
-  int q_len_per_request = query.size(1);
-  int sum_seq_q = batch_size * q_len_per_request;
-  int num_qo_heads = query.size(2);
+  int batch_size;
+  int max_q_len;
+  int sum_seq_q;
+  int num_qo_heads;
+  int* cum_seq_lens_q_ptr = nullptr;
+  int* cum_seq_lens_kv_ptr = nullptr;
+  if (!optional_max_q_len.has_value()) {
+    // every request has the same query length
+
+    // NOTE(Zihao): query is [B, Q, H, D]
+    // where Q is the number of query tokens per request, used in MTP
+    // based on profiled results, always use decode mode for MTP (q_len is small)
+    // example: when kv_len = 10000, q < 200, decode mode is faster
+    int q_len_per_request = query.size(1);
+    batch_size = query.size(0);
+    sum_seq_q = batch_size * q_len_per_request;
+    num_qo_heads = query.size(2);
+    max_q_len = q_len_per_request;
+  } else {
+    // each request has a different query length
+    TVM_FFI_CHECK(cum_seq_lens_q.has_value(), "cum_seq_lens_q must be provided when max_q_len is provided");
+    TVM_FFI_CHECK(cum_seq_lens_kv.has_value(), "cum_seq_lens_kv must be provided when max_q_len is provided");
+    // the shape of query: [sum_seq_q, num_qo_heads, head_dim_q]
+    // the shape of cum_seq_lens_q: [batch_size + 1]
+    batch_size = cum_seq_lens_q.value().size(0) - 1;
+    sum_seq_q = query.size(0);
+    num_qo_heads = query.size(1);
+    max_q_len = optional_max_q_len.value();
+    cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr());
+    cum_seq_lens_kv_ptr = static_cast<int*>(cum_seq_lens_kv.value().data_ptr());
+  }
   // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even.
   int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1);
   int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1);
@@ -285,9 +316,9 @@ void trtllm_paged_attention_decode(
       out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
       static_cast<int*>(seq_lens.data_ptr()),
-      /*cum_seq_lens_q=*/nullptr,
-      /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
-      TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
+      cum_seq_lens_q_ptr,
+      cum_seq_lens_kv_ptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
+      TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
      kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
       bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
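For reference, the new cum_seq_lens_q / cum_seq_lens_kv arrays describe a packed (ragged) batch layout: entries i and i+1 bracket request i's rows in the packed query tensor. The launcher casts them to int*, so int32 tensors are assumed. A minimal Python sketch of the invariants the branch above relies on; the concrete lengths are made up for illustration:

    import torch

    # Hypothetical per-request query lengths for a batch of 3 requests.
    q_lens = torch.tensor([3, 1, 2], dtype=torch.int32)

    # cum_seq_lens_q has shape [batch_size + 1]: prefix sums with a leading zero.
    cum_seq_lens_q = torch.cat([torch.zeros(1, dtype=torch.int32),
                                q_lens.cumsum(0, dtype=torch.int32)])  # [0, 3, 4, 6]

    batch_size = cum_seq_lens_q.size(0) - 1  # 3, mirrors `size(0) - 1` in the C++ above
    sum_seq_q = int(cum_seq_lens_q[-1])      # 6, must equal query.size(0)
    max_q_len = int(q_lens.max())            # 3, what optional_max_q_len carries

    # Request i owns packed query rows [cum_seq_lens_q[i], cum_seq_lens_q[i + 1]).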

flashinfer/decode.py

Lines changed: 18 additions & 5 deletions
@@ -1927,6 +1927,9 @@ def _paged_run(
         enable_pdl,
         workspace_size,
         sinks,
+        None,  # max_q_len
+        None,  # cum_seq_lens_q
+        None,  # cum_seq_lens_kv
     )
     return out

@@ -2073,7 +2076,7 @@ def trtllm_batch_decode_with_kv_cache(
    workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
+    max_kv_len: int,
     bmm1_scale: Union[float, torch.Tensor] = 1.0,
     bmm2_scale: Union[float, torch.Tensor] = 1.0,
     window_left: int = -1,
@@ -2088,12 +2091,15 @@ def trtllm_batch_decode_with_kv_cache(
     q_len_per_req: Optional[int] = 1,
     o_scale: Optional[float] = 1.0,
     mask: Optional[torch.Tensor] = None,
+    max_q_len: Optional[int] = None,
+    cum_seq_lens_q: Optional[torch.Tensor] = None,
+    cum_seq_lens_kv: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, FP4Tensor]:
     """
     Parameters
     ----------
     query : torch.Tensor
-        query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request
+        query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = total query tokens in the batch.
 
     kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
         If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``,
@@ -2192,6 +2198,10 @@ def trtllm_batch_decode_with_kv_cache(
             raise ValueError("xqa backend does not support nvfp4 output")
         if o_sf_scale is not None or o_sf_vec_size is not None:
             raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
+        if max_q_len is not None or cum_seq_lens_q is not None or cum_seq_lens_kv is not None:
+            raise ValueError(
+                "xqa backend does not support cum_seq_lens_q or cum_seq_lens_kv"
+            )
 
         # Handle out and out_dtype
         if out_dtype is None:
@@ -2206,7 +2216,7 @@ def trtllm_batch_decode_with_kv_cache(
             workspace_buffer=workspace_buffer,
             block_tables=block_tables,
             seq_lens=seq_lens,
-            max_seq_len=max_seq_len,
+            max_seq_len=max_kv_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=window_left,
@@ -2316,13 +2326,13 @@ def trtllm_batch_decode_with_kv_cache(
                 q_len_per_req,
                 query.size(1),
                 query.size(2),
-            ),
+            ) if q_len_per_req is not None else query,
             k_cache,
             v_cache,
             workspace_buffer,
             block_tables,
             seq_lens,
-            max_seq_len,
+            max_kv_len,
             bmm1_scale,
             bmm2_scale,
             o_sf_scale or -1.0,
@@ -2334,6 +2344,9 @@ def trtllm_batch_decode_with_kv_cache(
            enable_pdl,
            workspace_buffer.numel() * workspace_buffer.element_size(),
            sinks,
+            max_q_len,
+            cum_seq_lens_q,
+            cum_seq_lens_kv,
        )

    return (
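Putting the pieces together, a variable-length decode call through the updated API might look like the sketch below. Only the parameter names come from the diff above; the head counts, page size, dtypes, block-table construction, and the 128 MB workspace are illustrative assumptions, not values taken from this commit:

    import torch
    from flashinfer.decode import trtllm_batch_decode_with_kv_cache

    # Illustrative sizes -- assumptions for the sketch.
    num_qo_heads, num_kv_heads, head_dim, page_size = 8, 1, 128, 16
    q_lens = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
    kv_lens = torch.tensor([64, 32, 48], dtype=torch.int32, device="cuda")
    batch_size = q_lens.numel()

    zero = torch.zeros(1, dtype=torch.int32, device="cuda")
    cum_seq_lens_q = torch.cat([zero, q_lens.cumsum(0, dtype=torch.int32)])
    cum_seq_lens_kv = torch.cat([zero, kv_lens.cumsum(0, dtype=torch.int32)])

    # Packed query: [sum_seq_q, num_heads, head_dim], no per-request padding.
    query = torch.randn(int(q_lens.sum()), num_qo_heads, head_dim,
                        device="cuda", dtype=torch.bfloat16)

    # Paged KV cache, HND layout: [num_pages, 2, num_kv_heads, page_size, head_dim].
    pages_per_seq = (int(kv_lens.max()) + page_size - 1) // page_size
    kv_cache = torch.randn(batch_size * pages_per_seq, 2, num_kv_heads, page_size,
                           head_dim, device="cuda", dtype=torch.bfloat16)
    block_tables = torch.arange(batch_size * pages_per_seq, dtype=torch.int32,
                                device="cuda").view(batch_size, pages_per_seq)
    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    out = trtllm_batch_decode_with_kv_cache(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=kv_lens,
        max_kv_len=int(kv_lens.max()),
        q_len_per_req=None,            # skip the fixed [B, Q, H, D] reshape
        max_q_len=int(q_lens.max()),   # selects the variable-length path
        cum_seq_lens_q=cum_seq_lens_q,
        cum_seq_lens_kv=cum_seq_lens_kv,
    )

Passing q_len_per_req=None together with max_q_len and both cumulative arrays keeps the query in its packed form; the C++ side then derives batch_size from cum_seq_lens_q and reads per-request offsets through the two pointers instead of assuming a uniform q length.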

tests/attention/test_trtllm_gen_attention.py

Lines changed: 4 additions & 0 deletions
@@ -1434,3 +1434,7 @@ def test_trtllm_gen_prefill_deepseek_bs1(
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
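The new __main__ guard simply hands this file to pytest.main, so the suite can be run directly with python tests/attention/test_trtllm_gen_attention.py as well as through the pytest CLI.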
