From 52d70fb4de1fed7e780c23a7665c1979338dcb53 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 3 Feb 2026 21:21:49 +0900 Subject: [PATCH] CUDA challenge day 19 --- csrc/day19_rope.cu | 74 +++++++++++++++++++++--------------- src/gpu_20days/day19_rope.py | 64 +++++++++++++++++++------------ 2 files changed, 82 insertions(+), 56 deletions(-) diff --git a/csrc/day19_rope.cu b/csrc/day19_rope.cu index 0f0d708..97cf966 100644 --- a/csrc/day19_rope.cu +++ b/csrc/day19_rope.cu @@ -1,20 +1,8 @@ #include -#include #define ceil(x, y) (((x) + (y) - 1) / (y)) -// TODO: RoPE (Rotary Position Embedding) kernel 구현 -// 회전 위치 임베딩을 적용합니다 -// -// 힌트: -// 1. 삼각함수 연산 (sin, cos) 사용 -// 2. 위치 정보를 이용한 회전 변환 -// 3. Query와 Key에 각각 적용 -// -// 입력: query (num_heads, seq_len, head_dim) - batch_size는 항상 1 -// key (num_heads, seq_len, head_dim) -// cos_cache (seq_len, head_dim / 2) -// sin_cache (seq_len, head_dim / 2) -// 출력: rotated_query, rotated_key +// RoPE: (x_2i, x_2i+1) 쌍을 cos/sin으로 회전 +// x'_2i = x_2i*cos - x_2i+1*sin, x'_2i+1 = x_2i*sin + x_2i+1*cos __global__ void rope_kernel( const float* query, @@ -27,19 +15,42 @@ __global__ void rope_kernel( int seq_len, int head_dim ) { - // TODO: 구현하세요 - int head_idx = blockIdx.y; - int seq_idx = blockIdx.x; - int dim_idx = threadIdx.x; + // 1차원 그리드: 스레드 하나가 한 (head, seq, pair) 담당 + int total_pairs = num_heads * seq_len * (head_dim / 2); + int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (head_idx < num_heads && seq_idx < seq_len && dim_idx < head_dim) { - // TODO: RoPE 계산 - int q_idx = head_idx * seq_len * head_dim + - seq_idx * head_dim + - dim_idx; - rotated_query[q_idx] = query[q_idx]; - rotated_key[q_idx] = key[q_idx]; - } + if (idx >= total_pairs) + return; + + int half_dim = head_dim / 2; + int pairs_per_head = seq_len * half_dim; + + int head_idx = idx / pairs_per_head; + int rest = idx % pairs_per_head; + int seq_idx = rest / half_dim; + int pair_idx = rest % half_dim; + + // query/key에서 이 쌍의 인덱스 + int base = head_idx * seq_len * head_dim + seq_idx * head_dim; + int d0 = base + 2 * pair_idx; + int d1 = base + 2 * pair_idx + 1; + + // cos, sin 캐시에서 읽기 (캐시 shape: seq_len, half_dim) + int cache_idx = seq_idx * half_dim + pair_idx; + float cos_val = cos_cache[cache_idx]; + float sin_val = sin_cache[cache_idx]; + + // query 회전 + float q0 = query[d0]; + float q1 = query[d1]; + rotated_query[d0] = q0 * cos_val - q1 * sin_val; + rotated_query[d1] = q0 * sin_val + q1 * cos_val; + + // key 회전 + float k0 = key[d0]; + float k1 = key[d1]; + rotated_key[d0] = k0 * cos_val - k1 * sin_val; + rotated_key[d1] = k0 * sin_val + k1 * cos_val; } extern "C" void day19_rope( @@ -53,12 +64,13 @@ extern "C" void day19_rope( int seq_len, int head_dim ) { - // TODO: kernel launch configuration 설정 - // batch_size는 항상 1이므로 제거 - dim3 threadsPerBlock(head_dim); - dim3 blocksPerGrid(seq_len, num_heads); + int half_dim = head_dim / 2; + int total_pairs = num_heads * seq_len * half_dim; + + int threads = 256; + int blocks = ceil(total_pairs, threads); - rope_kernel<<>>( + rope_kernel<<>>( query, key, rotated_query, rotated_key, cos_cache, sin_cache, num_heads, seq_len, head_dim diff --git a/src/gpu_20days/day19_rope.py b/src/gpu_20days/day19_rope.py index b6c9be4..cdd287e 100644 --- a/src/gpu_20days/day19_rope.py +++ b/src/gpu_20days/day19_rope.py @@ -18,48 +18,62 @@ def day19_rope_kernel( num_heads, seq_len, head_dim, - BLOCK_SIZE: tl.constexpr, ): - """ - TODO: RoPE kernel 구현 + """RoPE: 한 프로그램이 (head, seq, pair) 하나만 처리. x'_2i = x_2i*cos - x_2i+1*sin 등""" + half_dim = head_dim // 2 - 회전 위치 임베딩을 적용합니다 + # 그리드: (num_heads, seq_len, half_dim) → 프로그램 하나당 pair 하나 + head_idx = tl.program_id(0) + seq_idx = tl.program_id(1) + pair_idx = tl.program_id(2) - 힌트: - 1. 삼각함수 연산 (sin, cos) 사용 - 2. 위치 정보를 이용한 회전 변환 - 3. Query와 Key에 각각 적용 + base = head_idx * seq_len * head_dim + seq_idx * head_dim + d0 = base + 2 * pair_idx + d1 = base + 2 * pair_idx + 1 - batch_size는 항상 1입니다. - """ - # TODO: 구현하세요 - pass + cache_idx = seq_idx * half_dim + pair_idx + cos_val = tl.load(cos_cache_ptr + cache_idx) + sin_val = tl.load(sin_cache_ptr + cache_idx) + + q0 = tl.load(query_ptr + d0) + q1 = tl.load(query_ptr + d1) + tl.store(rotated_query_ptr + d0, q0 * cos_val - q1 * sin_val) + tl.store(rotated_query_ptr + d1, q0 * sin_val + q1 * cos_val) + + k0 = tl.load(key_ptr + d0) + k1 = tl.load(key_ptr + d1) + tl.store(rotated_key_ptr + d0, k0 * cos_val - k1 * sin_val) + tl.store(rotated_key_ptr + d1, k0 * sin_val + k1 * cos_val) def day19_rope( query: torch.Tensor, key: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """Day 19: Rotary position embedding (batch_size is always 1)""" - # TODO: 구현하세요 - BLOCK_SIZE = 256 - # batch_size는 항상 1이므로 입력은 3D + """Day 19: Rotary position embedding (batch_size는 항상 1)""" if query.dim() != 3: raise ValueError( "day19_rope expects 3D tensor (num_heads, seq_len, head_dim), batch_size is always 1" ) num_heads, seq_len, head_dim = query.shape + half_dim = head_dim // 2 - rotated_query = torch.zeros_like(query) - rotated_key = torch.zeros_like(key) + rotated_query = torch.empty_like(query) + rotated_key = torch.empty_like(key) + # 그리드: (num_heads, seq_len, half_dim) → 프로그램 하나당 pair 하나 def grid(meta): - return (num_heads, triton.cdiv(seq_len, BLOCK_SIZE)) + return (num_heads, seq_len, half_dim) - # day19_rope_kernel[grid]( - # query, key, rotated_query, rotated_key, - # cos_cache, sin_cache, - # num_heads, seq_len, head_dim, - # BLOCK_SIZE=BLOCK_SIZE - # ) + day19_rope_kernel[grid]( + query, + key, + rotated_query, + rotated_key, + cos_cache, + sin_cache, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, + ) return rotated_query, rotated_key