diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 12afe0f25a87b..a99c55820041e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -158,6 +158,7 @@ option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp on option(GGML_HIP "ggml: use HIP" OFF) option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) +option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index adf0d3ecb566c..1832314ec133b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -62,6 +62,7 @@ #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) @@ -196,6 +197,10 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) +#define FP16_MMA_AVAILABLE +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3)) + #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) { // Any FP16 tensor core instructions are available for ggml code. static bool fp16_mma_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN) } // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; + return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 7b9566fb4be32..46de14093545c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const int shift = k_KQ & (QI8_1/2); const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size]; sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); } else #endif // FP16_AVAILABLE { const float2 * Q_ds = (const float2 *) Q_ds_v; - sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); } } @@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const int shift = k_KQ & (QI8_1/2); const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size]; const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); } else @@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( { const float2 * Q_ds = (const float2 *) Q_ds_v; - const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; - const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; + const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; sum += (T) (sumid4d8 + m4s8scaled); } @@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( v |= (vh << 18) & 0x00100000; // 2 -> 20 v |= (vh << 25) & 0x10000000; // 3 -> 28 - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size]; sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; } else #endif // FP16_AVAILABLE { const float2 * Q_ds = (const float2 *) Q_ds_v; - sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); } } @@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( v |= (vh << 18) & 0x00100000; // 2 -> 20 v |= (vh << 25) & 0x10000000; // 3 -> 28 - const int u = Q_q8[k_KQ_0/WARP_SIZE]; + const int u = Q_q8[k_KQ_0/warp_size]; const int sumi = ggml_cuda_dp4a(v, u, 0); @@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size]; const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); } else @@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( { const float2 * Q_ds = (const float2 *) Q_ds_v; - const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; - const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; + const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; sum += (T) (sumid5d8 + m5s8scaled); } @@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_v); T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T Q_d; if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; - Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); + Q_d = __low2half(Q_ds[k_KQ_0/warp_size]); } else { const float2 * Q_ds = (const float2 *) Q_ds_v; - Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; + Q_d = Q_ds[k_KQ_0/warp_size].x; } - sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d); } return sum; @@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { const half2 * K_h2 = (const half2 *) K_c; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); @@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( half2 sum2 = make_half2(0.0f, 0.0f); #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; - sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + sum2 += K_ik * Q_h2[k_KQ_0/warp_size]; } return __low2half(sum2) + __high2half(sum2); @@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; - sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; - sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; + sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x; + sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y; } return sum; @@ -698,6 +704,8 @@ void launch_fattn( GGML_ASSERT(Q->ne[3] == 1); + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -750,7 +758,7 @@ void launch_fattn( const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; - const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 block_dim(warp_size, nwarps, 1); dim3 blocks_num; if (parallel_blocks == 0) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. @@ -796,6 +804,8 @@ void launch_fattn( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + GGML_ASSERT(block_dim.x % warp_size == 0); + GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size); fattn_kernel<<>>( (const char *) Q->data, K_data, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8828652fb5e7f..622cf28576d29 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -7,14 +7,19 @@ #include "fattn-wmma-f16.cuh" #ifdef FP16_MMA_AVAILABLE +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #include +namespace wmma = nvcuda::wmma; +#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers +#include +namespace wmma = rocwmma; +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -60,6 +65,8 @@ static __global__ void flash_attn_ext_f16( //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. @@ -68,11 +75,11 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); - typedef nvcuda::wmma::fragment frag_a_K; - typedef nvcuda::wmma::fragment frag_a_V; - typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c_KQ; - typedef nvcuda::wmma::fragment frag_c_VKQ; + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -132,9 +139,9 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D/2; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + warp_size > D/2 && i >= D/2) { break; } VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); @@ -146,9 +153,9 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + warp_size > D && i >= D) { break; } KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; @@ -162,7 +169,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); } } @@ -176,20 +183,20 @@ static __global__ void flash_attn_ext_f16( frag_c_KQ KQ_c[ncols/frag_n]; #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + wmma::fill_fragment(KQ_c[j], static_cast(0.0f)); } #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major); } } @@ -202,27 +209,27 @@ static __global__ void flash_attn_ext_f16( const int j = j0 + threadIdx.y; if (std::is_same::value) { - float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; + float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k]; if (use_logit_softcap) { - KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]); } } float KQ_max_new = KQ_max_f[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]); } - KQ_max_new = warp_reduce_max(KQ_max_new); + KQ_max_new = warp_reduce_max(KQ_max_new); const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; KQ_max_scale_f[j0/nwarps] = expf(diff); @@ -233,48 +240,48 @@ static __global__ void flash_attn_ext_f16( float KQ_rowsum_add = 0.0f; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { const int k = k0 + threadIdx.x; - const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; - KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/warp_size] = expf(diff); if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + KQ_f_tmp[k0/warp_size] = 0.0f; } - KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; - KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + KQ_rowsum_add += KQ_f_tmp[k0/warp_size]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size]; } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; } else { - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k]; if (use_logit_softcap) { // There is no dedicated tangens hyperbolicus function for half2. - KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); - KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) - /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f)); - KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + KQ2_tmp[k0/warp_size] *= logit_softcap_2; } } half2 KQ_max_new = KQ_max_h2[j0/nwarps]; #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); } - KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); @@ -283,17 +290,17 @@ static __global__ void flash_attn_ext_f16( half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); #pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { const int k = k0 + threadIdx.x; - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/warp_size] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/warp_size]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size]; } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; @@ -308,7 +315,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - nvcuda::wmma::load_matrix_sync( + wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], KQ + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); @@ -320,7 +327,7 @@ static __global__ void flash_attn_ext_f16( for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast(0.0f)); } #pragma unroll @@ -328,10 +335,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } @@ -343,10 +350,10 @@ static __global__ void flash_attn_ext_f16( for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync( + wmma::store_matrix_sync( KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); + D_padded, wmma::mem_col_major); } } @@ -364,9 +371,9 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D/2; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + warp_size > D/2 && i >= D/2) { break; } @@ -398,9 +405,9 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + warp_size > D && i >= D) { break; } float dst_val = VKQ[j_VKQ*D_padded + i]; @@ -425,7 +432,7 @@ static __global__ void flash_attn_ext_f16( } #else NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) } constexpr int get_max_power_of_2(int x) { @@ -515,6 +522,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * Q = dst->src[0]; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; if (prec != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { @@ -571,7 +579,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten return; } - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) { constexpr int cols_per_block = 8; switch (Q->ne[0]) { case 64: @@ -592,6 +601,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten } return; } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) if (Q->ne[1] <= 32) { constexpr int cols_per_block = 16; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index b1becccb4de72..24f973056aa9a 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= GGML_CUDA_CC_OFFSET_AMD) { +#if defined(GGML_HIP_ROCWMMA_FATTN) + if (fp16_mma_available(cc)) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + return; + } +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + + // On AMD the tile kernels perform poorly, use the vec kernel instead: if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { @@ -291,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int gqa_ratio = Q->ne[2] / K->ne[2]; const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask; - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) { + if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); return; diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 4a0384dd47654..e3762649fd275 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -39,6 +39,12 @@ endif() find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_ROCWMMA_FATTN) + CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA) + if (NOT ${FOUND_ROCWMMA}) + message(FATAL_ERROR "rocwmma has not been found") + endif() +endif() if (${hip_VERSION} VERSION_LESS 5.5) message(FATAL_ERROR "At least ROCM/HIP V5.5 is required") @@ -107,6 +113,10 @@ if (GGML_HIP_NO_VMM) add_compile_definitions(GGML_HIP_NO_VMM) endif() +if (GGML_HIP_ROCWMMA_FATTN) + add_compile_definitions(GGML_HIP_ROCWMMA_FATTN) +endif() + if (NOT GGML_CUDA_FA) add_compile_definitions(GGML_CUDA_NO_FA) endif()