Commit 84f60cb

Model-specific tests for flash attention decode
1 parent 0e24202 commit 84f60cb

4 files changed: +426 −6 lines changed

test/unit/flash_attention/flash_attention_decode/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -58,13 +58,20 @@ cutlass_test_unit_add_executable(
   xe_flash_decode_fp16_fp32_fp32_h192_1024_nonpaged.cpp
 )

+cutlass_test_unit_add_executable(
+  cutlass_test_unit_flash_attention_decode_models_xe
+  xe_flash_decode_models_fp16_nonpaged.cpp
+  xe_flash_decode_models_bf16_nonpaged.cpp
+)
+
 add_custom_target(
   cutlass_test_unit_flash_attention_decode
   DEPENDS
   cutlass_test_unit_flash_attention_decode_h64_xe
   cutlass_test_unit_flash_attention_decode_h96_xe
   cutlass_test_unit_flash_attention_decode_h128_xe
   cutlass_test_unit_flash_attention_decode_h192_xe
+  cutlass_test_unit_flash_attention_decode_models_xe
 )

 add_custom_target(
@@ -74,4 +81,5 @@ add_custom_target(
   test_unit_flash_attention_decode_h96_xe
   test_unit_flash_attention_decode_h128_xe
   test_unit_flash_attention_decode_h192_xe
+  test_unit_flash_attention_decode_models_xe
 )

test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp

Lines changed: 48 additions & 6 deletions
@@ -798,14 +798,56 @@ struct Testbed3x {
 };

 template <typename FlashDecode>
-bool TestFlashDecodeAll(int head_size) {
+bool TestFlashDecodeAll(int head_size, std::string config = "default") {
   Testbed3x<FlashDecode> testbed;

-  std::vector<int> problem_size_batch{16};
-  std::vector<int> problem_size_num_heads{32};
-  std::vector<int> problem_size_seq_len{1024};
-  std::vector<int> problem_size_seq_len_cache{0, 1024};
-  std::vector<int> cache_page_size{64, 128};
+  std::vector<int> problem_size_batch;
+  std::vector<int> problem_size_num_heads;
+  std::vector<int> problem_size_seq_len;
+  std::vector<int> problem_size_seq_len_cache;
+  std::vector<int> cache_page_size;
+  if (config == "whisper_v3_large") {
+    problem_size_batch = {1, 2, 4};
+    problem_size_num_heads = {20};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "llama3_8b") {
+    problem_size_batch = {1, 2, 4};
+    problem_size_num_heads = {32};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "llama3_405b") {
+    problem_size_batch = {1, 2};
+    problem_size_num_heads = {128};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "qwen2_5_72b") {
+    problem_size_batch = {1, 2};
+    problem_size_num_heads = {64};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "deepseek_r1") {
+    problem_size_batch = {1, 2};
+    problem_size_num_heads = {64};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else {
+    problem_size_batch = {16};
+    problem_size_num_heads = {32};
+    problem_size_seq_len = {1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
   std::vector<float> problem_size_softmax_scale{ 1.f / sqrt(static_cast<float>(head_size)) };
   bool passed = true;
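Each branch in the new TestFlashDecodeAll fills the same five vectors, so the per-model settings could equally live in a lookup table. A minimal, hypothetical refactor sketch, not the committed code: ModelProblemSizes and lookup_config are invented names, and the values are copied from the diff above.

#include <map>
#include <string>
#include <vector>

// Hypothetical aggregate bundling the five problem-size vectors.
struct ModelProblemSizes {
  std::vector<int> batch, num_heads, seq_len, seq_len_cache, page_size;
};

// Table-driven equivalent of the if/else chain; unknown configs fall back
// to the "default" entry, matching the final else branch above.
inline ModelProblemSizes const& lookup_config(std::string const& config) {
  static std::map<std::string, ModelProblemSizes> const table{
    {"whisper_v3_large", {{1, 2, 4}, {20},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"llama3_8b",        {{1, 2, 4}, {32},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"llama3_405b",      {{1, 2},    {128}, {512, 1024}, {0, 1024}, {64, 128}}},
    {"qwen2_5_72b",      {{1, 2},    {64},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"deepseek_r1",      {{1, 2},    {64},  {512, 1024}, {0, 1024}, {64, 128}}},
    {"default",          {{16},      {32},  {1024},      {0, 1024}, {64, 128}}},
  };
  auto it = table.find(config);
  return it != table.end() ? it->second : table.at("default");
}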
test/unit/flash_attention/flash_attention_decode/xe_flash_decode_models_bf16_nonpaged.cpp

Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@
/****************************************************************************
 * Copyright (C) 2025 Intel Corporation. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * BF16 model-specific flash attention decode tests, non-paged.
 * Coverage: 5 models × 4 head sizes (h64, h96, h128, h192), KV512, causal,
 * varlen = 20 BF16 tests. See xe_flash_decode_models_fp16_nonpaged.cpp for
 * the FP16 counterpart.
 ***************************************************************************/

#include "flash_decode_testbed_3x.hpp"

namespace cutlass {

using MMAOperationBF16 = test::flash_attention::MMAOperationBF16;
using GmemTiledCopyQ = test::flash_attention::GmemTiledCopyQU16;
using GmemTiledCopyK = test::flash_attention::GmemTiledCopyKU16;
using GmemTiledCopyV = test::flash_attention::GmemTiledCopyVU16;
using GmemTiledCopyStore = test::flash_attention::GmemTiledCopyStoreU32;

// 20 tests: 5 models × 4 head sizes, KV512, causal, varlen

// h64 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_whisper) {
  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_llama8b) {
  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_llama405b) {
  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_qwen25) {
  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_deepseek) {
  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "deepseek_r1"));
}

// h96 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_whisper) {
  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_llama8b) {
  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_llama405b) {
  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_qwen25) {
  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_deepseek) {
  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "deepseek_r1"));
}

// h128 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_whisper) {
  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_llama8b) {
  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_llama405b) {
  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_qwen25) {
  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_deepseek) {
  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "deepseek_r1"));
}

// h192 × KV512 × Causal × VarLen
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_whisper) {
  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "whisper_v3_large"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_llama8b) {
  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_8b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_llama405b) {
  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_405b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_qwen25) {
  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "qwen2_5_72b"));
}
TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_deepseek) {
  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, true, true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;
  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "deepseek_r1"));
}

} // namespace cutlass
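
The 20 TEST blocks above differ only in the shape alias, the head size, and the config string. A hypothetical helper (not part of the commit) could factor out the repeated kernel instantiation; it reuses the aliases defined at the top of the file, and the commented boolean parameter names (Causal, VarLen, PagedKV) are inferred from the test names rather than taken from the kernel's declaration.

// Hypothetical helper; not part of the commit. Template arguments mirror
// the TEST bodies above.
template <typename Shape_h>
bool RunModelDecode(int head_size, std::string const& config) {
  using Kernel = typename test::flash_attention::XE_Flash_Attention_Decode<
      bfloat16_t, float, float,
      typename Shape_h::ShapeQK, typename Shape_h::ShapePV,
      typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,
      MMAOperationBF16, /*Causal=*/true, /*VarLen=*/true,
      GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore,
      /*PagedKV=*/false>::Kernel;
  return test::flash_attention::TestFlashDecodeAll<Kernel>(head_size, config);
}

// Each test body would then reduce to a single line, e.g.:
//   EXPECT_TRUE(RunModelDecode<test::flash_attention::Shape_h64<512, 8>>(64, "whisper_v3_large"));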
