diff --git a/src/kernels.cu b/src/kernels.cu
index 8df8130..c81b460 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -1,4 +1,5 @@
 #include <vector>
+#include <cuda_runtime.h>
 
 #include "../tester/utils.h"
@@ -15,10 +16,69 @@
  * @note For invalid cases, return T(-100).
  * @note Handles device memory management (allocate/copy/free) internally. Errors should be thrown.
  */
+
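+// Approach: pad the input up to the next power of two with a large negative
+// sentinel, bitonic-sort the padded array in descending order on the device,
+// then copy back the element at index k-1.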
+template <typename T>
+__global__ void bitonic_sort_kernel(T* input, int n, int desc, int k, int j) {
+  // One thread per element; compare-exchange with the partner at distance j.
+  int x = blockDim.x * blockIdx.x + threadIdx.x;
+  int y = x ^ j;
+  if (x >= n || x > y) return;
+  T valx = input[x];
+  T valy = input[y];
+  // Direction alternates across k-sized blocks to form bitonic runs.
+  if ((valx < valy) == (desc ^ ((x & k) != 0))) {
+    input[x] = valy;
+    input[y] = valx;
+  }
+}
+
+template <typename T>
+__global__ void init_input(T* input, int n, int mx) {
+  // Fill the padding elements [mx, n) with a large negative sentinel so they
+  // sort to the end of the descending order.
+  int x = blockDim.x * blockIdx.x + threadIdx.x;
+  if (x < n && x >= mx) {
+    input[x] = T(-1e9);
+  }
+}
+
 template <typename T>
 T kthLargest(const std::vector<T>& h_input, size_t k) {
   // TODO: Implement the kthLargest function
-  return T(-1000);
+  int n = h_input.size();
+
+  if (k < 1 || k > n)
+    return T(-100);
+
+  // Round n up to the next power of two (the bitonic network requires it).
+  while (__builtin_popcount(n) != 1) n += (n & (-n));
+
+  T* d_input;
+
+  size_t size_input = h_input.size() * sizeof(T);
+  size_t size_malloc = n * sizeof(T);
+
+  CUDA_CHECK(cudaMalloc(&d_input, size_malloc));
+  CUDA_CHECK(cudaMemcpy(d_input, h_input.data(), size_input, cudaMemcpyHostToDevice));
+
+  int blockSize = 256;
+  int gridSize = (n + blockSize - 1) / blockSize;
+
+  init_input<<<gridSize, blockSize>>>(d_input, n, h_input.size());
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  // Standard bitonic sorting network: the outer pass doubles the sorted block
+  // size, the inner pass halves the compare distance. desc = 1 sorts descending.
+  for (int len = 2; len <= n; len <<= 1) {
+    for (int j = (len >> 1); j > 0; j >>= 1) {
+      bitonic_sort_kernel<<<gridSize, blockSize>>>(d_input, n, 1, len, j);
+      CUDA_CHECK(cudaDeviceSynchronize());
+    }
+  }
+
+  T result;
+  CUDA_CHECK(cudaMemcpy(&result, d_input + (k - 1), sizeof(T), cudaMemcpyDeviceToHost));
+
+  CUDA_CHECK(cudaFree(d_input));
+
+  return result;
 }
 
 /**
@@ -37,11 +97,177 @@ T kthLargest(const std::vector<T>& h_input, size_t k) {
  * @param[in] head_dim Dimension size of each attention head
  * @param[in] is_causal Whether to apply causal masking
  */
+
 template <typename T>
-void flashAttention(const std::vector<T>& h_q, const std::vector<T>& h_k,
-                    const std::vector<T>& h_v, std::vector<T>& h_o,
-                    int batch_size, int target_seq_len, int src_seq_len,
-                    int query_heads, int kv_heads, int head_dim, bool is_causal) {
+__device__ T block_reduce_max(T* smem, T val) {
+  // Warp-level max reduction followed by a cross-warp reduction in shared
+  // memory; the launch uses blockDim.x == 256 (8 warps), so smem needs 8 slots.
+  int tid = threadIdx.x;
+  int lane = tid % 32;
+  int warp_id = tid / 32;
+
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+    val = fmaxf(val, other);
+  }
+
+  if (lane == 0)
+    smem[warp_id] = val;
+  __syncthreads();
+
+  if (warp_id == 0) {
+    val = lane >= 8 ? T(-1e9) : smem[lane];
+    for (int offset = 16; offset > 0; offset >>= 1) {
+      T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+      val = fmaxf(val, other);
+    }
+    smem[0] = val;
+  }
+  __syncthreads();
+
+  return smem[0];
+}
+
+template <typename T>
+__device__ T block_reduce_sum(T* smem, T val) {
+  // Same structure as block_reduce_max, with addition as the reduction op.
+  int tid = threadIdx.x;
+  int lane = tid % 32;
+  int warp_id = tid / 32;
+
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+    val = val + other;
+  }
+
+  if (lane == 0)
+    smem[warp_id] = val;
+  __syncthreads();
+
+  if (warp_id == 0) {
+    val = lane >= 8 ? T(0) : smem[lane];
+    for (int offset = 16; offset > 0; offset >>= 1) {
+      T other = __shfl_xor_sync(0xFFFFFFFF, val, offset);
+      val = val + other;
+    }
+    smem[0] = val;
+  }
+  __syncthreads();
+
+  return smem[0];
+}
+
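+// One thread block handles a single (batch, query head) pair and walks the
+// target positions serially. Dynamic shared memory holds one row of attention
+// scores (src_seq_len entries) followed by 8 slots of reduction scratch space.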
+template <typename T>
+__global__ void flashAttentionKernel(
+    const T* query, const T* key,
+    const T* value, T* output,
+    int batch_size, int target_seq_len, int src_seq_len,
+    int query_heads, int kv_heads, int head_dim, bool is_causal
+) {
+  int batch_id = blockIdx.x;
+  int heads_id = blockIdx.y;
+  // Grouped-query attention: map each query head to its shared KV head.
+  int kv_id = heads_id * kv_heads / query_heads;
+
+  if (batch_id >= batch_size || heads_id >= query_heads) return;
+
+  extern __shared__ unsigned char shared_mem[];
+  T* scores = reinterpret_cast<T*>(shared_mem);
+  T* smem = scores + src_seq_len;
+
+  for (int tgt = 0; tgt < target_seq_len; tgt++) {
+    // Scaled dot products between the current query row and all key rows.
+    float mx = -1e9;
+    for (int src = threadIdx.x; src < src_seq_len; src += blockDim.x) {
+      if (is_causal && tgt < src) {
+        scores[src] = -1e9f;
+      } else {
+        float sum = 0.f;
+        for (int dim = 0; dim < head_dim; dim++) {
+          // query[batch_id][tgt][heads_id][dim]
+          int qid = dim + head_dim * (heads_id + query_heads * (tgt + target_seq_len * batch_id));
+          // key[batch_id][src][kv_id][dim]
+          int kid = dim + head_dim * (kv_id + kv_heads * (src + src_seq_len * batch_id));
+          sum += query[qid] * key[kid];
+        }
+        scores[src] = sum / sqrtf(float(head_dim));
+        mx = fmaxf(mx, scores[src]);
+      }
+    }
+    __syncthreads();
+
+    // Numerically stable softmax over the score row.
+    mx = block_reduce_max(smem, mx);
+
+    T sum = 0.f;
+    for (int src = threadIdx.x; src < src_seq_len; src += blockDim.x) {
+      scores[src] = expf(scores[src] - mx);
+      sum += scores[src];
+    }
+    __syncthreads();
+
+    sum = block_reduce_sum(smem, sum);
+
+    for (int src = threadIdx.x; src < src_seq_len; src += blockDim.x) {
+      scores[src] = scores[src] / (sum + 1e-8f);
+    }
+    __syncthreads();
+
+    // Weighted sum of value rows, one output dimension per thread.
+    for (int dim = threadIdx.x; dim < head_dim; dim += blockDim.x) {
+      float sum = 0.f;
+      for (int src = 0; src < src_seq_len; src++) {
+        // value[batch_id][src][kv_id][dim]
+        int vidx = dim + head_dim * (kv_id + kv_heads * (src + src_seq_len * batch_id));
+        sum += scores[src] * value[vidx];
+      }
+      // output[batch_id][tgt][heads_id][dim]
+      int oidx = dim + head_dim * (heads_id + query_heads * (tgt + target_seq_len * batch_id));
+      output[oidx] = sum;
+    }
+    __syncthreads();
+  }
+}
+
+template <typename T>
+void flashAttention(
+    const std::vector<T>& h_q, const std::vector<T>& h_k,
+    const std::vector<T>& h_v, std::vector<T>& h_o,
+    int batch_size, int target_seq_len, int src_seq_len,
+    int query_heads, int kv_heads, int head_dim, bool is_causal
+) {
+  h_o.resize(batch_size * target_seq_len * query_heads * head_dim);
+
+  T *d_q, *d_k, *d_v, *d_o;
+  size_t size_q = h_q.size() * sizeof(T);
+  size_t size_k = h_k.size() * sizeof(T);
+  size_t size_v = h_v.size() * sizeof(T);
+  size_t size_o = h_o.size() * sizeof(T);
+
+  CUDA_CHECK(cudaMalloc(&d_q, size_q));
+  CUDA_CHECK(cudaMalloc(&d_k, size_k));
+  CUDA_CHECK(cudaMalloc(&d_v, size_v));
+  CUDA_CHECK(cudaMalloc(&d_o, size_o));
+
+  CUDA_CHECK(cudaMemcpy(d_q, h_q.data(), size_q, cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy(d_k, h_k.data(), size_k, cudaMemcpyHostToDevice));
+  CUDA_CHECK(cudaMemcpy(d_v, h_v.data(), size_v, cudaMemcpyHostToDevice));
+
+  // One block per (batch, query head); dynamic shared memory = score row + scratch.
+  dim3 gridSize(batch_size, query_heads);
+  int blockSize = 256;
+  size_t shared_mem = src_seq_len * sizeof(T) + 8 * sizeof(T);
+  flashAttentionKernel<<<gridSize, blockSize, shared_mem>>>(
+      d_q, d_k, d_v, d_o,
+      batch_size, target_seq_len, src_seq_len,
+      query_heads, kv_heads, head_dim, is_causal
+  );
+
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  CUDA_CHECK(cudaMemcpy(h_o.data(), d_o, size_o, cudaMemcpyDeviceToHost));
+
+  CUDA_CHECK(cudaFree(d_q));
+  CUDA_CHECK(cudaFree(d_k));
+  CUDA_CHECK(cudaFree(d_v));
+  CUDA_CHECK(cudaFree(d_o));
+
 }
 
 // *********************************************************************
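For reference, a minimal host-side sketch of how the two entry points could be exercised. It is not part of the patch: the driver below and the assumption that kthLargest<T> and flashAttention<T> are visible to it (e.g. through the tester harness that builds kernels.cu) are illustrative only, and the tensors follow the [batch, seq, head, dim] layout the kernels index.

// Hypothetical driver, not part of this patch; assumes kthLargest<T> and
// flashAttention<T> are declared/compiled alongside it.
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> vals = {3.f, 1.f, 4.f, 1.f, 5.f, 9.f, 2.f, 6.f};
  printf("3rd largest = %f\n", kthLargest<float>(vals, 3));  // expect 5.0

  int B = 1, Tq = 2, S = 2, QH = 2, KVH = 1, D = 4;
  std::vector<float> q(B * Tq * QH * D, 0.1f);
  std::vector<float> k(B * S * KVH * D, 0.2f);
  std::vector<float> v(B * S * KVH * D, 0.3f);
  std::vector<float> o;
  flashAttention<float>(q, k, v, o, B, Tq, S, QH, KVH, D, /*is_causal=*/true);
  printf("o[0] = %f (size %zu)\n", o[0], o.size());
  return 0;
}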