Reduce allocation overhead in quantized sdpa #15610
base: gh/kimishpatel/202/base
Changes from all commits
ae61ab4
99902b8
72292a7
93e17be
602a3a7
```diff
@@ -213,13 +213,13 @@ void dequant_and_gemm(
     const int64_t v_stride_n,
     float* o_data,
     const int64_t o_stride_m,
-    const float beta) {
-  std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+    const float beta,
+    float* buf_qdq_ptr) {
   dequantize_per_channel_optimized(
       static_cast<const int8_t*>(v_data.data),
       static_cast<const float*>(v_data.scales),
       static_cast<const int8_t*>(v_data.zero_points),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       -128,
       127,
       1,
```
```diff
@@ -237,7 +237,7 @@ void dequant_and_gemm(
       m,
       k,
       static_cast<float>(1),
-      dequantized_v_data.data(),
+      buf_qdq_ptr,
       v_data.n,
       qk_data,
       qk_stride_m,
```
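These two hunks carry the core of the change: `dequant_and_gemm` no longer allocates a `std::vector<float>` for the dequantized V block on every call and instead writes into a caller-provided `buf_qdq_ptr`. A minimal sketch of that before/after pattern, using illustrative signatures rather than the PR's real ones:

```cpp
#include <cstdint>
#include <vector>

// Before: every call pays for a heap allocation (and free) of m * n floats.
void dequant_and_gemm_before(int64_t m, int64_t n) {
  std::vector<float> dequantized_v_data(m * n);
  // ... dequantize into dequantized_v_data.data(), then run the gemm ...
}

// After: the caller owns the scratch buffer and reuses it across blocks,
// so the hot loop does no allocation at all.
void dequant_and_gemm_after(int64_t m, int64_t n, float* buf_qdq_ptr) {
  // ... dequantize into buf_qdq_ptr, then run the gemm ...
  (void)m;
  (void)n;
  (void)buf_qdq_ptr;
}
```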
```diff
@@ -257,7 +257,8 @@ void _qk_at_v_gemm(
     const int64_t v_stride_n,
     accum_t* o_data,
     const int64_t o_stride_m,
-    const accum_t beta) {
+    const accum_t beta,
+    accum_t* buf_qdq_ptr) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
       if (m > 4) {
```
```diff
@@ -273,7 +274,8 @@ void _qk_at_v_gemm(
             v_stride_n,
             o_data,
             o_stride_m,
-            beta);
+            beta,
+            buf_qdq_ptr);
       } else {
         // For smaller batch sizes, use quantized gemm
         int a_stride_m_tmp, b_stride_n_tmp;
```
```diff
@@ -773,6 +775,17 @@ void cpu_flash_attention(
   // at::Tensor buf_reduced = at::empty(
   //     {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
   //     query.options());
+  int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
+  // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
```
Copilot AI (Nov 17, 2025) suggested a change to the newly added comment:

```diff
-  // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
+  // Lets align size_per_thread_qdq_vec to 32 elements (128 bytes for float), for coalesced cache reads,
```
Copilot AI: The alignment calculation is incorrect. The formula `(x + kAlignment - 1) & (-(kAlignment - 1))` uses the wrong mask. For aligning to a power-of-two boundary, the correct formula is `(size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment`, or equivalently `(size_per_thread_qdq_vec + kAlignment - 1) & ~(kAlignment - 1)`. The current code uses `-(kAlignment - 1)`, which equals -31 (0xFFFFFFE1), but the correct mask is -32 (0xFFFFFFE0), which properly zeroes out the bottom 5 bits.

Suggested change:

```diff
-      (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
+      (size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment;
```
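To see the difference concretely, here is a small standalone check of the two masks; the value of `x` is illustrative and not taken from the PR:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t kAlignment = 32;  // elements, matching the PR's constant
  const int64_t x = 34;           // example size; x + 31 is odd, which exposes the bug

  // Correct round-up to a multiple of 32: mask ...11100000 clears the low 5 bits.
  const int64_t aligned = (x + kAlignment - 1) & -kAlignment;
  // Mask in the PR: -(kAlignment - 1) == -31 is ...11100001, so bit 0 leaks through.
  const int64_t buggy = (x + kAlignment - 1) & (-(kAlignment - 1));

  std::printf("aligned=%lld buggy=%lld\n",
              static_cast<long long>(aligned),  // 64, a multiple of 32
              static_cast<long long>(buggy));   // 65, not a multiple of 32
  return 0;
}
```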
@copilot are you sure? Please double check again
The buffer size calculation appears to be larger than necessary. The dequantize operation needs `kvBlockSize * headSize` elements (at most `kvSplitSize * headSize`), but this allocates `qSplitSize * kvSplitSize * headSize`. The extra `qSplitSize` factor seems unnecessary and wastes memory per thread. Consider changing to:

```cpp
int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
```
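Taken together with the alignment fix above, the per-thread scratch sizing the review is driving at would look roughly like the sketch below. The concrete numbers are illustrative, and the per-thread slicing into `buf_qdq_ptr` is an assumption about how the buffer would be consumed, not a quote of the PR:

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Illustrative values; in the PR these come from cpu_flash_attention's
  // kvSplitSize / headSize and the thread count.
  const int64_t kvSplitSize = 512;
  const int64_t headSize = 128;
  const int64_t num_thread = 8;
  const int64_t kAlignment = 32;  // elements; 128 bytes for float

  // Per the comment: the dequant scratch only needs one V block of
  // kvSplitSize * headSize elements, rounded up with the corrected mask.
  int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
  size_per_thread_qdq_vec = (size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment;

  // One up-front allocation shared by all threads, instead of a std::vector
  // inside every dequant_and_gemm call.
  std::vector<float> scratch(num_thread * size_per_thread_qdq_vec);

  // Thread t would pass &scratch[t * size_per_thread_qdq_vec] as buf_qdq_ptr.
  (void)scratch;
  return 0;
}
```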