26 | 26 | #include <sstream> |
27 | 27 | #include <unordered_map> |
28 | 28 |
| 29 | +#include "tvm/ffi/error.h" |
29 | 30 | #include "tvm_ffi_utils.h" |
30 | 31 |
31 | 32 | using tvm::ffi::Optional; |
@@ -163,6 +164,9 @@ void trtllm_paged_attention_launcher( |
163 | 164 | use_multi_block ? TileScheduler::Static : TileScheduler::Persistent; |
164 | 165 | runner_params.mMultiCtasKvMode = use_multi_block; |
165 | 166 |
| 167 | + runner_params.cumSeqLensQPtr = cum_seq_lens_q; |
| 168 | + runner_params.cumSeqLensKvPtr = cum_seq_lens_kv; |
| 169 | + |
166 | 170 | size_t max_batch_size = 8192; // todo(Yingyi): get from dlfw |
167 | 171 | size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB |
168 | 172 | size_t num_semaphores = |
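
Note (context, not part of the diff): the new runner_params.cumSeqLensQPtr / cumSeqLensKvPtr fields carry cumulative sequence lengths as int32 device arrays of size batch_size + 1 (shown for the query side; the KV side presumably mirrors it, and the caller below casts the tensors to int*). A minimal host-side sketch of that layout, assuming the usual varlen convention where cum[0] = 0 and cum[batch_size] equals the total token count; the helper name is illustrative:

    #include <cstdint>
    #include <vector>

    // Build the cumulative-length array: cum[0] = 0,
    // cum[i] = lengths[0] + ... + lengths[i-1], size = batch_size + 1.
    std::vector<int32_t> build_cum_seq_lens(const std::vector<int32_t>& lengths) {
      std::vector<int32_t> cum(lengths.size() + 1, 0);
      for (size_t i = 0; i < lengths.size(); ++i) {
        cum[i + 1] = cum[i] + lengths[i];  // running prefix sum
      }
      return cum;  // e.g. {3, 1, 2} -> {0, 3, 4, 6}
    }
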
@@ -213,22 +217,49 @@ void trtllm_paged_attention_decode( |
213 | 217 | TensorView seq_lens, int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale, |
214 | 218 | Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, |
215 | 219 | int64_t o_sf_start_index, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, |
216 | | - bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) { |
| 220 | + bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks, |
| 221 | + Optional<int64_t> optional_max_q_len, |
| 222 | + Optional<TensorView> cum_seq_lens_q, |
| 223 | + Optional<TensorView> cum_seq_lens_kv |
| 224 | + ) { |
217 | 225 | auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); |
218 | 226 | auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); |
219 | 227 | TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); |
220 | 228 | for (int i = 0; i < key_cache.ndim(); i++) { |
221 | 229 | TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i)); |
222 | 230 | } |
223 | 231 | auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); |
224 | | - // NOTE(Zihao): query is [B, Q, H, D] |
225 | | - // where Q is the number of query tokens per request, used in MTP |
226 | | - // based on profiled results, always use decode mode for MTP (q_len is small) |
227 | | - // example: when kv_len = 10000, q < 200, decode mode is faster |
228 | | - int batch_size = query.size(0); |
229 | | - int q_len_per_request = query.size(1); |
230 | | - int sum_seq_q = batch_size * q_len_per_request; |
231 | | - int num_qo_heads = query.size(2); |
| 232 | + int batch_size; |
| 233 | + int max_q_len; |
| 234 | + int sum_seq_q; |
| 235 | + int num_qo_heads; |
| 236 | +  int* cum_seq_lens_q_ptr = nullptr;   // set only on the ragged (variable-length) path
| 237 | +  int* cum_seq_lens_kv_ptr = nullptr;  // set only on the ragged (variable-length) path
| 238 | + if (!optional_max_q_len.has_value()) { |
| 239 | + // each request has the same length |
| 240 | + |
| 241 | + // NOTE(Zihao): query is [B, Q, H, D] |
| 242 | + // where Q is the number of query tokens per request, used in MTP |
| 243 | + // based on profiled results, always use decode mode for MTP (q_len is small) |
| 244 | + // example: when kv_len = 10000, q < 200, decode mode is faster |
| 245 | + int q_len_per_request = query.size(1); |
| 246 | + batch_size = query.size(0); |
| 247 | + sum_seq_q = batch_size * q_len_per_request; |
| 248 | + num_qo_heads = query.size(2); |
| 249 | + max_q_len = q_len_per_request; |
| 250 | + } else { |
| 251 | +    // each request has a different query length
| 252 | + TVM_FFI_CHECK(cum_seq_lens_q.has_value(), "cum_seq_lens_q must be provided when max_q_len is provided"); |
| 253 | + TVM_FFI_CHECK(cum_seq_lens_kv.has_value(), "cum_seq_lens_kv must be provided when max_q_len is provided"); |
| 254 | + // the shape of query: [sum_seq_q, num_qo_heads, head_dim_q] |
| 255 | + // the shape of cum_seq_lens_q: [batch_size + 1] |
| 256 | + batch_size = cum_seq_lens_q.value().size(0) - 1; |
| 257 | + sum_seq_q = query.size(0); |
| 258 | + num_qo_heads = query.size(1); |
| 259 | + max_q_len = optional_max_q_len.value(); |
| 260 | + cum_seq_lens_q_ptr = static_cast<int*>(cum_seq_lens_q.value().data_ptr()); |
| 261 | + cum_seq_lens_kv_ptr = static_cast<int*>(cum_seq_lens_kv.value().data_ptr()); |
| 262 | + } |
232 | 263 | // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even. |
233 | 264 | int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1); |
234 | 265 | int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1); |
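
Note (context, not part of the diff): the new branch supports two query layouts. Without optional_max_q_len, query keeps the uniform [B, Q, H, D] shape and sum_seq_q = B * Q; with it, query is packed as [sum_seq_q, num_qo_heads, head_dim_q] and cum_seq_lens_q (shape [batch_size + 1]) delimits each request's tokens. A hedged sketch of the invariants a caller could check before taking the ragged path, assuming the exclusive-prefix-sum convention above; the helper name is illustrative:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Consistency checks for the ragged (variable-length) query layout.
    void check_ragged_layout(const std::vector<int32_t>& cum_q, int64_t sum_seq_q,
                             int64_t max_q_len) {
      assert(cum_q.size() >= 2);          // batch_size + 1 entries, at least one request
      assert(cum_q.front() == 0);         // prefix sum starts at zero
      assert(cum_q.back() == sum_seq_q);  // total must match query.size(0)
      for (size_t i = 1; i < cum_q.size(); ++i) {
        int32_t len = cum_q[i] - cum_q[i - 1];
        assert(len >= 0 && len <= max_q_len);  // each request bounded by max_q_len
      }
    }
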
@@ -285,9 +316,9 @@ void trtllm_paged_attention_decode( |
285 | 316 | out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), |
286 | 317 | workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), |
287 | 318 | static_cast<int*>(seq_lens.data_ptr()), |
288 | | - /*cum_seq_lens_q=*/nullptr, |
289 | | - /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, |
290 | | - TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len, |
| 319 | + cum_seq_lens_q_ptr, |
| 320 | + cum_seq_lens_kv_ptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, |
| 321 | + TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, |
291 | 322 | num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, |
292 | 323 | kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, |
293 | 324 | bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, |
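
Note (context, not part of the diff): with this change the launcher call forwards cum_seq_lens_q_ptr / cum_seq_lens_kv_ptr (nullptr on the uniform path, matching the previous behavior) and the computed max_q_len instead of the hard-coded q_len_per_request. A hedged caller-side sketch of deciding between the two layouts; per_request_q_lens and the helper are illustrative, not part of this patch:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct QLayout {
      bool ragged;        // true -> also pass max_q_len and cum_seq_lens_q/kv
      int64_t max_q_len;  // upper bound on per-request query length
    };

    // Assumes at least one request.
    QLayout pick_q_layout(const std::vector<int32_t>& per_request_q_lens) {
      auto [mn, mx] = std::minmax_element(per_request_q_lens.begin(),
                                          per_request_q_lens.end());
      return {*mn != *mx, static_cast<int64_t>(*mx)};
    }
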