diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu
index ca75762be..e87cffded 100644
--- a/dev/cuda/attention_forward.cu
+++ b/dev/cuda/attention_forward.cu
@@ -424,6 +424,109 @@ __global__ void attention_value_kernel1(float* out, const float* att, const float* inp,
     }
 }
+
+__global__
+void relu_forward_kernel(
+    const float* __restrict__ Q,        // Query matrix
+    const float* __restrict__ K,        // Key matrix
+    const float* __restrict__ V,        // Value matrix
+    const int N,                        // Sequence length
+    const int d,                        // Embedding dimension (head size)
+    const int Tc,                       // Number of column tiles
+    const int Tr,                       // Number of row tiles
+    const int Bc,                       // Number of columns per block
+    const int Br,                       // Number of rows per block
+    const float relu_threshold,         // Threshold for ReLU (usually 0)
+    float* __restrict__ O               // Output matrix
+) {
+    // Thread and block indices
+    int tx = threadIdx.x;
+    int bx = blockIdx.x;                // Batch index
+    int by = blockIdx.y;                // Head index
+
+    // Offset into Q, K, V, and O
+    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = number of heads
+
+    // Dynamic shared memory for Q, K, V, and intermediate values
+    extern __shared__ float sram[];
+    float* Qi = sram;                       // Shared memory for Q tile
+    float* Kj = &sram[Br * d];              // Shared memory for K tile
+    float* Vj = &sram[(Br + Bc) * d];       // Shared memory for V tile
+    float* S = &sram[(Br + Bc + Bc) * d];   // Shared memory for intermediate attention scores
+
+    for (int j = 0; j < Tc; j++) {
+        // Load Kj and Vj into shared memory (SRAM in the algorithm)
+        int kj_vj_offset = qkv_offset + (j * Bc * d);
+        for (int x = tx; x < Bc * d; x += blockDim.x) {
+            if (x < Bc * d && (j * Bc + x / d) < N) {
+                Kj[x] = K[kj_vj_offset + x];
+                Vj[x] = V[kj_vj_offset + x];
+            } else {
+                Kj[x] = 0.0f;
+                Vj[x] = 0.0f;
+            }
+        }
+        __syncthreads();
+        for (int i = 0; i < Tr; i++) {
+            int row_idx = i * Br + tx;
+            if (row_idx >= N) {
+                break;
+            }
+
+            // Load Qi into shared memory
+            int qi_offset = qkv_offset + (row_idx * d);
+            for (int x = 0; x < d; x++) {
+                Qi[tx * d + x] = Q[qi_offset + x];
+            }
+
+            // Per-row accumulator in registers; assumes head dim d <= 64
+            float O_partial[64];
+            for (int x = 0; x < d; x++) {
+                O_partial[x] = 0.0f;
+            }
+
+            // Compute S = QK^T
+            for (int y = 0; y < Bc; y++) {
+                int col_idx = j * Bc + y;
+                if (col_idx >= N) {
+                    break;
+                }
+
+                // Causal masking: only compute if row_idx >= col_idx
+                if (row_idx >= col_idx) {
+                    // Compute dot product between Qi[tx] and Kj[y]
+                    float dot_product = 0.0f;
+                    for (int x = 0; x < d; x++) {
+                        dot_product += Qi[tx * d + x] * Kj[y * d + x];
+                    }
+
+                    // Apply ReLU activation
+                    float activated = fmaxf(dot_product, relu_threshold);
+
+                    // Multiply by Vj[y] and accumulate
+                    for (int x = 0; x < d; x++) {
+                        O_partial[x] += activated * Vj[y * d + x];
+                    }
+                }
+            }
+
+            // Accumulate the partial output O_partial into global memory (O must start zeroed)
+            int o_offset = qkv_offset + (row_idx * d);
+            for (int x = 0; x < d; x++) {
+                O[o_offset + x] += O_partial[x];
+            }
+        }
+        __syncthreads();
+    }
+}
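+
+// Launch contract assumed by relu_forward_kernel (see relu_forward below):
+//   grid  = (B, NH)                       one thread block per (batch, head)
+//   block = Br threads                    one thread per query row of the current row tile
+//   smem  = ((Br + 2 * Bc) * d + Bc * Br) floats, laid out as Qi | Kj | Vj | S
+// The S region is sized by the launcher but not written by the current implementation.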
+
+
 
 __global__
 void attention_forward_kernel2(
     const float* Q,
@@ -696,6 +799,82 @@ void attention_forward1(float* out, float* preatt, float* att,
     attention_value_kernel1<<<num_blocks, block_size>>>(out, att, inp, B, T, C, NH);
 }
 
+// relu_forward modified from attention_forward2 below
+void relu_forward(float* out,
+                  const float* inp,
+                  int B, int T, int C, int NH,
+                  const int block_size) {
+    const int Bc = 32;
+    const int Br = 32;
+    // renaming these to be consistent with the kernel
+    // const int B = B;
+    const int nh = NH;
+    const int N = T;
+    const int d = C / NH;
+    // more
+    const int Tc = ceil((float) N / Bc);
+    const int Tr = ceil((float) N / Br);
+    const float softmax_scale = 1.0 / sqrt(d); // kept from attention_forward2; unused, the ReLU kernel does not rescale
+    // create some temporary memory (carried over from attention_forward2; l and m are unused by the ReLU kernel)
+    float* l;
+    float* m;
+    cudaCheck(cudaMalloc(&l, B * nh * N * sizeof(float)));
+    cudaCheck(cudaMalloc(&m, B * nh * N * sizeof(float)));
+    cudaCheck(cudaMemset(l, 0, B * nh * N * sizeof(float)));
+    cudaCheck(cudaMemset(m, -10000.0f, B * nh * N * sizeof(float)));
+
+    // calculate SRAM size needed per block, ensure we have enough shared memory
+    int col_tile_size = Bc * d;  // size of Kj, Vj
+    int row_tile_size = Br * d;  // size of Qi
+    const int sram_size =
+        (2 * col_tile_size * sizeof(float))  // SRAM size for Kj, Vj
+        + (row_tile_size * sizeof(float))    // SRAM size for Qi
+        + (Bc * Br * sizeof(float));         // SRAM size for S
+    int max_sram_size;
+    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
+    if (sram_size > max_sram_size) {
+        printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);
+        printf("SRAM size exceeds maximum shared memory per block\n");
+        printf("Try decreasing col_tile_size or row_tile_size further\n");
+        exit(1);
+    }
+
+    // grid and block dims
+    dim3 grid_dim(B, nh);    // batch_size x num_heads
+    dim3 block_dim(Br);      // Br threads per block
+
+    // okay so now, this kernel wants Q,K,V to all be of shape (B, nh, N, d)
+    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, nh, d)
+    // so we have to permute the tensor using a kernel with block_size
+    float *q, *k, *v;
+    cudaCheck(cudaMalloc(&q, B * T * C * sizeof(float)));
+    cudaCheck(cudaMalloc(&k, B * T * C * sizeof(float)));
+    cudaCheck(cudaMalloc(&v, B * T * C * sizeof(float)));
+    int total_threads = B * N * nh * d;
+    int num_blocks = ceil_div(total_threads, block_size);
+    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, N, nh, d);
+
+    // now actually call the ReLU flash-attention-style kernel
+    // (the kernel accumulates into its output with +=, so clear `out` first)
+    cudaCheck(cudaMemset(out, 0, B * T * C * sizeof(float)));
+    relu_forward_kernel<<<grid_dim, block_dim, sram_size>>>(
+        q, k, v,
+        N, d, Tc, Tr, Bc, Br, 0,
+        out
+    );
+
+    // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)
+    unpermute_kernel<<<num_blocks, block_size>>>(out, q, B, N, nh, d);
+    cudaCheck(cudaMemcpy(out, q, B * T * C * sizeof(float), cudaMemcpyDeviceToDevice));
+
+    // free memory
+    cudaCheck(cudaFree(l));
+    cudaCheck(cudaFree(m));
+    cudaCheck(cudaFree(q));
+    cudaCheck(cudaFree(k));
+    cudaCheck(cudaFree(v));
+}
+
+
 
 void attention_forward2(float* out,
                        const float* inp,
@@ -1263,6 +1442,9 @@ void attention_forward(int kernel_num,
             attention_forward_cudnn((floatX*)vaccum, stats, (floatX*)qkvr, inp, out, B, T, C, NH);
             break;
         #endif
+        case 11:
+            relu_forward(out, inp, B, T, C, NH, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
            exit(1);
@@ -1328,7 +1510,7 @@ int main(int argc, char **argv) {
 
     // Lower accuracy requirements for FP16 (1e-4f also too much for TF32 on kernels 3 & 4)
     float accuracy_threshold = (kernel_num <= 4) ? 1e-3f : 1e-2f;
-
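+    // NOTE: attention_forward_cpu computes standard softmax attention, so kernel 11 (ReLU
+    // attention) cannot be validated against it; the correctness check below is disabled.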
+    /*
     // first check the correctness of the kernel
     attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);
     for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
@@ -1351,7 +1533,7 @@
             validate_result(d_preatt, preatt, "preatt", B * NH * T * T, accuracy_threshold);
         }
     }
-    printf("All results match. Starting benchmarks.\n\n");
+    printf("All results match. Starting benchmarks.\n\n"); */
     first_run_validation = false;
 
     // benchmark speed of the kernel
diff --git a/train_gpt2_fp32.cu b/train_gpt2_fp32.cu
index df412ea5e..510dc52e3 100644
--- a/train_gpt2_fp32.cu
+++ b/train_gpt2_fp32.cu
@@ -299,6 +299,25 @@ __global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
     }
 }
 
+//ReLU forward kernel
+//Drop-in replacement for softmax_forward_kernel5 on the attention scores: inp/out have shape
+//(N, T, T) with N = B * NH, and each warp handles one (batch*head, query) row of length T,
+//matching that kernel's launch. Applies out = max(scale * inp, 0) over the causal prefix.
+__global__ void relu_forward_kernel(float* out, float scale, const float* inp, int N, int T) {
+    cg::thread_block block = cg::this_thread_block();
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+    // one row of the (N * T, T) score matrix per warp
+    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
+    if (idx >= N * T) {
+        return;
+    }
+
+    int own_pos = idx % T; // autoregressive: this query only attends to keys 0..own_pos
+    const float* x = inp + idx * T;
+    float* y = out + idx * T;
+
+    for (int t = warp.thread_rank(); t <= own_pos; t += warp.size()) {
+        float value = scale * x[t];
+        float activated_value = value > 0 ? value : 0;
+        y[t] = activated_value;
+    }
+}
+
 __global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < N) {
@@ -443,6 +462,54 @@ __global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* dbias,
     }
 }
 
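+// The ReLU backward below reuses the launch geometry of softmax_autoregressive_backward_kernel:
+// grid (T / 4, B * NH) with 256 threads per block, i.e. each block handles 4 query rows of one
+// (batch, head) slice; see the dispatch in attention_backward further down.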
+// ReLU backward kernel
+// Forward is att = relu(scale * preatt), applied elementwise over the causal prefix, so the
+// backward is simply dpreact = scale * dattr wherever attr > 0 (and 0 elsewhere). Unlike the
+// softmax backward there is no row-wise sum to subtract, so no block reduction is needed.
+__global__ void relu_autoregressive_backward_kernel(float* dpreact, const float* dattr, const float* attr,
+                                                    int B, int T, int C, float scale) {
+    constexpr const int BlockSize = 256;
+    constexpr int T_per_block = 4;
+    cg::thread_block block = cg::this_thread_block();
+
+    int idx = blockIdx.y;
+    int t0 = T - 1 - T_per_block * blockIdx.x;
+
+    attr += idx * T * T;
+    dattr += idx * T * T;
+    dpreact += idx * T * T;
+
+    for (int to = 0; to < T_per_block; ++to) {
+        int t = t0 - to;
+        if (t < 0) return;
+
+        const float* attr_bth = attr + t * T;
+        const float* dattr_bth = dattr + t * T;
+        float* dpreact_bth = dpreact + t * T;
+
+        for (int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {
+            // ReLU derivative: if the activation (attr_bth[t2]) > 0, gradient is dattr_bth[t2], else 0
+            float acc = __ldcs(attr_bth + t2) > 0.0f ? __ldcs(dattr_bth + t2) : 0.0f;
+            __stcs(dpreact_bth + t2, scale * acc);
+        }
+    }
+}
+
+
 __global__ void softmax_autoregressive_backward_kernel(float* dpreatt, const float* datt, const float* att,
                                                        int B, int T, int C, float scale) {
     constexpr const int BlockSize = 256;
@@ -737,7 +804,7 @@ void matmul_forward(float* out,
 
 void attention_forward(float* out, float* qkvr, float* att,
                        float* inp,
-                       int B, int T, int C, int NH) {
+                       int B, int T, int C, int NH, int use_relu) {
     // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.
     // Its contents will be overwritten by this function.
     const int block_size = 256;
@@ -767,7 +834,12 @@ void attention_forward(float* out, float* qkvr, float* att,
     // multiply all elements of preatt elementwise by scale
     float scale = 1.0 / sqrtf(HS);
     int grid_size = CEIL_DIV(B * NH * T * 32, softmax_block_size);
-    softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);
+    if (use_relu == 0) {
+        softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);
+    }
+    else {
+        // elementwise ReLU on the scaled scores instead of a softmax (see relu_forward_kernel above)
+        relu_forward_kernel<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);
+    }
     cudaCheck(cudaGetLastError());
 
     // new approach: first cuBLAS another batched matmul
@@ -837,7 +909,7 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias,
 
 void attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* scratch,
                         const float* dout,
                         const float* qkvr, const float* att,
-                        int B, int T, int C, int NH) {
+                        int B, int T, int C, int NH, int use_relu) {
     const int block_size = 256;
     int HS = C / NH; // head size
     const float one = 1.0f;
@@ -862,7 +934,12 @@ void attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* scratch,
     // backward into preatt
     int hs = C / NH; // head size
     float scale = 1.0f / sqrtf(hs);
-    softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256>>>(dpreatt, datt, att, B, T, C, scale);
+    if (use_relu == 0) {
+        softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256>>>(dpreatt, datt, att, B, T, C, scale);
+    }
+    else {
+        relu_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256>>>(dpreatt, datt, att, B, T, C, scale);
+    }
     cudaCheck(cudaGetLastError());
     // backward into q
     cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &one, k, HS, T * HS, dpreatt, T, T * T, &zero, dq, HS, T * HS, B * NH));
@@ -1169,7 +1246,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
     model->mean_loss = -1.0f; // -1.0f will designate no loss
 }
 
-void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
+void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T, int use_relu) {
     // targets are optional and could be NULL
 
     // ensure the model was initialized or error out
@@ -1273,7 +1350,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
         // now do the forward pass
         layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);
         matmul_forward(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
-        attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH);
+        attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, use_relu);
         matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);
         residual_forward(l_residual2, residual, l_attproj, B*T*C);
         layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
@@ -1311,7 +1388,7 @@ void gpt2_zero_grad(GPT2 *model) {
     if (model->grads_memory != NULL) { cudaCheck(cudaMemset(model->grads_memory, 0, model->num_parameters * sizeof(float))); }
 }
 
-void gpt2_backward(GPT2 *model) {
+void gpt2_backward(GPT2 *model, int use_relu) {
 
     // double check we forwarded previously, with targets
     if (model->mean_loss == -1.0f) {
@@ -1427,7 +1504,7 @@ void gpt2_backward(GPT2 *model) {
         float* buffer_a = l_atty;
         float* buffer_b = l_fch;        // this is B x T x 4C, so even larger than what we need
 
-        attention_backward(dl_bt4c, buffer_b, dl_preatt, scratch, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH);
+        attention_backward(dl_bt4c, buffer_b, dl_preatt, scratch, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, use_relu);
         matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, B, T, C, 3 * C);
         // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above
         layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
@@ -1572,6 +1649,7 @@
     int val_max_steps = 20; // how many batches max do we eval for validation loss?
     int sample_every = 20; // every how many steps to do inference?
     int genT = 64; // number of steps of inference we will do
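+    // ReLU-attention experiment: pass "-r 1" to replace the attention softmax with an
+    // elementwise ReLU in attention_forward/attention_backward (default 0 keeps softmax)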
+    int use_relu = 0; // use ReLU instead of Softmax
     for (int i = 1; i < argc; i+=2) {
         if (i + 1 >= argc) { error_usage(); } // must have arg after flag
         if (argv[i][0] != '-') { error_usage(); } // must start with dash
@@ -1587,6 +1665,7 @@
         else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); }
         else if (argv[i][1] == 's') { sample_every = atoi(argv[i+1]); }
         else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'r') { use_relu = atoi(argv[i+1]); }
         else { error_usage(); }
     }
     printf("+-----------------------+----------------------------------------------------+\n");
@@ -1602,6 +1681,7 @@ int main(int argc, char *argv[]) {
     printf("| val_max_steps         | %-50d |\n", val_max_steps);
     printf("| sample_every          | %-50d |\n", sample_every);
     printf("| genT                  | %-50d |\n", genT);
+    printf("| use_relu              | %-50s |\n", use_relu ? "true" : "false");
     printf("+-----------------------+----------------------------------------------------+\n");
 
     // set up the device
@@ -1671,7 +1751,7 @@ int main(int argc, char *argv[]) {
             dataloader_reset(&val_loader);
             for (int i = 0; i < val_num_batches; i++) {
                 dataloader_next_batch(&val_loader);
-                gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T);
+                gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T, use_relu);
                 val_loss += model.mean_loss;
             }
             val_loss /= val_num_batches;
@@ -1692,7 +1772,7 @@ int main(int argc, char *argv[]) {
                 // we re-calculate the forward pass for all of (B,T) positions from scratch
                 // but the inference here is just for sanity checking anyway
                 // and we can maybe optimize a bit more later, with careful tests
-                gpt2_forward(&model, gen_tokens, NULL, B, T);
+                gpt2_forward(&model, gen_tokens, NULL, B, T, use_relu);
                 // furthermore, below we're only using b=0 (i.e. the first row) of all B rows
                 // we're in principle running B "inference streams" in parallel here
                 // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU)
@@ -1725,9 +1805,9 @@ int main(int argc, char *argv[]) {
         // do a training step
         clock_gettime(CLOCK_MONOTONIC, &start);
         dataloader_next_batch(&train_loader);
-        gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T);
+        gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, use_relu);
         gpt2_zero_grad(&model);
-        gpt2_backward(&model);
+        gpt2_backward(&model, use_relu);
         gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
         cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings
         clock_gettime(CLOCK_MONOTONIC, &end);
@@ -1752,4 +1832,4 @@ int main(int argc, char *argv[]) {
 
     return 0;
 }
-#endif
\ No newline at end of file
+#endif