Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 52 additions & 40 deletions csrc/day15_fused_attention.cu
Original file line number Diff line number Diff line change
@@ -1,51 +1,63 @@
#include <cuda_runtime.h>
#define ceil(x, y) (((x) + (y) - 1) / (y))

// TODO: Fused Attention kernel 구현
// Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
//
// 힌트:
// 1. Q @ K^T 계산 (matrix multiplication)
// 2. Scale by sqrt(d_k)
// 3. Apply mask (optional)
// 4. Softmax
// 5. @ V (matrix multiplication)
//
// 모든 연산을 하나의 커널로 융합하여 성능 최적화
//
// 입력: Q (num_heads, seq_len, head_dim) - batch_size는 항상 1
// K (num_heads, seq_len, head_dim)
// V (num_heads, seq_len, head_dim)
// mask (optional) - attention mask
// 출력: output (num_heads, seq_len, head_dim)

// Fused Attention: output = softmax(Q @ K^T * scale) @ V
//
// One pass of online (streaming) softmax per query row, so the seq_len x
// seq_len score matrix is never materialized.
//
// Launch configuration (see the host launcher):
//   grid  = (seq_len, num_heads)  -> one block per (query row, head) pair
//   block = head_dim threads      -> one thread per output feature
//
// Preconditions: blockDim.x == head_dim and head_dim <= 256 (shared scratch
// size). Q, K, V, output are (num_heads, seq_len, head_dim), row-major,
// batch_size is always 1.
__global__ void fused_attention_kernel(
    const float* Q,
    const float* K,
    const float* V,
    float* output,
    int seq_len,
    int head_dim,
    float scale
) {
    int head = blockIdx.y;
    int row  = blockIdx.x;
    int d    = threadIdx.x;

    int head_offset = head * seq_len * head_dim;
    int q_base = head_offset + row * head_dim;

    // Scratch for the per-step Q.K dot product. Declared once, outside the
    // key loop, and sized for the maximum supported head_dim.
    __shared__ float dot_buffer[256];

    // Round blockDim.x up to a power of two so the tree reduction below is
    // exact even when head_dim is not a power of two (e.g. 96); the original
    // reduction silently dropped elements in that case.
    int width = 1;
    while (width < blockDim.x) width <<= 1;

    // Zero the pad slots [blockDim.x, width). The reduction only ever writes
    // indices < width/2, so these stay zero for every iteration.
    for (int i = d; i < width; i += blockDim.x) dot_buffer[i] = 0.0f;

    // Guarded load (no early return: all threads must reach every barrier).
    float q_val = (d < head_dim) ? Q[q_base + d] : 0.0f;

    // Online-softmax running state: max score seen so far, running exp-sum,
    // and this thread's running weighted-V accumulator.
    float max_score = -INFINITY;
    float sum_exp = 0.0f;
    float out_val = 0.0f;

    for (int s = 0; s < seq_len; s++) {
        int kv_base = head_offset + s * head_dim;

        dot_buffer[d] = (d < head_dim) ? q_val * K[kv_base + d] : 0.0f;
        __syncthreads();

        // Tree reduction over the zero-padded power-of-two width.
        for (int stride = width / 2; stride > 0; stride >>= 1) {
            if (d < stride) {
                dot_buffer[d] += dot_buffer[d + stride];
            }
            __syncthreads();
        }

        float score = dot_buffer[0] * scale;
        // All threads must have read dot_buffer[0] before the next iteration
        // overwrites the buffer; without this barrier there is a data race.
        __syncthreads();

        // Online softmax update: rescale running sums by exp(old_max - new_max).
        float new_max = fmaxf(max_score, score);
        float correction = expf(max_score - new_max);
        float exp_score  = expf(score - new_max);

        if (d < head_dim) {
            out_val = out_val * correction + exp_score * V[kv_base + d];
        }
        sum_exp   = sum_exp * correction + exp_score;
        max_score = new_max;
    }

    // Normalize and store this feature of the output row.
    if (d < head_dim) {
        output[q_base + d] = out_val / sum_exp;
    }
}

