Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 34 additions & 28 deletions csrc/day17_persistent_matmul.cu
Original file line number Diff line number Diff line change
@@ -1,34 +1,32 @@
#include <cuda_runtime.h>
#define ceil(x, y) (((x) + (y) - 1) / (y))

// Persistent Matmul: a fixed pool of blocks (one per SM) iterates over the
// rows of the output instead of launching one block per tile.
// A(M,K) @ B(K,N) = C(M,N), all row-major.

// Persistent matmul kernel: a fixed pool of blocks loops over the rows of C
// instead of launching one block per output tile.
//
// Launch contract:
//   grid  = (NUM_SMS, 1, 1)  -- one block per SM (clamped to M on the host)
//   block = (T, 1, 1)        -- threads stride across the N columns of a row
//
// A is (M, K), B is (K, N), C is (M, N); all row-major device pointers.
__global__ void persistent_matmul_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int K, int N,
    int NUM_SMS
) {
    // Persistent loop: block b handles rows b, b + NUM_SMS, b + 2*NUM_SMS, ...
    for (int row_idx = blockIdx.x; row_idx < M; row_idx += NUM_SMS) {
        // Column-stride loop: correct for any blockDim.x. The original body
        // assigned exactly one column per thread, which silently dropped
        // columns whenever blockDim.x < N; this form also lifts the implicit
        // N <= 1024 (max threads per block) limit.
        for (int col_idx = threadIdx.x; col_idx < N; col_idx += blockDim.x) {
            float accumulator = 0.0f;

            // Dot product of row `row_idx` of A with column `col_idx` of B.
            // The B access is coalesced: adjacent threads read adjacent
            // elements of each B row.
            for (int k = 0; k < K; k++) {
                float a_val = A[row_idx * K + k];
                float b_val = B[k * N + col_idx];
                accumulator += a_val * b_val;
            }

            C[row_idx * N + col_idx] = accumulator;
        }
    }
}

// Host launcher: queries the SM count of the current device and launches
// exactly that many persistent blocks (capped at M so every block has work).
//
// A is (M, K), B is (K, N), C is (M, N); all row-major device pointers.
// NOTE: one thread is launched per output column, so this configuration
// requires N <= 1024 (the per-block thread limit). The kernel itself
// tolerates smaller blocks via its column-stride loop.
extern "C" void day17_persistent_matmul(
    const float* A,
    const float* B,
    float* C,
    int M,
    int N,
    int K
) {
    // Nothing to compute; also avoids an invalid 0-thread/0-block launch.
    if (M <= 0 || N <= 0) return;

    // Number of SMs on the current device = number of persistent blocks.
    int device;
    cudaGetDevice(&device);
    cudaDeviceProp props;
    cudaGetDeviceProperties(&props, device);
    int NUM_SMS = props.multiProcessorCount;
    if (NUM_SMS > M) NUM_SMS = M;  // never launch more blocks than rows

    // Persistent pattern: launch a fixed, small grid and let each block
    // iterate over rows inside the kernel.
    dim3 threadsPerBlock(N);      // one thread per output column (N <= 1024)
    dim3 blocksPerGrid(NUM_SMS);  // fixed block count, one per SM

    // Kernel takes dims in (M, K, N) order, unlike this wrapper's (M, N, K).
    persistent_matmul_kernel<<<blocksPerGrid, threadsPerBlock>>>(
        A, B, C, M, K, N, NUM_SMS
    );
    cudaDeviceSynchronize();
}
60 changes: 36 additions & 24 deletions src/gpu_20days/day17_persistent_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,48 @@


@triton.jit
def day17_persistent_matmul_kernel(
    A_ptr,
    B_ptr,
    C_ptr,
    M,
    K,
    N,  # A(M,K) @ B(K,N) = C(M,N), all row-major and contiguous
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Persistent matmul: NUM_SMS programs each stride over the rows of C.

    Each program owns one whole output row at a time; its BLOCK_SIZE_N lanes
    cover the N columns. BLOCK_SIZE_N must be a power of two >= N (extra
    lanes are masked off).
    """
    sm_pid = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE_N)
    col_mask = col_offsets < N  # disable the padding lanes past column N

    # Persistent loop: program p handles rows p, p + NUM_SMS, p + 2*NUM_SMS, ...
    for row_idx in range(sm_pid, M, NUM_SMS):
        accumulator = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32)

        # Walk the K axis: broadcast one scalar of A's row against a full
        # row of B, accumulating a whole row of C at once.
        for k in range(K):
            a_val = tl.load(A_ptr + row_idx * K + k)
            b_vals = tl.load(B_ptr + k * N + col_offsets, mask=col_mask, other=0.0)
            accumulator += a_val * b_vals

        tl.store(C_ptr + row_idx * N + col_offsets, accumulator, mask=col_mask)


def day17_persistent_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Day 17: Persistent matrix multiplication.

    Computes C = A @ B with a persistent Triton kernel: only NUM_SMS programs
    are launched, and each strides over the rows of C.

    Args:
        A: (M, K) CUDA tensor, contiguous row-major.
        B: (K, N) CUDA tensor on the same device as A.

    Returns:
        (M, N) tensor C = A @ B on A's device with A's dtype.
    """
    M, K = A.shape
    _, N = B.shape

    C = torch.zeros(M, N, device=A.device, dtype=A.dtype)
    if M == 0 or N == 0:
        return C  # empty result; skip the (0,)-grid launch entirely

    # tl.arange needs a power-of-two extent, so round N up; the kernel masks
    # the out-of-range lanes.
    BLOCK_SIZE_N = triton.next_power_of_2(N)
    # At most one persistent program per SM, and never more than there are
    # rows. Query A's own device (hard-coding device 0 here would be wrong
    # for tensors living on another GPU).
    sm_count = torch.cuda.get_device_properties(A.device).multi_processor_count
    NUM_SMS = min(M, sm_count)

    day17_persistent_matmul_kernel[(NUM_SMS,)](
        A,
        B,
        C,
        M,
        K,
        N,
        NUM_SMS=NUM_SMS,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
    )
    return C
Loading