55 changes: 43 additions & 12 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -26,6 +26,7 @@
#include <sstream>
#include <unordered_map>

#include "tvm/ffi/error.h"
#include "tvm_ffi_utils.h"

using tvm::ffi::Optional;
@@ -163,6 +164,9 @@ void trtllm_paged_attention_launcher(
use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
runner_params.mMultiCtasKvMode = use_multi_block;

runner_params.cumSeqLensQPtr = cum_seq_lens_q;
runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;

size_t max_batch_size = 8192; // todo(Yingyi): get from dlfw
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
size_t num_semaphores =
@@ -213,22 +217,49 @@ void trtllm_paged_attention_decode(
TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<int64_t> optional_max_q_len, Optional<TensorView> cum_seq_lens_q,
Optional<TensorView> cum_seq_lens_kv) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
for (int i = 0; i < key_cache.ndim(); i++) {
TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i));
}
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
// NOTE(Zihao): query is [B, Q, H, D]
// where Q is the number of query tokens per request, used in MTP
// based on profiled results, always use decode mode for MTP (q_len is small)
// example: when kv_len = 10000, q < 200, decode mode is faster
int batch_size = query.size(0);
int q_len_per_request = query.size(1);
int sum_seq_q = batch_size * q_len_per_request;
int num_qo_heads = query.size(2);
int batch_size;
int max_q_len;
int sum_seq_q;
int num_qo_heads;
int* cum_seq_lens_q_ptr = nullptr;
int* cum_seq_lens_kv_ptr = nullptr;
if (!optional_max_q_len.has_value()) {
// Each request has the same query length.

// NOTE(Zihao): query is [B, Q, H, D]
// where Q is the number of query tokens per request, used in MTP
// based on profiled results, always use decode mode for MTP (q_len is small)
// example: when kv_len = 10000, q < 200, decode mode is faster
int q_len_per_request = query.size(1);
batch_size = query.size(0);
sum_seq_q = batch_size * q_len_per_request;
num_qo_heads = query.size(2);
max_q_len = q_len_per_request;
} else {
// Each request has a different query length.
TVM_FFI_CHECK(cum_seq_lens_q.has_value(), "cum_seq_lens_q must be provided when max_q_len is provided");
TVM_FFI_CHECK(cum_seq_lens_kv.has_value(), "cum_seq_lens_kv must be provided when max_q_len is provided");
// the shape of query: [sum_seq_q, num_qo_heads, head_dim_q]
// the shape of cum_seq_lens_q: [batch_size + 1]
batch_size = cum_seq_lens_q.value().size(0) - 1;
sum_seq_q = query.size(0);
num_qo_heads = query.size(1);
max_q_len = optional_max_q_len.value();
cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr());
cum_seq_lens_kv_ptr = static_cast<int*>(cum_seq_lens_kv.value().data_ptr());
}
// Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even.
int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1);
int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1);
@@ -285,9 +316,9 @@ void trtllm_paged_attention_decode(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/nullptr,
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
cum_seq_lens_q_ptr,
cum_seq_lens_kv_ptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
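Note: the ragged branch added above recovers batch_size from cum_seq_lens_q and reads the packed token count from query.size(0). A minimal Python sketch, with illustrative values that are not taken from this PR, of how that metadata is typically assembled from per-request query lengths:

    import torch

    # Per-request query lengths for a batch of 4 requests (illustrative values).
    q_lens = torch.tensor([3, 1, 7, 2], dtype=torch.int32, device="cuda")

    # Exclusive prefix sum, shape [batch_size + 1]; entry i is the token offset of request i.
    cum_seq_lens_q = torch.zeros(q_lens.numel() + 1, dtype=torch.int32, device="cuda")
    cum_seq_lens_q[1:] = torch.cumsum(q_lens, dim=0)

    max_q_len = int(q_lens.max())             # passed as optional_max_q_len
    batch_size = cum_seq_lens_q.numel() - 1   # the launcher recomputes this as size(0) - 1
    sum_seq_q = int(cum_seq_lens_q[-1])       # equals query.size(0) in the packed [sum_seq_q, H, D] layout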
23 changes: 18 additions & 5 deletions flashinfer/decode.py
@@ -1927,6 +1927,9 @@ def _paged_run(
enable_pdl,
workspace_size,
sinks,
None,  # max_q_len
None,  # cum_seq_lens_q
None,  # cum_seq_lens_kv
)
return out

@@ -2073,7 +2076,7 @@ def trtllm_batch_decode_with_kv_cache(
workspace_buffer: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
max_kv_len: int,
bmm1_scale: Union[float, torch.Tensor] = 1.0,
bmm2_scale: Union[float, torch.Tensor] = 1.0,
window_left: int = -1,
@@ -2088,12 +2091,15 @@
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
mask: Optional[torch.Tensor] = None,
max_q_len: Optional[int] = None,
cum_seq_lens_q: Optional[torch.Tensor] = None,
cum_seq_lens_kv: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
----------
query : torch.Tensor
query tensor with shape [num_tokens, num_heads, head_dim], num_tokens = batch_size * q_len_per_request
query tensor with shape [num_tokens, num_heads, head_dim], where num_tokens is the total number of query tokens in the batch (batch_size * q_len_per_req for uniform-length requests, or cum_seq_lens_q[-1] when variable-length requests are described by cum_seq_lens_q).

kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is ``HND``,
@@ -2192,6 +2198,10 @@ def trtllm_batch_decode_with_kv_cache(
raise ValueError("xqa backend does not support nvfp4 output")
if o_sf_scale is not None or o_sf_vec_size is not None:
raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
if max_q_len is not None or cum_seq_lens_q is not None or cum_seq_lens_kv is not None:
raise ValueError(
"xqa backend does not support cum_seq_lens_q or cum_seq_lens_kv"
)

# Handle out and out_dtype
if out_dtype is None:
@@ -2206,7 +2216,7 @@
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
max_seq_len=max_kv_len,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
window_left=window_left,
@@ -2316,13 +2326,13 @@ def trtllm_batch_decode_with_kv_cache(
q_len_per_req,
query.size(1),
query.size(2),
),
) if q_len_per_req is not None else query,
k_cache,
v_cache,
workspace_buffer,
block_tables,
seq_lens,
max_seq_len,
max_kv_len,
bmm1_scale,
bmm2_scale,
o_sf_scale or -1.0,
@@ -2334,6 +2344,9 @@
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
max_q_len,
cum_seq_lens_q,
cum_seq_lens_kv,
)

return (
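A hedged usage sketch of the variable-length decode path these changes enable. Shapes, dtypes, the paged KV cache contents, and the way cum_seq_lens_kv is built are illustrative placeholder assumptions (in practice they come from the serving framework); only the argument names follow the signature in this diff:

    import torch
    import flashinfer

    num_kv_heads, head_dim, page_size, batch_size = 4, 128, 64, 3
    num_qo_heads = num_kv_heads  # head group size of 1 for simplicity

    # Total KV length and new query length per request.
    seq_lens = torch.tensor([17, 40, 9], dtype=torch.int32, device="cuda")
    q_lens = torch.tensor([2, 5, 1], dtype=torch.int32, device="cuda")

    # Cumulative (exclusive prefix-sum) index pointers, both [batch_size + 1].
    cum_seq_lens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cum_seq_lens_q[1:] = torch.cumsum(q_lens, dim=0)
    cum_seq_lens_kv = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cum_seq_lens_kv[1:] = torch.cumsum(seq_lens, dim=0)  # assumed token-level offsets

    # Packed query: [sum_seq_q, num_qo_heads, head_dim].
    query = torch.randn(int(q_lens.sum()), num_qo_heads, head_dim,
                        dtype=torch.bfloat16, device="cuda")

    # Paged KV cache, HND layout: [num_pages, 2, num_kv_heads, page_size, head_dim].
    pages_per_seq = (int(seq_lens.max()) + page_size - 1) // page_size
    num_pages = batch_size * pages_per_seq
    kv_cache = torch.randn(num_pages, 2, num_kv_heads, page_size, head_dim,
                           dtype=torch.bfloat16, device="cuda")
    block_tables = torch.arange(num_pages, dtype=torch.int32, device="cuda").reshape(
        batch_size, pages_per_seq)
    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    out = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_kv_len=int(seq_lens.max()),
        q_len_per_req=None,            # skip the uniform-length reshape path
        max_q_len=int(q_lens.max()),   # longest per-request query length
        cum_seq_lens_q=cum_seq_lens_q,
        cum_seq_lens_kv=cum_seq_lens_kv,
    )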
114 changes: 108 additions & 6 deletions tests/attention/test_trtllm_gen_attention.py
@@ -1,5 +1,8 @@
import math

import sys
sys.path.append("./")

import pytest
import torch
from tests.test_helpers.utils_fp4 import (
@@ -54,8 +57,13 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
return q_lens, in_kv_lens, seq_lens


def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len):
q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len):
if q_len_per_req is not None:
assert max_q_len is None, "Cannot specify both q_len_per_req and max_q_len."
q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
else:
assert max_q_len is not None, "Must specify either q_len_per_req or max_q_len."
q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
in_kv_lens[-1] = max_in_kv_len
seq_lens = q_lens + in_kv_lens
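For illustration, the two modes of generate_seq_lens_decode behave as sketched below (the ragged values shown are one possible random draw):

    # Uniform mode: every request gets exactly q_len_per_req query tokens.
    q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
        batch_size=4, q_len_per_req=3, max_in_kv_len=110, max_q_len=None
    )
    # q_lens == tensor([3, 3, 3, 3]); seq_lens[i] == 3 + in_kv_lens[i]

    # Ragged mode: per-request query lengths drawn uniformly from [1, max_q_len].
    q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
        batch_size=4, q_len_per_req=None, max_in_kv_len=110, max_q_len=12
    )
    # e.g. q_lens == tensor([5, 12, 1, 8]); in_kv_lens[-1] is pinned to max_in_kv_len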
@@ -746,6 +754,7 @@ def _test_trtllm_batch_decode(
max_in_kv_len,
head_dim,
device_scale=False,
max_q_len=None,
):
"""
Common function for testing trtllm-gen decode.
@@ -780,7 +789,7 @@ def _test_trtllm_batch_decode(
# Generate random sequence lengths
num_qo_heads = num_kv_heads * head_grp_size
q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
batch_size, q_len_per_req, max_in_kv_len
batch_size, q_len_per_req, max_in_kv_len, max_q_len
)

# Create query tensor and related data
@@ -835,7 +844,7 @@ def _test_trtllm_batch_decode(
"window_left": window_left,
}
if not enable_sink:
if q_len_per_req == 1:
if q_len_per_req is not None and q_len_per_req == 1:
wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer_ref, kv_layout, use_tensor_cores=True
)
@@ -886,7 +895,7 @@ def _test_trtllm_batch_decode(
kv_indptr=kv_indptr_tokens,
)

if q_len_per_req > 1:
if q_len_per_req is not None and q_len_per_req > 1:
mask = generate_causal_mask(batch_size, q_len_per_req, GPU_DEVICE)
else:
mask = None
@@ -923,6 +932,9 @@ def _test_trtllm_batch_decode(
q_len_per_req=q_len_per_req,
o_scale=o_scale,
mask=mask,
max_q_len=max_q_len,
cum_seq_lens_q=q_indptr if max_q_len is not None else None,
cum_seq_lens_kv=kv_indptr if max_q_len is not None else None,
)
if backend == "trtllm-gen":
# check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
@@ -948,7 +960,7 @@

# convert to float32 for fp8 is not supported by assert_close
# relax rtol and atol for speculative decoding test
if q_len_per_req > 1:
if (q_len_per_req and q_len_per_req > 1) or (max_q_len and max_q_len > 1):
rtol, atol = rtol * 2, atol * 2

# Arbitary small mismatch rate
Expand Down Expand Up @@ -1434,3 +1446,93 @@ def test_trtllm_gen_prefill_deepseek_bs1(
test_trtllm_gen_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
)


def test_trtllm_batch_decode_spec(
kv_layout,
batch_size,
max_q_len,
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
):
_test_trtllm_batch_decode(
"trtllm-gen",
kv_layout,
batch_size,
None, # q_len_per_req
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
max_q_len=max_q_len,
)


if __name__ == "__main__":
# pytest.main([__file__])
test_trtllm_batch_decode_spec(
kv_layout="HND",
batch_size=4,
max_q_len=12,
page_size=64,
num_kv_heads=4,
head_grp_size=1,
window_left=-1,
q_dtype="bf16",
kv_dtype="bf16",
o_dtype="bf16",
enable_pdl=None,
enable_sink=False,
max_in_kv_len=110,
head_dim=128,
)
_test_trtllm_batch_decode(
backend="trtllm-gen",
kv_layout="HND",
batch_size=4,
q_len_per_req=3,
page_size=64,
num_kv_heads=4,
head_grp_size=1,
window_left=-1,
q_dtype="bf16",
kv_dtype="bf16",
o_dtype="bf16",
enable_pdl=None,
enable_sink=False,
max_in_kv_len=110,
head_dim=128,
)
_test_trtllm_batch_decode(
backend="trtllm-gen",
kv_layout="HND",
batch_size=4,
q_len_per_req=1,
page_size=64,
num_kv_heads=4,
head_grp_size=1,
window_left=-1,
q_dtype="fp8",
kv_dtype="fp8",
o_dtype="nvfp4",
enable_pdl=None,
enable_sink=False,
max_in_kv_len=110,
head_dim=128,
)
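With the __main__ block above, the new spec-decode path (and the uniform-length paths) can be exercised directly from the repository root, assuming a GPU supported by the trtllm-gen kernels:

    python tests/attention/test_trtllm_gen_attention.py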