| 
 | 1 | +/****************************************************************************  | 
 | 2 | + * Copyright (C) 2025 Intel Corporation. All rights reserved.  | 
 | 3 | + * SPDX-License-Identifier: BSD-3-Clause  | 
 | 4 | + *  | 
 | 5 | + * BF16 flash-attention decode tests, non-paged. This file holds the KV512 × causal × varlen slice: 4 head sizes × 5 models = 20 tests.  | 
 | 6 | + * Full BF16 coverage: 4 heads × 2 KV × 2 causal × 2 varlen × 5 models = 160 BF16 tests  | 
 | 7 | + * Total Matrix: 2×4×2×5×2×2 = 320 combinations (FP16 + BF16)  | 
 | 8 | + ***************************************************************************/  | 
 | 9 | + | 
 | 10 | +#include "flash_decode_testbed_3x.hpp"  | 
 | 11 | + | 
 | 12 | +namespace cutlass {  | 
 | 13 | + | 
 | 14 | +using MMAOperationBF16 = test::flash_attention::MMAOperationBF16;  | 
 | 15 | +using GmemTiledCopyQ = test::flash_attention::GmemTiledCopyQU16;  | 
 | 16 | +using GmemTiledCopyK = test::flash_attention::GmemTiledCopyKU16;  | 
 | 17 | +using GmemTiledCopyV = test::flash_attention::GmemTiledCopyVU16;  | 
 | 18 | +using GmemTiledCopyStore = test::flash_attention::GmemTiledCopyStoreU32;  | 
 | 19 | + | 
 | 20 | +// 20 tests: 5 models × 4 head sizes, KV512, causal, varlen  | 
 | 21 | + | 
 | 22 | +// h64 × KV512 × Causal × VarLen  | 
 | 23 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_whisper) {  | 
 | 24 | +  using Shape_h = test::flash_attention::Shape_h64<512, 8>;  | 
 | 25 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 26 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 27 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 28 | +                  MMAOperationBF16, true, true,  | 
 | 29 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 30 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "whisper_v3_large"));  | 
 | 31 | +}  | 
 | 32 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_llama8b) {  | 
 | 33 | +  using Shape_h = test::flash_attention::Shape_h64<512, 8>;  | 
 | 34 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 35 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 36 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 37 | +                  MMAOperationBF16, true, true,  | 
 | 38 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 39 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_8b"));  | 
 | 40 | +}  | 
 | 41 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_llama405b) {  | 
 | 42 | +  using Shape_h = test::flash_attention::Shape_h64<512, 8>;  | 
 | 43 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 44 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 45 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 46 | +                  MMAOperationBF16, true, true,  | 
 | 47 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 48 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_405b"));  | 
 | 49 | +}  | 
 | 50 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_qwen25) {  | 
 | 51 | +  using Shape_h = test::flash_attention::Shape_h64<512, 8>;  | 
 | 52 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 53 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 54 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 55 | +                  MMAOperationBF16, true, true,  | 
 | 56 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 57 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "qwen2_5_72b"));  | 
 | 58 | +}  | 
 | 59 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h64_kv512_causal_varlen_deepseek) {  | 
 | 60 | +  using Shape_h = test::flash_attention::Shape_h64<512, 8>;  | 
 | 61 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 62 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 63 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 64 | +                  MMAOperationBF16, true, true,  | 
 | 65 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 66 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "deepseek_r1"));  | 
 | 67 | +}  | 
 | 68 | + | 
 | 69 | +// h96 × KV512 × Causal × VarLen  | 
 | 70 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_whisper) {  | 
 | 71 | +  using Shape_h = test::flash_attention::Shape_h96<512, 8>;  | 
 | 72 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 73 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 74 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 75 | +                  MMAOperationBF16, true, true,  | 
 | 76 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 77 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "whisper_v3_large"));  | 
 | 78 | +}  | 
 | 79 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_llama8b) {  | 
 | 80 | +  using Shape_h = test::flash_attention::Shape_h96<512, 8>;  | 
 | 81 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 82 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 83 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 84 | +                  MMAOperationBF16, true, true,  | 
 | 85 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 86 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_8b"));  | 
 | 87 | +}  | 
 | 88 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_llama405b) {  | 
 | 89 | +  using Shape_h = test::flash_attention::Shape_h96<512, 8>;  | 
 | 90 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 91 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 92 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 93 | +                  MMAOperationBF16, true, true,  | 
 | 94 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 95 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_405b"));  | 
 | 96 | +}  | 
 | 97 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_qwen25) {  | 
 | 98 | +  using Shape_h = test::flash_attention::Shape_h96<512, 8>;  | 
 | 99 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 100 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 101 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 102 | +                  MMAOperationBF16, true, true,  | 
 | 103 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 104 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "qwen2_5_72b"));  | 
 | 105 | +}  | 
 | 106 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h96_kv512_causal_varlen_deepseek) {  | 
 | 107 | +  using Shape_h = test::flash_attention::Shape_h96<512, 8>;  | 
 | 108 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 109 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 110 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 111 | +                  MMAOperationBF16, true, true,  | 
 | 112 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 113 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "deepseek_r1"));  | 
 | 114 | +}  | 
 | 115 | + | 
 | 116 | +// h128 × KV512 × Causal × VarLen  | 
 | 117 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_whisper) {  | 
 | 118 | +  using Shape_h = test::flash_attention::Shape_h128<512, 8>;  | 
 | 119 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 120 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 121 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 122 | +                  MMAOperationBF16, true, true,  | 
 | 123 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 124 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "whisper_v3_large"));  | 
 | 125 | +}  | 
 | 126 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_llama8b) {  | 
 | 127 | +  using Shape_h = test::flash_attention::Shape_h128<512, 8>;  | 
 | 128 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 129 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 130 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 131 | +                  MMAOperationBF16, true, true,  | 
 | 132 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 133 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_8b"));  | 
 | 134 | +}  | 
 | 135 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_llama405b) {  | 
 | 136 | +  using Shape_h = test::flash_attention::Shape_h128<512, 8>;  | 
 | 137 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 138 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 139 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 140 | +                  MMAOperationBF16, true, true,  | 
 | 141 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 142 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_405b"));  | 
 | 143 | +}  | 
 | 144 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_qwen25) {  | 
 | 145 | +  using Shape_h = test::flash_attention::Shape_h128<512, 8>;  | 
 | 146 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 147 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 148 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 149 | +                  MMAOperationBF16, true, true,  | 
 | 150 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 151 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "qwen2_5_72b"));  | 
 | 152 | +}  | 
 | 153 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h128_kv512_causal_varlen_deepseek) {  | 
 | 154 | +  using Shape_h = test::flash_attention::Shape_h128<512, 8>;  | 
 | 155 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 156 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 157 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 158 | +                  MMAOperationBF16, true, true,  | 
 | 159 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 160 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "deepseek_r1"));  | 
 | 161 | +}  | 
 | 162 | + | 
 | 163 | +// h192 × KV512 × Causal × VarLen  | 
 | 164 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_whisper) {  | 
 | 165 | +  using Shape_h = test::flash_attention::Shape_h192<512, 8>;  | 
 | 166 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 167 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 168 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 169 | +                  MMAOperationBF16, true, true,  | 
 | 170 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 171 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "whisper_v3_large"));  | 
 | 172 | +}  | 
 | 173 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_llama8b) {  | 
 | 174 | +  using Shape_h = test::flash_attention::Shape_h192<512, 8>;  | 
 | 175 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 176 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 177 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 178 | +                  MMAOperationBF16, true, true,  | 
 | 179 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 180 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_8b"));  | 
 | 181 | +}  | 
 | 182 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_llama405b) {  | 
 | 183 | +  using Shape_h = test::flash_attention::Shape_h192<512, 8>;  | 
 | 184 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 185 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 186 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 187 | +                  MMAOperationBF16, true, true,  | 
 | 188 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 189 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_405b"));  | 
 | 190 | +}  | 
 | 191 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_qwen25) {  | 
 | 192 | +  using Shape_h = test::flash_attention::Shape_h192<512, 8>;  | 
 | 193 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 194 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 195 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 196 | +                  MMAOperationBF16, true, true,  | 
 | 197 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 198 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "qwen2_5_72b"));  | 
 | 199 | +}  | 
 | 200 | +TEST(XE_Flash_Attention_Decode_BF16_Short, bf16_h192_kv512_causal_varlen_deepseek) {  | 
 | 201 | +  using Shape_h = test::flash_attention::Shape_h192<512, 8>;  | 
 | 202 | +  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<bfloat16_t, float, float,  | 
 | 203 | +                  typename Shape_h::ShapeQK, typename Shape_h::ShapePV,  | 
 | 204 | +                  typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout,  | 
 | 205 | +                  MMAOperationBF16, true, true,  | 
 | 206 | +                  GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, GmemTiledCopyStore, false>::Kernel;  | 
 | 207 | +  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "deepseek_r1"));  | 
 | 208 | +}  | 
 | 209 | + | 
 | 210 | +} // namespace cutlass  | 
0 commit comments