Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 43 additions & 31 deletions csrc/day19_rope.cu
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
#include <cuda_runtime.h>
#include <cmath>
// WARNING(review): this macro is a ceil-division helper, but its name shadows
// ::ceil / std::ceil from <cmath> for the rest of this translation unit — any
// later one-argument ceil(x) call will break. Consider renaming to CEIL_DIV.
#define ceil(x, y) (((x) + (y) - 1) / (y))

// TODO: RoPE (Rotary Position Embedding) kernel 구현
// 회전 위치 임베딩을 적용합니다
//
// 힌트:
// 1. 삼각함수 연산 (sin, cos) 사용
// 2. 위치 정보를 이용한 회전 변환
// 3. Query와 Key에 각각 적용
//
// 입력: query (num_heads, seq_len, head_dim) - batch_size는 항상 1
// key (num_heads, seq_len, head_dim)
// cos_cache (seq_len, head_dim / 2)
// sin_cache (seq_len, head_dim / 2)
// 출력: rotated_query, rotated_key
// RoPE: (x_2i, x_2i+1) 쌍을 cos/sin으로 회전
// x'_2i = x_2i*cos - x_2i+1*sin, x'_2i+1 = x_2i*sin + x_2i+1*cos

// RoPE (rotary position embedding): rotate each adjacent pair (x_2i, x_2i+1)
// by the position-dependent angle stored in the cos/sin caches:
//   x'_2i   = x_2i * cos - x_2i+1 * sin
//   x'_2i+1 = x_2i * sin + x_2i+1 * cos
//
// Layouts: query/key are (num_heads, seq_len, head_dim), row-major;
//          cos_cache/sin_cache are (seq_len, head_dim / 2).
// Launch:  1-D grid, one thread per (head, seq, pair); batch_size is always 1.
// Precondition: head_dim is even.
__global__ void rope_kernel(
    const float* __restrict__ query,
    const float* __restrict__ key,
    float* __restrict__ rotated_query,
    float* __restrict__ rotated_key,
    const float* __restrict__ cos_cache,
    const float* __restrict__ sin_cache,
    int num_heads,
    int seq_len,
    int head_dim
) {
    int half_dim = head_dim / 2;
    int total_pairs = num_heads * seq_len * half_dim;

    // Flat 1-D index: one thread handles one (head, seq, pair) triple.
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_pairs)
        return;

    // Decompose the flat index back into (head, seq, pair).
    int pairs_per_head = seq_len * half_dim;
    int head_idx = idx / pairs_per_head;
    int rest = idx % pairs_per_head;
    int seq_idx = rest / half_dim;
    int pair_idx = rest % half_dim;

    // Offsets of the even/odd element of this pair in query/key.
    int base = head_idx * seq_len * head_dim + seq_idx * head_dim;
    int d0 = base + 2 * pair_idx;
    int d1 = d0 + 1;

    // cos/sin caches are indexed by (seq position, pair).
    int cache_idx = seq_idx * half_dim + pair_idx;
    float cos_val = cos_cache[cache_idx];
    float sin_val = sin_cache[cache_idx];

    // Rotate the query pair.
    float q0 = query[d0];
    float q1 = query[d1];
    rotated_query[d0] = q0 * cos_val - q1 * sin_val;
    rotated_query[d1] = q0 * sin_val + q1 * cos_val;

    // Rotate the key pair with the same angle.
    float k0 = key[d0];
    float k1 = key[d1];
    rotated_key[d0] = k0 * cos_val - k1 * sin_val;
    rotated_key[d1] = k0 * sin_val + k1 * cos_val;
}

