diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f4b20d3515b..59e16a39f6fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -677,6 +677,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") # _rocm_C extension # set(VLLM_ROCM_EXT_SRC + "csrc/rocm/skinny_gemms.cu" "csrc/rocm/torch_bindings.cpp" "csrc/rocm/attention.cu") diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index afb735450e0c..179d711fdc8a 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -13,3 +13,9 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale); + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); + +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount); \ No newline at end of file diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu new file mode 100644 index 000000000000..a5dcf3bf0094 --- /dev/null +++ b/csrc/rocm/skinny_gemms.cu @@ -0,0 +1,1714 @@ +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "quantization/fp8/common.cuh" + +#if defined(__HIPCC__) && \ + (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__GFX9__ +#endif + +#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__MI3XX__ +#endif + +#if defined(__gfx950__) + #define LDS_SIZE 160 * 1024 +#else + #define LDS_SIZE 64 * 1024 +#endif + +int get_lds_size() { + static bool is_cached = false; + static int result; + if (is_cached == false) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + std::string device_arch = dprops->gcnArchName; + size_t substring = device_arch.find("gfx95"); + result = (substring == std::string::npos ? 
64 * 1024 : 160 * 1024); + is_cached = true; + } + return result; +} + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +template +struct scalar {}; + +template +struct scalar2 {}; + +template +__device__ __forceinline__ float2 __s22float2(T v); + +template +__device__ __forceinline__ T __float2s(float v); + +template +__device__ __forceinline__ T __float22s2_rn(float2 v); + +// Definitions and cvt functions for fp16 +template <> +struct scalar { + using type = half; +}; + +template <> +struct scalar2 { + using type = __half2; +}; + +template <> +__device__ __forceinline__ half __float2s(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ float2 __s22float2(__half2 v) { + return __half22float2(v); +} + +template <> +__device__ __forceinline__ __half2 __float22s2_rn(float2 v) { + return __float22half2_rn(v); +} + +// Definitions and cvt functions for bf16 +template <> +struct scalar { + using type = __hip_bfloat16; +}; + +template <> +struct scalar2 { + using type = __hip_bfloat162; +}; + +template <> +__device__ __forceinline__ __hip_bfloat16 __float2s(float v) { + return __float2bfloat16(v); +} + +template <> +__device__ __forceinline__ float2 __s22float2(__hip_bfloat162 v) { + return __bfloat1622float2(v); +} + +template <> +__device__ __forceinline__ __hip_bfloat162 __float22s2_rn(float2 v) { + return __float22bfloat162_rn(v); +} + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, + scalar_t* out_c, const int K) { + using scalar2_t = typename scalar2::type; + auto af4 = reinterpret_cast(in_a); + auto bf4 = reinterpret_cast(in_b); + auto c = reinterpret_cast(out_c); + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float acc[NUM_A_ROWS_PER_BLOCK]; + scalar2_t acch2; + scalar2_t oval; + + // As we later use warp shuffle operations, we may have more threads in the + // block than the actual available data, hence the if guard here. + if (threadid * 8 < K) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // rowA_elem4[i] holds 8 * half numbers seen as a single float4. 
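      // af4 views A as float4, so each float4 covers 8 scalar_t values and a
      // row of A spans K / 8 float4 entries; thread `threadid` therefore loads
      // the threadid-th float4 of row (blockIdx.x * NUM_A_ROWS_PER_BLOCK + i).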
+ rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); + } + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + } + + scalar2_t Af2; + float2 S; + + auto Ah2ptr = reinterpret_cast(&rowA_elem4); + scalar2_t* ah2lptr; + +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // Multiply-add on 8 scalar_t. + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __s22float2(acch2); + + // See comment above concerning the if guard. + acc[i] = (threadid * 8 < K ? S.x + S.y : 0.f); + } + +// all reduce across warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + if (lane < NUM_A_ROWS_PER_BLOCK) { + red_smem[lane][warp] = acc[lane]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { + acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; +#pragma unroll + for (int mask = 16 / 2; mask >= 1; mask /= 2) { + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); + } + float oval2 = __shfl_xor(acc[qwarpid], 16); + + if (lane % 32 == 0) { + oval = __float22s2_rn(make_float2(acc[qwarpid], oval2)); + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; + } + } +} + +torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + auto N = in_b.size(0); + + TORCH_CHECK(N == 1, "Row number of activation tensor must be 1."); + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(in_b.dtype() == torch::kFloat16 || + in_b.dtype() == torch::kBFloat16); + + auto out_c = torch::empty( + {N, M}, torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle + // operations. + const int NUM_THREADS = + max(rows_per_block * 16, + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE)); + + int NUM_BLOCKS = M / rows_per_block; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_b)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // call the kernel function... 
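  // AT_DISPATCH_REDUCED_FLOATING_TYPES instantiates the kernel for both
  // c10::Half and c10::BFloat16; rows_per_block selects the row count per
  // block, falling back to 4 rows per block (and NUM_BLOCKS = M / 4) for
  // unsupported values.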
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "LLGemm1", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + auto c_ptr = out_c.data_ptr(); + if (rows_per_block == 2) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel + <<>>(a_ptr, b_ptr, c_ptr, K); + } + }); + + return out_c; +} + +#define DOT2C(V0, V2, V3) \ + if constexpr (std::is_same_v) { \ + asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \ + } else if constexpr (std::is_same_v) { \ + float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ + __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ + V0 += (s.x + s.y); \ + } + +#if defined(__HIP__GFX9__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64/160 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not going to work! + //---------------------------------------------------- + __shared__ scalar_t s[max_lds_len]; + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * N, max_lds_len); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, max_lds_len)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. 
+ // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (m < M) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + for (int y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
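        // Each bigA/bigB chunk packs A_CHUNK (8) scalar_t values. On the fp16
        // path DOT2C maps to v_dot2c_f32_f16, a packed 2-way dot product with
        // fp32 accumulation per 32-bit word; on the MI3xx bf16 path the
        // 4x4x4 bf16 MFMA accumulates into the 4-wide sum4 vector instead,
        // which is reduced across lanes after the K loop.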
+ #pragma unroll + for (uint32_t n = 0; n < N; n++) { + #pragma unroll + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll + for (int n = 0; n < N; n++) { + #pragma unroll + for (int y = 0; y < YTILE; y++) { + // float accm1 = 0; + // for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[n][y][i%4], i); + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } + } + } + } + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // 
defined(__HIP__GFX9__) TODO: Add NAVI support + +#if defined(__HIP__GFX9__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not going to work! + //---------------------------------------------------- + __shared__ scalar_t s[max_lds_len]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmentation! + // This will happen only for the last wave! + if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * N, max_lds_len); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, max_lds_len)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. 
+ // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (m < M) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + for (int b = 0; b < YTILE; b++) + bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + if (k_ + K * n < max_lds_len) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t n = 0; n < N; n++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll + for (int n = 0; n < N; n++) { + #pragma unroll + for (int y = 0; y < YTILE; y++) { + // float accm1 = 0; + // for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[n][y][i%4], i); + + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmentation! + // This will happen only for the last wave! 
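    // When the final tile would run past M, the wave is shifted back so that
    // it starts at M - YTILE, and the columns in [M - YTILE, m) that are
    // already covered by another wave's range are masked out via commitColumn
    // so they are not written twice.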
+ if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + } +} + +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support + +#if defined(__HIP__GFX9__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + constexpr int max_lds_len = LDS_SIZE / 2; + #if defined(__HIP__MI3XX__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; + scalar8 h8; + }; + + //---------------------------------------------------- + // Reserving 64/160 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not going to work! + //---------------------------------------------------- + __shared__ scalar_t s[max_lds_len]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmentation! + // This will happen only for the last wave! 
+ if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + #define PCML + #ifndef PCML + for (uint32_t k = 0; k < min(K * N, max_lds_len); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * N, max_lds_len)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (max_lds_len) / N; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Mrndp = (M % YW == 0) ? M : (M - M % YW + YW); + while (m < Mrndp) { + #else + while (m < M) { + #endif + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. 
+ // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t n = 0; n < N; n++) { + uint32_t k_in = kBase + n * K + kOff; + uint32_t k_ot = n * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (m >= M) continue; + #endif + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const scalar_t* B_ = &B[(m + 0) * K + k_]; + for (int b = 0; b < YTILE; b++) + bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int n = 0; n < N; n++) { + #ifdef PCML + bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n]))); + #else + if (k_ + K * n < 32 * 1024) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + #endif + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + #pragma unroll + for (uint32_t n = 0; n < N; n++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+ #pragma unroll + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); + } + } + } + } + + #ifdef PCML + if (m >= M) { + m += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } + } + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll + for (int n = 0; n < N; n++) { + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + kBase = 0; + + // Check whether there will be fragmentation! + // This will happen only for the last wave! 
+ if (m < M && (m + YTILE) >= M) { + uint32_t startColumn = M - YTILE; + for (uint32_t i = 0; i < (m - startColumn); i++) { + commitColumn[i] = 0; + } + m = startColumn; + } + } +} +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +template +__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, + const scalar_t* __restrict__ A, scalar_t* C, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support + +int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, + const int64_t CuCount) { + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + auto N_in = in_b.size(0); + + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); + TORCH_CHECK(in_a.dtype() == torch::kFloat16 || + in_a.dtype() == torch::kBFloat16); + + auto out_c = torch::empty( + {N_in, M_in}, + torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + + dim3 grid(CuCount); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int max_lds_len = get_lds_size() / 2; + +#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitK_hf_sml_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else if (K_in * N_in <= max_lds_len * 1.2) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitK_hf_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ + wvSplitK_hf_big_ \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ + } \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { + using fptype = typename scalar::type; + fptype* af4 = reinterpret_cast(in_a.data_ptr()); + const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + fptype* c = reinterpret_cast(out_c.data_ptr()); + switch (N_in) { + case 1: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + return out_c; +} + +#if defined(__HIP__MI3XX__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, + const 
float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { + constexpr int max_lds_len = LDS_SIZE; + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[max_lds_len]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0.f}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + #pragma unroll + for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f}; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f}; + } + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + #pragma unroll + for (uint32_t y = 0; y < YTILE; ++y) { + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + if (k >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 
row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support + +#if defined(__HIP__MI3XX__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, + const fp8_t* __restrict__ A, scalar_t* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, + const int _WvPrGrp, const int CuCount) { + constexpr int max_lds_len = LDS_SIZE; + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 4) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK]; + char2 c2[A_CHUNK / 2]; + scalar_t h[A_CHUNK / 2]; + float f[A_CHUNK / 4]; + int i[A_CHUNK / 4]; + long l[A_CHUNK / 8]; + intx4 l2[A_CHUNK / 16]; + scalar8 h8; + }; + + __shared__ fp8_t s[max_lds_len]; + + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[N][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (m < M) { + for (int i = 0; i < YTILE; i++) + for (int n = 0; n < N; n++) sum[n][i] = {0}; + + bigType bigA[N][UNRL]; + bigType bigB[YTILE][UNRL]; + + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const fp8_t* B_ = &B[(m + 0) * Kp + k_]; + for (int y = 0; y < YTILE; ++y) { + if (y + m >= M) break; // To avoid mem access fault. 
+ bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp]))); + } + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + for (int n = 0; n < N; n++) { + if (k_ + K * n < max_lds_len) + bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + else + bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + for (uint32_t n = 0; n < N; n++) { + for (int i = 0; i < A_CHUNK; i += 8) { + for (int y = 0; y < YTILE; ++y) { + sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0, + 0); + } + } + } + } + } + + // Final reduction + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[n][y][0]; + float accm16 = sum[n][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[n][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[n][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[n][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + if (y + m >= M) break; // To avoid mem access fault. 
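          // The fp32 MFMA accumulator is dequantized with the per-tensor
          // scales sA * sB before being cast to the fp16/bf16 output type.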
+ C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + } + } + } + + m += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support +template +__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, + const fp8_t* B, const fp8_t* __restrict__ A, + scalar_t* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support + +void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, + const int64_t CuCount) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + auto M_in = in_a.size(0); + auto K_in = in_a.size(1); + auto N_in = in_b.size(0); + auto Kp_in = in_a.stride(0); + TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); + TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); + TORCH_CHECK(out_c.dtype() == torch::kFloat16 || + out_c.dtype() == torch::kBFloat16); + + dim3 grid(CuCount); + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const int max_lds_len = get_lds_size(); + +#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSplitKQ_hf_sml_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSplitKQ_hf_ \ + <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ + s_a, s_b, __wvPrGrp, CuCount); \ + } \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] { + using fptype = typename scalar::type; + auto c_ptr = reinterpret_cast(out_c.data_ptr()); + auto s_a = scale_a.data_ptr(); + auto s_b = scale_b.data_ptr(); + VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { + auto a_ptr = in_a.data_ptr(); + auto b_ptr = in_b.data_ptr(); + switch (N_in) { + case 1: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) + break; + case 2: + WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2) + break; + case 3: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3) + break; + case 4: + WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4) + break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + }); +} \ No newline at end of file diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 537e9357d52b..4d626cefaedb 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -12,7 +12,19 @@ // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { + // Custom gemm op for skinny matrix-matrix multiplication + rocm_ops.def( + "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " + "Tensor"); + + rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); + // vLLM custom ops for rocm + // wvSplitK for fp8 + rocm_ops.def( + "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! 
out_c, Tensor scale_a, " + " Tensor scale_b, int CuCount) -> ()"); + rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); // Custom attention op // Compute the attention between an input query and the cached diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py new file mode 100644 index 000000000000..c084f410c573 --- /dev/null +++ b/vllm/_aiter_ops.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +def rocm_aiter_tuned_gemm_impl( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + scale_a: Optional[torch.Tensor] = None, + scale_b: Optional[torch.Tensor] = None) -> torch.Tensor: + + # This AITER function can be used for + # - BF16 and FP16 matmul + # e.g. vllm/model_executor/layers/linear.py + # - per-tensor activations + per-tensor weights + # e.g. vllm/model_executor/layers/quantization/utils/w8a8_utils.py + from aiter.tuned_gemm import tgemm as aiter_tgemm + + return aiter_tgemm.mm(input, + weight, + otype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + +def rocm_aiter_tuned_gemm_fake( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + scale_a: Optional[torch.Tensor] = None, + scale_b: Optional[torch.Tensor] = None) -> torch.Tensor: + + m = input.shape[0] + n = weight.shape[0] + if out_dtype is None: + out_dtype = input.dtype + return torch.empty((m, n), dtype=out_dtype, device=input.device) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_tuned_gemm", + op_func=rocm_aiter_tuned_gemm_impl, + mutates_args=[], + fake_impl=rocm_aiter_tuned_gemm_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +class aiter_ops: + + @staticmethod + def rocm_aiter_tuned_gemm( + input: torch.Tensor, # [M, K] + weight: torch.Tensor, # [N, K] + bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + scale_a: Optional[torch.Tensor] = None, + scale_b: Optional[torch.Tensor] = None) -> torch.Tensor: + + return torch.ops.vllm.rocm_aiter_tuned_gemm( + input, + weight, + bias=bias, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + ) \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bd930bb90653..f1cbaf1fed26 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -36,6 +36,20 @@ def register_fake(fn): from torch.library import impl_abstract as register_fake +def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, + cu_count: int) -> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), + dtype=out_dtype, + device=b.device) + torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + return out + + +def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, cu_count) + + # page attention ops def paged_attention_v1( out: torch.Tensor, diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py new file mode 100644 index 000000000000..0fe702ee70ee --- /dev/null +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from 
dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Type, Union + +import torch + +import vllm._custom_ops as ops +import vllm.envs as envs +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.backends.utils import (compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, + get_aiter_mla_metadata) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA" + + @staticmethod + def get_impl_cls() -> Type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["AiterMLAState"]: + return AiterMLAState + + +@dataclass +class AiterMLAMetadata(MLACommonMetadata): + # The following 5 tensors are for current version of AITER MLA + block_table_bound: Optional[torch.Tensor] = None + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_lens: Optional[torch.Tensor] = None + + # This is just to make new AITER MLA API work + # -- MTP support is not added yet. 
+ qo_indptr: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self): + prefill_metadata = super().prefill_metadata + self._cached_prefill_metadata = prefill_metadata + + if prefill_metadata is not None: + prefill_metadata.paged_kv_indptr = self.paged_kv_indptr + prefill_metadata.paged_kv_indices = self.paged_kv_indices + prefill_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + prefill_metadata.block_table_bound = self.block_table_bound + prefill_metadata.qo_indptr = self.qo_indptr + + # update the cache + self._cached_prefill_metadata = self.__class__( + **prefill_metadata.__dict__) + + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + + self._cached_decode_metadata = decode_metadata + + if decode_metadata is not None: + decode_metadata.paged_kv_indptr = self.paged_kv_indptr + decode_metadata.paged_kv_indices = self.paged_kv_indices + decode_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + decode_metadata.block_table_bound = self.block_table_bound + decode_metadata.qo_indptr = self.qo_indptr + + # update the cache + self._cached_decode_metadata = self.__class__( + **decode_metadata.__dict__) + + return self._cached_decode_metadata + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_lens=self.paged_kv_last_page_lens, + block_table_bound=self.block_table_bound) + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + super().__init__(input_builder) + assert self.block_size == 1, "AITER MLA requires only block size 1." + + def prepare(self): + super().prepare() + self.paged_kv_indices: list[int] = [] + self.paged_kv_indptr: list[int] = [0] + self.paged_kv_last_page_lens: list[int] = [] + self.total_blocks = 0 + self.qo_indptr: list[int] = [0] + + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. 
+ # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + if is_profile_run: + return + + # Update paged_kv_* tensors only for non-profile run + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + self.qo_indptr.append(self.qo_indptr[-1] + 1) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_lens.append(last_page_len) + + def build(self, seq_lens: list[int], query_lens: list[int], + cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: + metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + if use_captured_graph: + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) + last_qo_indptr = self.qo_indptr[-1] + self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) + + # For current version of AITER MLA + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device=device, + dtype=torch.int) + paged_kv_last_page_lens_tensor = torch.tensor( + self.paged_kv_last_page_lens, device=device, dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device=device, + dtype=torch.int) + + qo_indptr = torch.tensor(self.qo_indptr, + device=device, + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_lens_tensor = None + block_table_bound_tensor = None + qo_indptr = None + + metadata.paged_kv_indptr = paged_kv_indptr_tensor + metadata.paged_kv_indices = paged_kv_indices_tensor + metadata.paged_kv_last_page_lens = 
paged_kv_last_page_lens_tensor + metadata.block_table_bound = block_table_bound_tensor + metadata.qo_indptr = qo_indptr + + return metadata + + +class AiterMLAState(MLACommonState[AiterMLAMetadata]): + + @contextmanager + def graph_capture(self, max_batch_size: int): + kv_indices, kv_indptr, last_page_lens, qo_indptr = \ + get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=\ + self.runner.get_max_block_per_batch(), + device=self.runner.device) + self._paged_kv_indices_tensor = kv_indices + self._paged_kv_indptr_tensor = kv_indptr + self._paged_kv_last_page_lens_tensor = last_page_lens + self._qo_indptr_tensor = qo_indptr + + with super().graph_capture(max_batch_size): + yield + + del self._paged_kv_indices_tensor + del self._paged_kv_indptr_tensor + del self._paged_kv_last_page_lens_tensor + del self._qo_indptr_tensor + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: + + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + + paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] + paged_kv_indices = self._paged_kv_indices_tensor + paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: + batch_size] + qo_indptr = self._qo_indptr_tensor[:batch_size + 1] + + metadata.paged_kv_indptr = paged_kv_indptr + metadata.paged_kv_indices = paged_kv_indices + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens + metadata.qo_indptr = qo_indptr + + return metadata + + def get_graph_input_buffers(self, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers[ + 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr + input_buffers[ + "paged_kv_indices"] = attn_metadata.\ + decode_metadata.paged_kv_indices + input_buffers[ + "paged_kv_last_page_lens"] = attn_metadata.\ + decode_metadata.paged_kv_last_page_lens + input_buffers['qo_indptr'] = attn_metadata.qo_indptr + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ + 0] + input_buffers["paged_kv_indptr"].copy_( + attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) + input_buffers["paged_kv_indices"][:num_total_blocks].copy_( + attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) + input_buffers["paged_kv_last_page_lens"].copy_( + attn_metadata.decode_metadata.paged_kv_last_page_lens, + non_blocking=True) + input_buffers["qo_indptr"].copy_( + attn_metadata.decode_metadata.qo_indptr, non_blocking=True) + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, 
attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + softmax_scale: float, return_softmax_lse: bool, + **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: + output = self.flash_attn_varlen_func( + q, + k, + v, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.empty(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.qo_indptr, + attn_metadata.max_query_len, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py new file mode 100644 index 000000000000..cce6b4639460 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +def get_aiter_mla_metadata(max_batch_size: int, block_size: int, + max_block_per_batch: int, + device: torch.device) -> tuple[torch.Tensor, ...]: + paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, + dtype=torch.int32, + device=device) + paged_kv_indptr = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=device) + paged_kv_last_page_lens = torch.full((max_batch_size, ), + block_size, + dtype=torch.int32) + qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr + + +def aiter_mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + logit_cap: float = 0.0, +): + + torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, + kv_buffer.view( + -1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + from aiter.mla import mla_decode_fwd + + mla_decode_fwd(q, + 
kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +if current_platform.is_rocm(): + direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=[torch.Tag.needs_fixed_stride_order]) diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py new file mode 100644 index 000000000000..69cde06fd72e --- /dev/null +++ b/vllm/attention/utils/fa_utils.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +from vllm import envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: + # import here to avoid circular dependencies + from vllm.platforms import current_platform + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) + device_capability = current_platform.get_device_capability() + + assert device_capability is not None + + # 1. default version depending on platform + fa_version = 3 if (device_capability.major == 9 + and is_fa_version_supported(3)) else 2 + + # 2. override if passed by environment + if envs.VLLM_FLASH_ATTN_VERSION is not None: + assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] + fa_version = envs.VLLM_FLASH_ATTN_VERSION + + # 3. fallback for unsupported combinations + if device_capability.major == 10 and fa_version == 3: + logger.warning_once( + "Cannot use FA version 3 on Blackwell platform " + "defaulting to FA version 2.") + fa_version = 2 + + if requires_alibi and fa_version == 3: + logger.warning_once("Cannot use FA version 3 with ALiBi, " + "defaulting to FA version 2.") + fa_version = 2 + + if not is_fa_version_supported(fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + fa_version, fa_version_unsupported_reason(fa_version)) + + assert is_fa_version_supported(fa_version) + return fa_version + except (ImportError, AssertionError): + return None + + +def flash_attn_supports_fp8() -> bool: + from vllm.platforms import current_platform + return get_flash_attn_version() == 3 and \ + current_platform.get_device_capability().major == 9 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1f719392bd9f..3a3436b02ee5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -548,7 +548,7 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32, 64, 128], + choices=[1, 8, 16, 32, 64, 128], help='Token block size for contiguous chunks of ' 'tokens. This is ignored on neuron devices and ' 'set to ``--max-model-len``. On CUDA devices, ' @@ -1522,8 +1522,14 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # No FlashInfer or XFormers so far. 
V1_BACKENDS = [ - "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", - "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" + "FLASH_ATTN_VLLM_V1", + "FLASH_ATTN", + "PALLAS", + "PALLAS_VLLM_V1", + "TRITON_ATTN_VLLM_V1", + "TRITON_MLA", + "FLASHMLA", + "ROCM_AITER_MLA", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/envs.py b/vllm/envs.py index 412034a43bd6..b0d88d613930 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -79,9 +79,11 @@ VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_ASMMOE: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -546,19 +548,23 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1")), - # Whether to use aiter asm moe ops. # By default is enabled. "VLLM_ROCM_USE_AITER_ASMMOE": lambda: (os.getenv("VLLM_ROCM_USE_AITER_ASMMOE", "False").lower() in ("true", "1")), - # use aiter rms norm op if aiter ops are enabled. "VLLM_ROCM_USE_AITER_RMSNORM": lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")), + # Whether to use aiter mla ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MLA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in + ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), @@ -572,6 +578,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + # Use skinny gemm for FP8 kernels + "VLLM_ROCM_USE_SKINNY_GEMM": + lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in + ("true", "1")), + # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 21035a9e5dbe..584db4da3280 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,13 +2,15 @@ import itertools from abc import abstractmethod -from typing import Any, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter +from vllm import _custom_ops as ops +from vllm import envs +from vllm._aiter_ops import aiter_ops from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -26,6 +28,7 @@ RowvLLMParameter) # yapf: enable from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -50,6 +53,66 @@ ] +def rocm_unquantized_gemm_wrapper(): + """Creates a wrapper function with the signature (x, weight, bias)""" + # Get configuration from environment variables + use_skinny = envs.VLLM_ROCM_USE_SKINNY_GEMM + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + ON_MI300 = any(arch in GPU_ARCH for arch in ["gfx942"]) + use_aiter = 
(envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR + and ON_MI300) + + def inner_function(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + k = weight.shape[1] + _use_skinny = (use_skinny and \ + x.dtype in [torch.float16, torch.bfloat16] \ + and k % 8 == 0 and bias is None) + + if _use_skinny is not True: + if use_aiter: + return aiter_ops.rocm_aiter_tuned_gemm(x, weight, bias) + return torch.nn.functional.linear(x, weight, bias) + + x_view = x.view(-1, x.size(-1)) + n = x_view.shape[0] + m = weight.shape[0] + cu_count = current_platform.get_cu_count() + + if m > 8 and 0 < n <= 4: + out = ops.wvSplitK(weight, x_view, cu_count) + return out.view(*x.shape[:-1], weight.shape[0]) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = ops.LLMM1(weight, x_view, 4) + return out.view(*x.shape[:-1], weight.shape[0]) + + if use_aiter: + return aiter_ops.rocm_aiter_tuned_gemm(x, weight, bias) + + return torch.nn.functional.linear(x, weight, bias) + + return inner_function + + +def dispatch_unquantized_gemm() -> Callable[ + [torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + """ + Dispatcher function that returns a function with signature (x, weight, bias) + based on the current platform and environment variables. + """ + if current_platform.is_rocm(): + return rocm_unquantized_gemm_wrapper() + + # Return a simple wrapper around linear to maintain the same signature + def linear_wrapper(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + return torch.nn.functional.linear(x, weight, bias) + + return linear_wrapper + + def adjust_marlin_shard(param, shard_size, shard_offset): marlin_tile_size = getattr(param, "marlin_tile_size", None) if marlin_tile_size is None: @@ -170,6 +233,10 @@ def apply(self, class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" + def __init__(self): + super().__init__() + self._gemm_func = dispatch_unquantized_gemm() + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, @@ -187,8 +254,7 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return F.linear(x, layer.weight, bias) + return self._gemm_func(x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index da92d8288215..f94eb37463a5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -576,14 +576,11 @@ def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled, shuffle_weights) - + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_use_asm = (self.rocm_aiter_moe_enabled and envs.VLLM_ROCM_USE_AITER_ASMMOE) - print(f"rocm_aiter_moe_enabled: {self.rocm_aiter_moe_enabled}") - print(f"rocm_aiter_use_asm: {self.rocm_aiter_use_asm}") - # TODO (rob): refactor block quant into separate class. 
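# A minimal sketch (not part of the patch) of the kernel-selection rule that the
# new rocm_unquantized_gemm_wrapper() in vllm/model_executor/layers/linear.py
# applies before falling back to torch.nn.functional.linear. The helper name
# pick_unquantized_gemm_path is hypothetical and only restates the conditions
# visible in the diff above.
import torch
from typing import Optional


def pick_unquantized_gemm_path(x: torch.Tensor, weight: torch.Tensor,
                               bias: Optional[torch.Tensor],
                               use_skinny: bool) -> str:
    m, k = weight.shape                      # output features, reduction dim
    n = x.view(-1, x.size(-1)).shape[0]      # number of input rows (the "skinny" side)
    if (use_skinny and x.dtype in (torch.float16, torch.bfloat16)
            and k % 8 == 0 and bias is None):
        if m > 8 and 0 < n <= 4:
            return "wvSplitK"                # split-K kernel for a handful of rows
        if m % 4 == 0 and n == 1 and k <= 8192:
            return "LLMM1"                   # single-row custom GEMM
    return "F.linear"                        # default (or AITER tuned gemm when enabled)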
if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" @@ -780,7 +777,6 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) - if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index b8e6384d7359..3b2c058502c7 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 +from functools import cache from typing import List, Optional, Tuple, Union import torch from vllm import _custom_ops as ops +from vllm import envs +from vllm._aiter_ops import aiter_ops from vllm.config import CompilationLevel, get_current_vllm_config from vllm.platforms import current_platform @@ -20,6 +23,58 @@ and current_platform.has_device_capability(94)) +@cache +def on_mi3xx() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"]) + + +@cache +def use_skinny_gemm() -> bool: + return envs.VLLM_ROCM_USE_SKINNY_GEMM + + +def rocm_aiter_per_tensor_w8a8_scaled_mm(qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + + output = aiter_ops.rocm_aiter_tuned_gemm(qinput, + weight.t(), + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def rocm_per_tensor_w8a8_scaled_mm(qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + + if use_skinny_gemm() and on_mi3xx( + ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: + output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, + current_platform.get_cu_count()) + else: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + def sparse_cutlass_supported() -> bool: if not current_platform.is_cuda(): return False @@ -158,6 +213,10 @@ def __init__(self, pad_output = config.level < CompilationLevel.PIECEWISE self.output_padding = 17 if pad_output else None + self.use_aiter_and_is_supported = (envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()) + def apply( self, input: torch.Tensor, @@ -218,10 +277,23 @@ def apply( else: qinput, x_scale = input_2d, input_scale - per_tensor_weights = (weight_scale.numel() == 1) - per_tensor_activations = (x_scale.numel() == 1) + per_tensor_weights = (weight_scale.numel() + == 1) and weight_scale.dim() < 2 + per_tensor_activations = (x_scale.numel() + == 1) and x_scale.dim() < 2 if per_tensor_weights and per_tensor_activations: + if current_platform.is_rocm(): + if self.use_aiter_and_is_supported: + print("Using ROCm AITER for per-tensor W8A8 scaled mm") + return rocm_aiter_per_tensor_w8a8_scaled_mm( + qinput, weight, out_dtype, x_scale, weight_scale, + bias, input_2d, output_shape) + + return rocm_per_tensor_w8a8_scaled_mm( + qinput, weight, out_dtype, x_scale, 
weight_scale, bias, + input_2d, output_shape) + # Fused GEMM_DQ output = torch._scaled_mm(qinput, weight, diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py new file mode 100644 index 000000000000..76f5092aac2c --- /dev/null +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from pathlib import Path +from typing import Optional + +import pandas as pd +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.platforms import current_platform +from vllm.platforms.rocm import on_gfx9 +from vllm.utils import aiter_linear_enabled, is_navi + +if aiter_linear_enabled(): + from aiter.tuned_gemm import tgemm as aiter_tgemm + +support_tuned_gemms = False +if current_platform.is_rocm(): + # import vllm._gradlib_C # noqa: F401 + support_tuned_gemms = True + + +def hipb_mm(inp, weights, solidx, bias=None): + return torch.ops._gradlib_C.hipb_mm(inp, weights, solidx, bias, None, None, + None, None) + + +def rocb_mm(inp, weights, solidx): + return torch.ops._gradlib_C.rocb_mm(inp, weights, solidx) + + +class TunedGemm: + + def __init__(self): + self.extensions_created = False + self.save_gemm = int(os.environ.get('VLLM_TUNE_GEMM', 0)) + self.untune_path = os.environ.get('VLLM_UNTUNE_FILE', + "/tmp/vllm_untuned.csv") + self.tune_path = os.environ.get('VLLM_TUNE_FILE', "tuned.csv") + self.bestsols = {} + self.load_best_sols() + self.create_ds() + self.cu_count = torch.cuda.get_device_properties( + device='cuda').multi_processor_count + + self.use_skinny = (current_platform.is_rocm() + and envs.VLLM_ROCM_USE_SKINNY_GEMM + and not is_navi()) + + if (self.save_gemm == 1): + self.tuned_df = pd.DataFrame( + columns=['M', 'N', 'K', 'bias', 'dtype']) + else: + self.tuned_df = None + + def load_best_sols(self): + if self.tune_path is not None and Path(self.tune_path).is_file(): + self.bestsols = pd.read_csv(self.tune_path) + + def create_ds(self): + df: pd.DataFrame = self.bestsols + solds = {} + for i in range(len(df)): + ds = df.iloc[i] + key = (ds['M'], ds['N'], ds['K'], ds['bias'], ds['dtype']) + if ds['libtype'] == 'hipblaslt': + soltype = 1 + elif ds['libtype'] == 'rocblas': + soltype = 2 + solds[key] = (soltype, int(ds['solidx'])) + self.solids = solds + + def query_sol(self, m, n, k, bias, dtype): + if envs.VLLM_USE_V1: + return 0, 0 + return self.solids.get((m, n, k, bias, str(dtype)), (0, 0)) + + def apply_skinny(self, m, n, k, inp_view, weights): + if not self.use_skinny: + return None + if inp_view.dtype != torch.float16 or k % 8 != 0: + return None + if m > 8 and 0 < n <= 4: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + ops.wvSpltK(weights, inp_view, out, n, self.cu_count) + return out + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + ops.LLMM1(weights, inp_view, out, 4) + return out + else: + return None + + def scaled_mm( + self, + inp: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + if aiter_linear_enabled(): + return aiter_tgemm.mm(inp, + weight.t(), + otype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + n = inp.shape[0] + if (not envs.VLLM_ROCM_USE_SKINNY_GEMM or n != 1 + or not current_platform.is_rocm() or on_gfx9() or is_navi()): + return 
torch._scaled_mm(inp, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + weightT = weight.t() + out = torch.empty(inp.shape[0], + weightT.shape[0], + dtype=out_dtype, + device='cuda') + + Otp = 1 #default bfloat16 + if out_dtype == torch.float16: + Otp = 0 + ops.wvSpltKQ(weightT, inp, out, scale_a, scale_b, n, Otp, + self.cu_count) + return out + + def mm(self, inp, weights, bias=None): + if not support_tuned_gemms: + return F.linear(inp, weights, bias) + # F.Linear can take a 3 dimensional input. vllm + # uses this for linear units. However, sampler + # will use torch.matmul with 2 dimensions only + if inp.dim() == 3: + try: + inp_view = inp.view(-1, inp.size(-1)) + batched = True + except RuntimeError: + return F.linear(inp, weights, bias) + else: + inp_view = inp + batched = False + if self.extensions_created is False: + torch.ops._gradlib_C.rocb_create_extension() + torch.ops._gradlib_C.hipb_create_extension() + self.extensions_created = True + m = weights.shape[0] + n = inp_view.shape[0] + k = inp_view.shape[1] + use_bias = bias is not None + soltype, solidx = self.query_sol(m=m, + n=n, + k=k, + bias=use_bias, + dtype=inp.dtype) + out = self.apply_skinny(m, n, k, inp_view, weights) + if out is not None: + if batched: + out = out.view(inp.shape[0], inp.shape[1], weights.shape[0]) + if bias is not None: + return out + bias + return out + elif soltype == 1: + out = hipb_mm(inp_view, weights.t(), solidx, bias) + elif soltype == 2: + out = rocb_mm(inp_view, weights.t(), solidx) + if bias is not None: + out = out + bias + else: + if (self.save_gemm == 1): + self.tuned_df = pd.concat([ + self.tuned_df, + pd.DataFrame({ + 'M': [m], + 'N': [n], + 'K': [k], + 'bias': [bias is not None], + 'dtype': [inp.dtype], + }) + ]).drop_duplicates() + self.tuned_df.to_csv(self.untune_path, index=False) + return F.linear(inp, weights, bias) + if batched: + out = out.view(inp.shape[0], inp.shape[1], weights.shape[0]) + return out + + +tgemm = TunedGemm() \ No newline at end of file diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 23b450aeddac..79934cafb5a8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -837,4 +837,4 @@ def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, for i in range(config.num_nextn_predict_layers): if weight_name.startswith(f"model.layers.{layer_idx+i}."): return layer_idx + i - return None + return None \ No newline at end of file diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 8c099b9531c5..5931a620dba7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -39,6 +39,8 @@ class _Backend(enum.Enum): TRITON_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() + ROCM_AITER_MLA = enum.auto() # Supported by V1 + ROCM_AITER_MLA_VLLM_V1 = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ca6528313a19..5603587d7bb1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -121,6 +121,12 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) +@cache +def on_gfx9() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) + + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM 
device_name: str = "rocm" @@ -140,8 +146,40 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_mla: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" + from vllm.attention.backends.rocm_aiter_mla import ( + is_aiter_mla_enabled) + + if selected_backend is None: + selected_backend = (_Backend.ROCM_AITER_MLA if + is_aiter_mla_enabled() or block_size == 1 + else _Backend.TRITON_MLA) + + if selected_backend == _Backend.TRITON_MLA: + if block_size != 1: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}.") + elif selected_backend == _Backend.ROCM_AITER_MLA \ + or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: + if block_size == 1: + if use_v1: + logger.info("Using AITER MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + else: + raise ValueError( + "AITER MLA backend is not ported on V0 engine.") + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}." + "(currently only supports block size 1)") + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"is not MLA type while requested for MLA backend.") selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if envs.VLLM_USE_V1: @@ -157,6 +195,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using ROCmFlashAttention backend.") return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 + @classmethod + def get_cu_count(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties( + device_id).multi_processor_count + @classmethod @lru_cache(maxsize=8) def get_device_capability(cls, diff --git a/vllm/utils.py b/vllm/utils.py index c6e2afff72d7..f1a9812665f7 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2662,3 +2662,30 @@ def is_torch_equal_or_newer(target: str) -> bool: except Exception: # Fallback to PKG-INFO to load the package info, needed by the doc gen. return Version(importlib.metadata.version('torch')) >= Version(target) + + +@cache +def is_navi() -> bool: + from vllm.platforms import current_platform + if not current_platform.is_rocm() or not torch.cuda.is_available(): + return False + # All (visible) GPUs must be of the same type, + # otherwise FP8 results can't be guaranteed. 
+ archName = torch.cuda.get_device_properties('cuda').gcnArchName + return archName is not None and "gfx1" in archName + + +@cache +def is_rocm() -> bool: + from vllm.platforms import current_platform + return current_platform.is_rocm() + + +@cache +def aiter_enabled() -> bool: + return is_rocm() and envs.VLLM_ROCM_USE_AITER + + +@cache +def aiter_linear_enabled() -> bool: + return aiter_enabled() and envs.VLLM_ROCM_USE_AITER_LINEAR diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index c0a6bd29623e..f28dd02de547 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -500,12 +500,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, # longer context lengths max_context_chunk = (self.chunked_prefill_workspace_size // num_prefills_with_context_cpu) - - # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, - self.page_size) + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py new file mode 100644 index 000000000000..68245913ee15 --- /dev/null +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +import vllm.envs as envs +from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd +# yapf conflicts with isort for this docstring +# yapf: disable +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) + +# yapf: enable + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + +@dataclass +class AiterMLADecodeMetadata(MLACommonDecodeMetadata): + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The query indptr, shape : [num_decode + 1] + qo_indptr: Optional[torch.Tensor] = None + + +class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): + pass + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + + def __init__(self, runner): + super().__init__(runner) + assert self.runner.block_size == 1, "AITER MLA" \ + "only supports block 
size 1." + + def _get_paged_kv_tensors( + self, block_table: torch.Tensor, + seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: + page_size = self.runner.block_size + block_table_bounds = (seq_lens + page_size - 1) // page_size + + mask = (torch.arange(block_table.size(1), + dtype=block_table.dtype, + device=block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + paged_kv_indptr = torch.cat([ + torch.zeros(1, + dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32) + ]) + + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) + qo_indptr = torch.arange(0, + self._num_decodes + 1, + step=1, + dtype=torch.int32, + device=block_table_bounds.device) + return ( + paged_kv_indices, + paged_kv_indptr, + paged_kv_last_page_len, + qo_indptr, + ) + + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + + ( + paged_kv_indices, + paged_kv_indptr, + paged_last_page_len, + qo_indptr, + ) = self._get_paged_kv_tensors(block_table, seq_lens) + + attn_metadata = AiterMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_last_page_len, + qo_indptr=qo_indptr) + + return attn_metadata + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + # max_seqlen_qo must be 1 except for MTP + # TODO: Find the best value for MTP + max_seqlen_qo = 1 + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.decode.qo_indptr, max_seqlen_qo, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + 
attn_metadata.decode.paged_kv_last_page_len) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index cff6181fa3ad..be9371c907eb 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -38,7 +38,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = 30 +EXECUTE_MODEL_TIMEOUT_S = 300 class MultiprocExecutor(Executor): @@ -151,7 +151,7 @@ def execute_model( def collective_rpc(self, method: Union[str, Callable], - timeout: Optional[float] = 180.0, + timeout: Optional[float] = None, args: tuple = (), kwargs: Optional[dict] = None, rank0_reply_only: bool = False) -> list[Any]:
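For reference, a minimal usage sketch of the two skinny-GEMM entry points added above, assuming a ROCm build of vLLM with the `_rocm_C` extension and MI300-class (gfx942) hardware; the shapes, scale values, and fp8 variant (`float8_e4m3fnuz` on gfx942) are illustrative assumptions, not part of the patch.

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

cu_count = current_platform.get_cu_count()   # classmethod added in rocm.py above
M, K = 4096, 4096                            # output features, reduction dim (K % 16 == 0)

# fp16/bf16 path: weight is [M, K], activation view is [N, K] with 0 < N <= 4.
weight = torch.randn(M, K, dtype=torch.float16, device="cuda")
x = torch.randn(1, K, dtype=torch.float16, device="cuda")
y = ops.wvSplitK(weight, x, cu_count).view(1, M)   # same reshape linear.py applies

# fp8 path: per-tensor scales (float32 scalars), a single input row, K % 16 == 0.
w_fp8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fnuz)
a_fp8 = torch.randn(1, K, device="cuda").to(torch.float8_e4m3fnuz)
act_scale = torch.tensor(1.0, device="cuda")
w_scale = torch.tensor(1.0, device="cuda")
y_fp8 = ops.wvSplitKQ(w_fp8, a_fp8, torch.float16, act_scale, w_scale,
                      cu_count)              # -> [1, M] fp16 output

Both calls mirror the gating in the patch: `wvSplitK` is used by `UnquantizedLinearMethod` when the activation batch is at most four rows, and `wvSplitKQ` is used by the per-tensor W8A8 path in w8a8_utils.py when the quantized input has a single row and K is a multiple of 16.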