diff --git a/test/unit/flash_attention/flash_attention_decode/CMakeLists.txt b/test/unit/flash_attention/flash_attention_decode/CMakeLists.txt
index 006c727f62..478b358ce6 100644
--- a/test/unit/flash_attention/flash_attention_decode/CMakeLists.txt
+++ b/test/unit/flash_attention/flash_attention_decode/CMakeLists.txt
@@ -58,6 +58,12 @@ cutlass_test_unit_add_executable(
   xe_flash_decode_fp16_fp32_fp32_h192_1024_nonpaged.cpp
 )
 
+cutlass_test_unit_add_executable(
+  cutlass_test_unit_flash_attention_decode_models_xe
+  xe_flash_decode_models_fp16_nonpaged.cpp
+  xe_flash_decode_models_bf16_nonpaged.cpp
+)
+
 add_custom_target(
   cutlass_test_unit_flash_attention_decode
   DEPENDS
@@ -65,6 +71,7 @@ add_custom_target(
   cutlass_test_unit_flash_attention_decode_h96_xe
   cutlass_test_unit_flash_attention_decode_h128_xe
   cutlass_test_unit_flash_attention_decode_h192_xe
+  cutlass_test_unit_flash_attention_decode_models_xe
 )
 
 add_custom_target(
@@ -74,4 +81,5 @@ add_custom_target(
   test_unit_flash_attention_decode_h96_xe
   test_unit_flash_attention_decode_h128_xe
   test_unit_flash_attention_decode_h192_xe
+  test_unit_flash_attention_decode_models_xe
 )
\ No newline at end of file
diff --git a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp
index 7180f75d5e..e0b1e8162a 100644
--- a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp
+++ b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp
@@ -798,14 +798,56 @@ struct Testbed3x {
 };
 
 template <typename FlashDecode>
-bool TestFlashDecodeAll(int head_size) {
+bool TestFlashDecodeAll(int head_size, std::string config = "default") {
   Testbed3x<FlashDecode> testbed;
 
-  std::vector<int> problem_size_batch{16};
-  std::vector<int> problem_size_num_heads{32};
-  std::vector<int> problem_size_seq_len{1024};
-  std::vector<int> problem_size_seq_len_cache{0, 1024};
-  std::vector<int> cache_page_size{64, 128};
+  std::vector<int> problem_size_batch;
+  std::vector<int> problem_size_num_heads;
+  std::vector<int> problem_size_seq_len;
+  std::vector<int> problem_size_seq_len_cache;
+  std::vector<int> cache_page_size;
+  if (config == "whisper_v3_large") {
+    problem_size_batch = {1, 2, 4};
+    problem_size_num_heads = {20};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "llama3_8b") {
+    problem_size_batch = {1, 2, 4};
+    problem_size_num_heads = {32};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "llama3_405b") {
+    problem_size_batch = {1, 2};
+    problem_size_num_heads = {128};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "qwen2_5_72b") {
+    problem_size_batch = {1, 2};
+    problem_size_num_heads = {64};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else if (config == "deepseek_r1") {
+    problem_size_batch = {1, 2};
+    problem_size_num_heads = {64};
+    problem_size_seq_len = {512, 1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
+  else {
+    problem_size_batch = {16};
+    problem_size_num_heads = {32};
+    problem_size_seq_len = {1024};
+    problem_size_seq_len_cache = {0, 1024};
+    cache_page_size = {64, 128};
+  }
 
   std::vector<float> problem_size_softmax_scale{ 1.f / sqrt(static_cast<float>(head_size)) };
   bool passed = true;
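Review note: the five model branches differ only in `problem_size_batch` and `problem_size_num_heads` (the `qwen2_5_72b` and `deepseek_r1` branches are in fact identical), while the sequence-length, KV-cache, and page-size vectors repeat the same values in every branch. A table-driven lookup would shrink the chain and make adding a model a one-line change. A minimal sketch of that alternative; `ModelConfig`, `kModelConfigs`, and `lookup_config` are hypothetical names, not part of this PR:

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical helper (not in this PR): one row per model; unknown config
// strings fall back to the existing "default" values.
struct ModelConfig {
  std::vector<int> batch, num_heads, seq_len;
};

inline const std::map<std::string, ModelConfig> kModelConfigs = {
    {"whisper_v3_large", {{1, 2, 4}, {20},  {512, 1024}}},
    {"llama3_8b",        {{1, 2, 4}, {32},  {512, 1024}}},
    {"llama3_405b",      {{1, 2},    {128}, {512, 1024}}},
    {"qwen2_5_72b",      {{1, 2},    {64},  {512, 1024}}},
    {"deepseek_r1",      {{1, 2},    {64},  {512, 1024}}},
    {"default",          {{16},      {32},  {1024}}},
};

// seq_len_cache {0, 1024} and cache_page_size {64, 128} are common to every
// branch in the patch, so they would stay hard-coded at the call site.
inline ModelConfig lookup_config(const std::string &config) {
  auto it = kModelConfigs.find(config);
  return it != kModelConfigs.end() ? it->second : kModelConfigs.at("default");
}
```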
diff --git a/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp
index 1dcd52cfc2..b2d4873786 100644
--- a/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp
+++ b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_bf16_fp32_fp32_h64_512_nonpaged.cpp
@@ -29,6 +29,7 @@
  *
  **************************************************************************************************/
 
+
 /*! \file
     \brief Tests for Xe flash attention decode bf16
 */
diff --git a/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_models_bf16_nonpaged.cpp b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_models_bf16_nonpaged.cpp
new file mode 100644
index 0000000000..f50ab5876a
--- /dev/null
+++ b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_models_bf16_nonpaged.cpp
@@ -0,0 +1,206 @@
+/****************************************************************************
+ * Copyright (C) 2025 Intel Corporation. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ***************************************************************************/
+
+#include "flash_decode_testbed_3x.hpp"
+
+namespace cutlass {
+
+using MMAOperationBF16 = test::flash_attention::MMAOperationBF16;
+using GmemTiledCopyQ = test::flash_attention::GmemTiledCopyQU16;
+using GmemTiledCopyK = test::flash_attention::GmemTiledCopyKU16;
+using GmemTiledCopyV = test::flash_attention::GmemTiledCopyVU16;
+using GmemTiledCopyStore = test::flash_attention::GmemTiledCopyStoreU32;
+
+// 20 tests: 5 models × 4 head sizes, KV512, causal, varlen
+
+// h64 × KV512 × Causal × VarLen
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_whisper) {
+  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "whisper_v3_large"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_llama8b) {
+  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_8b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_llama405b) {
+  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "llama3_405b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_qwen25) {
+  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "qwen2_5_72b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h64_kv512_causal_varlen_deepseek) {
+  using Shape_h = test::flash_attention::Shape_h64<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(64, "deepseek_r1"));
+}
+
+// h96 × KV512 × Causal × VarLen
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_whisper) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "whisper_v3_large"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_llama8b) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_8b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_llama405b) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_405b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_qwen25) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "qwen2_5_72b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h96_kv512_causal_varlen_deepseek) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "deepseek_r1"));
+}
+
+// h128 × KV512 × Causal × VarLen
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_whisper) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "whisper_v3_large"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_llama8b) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_8b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_llama405b) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_405b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_qwen25) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "qwen2_5_72b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h128_kv512_causal_varlen_deepseek) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "deepseek_r1"));
+}
+
+// h192 × KV512 × Causal × VarLen
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_whisper) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "whisper_v3_large"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_llama8b) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_8b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_llama405b) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_405b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_qwen25) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "qwen2_5_72b"));
+}
+TEST(XE_Flash_Attention_Decode_BF16, bf16_h192_kv512_causal_varlen_deepseek) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      bfloat16_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationBF16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "deepseek_r1"));
+}
+
+} // namespace cutlass
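Note: the twenty bf16 cases above and the fifteen fp16 cases below compile into the single `cutlass_test_unit_flash_attention_decode_models_xe` binary added in the CMake change, so one model's sweep can be built and run in isolation with a standard GoogleTest filter. An illustrative invocation, assuming a Ninja build tree (the binary path depends on your build layout):

```sh
# Build only the new model-config suite, then run just the Llama 3 405B sweeps.
ninja cutlass_test_unit_flash_attention_decode_models_xe
./test/unit/flash_attention/flash_attention_decode/cutlass_test_unit_flash_attention_decode_models_xe \
  --gtest_filter='*llama405b*'
```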
diff --git a/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_models_fp16_nonpaged.cpp b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_models_fp16_nonpaged.cpp
new file mode 100644
index 0000000000..52d56326b6
--- /dev/null
+++ b/test/unit/flash_attention/flash_attention_decode/xe_flash_decode_models_fp16_nonpaged.cpp
@@ -0,0 +1,156 @@
+/****************************************************************************
+ * Copyright (C) 2025 Intel Corporation. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ***************************************************************************/
+
+#include "flash_decode_testbed_3x.hpp"
+
+namespace cutlass {
+
+using MMAOperationFP16 = test::flash_attention::MMAOperationFP16;
+using GmemTiledCopyQ = test::flash_attention::GmemTiledCopyQU16;
+using GmemTiledCopyK = test::flash_attention::GmemTiledCopyKU16;
+using GmemTiledCopyV = test::flash_attention::GmemTiledCopyVU16;
+using GmemTiledCopyStore = test::flash_attention::GmemTiledCopyStoreU32;
+
+// 15 tests: 5 models × 3 head sizes, KV512, causal, varlen
+
+// h96 × KV512 × Causal × VarLen × [model]
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h96_kv512_causal_varlen_whisper) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "whisper_v3_large"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h96_kv512_causal_varlen_llama8b) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_8b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h96_kv512_causal_varlen_llama405b) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "llama3_405b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h96_kv512_causal_varlen_qwen25) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "qwen2_5_72b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h96_kv512_causal_varlen_deepseek) {
+  using Shape_h = test::flash_attention::Shape_h96<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(96, "deepseek_r1"));
+}
+
+// h128 × KV512 × Causal × VarLen × [model]
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h128_kv512_causal_varlen_whisper) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "whisper_v3_large"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h128_kv512_causal_varlen_llama8b) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_8b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h128_kv512_causal_varlen_llama405b) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "llama3_405b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h128_kv512_causal_varlen_qwen25) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "qwen2_5_72b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h128_kv512_causal_varlen_deepseek) {
+  using Shape_h = test::flash_attention::Shape_h128<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(128, "deepseek_r1"));
+}
+
+// h192 × KV512 × Causal × VarLen × [model]
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h192_kv512_causal_varlen_whisper) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "whisper_v3_large"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h192_kv512_causal_varlen_llama8b) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_8b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h192_kv512_causal_varlen_llama405b) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "llama3_405b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h192_kv512_causal_varlen_qwen25) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "qwen2_5_72b"));
+}
+TEST(XE_Flash_Attention_Decode_FP16, fp16_h192_kv512_causal_varlen_deepseek) {
+  using Shape_h = test::flash_attention::Shape_h192<512, 8>;
+  using Kernel = test::flash_attention::XE_Flash_Attention_Decode<
+      half_t, float, float, Shape_h::ShapeQK, Shape_h::ShapePV, Shape_h::ShapeOutput,
+      Shape_h::SubgroupLayout, MMAOperationFP16, GmemTiledCopyQ, GmemTiledCopyK,
+      GmemTiledCopyV, GmemTiledCopyStore, true, true, false>::Kernel;
+  EXPECT_TRUE(test::flash_attention::TestFlashDecodeAll<Kernel>(192, "deepseek_r1"));
+}
+
+} // namespace cutlass
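Review note: `TestFlashDecodeAll` takes the new `config` parameter as `std::string` by value, so every call copies a string literal into a fresh allocation. Since the function only compares the argument against fixed names, `std::string_view` (or `const std::string &`) would avoid the copies without touching any call site. A sketch of the alternative signature, hypothetical and not part of this PR:

```cpp
#include <string_view>

// string_view binds to the string literals at every call site without
// allocating, and the == comparisons against literals work unchanged.
template <typename FlashDecode>
bool TestFlashDecodeAll(int head_size, std::string_view config = "default");
```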