// Host-side launcher for rope_kernel (batch_size is always 1).
// query/key: (num_heads, seq_len, head_dim); caches: (seq_len, head_dim / 2).
// All pointers must point to device memory; head_dim must be even.
// Launch is asynchronous; callers should check cudaGetLastError() /
// synchronize before reading the outputs.
extern "C" void day19_rope(
    const float* query,
    const float* key,
    float* rotated_query,
    float* rotated_key,
    const float* cos_cache,
    const float* sin_cache,
    int num_heads,
    int seq_len,
    int head_dim
) {
    int half_dim = head_dim / 2;
    int total_pairs = num_heads * seq_len * half_dim;

    int threads = 256;
    // Ceil-division written inline rather than via the file's `ceil` macro,
    // which shadows std::ceil from <cmath>.
    int blocks = (total_pairs + threads - 1) / threads;

    rope_kernel<<<blocks, threads>>>(
        query, key, rotated_query, rotated_key,
        cos_cache, sin_cache,
        num_heads, seq_len, head_dim
    );
}
64 changes: 39 additions & 25 deletions src/gpu_20days/day19_rope.py
Original file line number Diff line number Diff line change
@triton.jit
def day19_rope_kernel(
    query_ptr,
    key_ptr,
    rotated_query_ptr,
    rotated_key_ptr,
    cos_cache_ptr,
    sin_cache_ptr,
    num_heads,
    seq_len,
    head_dim,
):
    """RoPE: each program rotates one (head, seq, pair) of query and key.

    x'_2i   = x_2i * cos - x_2i+1 * sin
    x'_2i+1 = x_2i * sin + x_2i+1 * cos

    Layouts: query/key are (num_heads, seq_len, head_dim), contiguous;
    cos/sin caches are (seq_len, head_dim // 2). head_dim must be even.
    Grid: (num_heads, seq_len, head_dim // 2).
    """
    half_dim = head_dim // 2

    # One program per rotated pair, addressed by the 3-D grid.
    head_idx = tl.program_id(0)
    seq_idx = tl.program_id(1)
    pair_idx = tl.program_id(2)

    # Flat offsets of the even/odd element of this pair.
    base = head_idx * seq_len * head_dim + seq_idx * head_dim
    d0 = base + 2 * pair_idx
    d1 = base + 2 * pair_idx + 1

    # cos/sin caches are indexed by (seq position, pair).
    cache_idx = seq_idx * half_dim + pair_idx
    cos_val = tl.load(cos_cache_ptr + cache_idx)
    sin_val = tl.load(sin_cache_ptr + cache_idx)

    # Rotate the query pair.
    q0 = tl.load(query_ptr + d0)
    q1 = tl.load(query_ptr + d1)
    tl.store(rotated_query_ptr + d0, q0 * cos_val - q1 * sin_val)
    tl.store(rotated_query_ptr + d1, q0 * sin_val + q1 * cos_val)

    # Rotate the key pair with the same angle.
    k0 = tl.load(key_ptr + d0)
    k1 = tl.load(key_ptr + d1)
    tl.store(rotated_key_ptr + d0, k0 * cos_val - k1 * sin_val)
    tl.store(rotated_key_ptr + d1, k0 * sin_val + k1 * cos_val)


def day19_rope(
    query: torch.Tensor, key: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Day 19: rotary position embedding (batch_size is always 1).

    Args:
        query: (num_heads, seq_len, head_dim) tensor; head_dim must be even.
        key: (num_heads, seq_len, head_dim) tensor, same shape as query.
        cos_cache: (seq_len, head_dim // 2) precomputed cosines.
        sin_cache: (seq_len, head_dim // 2) precomputed sines.

    Returns:
        (rotated_query, rotated_key), each the same shape as its input.

    Raises:
        ValueError: if query is not 3-dimensional.
    """
    if query.dim() != 3:
        raise ValueError(
            "day19_rope expects 3D tensor (num_heads, seq_len, head_dim), batch_size is always 1"
        )

    num_heads, seq_len, head_dim = query.shape
    half_dim = head_dim // 2

    # The kernel writes every element, so uninitialized storage is fine.
    rotated_query = torch.empty_like(query)
    rotated_key = torch.empty_like(key)

    # 3-D grid: one program per (head, seq position, rotation pair).
    def grid(meta):
        return (num_heads, seq_len, half_dim)

    day19_rope_kernel[grid](
        query,
        key,
        rotated_query,
        rotated_key,
        cos_cache,
        sin_cache,
        num_heads=num_heads,
        seq_len=seq_len,
        head_dim=head_dim,
    )
    return rotated_query, rotated_key
Loading