
Commit e2df126

wip
1 parent 3a23405 commit e2df126

4 files changed (+121 −53 lines)


benchmarks/routines/attention.py

Lines changed: 1 addition & 1 deletion
@@ -526,7 +526,7 @@ def run_backend_wrapper(backend):
             workspace_buffer=workspace_buffer,
             block_tables=block_tables,
             seq_lens=actual_seq_lens_kv,
-            max_seq_len=s_kv,
+            max_kv_len=s_kv,
             bmm1_scale=scale if k_scale is None else k_scale * scale,
             bmm2_scale=1.0 if v_scale is None else v_scale,
         )

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 22 additions & 19 deletions
@@ -157,6 +157,9 @@ void trtllm_paged_attention_launcher(
       use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
   runner_params.mMultiCtasKvMode = use_multi_block;
 
+  runner_params.cumSeqLensQPtr = cum_seq_lens_q;
+  runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;
+
   size_t max_batch_size = 8192;   // todo(Yingyi): get from dlfw
   size_t max_num_qo_heads = 256;  // todo(Yingyi): get from dlfw, in total 8MB
   size_t num_semaphores =
@@ -204,26 +207,26 @@ inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_T
 void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scale_factor,
                                    TensorView query, TensorView key_cache, TensorView value_cache,
                                    TensorView workspace_buffer, TensorView block_tables,
-                                   TensorView seq_lens, int64_t max_kv_len,
-                                   Variant<double, ffi::Tensor> bmm1_scale,
-                                   Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
-                                   int64_t o_sf_vec_size, int64_t o_sf_start_index,
-                                   int64_t window_left, int64_t sm_count, bool enable_pdl,
-                                   int64_t workspace_size, Optional<TensorView> attention_sinks) {
+    TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
+    Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
+    double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
+    int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
+    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
   for (int i = 0; i < key_cache.ndim(); i++) {
     TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i));
   }
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
-  // NOTE(Zihao): query is [B, Q, H, D]
-  // where Q is the number of query tokens per request, used in MTP
+  // NOTE(Zihao): query is [S, H, D]
+  // where S is the sum of query tokens for all requests, used in MTP
   // based on profiled results, always use decode mode for MTP (q_len is small)
   // example: when kv_len = 10000, q < 200, decode mode is faster
-  int batch_size = query.size(0);
-  int q_len_per_request = query.size(1);
-  int sum_seq_q = batch_size * q_len_per_request;
+  // int batch_size = query.size(0);
+  // int q_len_per_request = query.size(1);
+  // int sum_seq_q = batch_size * q_len_per_request;
+  int sum_seq_q = query.size(0);
   int num_qo_heads = query.size(2);
   // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even.
   int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1);
@@ -281,14 +284,14 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
       out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
       static_cast<int*>(seq_lens.data_ptr()),
-      /*cum_seq_lens_q=*/nullptr,
-      /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
-      TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
-      num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
-      kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
-      bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
-      o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size,
-      stream);
+      /*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
+      /*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
+      q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::ForGen, batch_size,
+      max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
+      head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
+      max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
+      bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
+      enable_pdl, workspace_size, stream);
 }
 
 void trtllm_paged_attention_context(
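The change above switches the decode entry point from a dense [B, Q, H, D] query to a flattened [S, H, D] query plus cumulative-length buffers (cumSeqLensQPtr / cumSeqLensKvPtr), so sum_seq_q is now simply query.size(0). A minimal host-side sketch of that ragged layout follows; the per-request lengths and head sizes are illustrative, and the batch_size + 1, leading-zero int32 offset convention is an assumption based on the usual FlashInfer indptr style rather than something this hunk pins down.

import torch

# Hypothetical per-request token counts (illustrative only).
q_lens = torch.tensor([2, 1, 4], dtype=torch.int32)         # new query tokens per request
kv_lens = torch.tensor([100, 57, 4096], dtype=torch.int32)  # cached kv tokens per request

num_qo_heads, head_dim = 32, 128
total_q = int(q_lens.sum())

# Ragged query: all requests concatenated along dim 0 -> [S, H, D].
query = torch.randn(total_q, num_qo_heads, head_dim, dtype=torch.bfloat16)

# Exclusive prefix sums backing cumSeqLensQPtr / cumSeqLensKvPtr (assumed convention).
zero = torch.zeros(1, dtype=torch.int32)
cum_seq_lens_q = torch.cat([zero, torch.cumsum(q_lens, dim=0, dtype=torch.int32)])
cum_seq_lens_kv = torch.cat([zero, torch.cumsum(kv_lens, dim=0, dtype=torch.int32)])

# Request i's query tokens are the slice between consecutive offsets, and
# sum_seq_q == query.size(0) == cum_seq_lens_q[-1].
for i in range(len(q_lens)):
    q_i = query[cum_seq_lens_q[i] : cum_seq_lens_q[i + 1]]
    assert q_i.shape[0] == int(q_lens[i])

Whether cum_seq_lens_kv counts cached tokens only or cached-plus-new tokens is not determined by this hunk; the sketch just mirrors the kernel's pointer names.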

flashinfer/decode.py

Lines changed: 22 additions & 17 deletions
@@ -2065,26 +2065,30 @@ def trtllm_batch_decode_with_kv_cache(
     workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
-    bmm1_scale: Union[float, torch.Tensor] = 1.0,
-    bmm2_scale: Union[float, torch.Tensor] = 1.0,
+    max_q_len: int,
+    max_kv_len: int,
+    bmm1_scale: Union[float, torch.Tensor],
+    bmm2_scale: Union[float, torch.Tensor],
+    batch_size: int,
+    cum_seq_lens_q: torch.Tensor,
+    cum_seq_lens_kv: torch.Tensor,
     window_left: int = -1,
     out: Optional[Union[torch.Tensor, FP4Tensor]] = None,
     out_dtype: Optional[Union[torch.dtype, str]] = None,
     o_sf_scale: Optional[float] = None,
     o_sf_vec_size: Optional[int] = None,
-    sinks: Optional[List[torch.Tensor]] = None,
     kv_layout: str = "HND",
     enable_pdl: Optional[bool] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
     backend: str = "auto",
-    q_len_per_req: Optional[int] = 1,
+    # the following args are xqa-specific
    o_scale: Optional[float] = 1.0,
 ) -> Union[torch.Tensor, FP4Tensor]:
     """
     Parameters
     ----------
     query : torch.Tensor
-        query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request
+        query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = total query tokens in the batch.
 
     kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
         If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``,
@@ -2185,6 +2189,8 @@ def trtllm_batch_decode_with_kv_cache(
             raise ValueError("xqa backend does not support nvfp4 output")
         if o_sf_scale is not None or o_sf_vec_size is not None:
             raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
+        if max_q_len != 1:
+            raise ValueError("xqa backend only supports max_q_len == 1")
 
         # Handle out and out_dtype
         if out_dtype is None:
@@ -2199,15 +2205,15 @@ def trtllm_batch_decode_with_kv_cache(
             workspace_buffer=workspace_buffer,
             block_tables=block_tables,
             seq_lens=seq_lens,
-            max_seq_len=max_seq_len,
+            max_seq_len=max_kv_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=window_left,
             out=out,
             sinks=sinks,
             kv_layout=kv_layout,
             enable_pdl=enable_pdl,
-            q_len_per_req=q_len_per_req,
+            q_len_per_req=1,
             o_scale=o_scale,
         )
     elif backend == "trtllm-gen":
@@ -2299,31 +2305,30 @@ def trtllm_batch_decode_with_kv_cache(
             bmm1_scale = bmm1_scale * log2e
         if isinstance(bmm2_scale, torch.Tensor):
             assert bmm2_scale.dtype == torch.float32
-
+        workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
         run_func(
             out,
             out_scale_factor,
-            query.view(
-                query.size(0) // q_len_per_req,
-                q_len_per_req,
-                query.size(1),
-                query.size(2),
-            ),
+            query,
             k_cache,
             v_cache,
             workspace_buffer,
             block_tables,
             seq_lens,
-            max_seq_len,
+            max_q_len,
+            max_kv_len,
             bmm1_scale,
             bmm2_scale,
             o_sf_scale or -1.0,
             o_sf_vec_size or -1,
             o_sf_start_index,
+            batch_size,
            window_left,
+            cum_seq_lens_q,
+            cum_seq_lens_kv,
            sm_count,
            enable_pdl,
-            workspace_buffer.numel() * workspace_buffer.element_size(),
+            workspace_size,
            sinks,
        )
 
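With the signature change above, the trtllm-gen path now takes the flattened query directly along with explicit max_q_len / max_kv_len, batch_size, and cum_seq_lens_q / cum_seq_lens_kv. A rough call-site sketch is below. Only the argument names are taken from this diff; the tensor shapes, the leading-zero offset convention, whether seq_lens and cum_seq_lens_kv include the new query tokens, the workspace size, and the dtypes are all assumptions for illustration and may not match what this WIP interface finally requires.

import math
import torch
import flashinfer

device = "cuda"
batch_size, max_q_len, max_kv_len = 4, 2, 256                     # illustrative sizes
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

q_lens = torch.full((batch_size,), max_q_len, dtype=torch.int32, device=device)
kv_lens = torch.randint(1, max_kv_len + 1, (batch_size,), dtype=torch.int32, device=device)
seq_lens = q_lens + kv_lens   # total tokens per request, mirroring the test helper

zero = torch.zeros(1, dtype=torch.int32, device=device)
cum_seq_lens_q = torch.cat([zero, torch.cumsum(q_lens, dim=0, dtype=torch.int32)])
cum_seq_lens_kv = torch.cat([zero, torch.cumsum(seq_lens, dim=0, dtype=torch.int32)])

# Ragged query [total_q_tokens, num_heads, head_dim] and an HND paged KV cache.
query = torch.randn(int(q_lens.sum()), num_qo_heads, head_dim,
                    dtype=torch.bfloat16, device=device)
pages_per_seq = math.ceil((max_q_len + max_kv_len) / page_size)
num_pages = batch_size * pages_per_seq
kv_cache = torch.randn(num_pages, 2, num_kv_heads, page_size, head_dim,
                       dtype=torch.bfloat16, device=device)
block_tables = torch.arange(num_pages, dtype=torch.int32, device=device).reshape(
    batch_size, pages_per_seq)
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device)

out = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
    query,
    kv_cache,
    workspace_buffer,
    block_tables,
    seq_lens,
    max_q_len=max_q_len,
    max_kv_len=max_kv_len,
    bmm1_scale=1.0 / math.sqrt(head_dim),
    bmm2_scale=1.0,
    batch_size=batch_size,
    cum_seq_lens_q=cum_seq_lens_q,
    cum_seq_lens_kv=cum_seq_lens_kv,
    backend="trtllm-gen",
)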

tests/attention/test_trtllm_gen_attention.py

Lines changed: 76 additions & 16 deletions
@@ -54,12 +54,12 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
     return q_lens, in_kv_lens, seq_lens
 
 
-def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len):
-    q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
-    in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
-    in_kv_lens[-1] = max_in_kv_len
-    seq_lens = q_lens + in_kv_lens
-    return q_lens, in_kv_lens, seq_lens
+def generate_seq_lens_decode(batch_size, max_q_len, max_kv_len):
+    q_lens = torch.full((batch_size,), max_q_len, dtype=torch.int32)
+    kv_lens = torch.randint(0, max_kv_len + 1, (batch_size,), dtype=torch.int)
+    kv_lens[-1] = max_kv_len
+    seq_lens = q_lens + kv_lens
+    return q_lens, kv_lens, seq_lens
 
 
 def generate_cumsum_lens(lens):
@@ -667,7 +667,6 @@ def _test_trtllm_batch_decode(
     backend,
     kv_layout,
     batch_size,
-    q_len_per_req,
     page_size,
     num_kv_heads,
     head_grp_size,
@@ -677,9 +676,10 @@ def _test_trtllm_batch_decode(
     kv_dtype,
     enable_pdl,
     enable_sink,
-    max_in_kv_len,
     head_dim,
-    device_scale=False,
+    max_q_len,
+    max_kv_len,
+    device_scale,
 ):
     """
     Common function for testing trtllm-gen decode.
@@ -702,12 +702,12 @@ def _test_trtllm_batch_decode(
         pytest.skip("xqa backend only supports fp16 and bf16 query")
 
     # xqa backend doesn't support speculative decoding yet
-    if backend == "xqa" and q_len_per_req > 1:
+    if backend == "xqa" and max_q_len > 1:
        pytest.skip(
            "xqa backend does not support speculative decoding (q_len_per_req > 1) yet"
        )
 
-    if o_dtype == "nvfp4" and q_len_per_req > 1:
+    if o_dtype == "nvfp4" and max_q_len > 1:
        # todo(Yingyi): add support for nvfp4 with speculative decoding
        pytest.skip("nvfp4 is not supported for q_len_per_req > 1")
 
@@ -719,8 +719,8 @@ def _test_trtllm_batch_decode(
 
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
-        batch_size, q_len_per_req, max_in_kv_len
+    q_lens, kv_lens, seq_lens = generate_seq_lens_decode(
+        batch_size, max_q_len, max_kv_len
     )
 
     # Create query tensor and related data
@@ -775,7 +775,7 @@ def _test_trtllm_batch_decode(
         "window_left": window_left,
     }
     if not enable_sink:
-        if q_len_per_req == 1:
+        if max_q_len == 1:
            wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer_ref, kv_layout, use_tensor_cores=True
            )
@@ -843,6 +843,8 @@ def _test_trtllm_batch_decode(
        workspace_buffer,
        page_table,
        seq_lens.to(GPU_DEVICE),
+        max_q_len,
+        max_kv_len,
        torch.max(seq_lens).item(),
        bmm1_scale,
        bmm2_scale,
@@ -855,7 +857,6 @@ def _test_trtllm_batch_decode(
        kv_layout=kv_layout,
        enable_pdl=enable_pdl,
        backend=backend,
-        q_len_per_req=q_len_per_req,
        o_scale=o_scale,
    )
    if backend == "trtllm-gen":
@@ -882,7 +883,7 @@ def _test_trtllm_batch_decode(
 
    # convert to float32 for fp8 is not supported by assert_close
    # relax rtol and atol for speculative decoding test
-    if q_len_per_req > 1:
+    if max_q_len > 1:
        rtol, atol = rtol * 2, atol * 2
 
    # Arbitary small mismatch rate
@@ -1224,6 +1225,65 @@ def test_trtllm_batch_decode_long_sequence_length(
     )
 
 
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
+@pytest.mark.parametrize(
+    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
+    [
+        (1, 1, 16, 8, 8),
+        (1, 1, 32, 8, 8),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("fp8", "fp8", "fp8"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_in_kv_len", [4096, 8192])
+@pytest.mark.parametrize("head_dim", [128])
+@pytest.mark.parametrize("device_scale", [True, False])
+def test_trtllm_batch_decode_spec(
+    kv_layout,
+    batch_size,
+    q_len_per_req,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+    device_scale,
+):
+    # Small number of test cases for batch size 1
+    _test_trtllm_batch_decode(
+        "trtllm-gen",
+        kv_layout,
+        batch_size,
+        q_len_per_req,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+        device_scale,
+    )
+
+
 @pytest.mark.parametrize("batch_size", [4, 128, 256])
 @pytest.mark.parametrize("s_qo", [32, 64, 87])
 @pytest.mark.parametrize("s_kv", [32, 64, 87])
