2 changes: 1 addition & 1 deletion .github/workflows/nightly-release.yml
@@ -98,7 +98,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install build twine wheel
-pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15
+pip install setuptools>=61.0 requests filelock torch tqdm numpy "apache-tvm-ffi>=0.1,<0.2"

- name: Build flashinfer-cubin wheel
env:
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -136,7 +136,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install build twine wheel
-pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15
+pip install setuptools>=61.0 requests filelock torch tqdm numpy "apache-tvm-ffi>=0.1,<0.2"

- name: Build flashinfer-cubin wheel
run: |
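Note: both workflows make the same dependency change, relaxing the exact pre-release pin apache-tvm-ffi==0.1.0b15 to the bounded range apache-tvm-ffi>=0.1,<0.2, which accepts future 0.1.x releases while excluding a potentially breaking 0.2. The range specifier is quoted so the shell does not parse its > as an output redirect.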
62 changes: 31 additions & 31 deletions csrc/batch_attention.cu
@@ -42,21 +42,21 @@ Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t head_dim_o, bool causal) {
size_t float_workspace_size_in_bytes =
-float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
+float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer);
size_t int_workspace_size_in_bytes =
-int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer);
+int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer);

HolisticPlanInfo<2> plan_info;

-cudaSetDevice(float_workspace_buffer->device.device_id);
-const cudaStream_t stream = get_stream(float_workspace_buffer->device);
+cudaSetDevice(float_workspace_buffer.device().device_id);
+const cudaStream_t stream = get_stream(float_workspace_buffer.device());

cudaError_t status = TwoStageHolisticPlan<IdType>(
-float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data,
-page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info,
-static_cast<IdType*>(qo_indptr->data), static_cast<IdType*>(kv_indptr->data),
-static_cast<IdType*>(kv_len->data), batch_size, num_qo_heads, num_kv_heads, head_dim_o,
-causal, stream);
+float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
+int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
+int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(qo_indptr.data_ptr()),
+static_cast<IdType*>(kv_indptr.data_ptr()), static_cast<IdType*>(kv_len.data_ptr()),
+batch_size, num_qo_heads, num_kv_heads, head_dim_o, causal, stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "Failed to plan persistent paged attention, error: " << cudaGetErrorString(status);
@@ -76,34 +76,34 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo
HolisticPlanInfo<2> plan_info;
plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));

-void* float_buffer_ptr = float_workspace_buffer->data;
-void* int_buffer_ptr = int_workspace_buffer->data;
+void* float_buffer_ptr = float_workspace_buffer.data_ptr();
+void* int_buffer_ptr = int_workspace_buffer.data_ptr();

const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

// NOTE (Yilong): assume both q and o are NHD
-unsigned int q_stride_n = q->strides[0];
-unsigned int q_stride_h = q->strides[1];
+unsigned int q_stride_n = q.stride(0);
+unsigned int q_stride_h = q.stride(1);

// layout only constraint paged KV
const QKVLayout kv_layout = static_cast<QKVLayout>(layout_code);
-unsigned int k_stride_page = k_cache->strides[0];
-unsigned int v_stride_page = v_cache->strides[0];
+unsigned int k_stride_page = k_cache.stride(0);
+unsigned int v_stride_page = v_cache.stride(0);
unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h;
if (kv_layout == QKVLayout::kNHD) {
-k_stride_h = k_cache->strides[2];
-k_stride_n = k_cache->strides[1];
-v_stride_h = v_cache->strides[2];
-v_stride_n = v_cache->strides[1];
+k_stride_h = k_cache.stride(2);
+k_stride_n = k_cache.stride(1);
+v_stride_h = v_cache.stride(2);
+v_stride_n = v_cache.stride(1);
} else {
-k_stride_h = k_cache->strides[1];
-k_stride_n = k_cache->strides[2];
-v_stride_h = v_cache->strides[1];
-v_stride_n = v_cache->strides[2];
+k_stride_h = k_cache.stride(1);
+k_stride_n = k_cache.stride(2);
+v_stride_h = v_cache.stride(1);
+v_stride_n = v_cache.stride(2);
}

-cudaSetDevice(q->device.device_id);
-const cudaStream_t stream = get_stream(q->device);
+cudaSetDevice(q.device().device_id);
+const cudaStream_t stream = get_stream(q.device());

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
@@ -112,17 +112,17 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo
IdType* len_kv_chunk =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.len_kv_chunk_offset);
for (int i = 0; i < 2; i++) {
-params[i].q = static_cast<DTypeQ*>(q->data);
-params[i].k = static_cast<DTypeKV*>(k_cache->data);
-params[i].v = static_cast<DTypeKV*>(v_cache->data);
+params[i].q = static_cast<DTypeQ*>(q.data_ptr());
+params[i].k = static_cast<DTypeKV*>(k_cache.data_ptr());
+params[i].v = static_cast<DTypeKV*>(v_cache.data_ptr());

params[i].q_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset);
params[i].kv_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset);
params[i].partial_indptr = GetPtrFromBaseOffset<IdType>(
int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset);
-params[i].kv_indices = static_cast<int*>(kv_indices->data);
+params[i].kv_indices = static_cast<int*>(kv_indices.data_ptr());
params[i].q_len =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_len_offset);
params[i].kv_len =
@@ -139,9 +139,9 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset);
params[i].len_kv_chunk = len_kv_chunk + i;

-params[i].final_o = static_cast<DTypeO*>(o->data);
+params[i].final_o = static_cast<DTypeO*>(o.data_ptr());
params[i].final_lse =
-maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value()->data) : nullptr;
+maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr;
params[i].partial_o =
GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
params[i].partial_lse =
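Every hunk in csrc/batch_attention.cu applies one mechanical pattern: DLTensor-style field access through operator-> is replaced by TensorView accessor methods. A minimal self-contained C++ sketch of the old-to-new mapping, using a mock tensor type (the mock is illustrative only and not the real apache-tvm-ffi TensorView API):

#include <cstdint>
#include <vector>

// Mock of the accessor surface this diff migrates to; purely illustrative.
struct MockTensorView {
  std::vector<int64_t> shape, strides;
  void* data = nullptr;
  int device_id = 0;
  int64_t size(int i) const { return shape[i]; }      // was: view->shape[i]
  int64_t stride(int i) const { return strides[i]; }  // was: view->strides[i]
  void* data_ptr() const { return data; }             // was: view->data
  int device() const { return device_id; }            // was: view->device.device_id
};

// Call sites keep their explicit element-type casts, because data_ptr()
// (like the old ->data field) erases the element type to void*.
template <typename IdType>
IdType* typed_ptr(const MockTensorView& t) {
  return static_cast<IdType*>(t.data_ptr());
}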
64 changes: 33 additions & 31 deletions csrc/batch_decode.cu
@@ -43,29 +43,29 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlan(
int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
TensorView empty_q_data, TensorView empty_kv_data) {
size_t float_workspace_size_in_bytes =
-float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
+float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer);
size_t int_workspace_size_in_bytes =
-int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer);
+int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer);

DecodePlanInfo plan_info;

TVM_FFI_ICHECK_EQ(head_dim_qk, head_dim_vo)
<< "CUDA cores template only supports equal head dim for QK and VO, please use tensor "
"cores template for different head dim";

-cudaSetDevice(float_workspace_buffer->device.device_id);
-const cudaStream_t stream = get_stream(float_workspace_buffer->device);
+cudaSetDevice(float_workspace_buffer.device().device_id);
+const cudaStream_t stream = get_stream(float_workspace_buffer.device());
DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] {
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
GROUP_SIZE, HEAD_DIM_QK, POS_ENCODING_MODE, AttentionVariant, Params>;
cudaError_t status = DecodePlan<HEAD_DIM_QK, POS_ENCODING_MODE, AttentionVariant, Params>(
-static_cast<void*>(float_workspace_buffer->data), float_workspace_size_in_bytes,
-static_cast<void*>(int_workspace_buffer->data),
-static_cast<void*>(page_locked_int_workspace_buffer->data),
-int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr->data),
+static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
+static_cast<void*>(int_workspace_buffer.data_ptr()),
+static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
+int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
batch_size, num_qo_heads, page_size, enable_cuda_graph,
/*stream=*/stream, work_estimation_func);

@@ -89,36 +89,36 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
DecodePlanInfo plan_info;
plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
-int64_t batch_size = q->shape[0];
-int64_t num_qo_heads = q->shape[1];
+int64_t batch_size = q.size(0);
+int64_t num_qo_heads = q.size(1);
int64_t num_kv_heads, page_size;

if (kv_layout == QKVLayout::kHND) {
-num_kv_heads = paged_k_cache->shape[1];
-page_size = paged_k_cache->shape[2];
+num_kv_heads = paged_k_cache.size(1);
+page_size = paged_k_cache.size(2);
} else {
-page_size = paged_k_cache->shape[1];
-num_kv_heads = paged_k_cache->shape[2];
+page_size = paged_k_cache.size(1);
+num_kv_heads = paged_k_cache.size(2);
}
-uint32_t head_dim_qk = q->shape[2];
-uint32_t head_dim_vo = paged_v_cache->shape[3];
+uint32_t head_dim_qk = q.size(2);
+uint32_t head_dim_vo = paged_v_cache.size(3);

TVM_FFI_ICHECK_EQ(head_dim_qk, head_dim_vo)
<< "CUDA cores template only supports equal head dim for QK and VO, please use tensor "
"cores template for different head dim";

if (maybe_lse.has_value()) {
const auto& lse = maybe_lse.value();
-TVM_FFI_ICHECK_EQ(lse->shape[0], batch_size);
-TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads);
+TVM_FFI_ICHECK_EQ(lse.size(0), batch_size);
+TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads);
}

-void* float_buffer = static_cast<void*>(float_workspace_buffer->data);
-void* int_buffer = static_cast<void*>(int_workspace_buffer->data);
+void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
+void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

// get q_stride_n and q_stride_h
-const auto q_stride_n = q->strides[0];
-const auto q_stride_h = q->strides[1];
+const auto q_stride_n = q.stride(0);
+const auto q_stride_h = q.stride(1);

// get kv_cache_strides
const int64_t* kv_cache_strides = nullptr;
@@ -130,24 +130,26 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
}
kv_cache_strides = k_strides.data();

-cudaSetDevice(q->device.device_id);
-const cudaStream_t stream = get_stream(q->device);
+cudaSetDevice(q.device().device_id);
+const cudaStream_t stream = get_stream(q.device());

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] {
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM_QK, batch_size, kv_layout,
-static_cast<DTypeKV*>(paged_k_cache->data), static_cast<DTypeKV*>(paged_v_cache->data),
-kv_cache_strides, static_cast<IdType*>(paged_kv_indices->data),
-static_cast<IdType*>(paged_kv_indptr->data),
-static_cast<IdType*>(paged_kv_last_page_len->data));
+static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
+static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
+static_cast<IdType*>(paged_kv_indices.data_ptr()),
+static_cast<IdType*>(paged_kv_indptr.data_ptr()),
+static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));

Params params;
-params.q = static_cast<DTypeQ*>(q->data);
+params.q = static_cast<DTypeQ*>(q.data_ptr());
params.paged_kv = paged_kv;
-params.o = static_cast<DTypeO*>(o->data);
-params.lse = maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value()->data) : nullptr;
+params.o = static_cast<DTypeO*>(o.data_ptr());
+params.lse =
+    maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr;
params.padded_batch_size = 0;
params.num_qo_heads = num_qo_heads;
params.q_stride_n = q_stride_n;
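The shape reads in BatchDecodeWithPagedKVCacheRun depend on the paged-KV layout: the same 4-D cache tensor is interpreted as [num_pages, num_kv_heads, page_size, head_dim] under kHND and [num_pages, page_size, num_kv_heads, head_dim] under kNHD, which is why the .size(1) and .size(2) reads swap between the two branches. A simplified C++ sketch of that dispatch (not the project's helper, just the branch above restated):

#include <cstdint>

enum class QKVLayout { kNHD, kHND };

struct PagedKVShape {
  int64_t num_kv_heads;
  int64_t page_size;
};

// dim1 and dim2 are cache.size(1) and cache.size(2); their meaning
// flips with the layout, exactly as in the branch above.
inline PagedKVShape ReadPagedKVShape(int64_t dim1, int64_t dim2, QKVLayout layout) {
  if (layout == QKVLayout::kHND) {
    return {/*num_kv_heads=*/dim1, /*page_size=*/dim2};
  }
  return {/*num_kv_heads=*/dim2, /*page_size=*/dim1};
}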
52 changes: 27 additions & 25 deletions csrc/batch_decode_mla_cute_sm80.cu
@@ -18,23 +18,24 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspac
int64_t num_qo_heads, int64_t page_size,
bool enable_cuda_graph) {
size_t float_workspace_size_in_bytes =
-float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
+float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer);
size_t int_workspace_size_in_bytes =
-int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer);
+int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer);

DecodePlanInfo plan_info;
-cudaSetDevice(float_workspace_buffer->device.device_id);
-const cudaStream_t stream = get_stream(float_workspace_buffer->device);
+cudaSetDevice(float_workspace_buffer.device().device_id);
+const cudaStream_t stream = get_stream(float_workspace_buffer.device());

auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80<
HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN, AttentionVariant, Params>;
cudaError_t status =
DecodePlan<HEAD_DIM_CKV, flashinfer::PosEncodingMode::kNone, AttentionVariant, Params>(
-static_cast<void*>(float_workspace_buffer->data), float_workspace_size_in_bytes,
-static_cast<void*>(int_workspace_buffer->data),
-static_cast<void*>(page_locked_int_workspace_buffer->data), int_workspace_size_in_bytes,
-plan_info, static_cast<IdType*>(indptr->data), batch_size, num_qo_heads, page_size,
-enable_cuda_graph, /*stream=*/stream, work_estimation_func);
+static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
+static_cast<void*>(int_workspace_buffer.data_ptr()),
+static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
+int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
+batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream,
+work_estimation_func);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "BatchDecodeWithPagedKVCachePlanMLA failed with error " << cudaGetErrorString(status);
@@ -55,31 +56,32 @@ void BatchDecodeWithPagedKVCacheRunMLA(
DecodePlanInfo plan_info;
plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));

-int64_t batch_size = q_nope->shape[0];
-int64_t num_qo_heads = q_nope->shape[1];
-int64_t page_size = paged_ckv_cache->shape[1];
+int64_t batch_size = q_nope.size(0);
+int64_t num_qo_heads = q_nope.size(1);
+int64_t page_size = paged_ckv_cache.size(1);

if (maybe_lse.has_value()) {
const auto& lse = maybe_lse.value();
-TVM_FFI_ICHECK_EQ(lse->shape[0], batch_size);
-TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads);
+TVM_FFI_ICHECK_EQ(lse.size(0), batch_size);
+TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads);
}

TVM_FFI_ICHECK_GE(logits_soft_cap, 0.f) << "logits_soft_cap must be non-negative";

-void* float_buffer = static_cast<void*>(float_workspace_buffer->data);
-void* int_buffer = static_cast<void*>(int_workspace_buffer->data);
+void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
+void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

paged_kv_mla_t<DTypeKV, IdType> paged_kv(
page_size, HEAD_DIM_CKV, HEAD_DIM_KPE, batch_size,
-static_cast<DTypeKV*>(paged_ckv_cache->data), paged_ckv_cache.strides().data(),
-static_cast<DTypeKV*>(paged_kpe_cache->data), paged_kpe_cache.strides().data(),
-static_cast<IdType*>(paged_kv_indices->data), static_cast<IdType*>(paged_kv_indptr->data),
-static_cast<IdType*>(paged_kv_last_page_len->data));
+static_cast<DTypeKV*>(paged_ckv_cache.data_ptr()), paged_ckv_cache.strides().data(),
+static_cast<DTypeKV*>(paged_kpe_cache.data_ptr()), paged_kpe_cache.strides().data(),
+static_cast<IdType*>(paged_kv_indices.data_ptr()),
+static_cast<IdType*>(paged_kv_indptr.data_ptr()),
+static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
Params params(
-static_cast<DTypeQ*>(q_nope->data), static_cast<DTypeQ*>(q_pe->data),
-/*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o->data),
-/*lse=*/(maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value()->data) : nullptr),
+static_cast<DTypeQ*>(q_nope.data_ptr()), static_cast<DTypeQ*>(q_pe.data_ptr()),
+/*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
+/*lse=*/(maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr),
num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);

DTypeO* tmp_v = nullptr;
@@ -101,8 +103,8 @@ void BatchDecodeWithPagedKVCacheRunMLA(
}
params.padded_batch_size = plan_info.padded_batch_size;

-cudaSetDevice(paged_ckv_cache->device.device_id);
-const cudaStream_t stream = get_stream(paged_ckv_cache->device);
+cudaSetDevice(paged_ckv_cache.device().device_id);
+const cudaStream_t stream = get_stream(paged_ckv_cache.device());
cudaError_t status = BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE,
QO_TILE_LEN, Params>(
params, tmp_v, tmp_s, /*stream=*/stream);
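One detail shared by all three files: before the launch stream is resolved, the host thread is pinned to the device that owns the input tensor via cudaSetDevice(tensor.device().device_id), so subsequent kernel launches and workspace accesses target the correct GPU in multi-device processes. A hedged C++ sketch of the idiom; get_stream_for below is a stand-in for the project's stream helper and simply returns the legacy default stream:

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the project's per-device stream lookup; illustrative only.
inline cudaStream_t get_stream_for(int /*device_id*/) { return cudaStreamLegacy; }

inline cudaStream_t BindDeviceAndGetStream(int device_id) {
  // Pin the calling thread to the tensor's device before any launch.
  cudaError_t err = cudaSetDevice(device_id);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaSetDevice(%d): %s\n", device_id, cudaGetErrorString(err));
  }
  return get_stream_for(device_id);
}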