Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions csrc/day20_conv2d.cu
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
#include <cuda_runtime.h>
#define ceil(x, y) (((x) + (y) - 1) / (y))

// TODO: 2D Convolution kernel 구현
// 2D 컨볼루션 연산을 수행합니다
//
// 힌트:
// 1. 2D 슬라이딩 윈도우 패턴
// 2. 메모리 타일링 최적화
// 3. Shared memory 활용
//
// 입력: input (in_channels, height, width) - batch_size는 항상 1
// kernel (out_channels, in_channels, kernel_h, kernel_w)
// padding, stride
// 출력: output (out_channels, out_height, out_width)
// 2D Convolution: 출력 한 픽셀당 스레드 하나. 입력 채널 × 커널 높이 × 커널 너비만큼 합산.

__global__ void conv2d_kernel(
const float* input,
Expand All @@ -31,21 +20,39 @@ __global__ void conv2d_kernel(
int stride_h,
int stride_w
) {
// TODO: 구현하세요
// 2D 컨볼루션 계산
int out_channel_idx = blockIdx.y;
int out_row = blockIdx.x / output_w;
int out_col = blockIdx.x % output_w;
// 1차원 그리드: 스레드 하나가 출력 한 칸 담당
int total_out = out_channels * output_h * output_w;
int idx = blockIdx.x * blockDim.x + threadIdx.x;

int thread_idx = threadIdx.x;
if (idx >= total_out)
return;

if (out_channel_idx < out_channels && out_row < output_h && out_col < output_w) {
// TODO: 2D Convolution 계산
int out_idx = out_channel_idx * output_h * output_w +
out_row * output_w +
out_col;
output[out_idx] = 0.0f;
// 출력 인덱스 (oc, oh, ow) 계산
int out_spatial = output_h * output_w;
int out_channel_idx = idx / out_spatial;
int out_linear = idx % out_spatial;
int out_row = out_linear / output_w;
int out_col = out_linear % output_w;

// 이 출력 칸에 대해서 입력 × 커널 합산
float sum = 0.0f;
for (int c = 0; c < in_channels; c++) {
for (int kh = 0; kh < kernel_h; kh++) {
for (int kw = 0; kw < kernel_w; kw++) {
int in_h = out_row * stride_h + kh - pad_h;
int in_w = out_col * stride_w + kw - pad_w;
if (in_h >= 0 && in_h < input_h && in_w >= 0 && in_w < input_w) {
int in_idx = c * input_h * input_w + in_h * input_w + in_w;
int k_idx = out_channel_idx * in_channels * kernel_h * kernel_w
+ c * kernel_h * kernel_w + kh * kernel_w + kw;
sum += input[in_idx] * kernel[k_idx];
}
}
}
}

int out_idx = out_channel_idx * output_h * output_w + out_row * output_w + out_col;
output[out_idx] = sum;
}

extern "C" void day20_conv2d(
Expand All @@ -65,12 +72,11 @@ extern "C" void day20_conv2d(
int stride_h,
int stride_w
) {
// TODO: kernel launch configuration 설정
// batch_size는 항상 1이므로 제거
dim3 threadsPerBlock(256);
dim3 blocksPerGrid(output_h * output_w, out_channels);
int total_out = output_h * output_w * out_channels;
int threads = 256;
int blocks = ceil(total_out, threads);

conv2d_kernel<<<blocksPerGrid, threadsPerBlock>>>(
conv2d_kernel<<<blocks, threads>>>(
input, kernel, output,
in_channels, out_channels,
input_h, input_w, kernel_h, kernel_w,
Expand Down
75 changes: 50 additions & 25 deletions src/gpu_20days/day20_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,39 @@ def day20_conv2d_kernel(
pad_w,
stride_h,
stride_w,
BLOCK_SIZE: tl.constexpr,
):
"""
TODO: 2D Convolution kernel 구현
"""2D Convolution: 한 프로그램이 출력 한 칸만 계산 (채널×커널H×커널W 합산)"""
idx = tl.program_id(0)
out_spatial = output_h * output_w
total_out = out_channels * out_spatial

2D 컨볼루션 연산을 수행합니다
if idx >= total_out:
return

힌트:
1. 2D 슬라이딩 윈도우 패턴
2. 메모리 타일링 최적화
3. Shared memory 활용
out_channel_idx = idx // out_spatial
out_linear = idx % out_spatial
out_row = out_linear // output_w
out_col = out_linear % output_w

batch_size는 항상 1입니다.
"""
# TODO: 구현하세요
pass
acc = 0.0
for c in range(in_channels):
for kh in range(kernel_h):
for kw in range(kernel_w):
in_row = out_row * stride_h + kh - pad_h
in_col = out_col * stride_w + kw - pad_w
in_bounds = (in_row >= 0) & (in_row < input_h) & (in_col >= 0) & (in_col < input_w)
if in_bounds:
in_idx = c * input_h * input_w + in_row * input_w + in_col
k_idx = (
out_channel_idx * in_channels * kernel_h * kernel_w
+ c * kernel_h * kernel_w
+ kh * kernel_w
+ kw
)
acc += tl.load(input_ptr + in_idx) * tl.load(kernel_ptr + k_idx)

out_idx = out_channel_idx * output_h * output_w + out_row * output_w + out_col
tl.store(output_ptr + out_idx, acc)


def day20_conv2d(
Expand All @@ -48,10 +65,7 @@ def day20_conv2d(
padding: tuple[int, int] = (0, 0),
stride: tuple[int, int] = (1, 1),
) -> torch.Tensor:
"""Day 20: Two-dimensional convolution (batch_size is always 1)"""
# TODO: 구현하세요
BLOCK_SIZE = 256
# batch_size는 항상 1이므로 입력은 3D
"""Day 20: 2D Convolution (batch_size는 항상 1)"""
if input.dim() != 3:
raise ValueError(
"day20_conv2d expects 3D tensor (in_channels, height, width), batch_size is always 1"
Expand All @@ -67,16 +81,27 @@ def day20_conv2d(
output_w = (input_w + 2 * pad_w - kernel_w) // stride_w + 1

output = torch.zeros(out_channels, output_h, output_w, device=input.device, dtype=input.dtype)
total_out = out_channels * output_h * output_w

# 그리드: 출력 원소 개수만큼 프로그램 실행 (하나당 한 픽셀)
def grid(meta):
return (out_channels, triton.cdiv(output_h * output_w, BLOCK_SIZE))
return (total_out,)

# day20_conv2d_kernel[grid](
# input, kernel, output,
# in_channels, out_channels,
# input_h, input_w, kernel_h, kernel_w,
# output_h, output_w,
# pad_h, pad_w, stride_h, stride_w,
# BLOCK_SIZE=BLOCK_SIZE
# )
day20_conv2d_kernel[grid](
input,
kernel,
output,
in_channels=in_channels,
out_channels=out_channels,
input_h=input_h,
input_w=input_w,
kernel_h=kernel_h,
kernel_w=kernel_w,
output_h=output_h,
output_w=output_w,
pad_h=pad_h,
pad_w=pad_w,
stride_h=stride_h,
stride_w=stride_w,
)
return output
Loading