// Host launcher for the fused attention kernel.
//
// batch_size is always 1, so tensors are 3-D (num_heads, seq_len, head_dim).
// `mask` and `num_heads` stay in the extern "C" signature for ABI
// compatibility with existing callers; the mask is currently NOT applied.
extern "C" void day15_fused_attention(
    const float* Q,
    const float* K,
    const float* V,
    float* output,
    const float* mask,
    int num_heads,
    int seq_len,
    int head_dim,
    float scale
) {
    (void)mask;  // NOTE(review): attention mask is accepted but ignored — confirm intended

    // One thread per output feature; the kernel's shared scratch buffer
    // limits a block to 256 threads.
    int threads = head_dim;
    if (threads > 256) threads = 256;
    // NOTE(review): if head_dim > 256 the cap makes the kernel skip features
    // beyond 255 and produce wrong results — callers must keep head_dim <= 256.

    // One block per (query row, head) pair.
    dim3 blocks(seq_len, num_heads);

    fused_attention_kernel<<<blocks, threads>>>(
        Q, K, V, output, seq_len, head_dim, scale
    );
    // Block until the kernel finishes so the output is ready for the caller.
    cudaDeviceSynchronize();
}
69 changes: 36 additions & 33 deletions src/gpu_20days/day15_fused_attention.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
Day 15: Fused Attention
Attention(Q, K, V) = softmax(Q @ K^T * scale) @ V
"""

from typing import Optional
@triton.jit
def day15_fused_attention_kernel(
    Q_ptr,
    K_ptr,
    V_ptr,
    output_ptr,
    seq_len,
    head_dim,
    scale,
    BLOCK_D: tl.constexpr,
):
    """Fused attention for one (head, query row) pair.

    Computes ``output[head, row] = softmax(Q[head, row] @ K[head].T * scale) @ V[head]``
    with a streaming (online) softmax, so the full row of scores is never stored.

    Grid: ``(num_heads, seq_len)``. ``BLOCK_D`` must be
    ``next_power_of_2(head_dim)``; lanes beyond ``head_dim`` are masked out.
    Tensors are contiguous (num_heads, seq_len, head_dim) float32.
    """
    head = tl.program_id(0)
    row = tl.program_id(1)

    head_offset = head * seq_len * head_dim
    q_base = head_offset + row * head_dim
    d_idx = tl.arange(0, BLOCK_D)
    d_mask = d_idx < head_dim

    # Query row for this program; padded lanes load 0 and contribute nothing.
    q = tl.load(Q_ptr + q_base + d_idx, mask=d_mask, other=0.0)

    # Online-softmax running state: max score, exp-sum, weighted-V accumulator.
    max_score = tl.zeros([1], dtype=tl.float32) - float("inf")
    sum_exp = tl.zeros([1], dtype=tl.float32)
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)

    for s in range(seq_len):
        kv_base = head_offset + s * head_dim  # hoisted: shared by K and V loads

        k = tl.load(K_ptr + kv_base + d_idx, mask=d_mask, other=0.0)
        score = tl.sum(q * k) * scale

        # Rescale the running sums by exp(old_max - new_max) before adding
        # the new key's contribution (numerically stable softmax).
        new_max = tl.maximum(max_score, score)
        correction = tl.exp(max_score - new_max)
        exp_score = tl.exp(score - new_max)

        v = tl.load(V_ptr + kv_base + d_idx, mask=d_mask, other=0.0)
        acc = acc * correction + exp_score * v
        sum_exp = sum_exp * correction + exp_score
        max_score = new_max

    # Normalize and write the output row (masked to the real head_dim).
    tl.store(output_ptr + q_base + d_idx, acc / sum_exp, mask=d_mask)


def day15_fused_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Day 15: fused attention (batch_size is always 1).

    Args:
        Q, K, V: CUDA tensors of shape (num_heads, seq_len, head_dim).
        mask: accepted for API compatibility but currently NOT applied —
            NOTE(review): confirm whether mask support is still planned.
        scale: score scaling factor; defaults to 1/sqrt(head_dim).

    Returns:
        Tensor of shape (num_heads, seq_len, head_dim) with
        softmax(Q @ K^T * scale) @ V per head.

    Raises:
        ValueError: if Q is not 3-dimensional.
    """
    # Validate with an exception, not `assert`: asserts vanish under -O and
    # a bad shape would then silently corrupt the kernel launch.
    if Q.dim() != 3:
        raise ValueError(
            "day15_fused_attention expects 3D tensor (num_heads, seq_len, head_dim), batch_size is always 1"
        )

    num_heads, seq_len, head_dim = Q.shape

    if scale is None:
        scale = 1.0 / (head_dim**0.5)

    # empty_like is enough: the kernel writes every element it is asked for.
    output = torch.empty_like(Q)
    # Triton block width must be a power of two; extra lanes are masked off.
    BLOCK_D = triton.next_power_of_2(head_dim)

    # One program per (head, query row).
    day15_fused_attention_kernel[(num_heads, seq_len)](
        Q, K, V, output, seq_len, head_dim, scale, BLOCK_D=BLOCK_D
    )
    return output
Loading