diff --git a/csrc/day15_fused_attention.cu b/csrc/day15_fused_attention.cu
index 06f3ff7..78234c9 100644
--- a/csrc/day15_fused_attention.cu
+++ b/csrc/day15_fused_attention.cu
@@ -1,51 +1,74 @@
 #include <cuda_runtime.h>
-#define ceil(x, y) (((x) + (y) - 1) / (y))
-
-// TODO: Fused Attention kernel 구현
-// Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
-//
-// 힌트:
-// 1. Q @ K^T 계산 (matrix multiplication)
-// 2. Scale by sqrt(d_k)
-// 3. Apply mask (optional)
-// 4. Softmax
-// 5. @ V (matrix multiplication)
-//
-// 모든 연산을 하나의 커널로 융합하여 성능 최적화
-//
-// 입력: Q (num_heads, seq_len, head_dim) - batch_size는 항상 1
-//       K (num_heads, seq_len, head_dim)
-//       V (num_heads, seq_len, head_dim)
-//       mask (optional) - attention mask
-// 출력: output (num_heads, seq_len, head_dim)
-
+// Fused Attention: softmax(Q @ K^T * scale) @ V
+// Each block handles one (head, query_row) pair.
+// Launch contract: grid = (seq_len, num_heads), blockDim.x = power of two
+// with head_dim <= blockDim.x <= 256 (head_dim > 256 is unsupported).
 __global__ void fused_attention_kernel(
     const float* Q,
     const float* K,
     const float* V,
     float* output,
-    const float* mask,
-    int num_heads,
     int seq_len,
     int head_dim,
     float scale
 ) {
-    // TODO: 구현하세요
-    // Fused Attention: QK^T -> scale -> mask -> softmax -> @V
-    int head_idx = blockIdx.y;
-    int seq_idx = blockIdx.x;
+    int head = blockIdx.y;
+    int row = blockIdx.x;
+    int d = threadIdx.x;
+
+    // Every thread must reach the __syncthreads() barriers below, so lanes
+    // with d >= head_dim stay active and contribute zeros to the reduction
+    // instead of returning early (an early return would hang the barrier).
+    bool active = (d < head_dim);
+
+    int head_offset = head * seq_len * head_dim;
+    int q_base = head_offset + row * head_dim;
+
+    float q_val = active ? Q[q_base + d] : 0.0f;
+
+    // Online softmax state: running max, running denominator, running output
+    float max_score = -INFINITY;
+    float sum_exp = 0.0f;
+    float out_val = 0.0f;
 
-    // TODO: Attention 계산
-    // 각 thread는 하나의 head_dim element를 처리할 수 있습니다
-    int feature_idx = threadIdx.x;
+    __shared__ float dot_buffer[256];
+
+    // Process each key-value pair
+    for (int s = 0; s < seq_len; s++) {
+        int kv_base = head_offset + s * head_dim;
 
-    if (head_idx < num_heads && seq_idx < seq_len && feature_idx < head_dim) {
-        int idx = head_idx * seq_len * head_dim +
-                  seq_idx * head_dim +
-                  feature_idx;
-        // TODO: Fused Attention 계산
-        output[idx] = Q[idx];
+        // Q · K dot product via shared-memory tree reduction
+        dot_buffer[d] = active ? q_val * K[kv_base + d] : 0.0f;
+        __syncthreads();
+
+        // blockDim.x is a power of two and inactive lanes wrote zeros,
+        // so the halving reduction is exact for any head_dim
+        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
+            if (d < stride) {
+                dot_buffer[d] += dot_buffer[d + stride];
+            }
+            __syncthreads();
+        }
+
+        float score = dot_buffer[0] * scale;
+        __syncthreads();  // all lanes read dot_buffer[0] before next overwrite
+
+        // Online softmax update: rescale old state by exp(old_max - new_max)
+        float new_max = fmaxf(max_score, score);
+        float correction = expf(max_score - new_max);
+        float exp_score = expf(score - new_max);
+
+        if (active) {
+            out_val = out_val * correction + exp_score * V[kv_base + d];
+        }
+        sum_exp = sum_exp * correction + exp_score;
+        max_score = new_max;
     }
+
+    // Normalize and store
+    if (active) {
+        output[q_base + d] = out_val / sum_exp;
+    }
 }
 
 extern "C" void day15_fused_attention(
@@ -59,13 +82,16 @@ extern "C" void day15_fused_attention(
     int head_dim,
     float scale
 ) {
-    // TODO: kernel launch configuration 설정
-    // batch_size는 항상 1이므로 제거
-    dim3 threadsPerBlock(head_dim);
-    dim3 blocksPerGrid(seq_len, num_heads);
+    // Round the thread count up to a power of two so the kernel's tree
+    // reduction is exact for any head_dim (head_dim <= 256 supported).
+    int threads = 1;
+    while (threads < head_dim) threads <<= 1;
+    if (threads > 256) threads = 256;
+    dim3 blocks(seq_len, num_heads);
 
-    fused_attention_kernel<<<blocksPerGrid, threadsPerBlock>>>(
-        Q, K, V, output, mask, num_heads, seq_len, head_dim, scale
+    fused_attention_kernel<<<blocks, threads>>>(
+        Q, K, V, output, seq_len, head_dim, scale
     );
     cudaDeviceSynchronize();
 }
diff --git a/src/gpu_20days/day15_fused_attention.py b/src/gpu_20days/day15_fused_attention.py
index 4d8f586..fdbfb99 100644
--- a/src/gpu_20days/day15_fused_attention.py
+++ b/src/gpu_20days/day15_fused_attention.py
@@ -1,5 +1,6 @@
 """
 Day 15: Fused Attention
+Attention(Q, K, V) = softmax(Q @ K^T * scale) @ V
 """
 
 from typing import Optional
@@ -15,31 +16,42 @@ def day15_fused_attention_kernel(
     K_ptr,
     V_ptr,
     output_ptr,
-    mask_ptr,
-    num_heads,
     seq_len,
     head_dim,
     scale,
-    BLOCK_SIZE: tl.constexpr,
+    BLOCK_D: tl.constexpr,
 ):
-    """
-    TODO: Fused Attention kernel 구현
+    head = tl.program_id(0)
+    row = tl.program_id(1)
 
-    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
+    head_offset = head * seq_len * head_dim
+    q_base = head_offset + row * head_dim
+    d_idx = tl.arange(0, BLOCK_D)
 
-    힌트:
-    1. Q @ K^T 계산 (matrix multiplication)
-    2. Scale by sqrt(d_k)
-    3. Apply mask (optional)
-    4. Softmax
-    5. @ V (matrix multiplication)
+    q = tl.load(Q_ptr + q_base + d_idx, mask=d_idx < head_dim, other=0.0)
 
-    모든 연산을 하나의 커널로 융합하여 성능 최적화
+    scores = tl.zeros([1], dtype=tl.float32)
+    max_score = tl.zeros([1], dtype=tl.float32) - float("inf")
+    output = tl.zeros([BLOCK_D], dtype=tl.float32)
+    sum_exp = tl.zeros([1], dtype=tl.float32)
 
-    batch_size는 항상 1입니다.
-    """
-    # TODO: 구현하세요
-    pass
+    for s in range(seq_len):
+        k = tl.load(K_ptr + head_offset + s * head_dim + d_idx, mask=d_idx < head_dim, other=0.0)
+        score = tl.sum(q * k) * scale
+
+        new_max = tl.maximum(max_score, score)
+        correction = tl.exp(max_score - new_max)
+        exp_score = tl.exp(score - new_max)
+
+        output = output * correction
+        sum_exp = sum_exp * correction + exp_score
+
+        v = tl.load(V_ptr + head_offset + s * head_dim + d_idx, mask=d_idx < head_dim, other=0.0)
+        output += exp_score * v
+        max_score = new_max
+
+    output = output / sum_exp
+    tl.store(output_ptr + q_base + d_idx, output, mask=d_idx < head_dim)
 
 
 def day15_fused_attention(
@@ -49,26 +61,17 @@ def day15_fused_attention(
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
 ) -> torch.Tensor:
-    """Day 15: Fused attention mechanism (batch_size is always 1)"""
-    # TODO: 구현하세요
-    BLOCK_SIZE = 256
-    # batch_size는 항상 1이므로 입력은 3D
-    if Q.dim() != 3:
-        raise ValueError(
-            "day15_fused_attention expects 3D tensor (num_heads, seq_len, head_dim), batch_size is always 1"
-        )
-
+    """Day 15: Fused attention (batch_size is always 1)"""
+    assert Q.dim() == 3, "Expected 3D tensor (num_heads, seq_len, head_dim)"
     num_heads, seq_len, head_dim = Q.shape
     if scale is None:
         scale = 1.0 / (head_dim**0.5)
 
-    output = torch.zeros_like(Q)
-
-    def grid(meta):
-        return (num_heads, triton.cdiv(seq_len, BLOCK_SIZE))
+    output = torch.empty_like(Q)
+    BLOCK_D = triton.next_power_of_2(head_dim)
 
-    # day15_fused_attention_kernel[grid](
-    #     Q, K, V, output, mask, num_heads, seq_len, head_dim, scale, BLOCK_SIZE=BLOCK_SIZE
-    # )
+    day15_fused_attention_kernel[(num_heads, seq_len)](
+        Q, K, V, output, seq_len, head_dim, scale, BLOCK_D=BLOCK_D
+    )
 
     return output