6 changes: 6 additions & 0 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
// we might consider another appraoch
if (seq_len >= 768) {
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
ctx,
output,
query,
key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
nullopt);
} else if (seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
ctx,
output,
query,
key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
nullopt);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
ctx,
output,
query,
key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
// we might consider another appraoch
if (seq_len >= 768) {
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
ctx,
output,
q,
k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
num_keys_for_causal_attention);
} else if (seq_len >= 192) {
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
ctx,
output,
q,
k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
num_keys_for_causal_attention);
} else {
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
ctx,
output,
q,
k,
39 changes: 26 additions & 13 deletions extension/llm/custom_ops/op_sdpa_impl.h
@@ -35,6 +35,7 @@ enum class SeqDim { ONE = 1, TWO };

namespace sdpa::impl {

static std::vector<char> scratch_for_quant_dequant_vec;
Copilot AI Nov 17, 2025

This static vector scratch_for_quant_dequant_vec is declared but never used in the code. It appears to be a leftover from the refactoring where the local vector was replaced with the temp allocator approach. This should be removed.

Suggested change
static std::vector<char> scratch_for_quant_dequant_vec;

struct MaybeQuantizedMatrixData {
const void* data{nullptr};
const int8_t* zero_points{nullptr};
@@ -543,6 +544,7 @@ TODO: Just handle conversion of bool mask to float
*/
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention(
RuntimeContext& ctx,
Tensor& output,
const Tensor& query,
const Tensor& key,
@@ -766,26 +768,37 @@ void cpu_flash_attention(
int64_t size_of_intermediate_precision = sizeof(accum_t);
int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
size_of_intermediate_precision;
std::vector<char> buf_vec(size_bytes);
void* buf = reinterpret_cast<void*>(buf_vec.data());
// Need to double check the following
size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
std::vector<char> buf_reduced_vec(size_bytes);
void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
// at::Tensor buf_reduced = at::empty(
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
// query.options());
Result<void*> buff_res = ctx.allocate_temp(size_bytes);
std::unique_ptr<char[]> allocated_buf;
void* buf;
if (!buff_res.ok()) {
allocated_buf = std::make_unique<char[]>(size_bytes);
buf = reinterpret_cast<void*>(allocated_buf.get());
} else {
buf = buff_res.get();
}
void* buf_reduced = nullptr;
int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
// Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
// by padding with right number of per thread elements
constexpr int64_t kAlignment = 32;
size_per_thread_qdq_vec =
(size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
int64_t size_per_thread_qdq_bytes =
size_per_thread_qdq_vec * size_of_intermediate_precision;
int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
accum_t* scratch_for_quant_dequant =
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
std::unique_ptr<char[]> allocated_buf_for_qdq;
Result<void*> scratch_for_quant_dequant_res =
ctx.allocate_temp(size_qdq_bytes);
accum_t* scratch_for_quant_dequant;
if (!scratch_for_quant_dequant_res.ok()) {
allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
scratch_for_quant_dequant =
reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
} else {
scratch_for_quant_dequant =
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
}

// Data ptrs
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
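
Taken together, the change in op_sdpa_impl.h replaces function-local std::vector scratch buffers with memory requested from the kernel's runtime context, falling back to a heap allocation only when allocate_temp cannot serve the request. Below is a minimal standalone sketch of that allocation pattern; TempArena, get_scratch, and the std::optional return type are illustrative stand-ins, not the ExecuTorch API, which uses ctx.allocate_temp() returning Result<void*>.

// Minimal sketch of the "temp allocator first, heap fallback" pattern.
#include <cstddef>
#include <memory>
#include <optional>

struct TempArena {
  // Returns std::nullopt when the arena cannot satisfy the request,
  // mirroring a failed Result from ctx.allocate_temp(size_bytes).
  std::optional<void*> allocate_temp(std::size_t size_bytes) {
    if (size_bytes > capacity_ - used_) {
      return std::nullopt;
    }
    void* p = storage_ + used_;
    used_ += size_bytes;
    return p;
  }

  alignas(64) char storage_[1 << 16] = {};
  std::size_t capacity_ = sizeof(storage_);
  std::size_t used_ = 0;
};

// Prefer runtime-managed temp memory; otherwise allocate on the heap and let
// `fallback` own it so the buffer outlives the kernel invocation, as the
// unique_ptr does in cpu_flash_attention.
void* get_scratch(TempArena& ctx,
                  std::size_t size_bytes,
                  std::unique_ptr<char[]>& fallback) {
  if (auto buf = ctx.allocate_temp(size_bytes)) {
    return *buf;
  }
  fallback = std::make_unique<char[]>(size_bytes);
  return fallback.get();
}

int main() {
  TempArena ctx;
  std::unique_ptr<char[]> owned;  // non-null only if the fallback path ran
  void* scratch = get_scratch(ctx, 4096, owned);
  (void)scratch;  // would serve as the attention intermediate buffer
  return 0;
}

In the actual kernel this pattern is applied twice: once for the per-thread accumulation buffer (buf) and once for the quantize/dequantize scratch space (scratch_for_quant_dequant).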