diff --git a/csrc/day17_persistent_matmul.cu b/csrc/day17_persistent_matmul.cu index ab05956..84decdb 100644 --- a/csrc/day17_persistent_matmul.cu +++ b/csrc/day17_persistent_matmul.cu @@ -1,34 +1,32 @@ #include -#define ceil(x, y) (((x) + (y) - 1) / (y)) - -// TODO: Persistent Matmul kernel 구현 -// Persistent 커널 패턴을 사용한 행렬 곱셈 -// -// 힌트: -// 1. Persistent 커널: 커널이 여러 작업을 반복 처리 -// 2. 메모리 관리 최적화 -// 3. 워프 레벨 최적화 -// -// 입력: A (M, K), B (K, N) -// 출력: C (M, N) + +// Persistent Matmul: 고정 블록 수로 여러 행을 순회 처리 +// A(M,K) @ B(K,N) = C(M,N) __global__ void persistent_matmul_kernel( const float* A, const float* B, float* C, - int M, - int N, - int K + int M, int K, int N, + int NUM_SMS ) { - // TODO: 구현하세요 - // Persistent 커널 패턴 사용 - int row_idx = blockIdx.x * blockDim.x + threadIdx.x; - int col_idx = blockIdx.y * blockDim.y + threadIdx.y; - - if (row_idx < M && col_idx < N) { - // TODO: Persistent Matmul 계산 - int c_idx = row_idx * N + col_idx; - C[c_idx] = 0.0f; + int sm_pid = blockIdx.x; // 블록 인덱스 (0 ~ NUM_SMS-1) + int col_idx = threadIdx.x; // 열 인덱스 (0 ~ N-1) + + // Persistent 루프: 각 블록이 여러 행을 순회 + for (int row_idx = sm_pid; row_idx < M; row_idx += NUM_SMS) { + if (col_idx < N) { + float accumulator = 0.0f; + + // K 축 순회하며 내적 계산 + for (int k = 0; k < K; k++) { + float a_val = A[row_idx * K + k]; + float b_val = B[k * N + col_idx]; + accumulator += a_val * b_val; + } + + C[row_idx * N + col_idx] = accumulator; + } } } @@ -40,12 +38,20 @@ extern "C" void day17_persistent_matmul( int N, int K ) { - // TODO: kernel launch configuration 설정 - dim3 threadsPerBlock(16, 16); - dim3 blocksPerGrid(ceil(M, 16), ceil(N, 16)); + // SM 수 가져오기 + int device; + cudaGetDevice(&device); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device); + int NUM_SMS = props.multiProcessorCount; + if (NUM_SMS > M) NUM_SMS = M; + + // Persistent: 고정 블록 수만 런칭 + dim3 threadsPerBlock(N); // 각 스레드가 한 열 담당 + dim3 blocksPerGrid(NUM_SMS); // SM 수만큼만 블록 런칭 persistent_matmul_kernel<<>>( - A, B, C, M, N, K + A, B, C, M, K, N, NUM_SMS ); cudaDeviceSynchronize(); } diff --git a/src/gpu_20days/day17_persistent_matmul.py b/src/gpu_20days/day17_persistent_matmul.py index 77917bc..150f1a8 100644 --- a/src/gpu_20days/day17_persistent_matmul.py +++ b/src/gpu_20days/day17_persistent_matmul.py @@ -8,36 +8,48 @@ @triton.jit -def day17_persistent_matmul_kernel(A_ptr, B_ptr, C_ptr, M, N, K, BLOCK_SIZE: tl.constexpr): - """ - TODO: Persistent Matmul kernel 구현 - - Persistent 커널 패턴을 사용한 행렬 곱셈 - - 힌트: - 1. Persistent 커널: 커널이 여러 작업을 반복 처리 - 2. 메모리 관리 최적화 - 3. 워프 레벨 최적화 - """ - # TODO: 구현하세요 - # row_idx = tl.program_id(0) - # col_idx = tl.program_id(1) - pass +def day17_persistent_matmul_kernel( + A_ptr, + B_ptr, + C_ptr, + M, + K, + N, # A(M,K) @ B(K,N) = C(M,N) + NUM_SMS: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + sm_pid = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE_N) + col_mask = col_offsets < N + + for row_idx in range(sm_pid, M, NUM_SMS): + accumulator = tl.zeros([BLOCK_SIZE_N], dtype=tl.float32) + for k in range(K): + a_val = tl.load(A_ptr + row_idx * K + k) + b_vals = tl.load(B_ptr + k * N + col_offsets, mask=col_mask, other=0.0) + accumulator += a_val * b_vals + tl.store(C_ptr + row_idx * N + col_offsets, accumulator, mask=col_mask) def day17_persistent_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: """Day 17: Persistent matrix multiplication""" - # TODO: 구현하세요 - BLOCK_SIZE = 256 M, K = A.shape _, N = B.shape C = torch.zeros(M, N, device=A.device, dtype=A.dtype) - - def grid(meta): - return (M, N) - - # day17_persistent_matmul_kernel[grid]( - # A, B, C, M, N, K, BLOCK_SIZE=BLOCK_SIZE - # ) + # 256 단위로 연산되므로 2의 거듭제곱 형태 + BLOCK_SIZE_N = triton.next_power_of_2(N) + # 런칭할 블록이 SM수보다 작아야 함 + NUM_SMS = min(M, torch.cuda.get_device_properties(0).multi_processor_count) + + day17_persistent_matmul_kernel[(NUM_SMS,)]( + A, + B, + C, + M, + K, + N, + NUM_SMS=NUM_SMS, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) return C