186 changes: 184 additions & 2 deletions dev/cuda/attention_forward.cu
@@ -424,6 +424,109 @@ __global__ void attention_value_kernel1(float* out, const float* att, const floa
}
}


__global__
void relu_forward_kernel(
const float* __restrict__ Q, // Query matrix
const float* __restrict__ K, // Key matrix
const float* __restrict__ V, // Value matrix
const int N, // Sequence length
const int d, // Embedding dimension
const int Tc, // Number of column tiles
const int Tr, // Number of row tiles
const int Bc, // Number of columns per block
const int Br, // Number of rows per block
const float relu_threshold, // Threshold for ReLU (usually 0)
float* __restrict__ O // Output matrix
) {
// Thread and block indices
int tx = threadIdx.x;
int bx = blockIdx.x;
int by = blockIdx.y; // Batch and head indices

// Offset into Q, K, V, and O
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = number of heads

// Dynamic shared memory for Q, K, V, and intermediate values
extern __shared__ float sram[];
float* Qi = sram; // Shared memory for Q tile
float* Kj = &sram[Br * d]; // Shared memory for K tile
float* Vj = &sram[(Br + Bc) * d]; // Shared memory for V tile
float* S = &sram[(Br + Bc + Bc) * d]; // Shared memory reserved for attention scores (not used in this ReLU variant)
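// SRAM layout in floats: [Qi: Br*d][Kj: Bc*d][Vj: Bc*d][S: Bc*Br],
// which matches the sram_size computed on the host in relu_forward below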

for (int j = 0; j < Tc; j++) {
// Load Kj and Vj into shared memory (SRAM in algorithm)
int kj_vj_offset = qkv_offset + (j * Bc * d);
for (int x = tx; x < Bc * d; x += blockDim.x) {
if (x < Bc * d && (j * Bc + x / d) < N) {
Kj[x] = K[kj_vj_offset + x];
Vj[x] = V[kj_vj_offset + x];
} else {
Kj[x] = 0.0f;
Vj[x] = 0.0f;
}
}
__syncthreads();
for (int i = 0; i < Tr; i++) {
int row_idx = i * Br + tx;
if (row_idx >= N) {
break;
}

// Load Qi into shared memory
int qi_offset = qkv_offset + (row_idx * d);
for (int x = 0; x < d; x++) {
Qi[tx * d + x] = Q[qi_offset + x];
}

float O_partial[64]; // per-thread output accumulator in registers (assumes d <= 64)
for (int x = 0; x < d; x++) {
O_partial[x] = 0.0f;
}

// Compute S = QK^T
for (int y = 0; y < Bc; y++) {
int col_idx = j * Bc + y;
if (col_idx >= N) {
break;
}

// Causal masking: only compute if row_idx >= col_idx
if (row_idx >= col_idx) {
// Compute dot product between Qi[tx] and Kj[y]
float dot_product = 0.0f;
for (int x = 0; x < d; x++) {
dot_product += Qi[tx * d + x] * Kj[y * d + x];
}

// Apply ReLU activation
float activated = fmaxf(dot_product, relu_threshold);

// Multiply by Vj[y] and accumulate
for (int x = 0; x < d; x++) {
O_partial[x] += activated * Vj[y * d + x];
}
}
}

// Accumulate this tile's partial output into O in global memory (O is zeroed on the host before launch)
int o_offset = qkv_offset + (row_idx * d);
for (int x = 0; x < d; x++) {
O[o_offset + x] += O_partial[x];
}
}
__syncthreads();
}
}
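// For reference, a minimal CPU sketch (illustrative only, not called anywhere in this file)
// of what relu_forward_kernel computes for a single (batch, head) slice with Q, K, V of
// shape (N, d): causal attention with the softmax replaced by an elementwise ReLU on the
// raw scores.
static void relu_attention_cpu_sketch(const float* Q, const float* K, const float* V,
                                      float* O, int N, int d, float relu_threshold) {
    for (int i = 0; i < N; i++) {
        // start from a zeroed output row
        for (int x = 0; x < d; x++) { O[i * d + x] = 0.0f; }
        // causal mask: only attend to keys at positions j <= i
        for (int j = 0; j <= i; j++) {
            float dot = 0.0f;
            for (int x = 0; x < d; x++) { dot += Q[i * d + x] * K[j * d + x]; }
            float activated = fmaxf(dot, relu_threshold); // ReLU in place of softmax
            for (int x = 0; x < d; x++) { O[i * d + x] += activated * V[j * d + x]; }
        }
    }
}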


__global__
void attention_forward_kernel2(
const float* Q,
@@ -696,6 +799,82 @@ void attention_forward1(float* out, float* preatt, float* att,
attention_value_kernel1<<<num_blocks, block_size>>>(out, att, inp, B, T, C, NH);
}

// relu_forward modified from attention_forward2 below
void relu_forward(float* out,
const float* inp,
int B, int T, int C, int NH,
const int block_size) {
const int Bc = 32;
const int Br = 32;
// renaming these to be consistent with the kernel
// const int B = B;
const int nh = NH;
const int N = T;
const int d = C / NH;
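// e.g. C = 768 with NH = 12 heads gives d = 64, which matches the 64-float O_partial buffer in the kernel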
// number of column (K/V) and row (Q) tiles along the sequence
const int Tc = ceil((float) N / Bc);
const int Tr = ceil((float) N / Br);
// note: the softmax scale and the running softmax statistics (l, m) used by the
// flash attention kernels are not needed here, since the softmax is replaced by a ReLU

// calculate SRAM size needed per block, ensure we have enough shared memory
int col_tile_size = Bc * d; // size of Kj, Vj
int row_tile_size = Br * d; // size of Qi
const int sram_size =
(2 * col_tile_size * sizeof(float)) // SRAM size for Kj, Vj
+ (row_tile_size * sizeof(float)) // SRAM size for Qi
+ (Bc * Br * sizeof(float)); // SRAM size for S
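// e.g. with Bc = Br = 32 and d = 64 this is 16384 + 8192 + 4096 = 28672 bytes (28 KB),
// comfortably under the default 48 KB per-block shared memory limit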
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
if (sram_size > max_sram_size) {
printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);
printf("SRAM size exceeds maximum shared memory per block\n");
printf("Try decreasing col_tile_size or row_tile_size further\n");
exit(1);
}

// grid and block dims
dim3 grid_dim(B, nh); // batch_size x num_heads
dim3 block_dim(Br); // Br threads per block

// okay so now, this kernel wants Q,K,V to all be of shape (B, nh, N, d)
// but instead, we have a single tensor QKV (inp) of shape (B, N, 3, nh, d)
// so we have to permute the tensor using a kernel with block_size
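// (index sketch, assuming the row-major layouts implied by the shapes above:
//  q[((b*nh + h)*N + n)*d + x] = inp[(((b*N + n)*3 + 0)*nh + h)*d + x],
//  with k and v coming from slices 1 and 2 of the size-3 axis)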
float *q, *k, *v;
cudaCheck(cudaMalloc(&q, B * T * C * sizeof(float)));
cudaCheck(cudaMalloc(&k, B * T * C * sizeof(float)));
cudaCheck(cudaMalloc(&v, B * T * C * sizeof(float)));
int total_threads = B * N * nh * d;
int num_blocks = ceil_div(total_threads, block_size);
permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, N, nh, d);

// the kernel accumulates into out across the K/V tiles, so zero it first
cudaCheck(cudaMemset(out, 0, B * T * C * sizeof(float)));

// now call the ReLU attention kernel
relu_forward_kernel<<<grid_dim, block_dim, sram_size>>>(
q, k, v,
N, d, Tc, Tr, Bc, Br, 0.0f, // relu_threshold
out
);

// out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)
unpermute_kernel<<<num_blocks, block_size>>>(out, q, B, N, nh, d);
cudaCheck(cudaMemcpy(out, q, B * T * C * sizeof(float), cudaMemcpyDeviceToDevice));

// free memory
cudaCheck(cudaFree(q));
cudaCheck(cudaFree(k));
cudaCheck(cudaFree(v));
}




void attention_forward2(float* out,
const float* inp,
@@ -1263,6 +1442,9 @@ void attention_forward(int kernel_num,
attention_forward_cudnn((floatX*)vaccum, stats, (floatX*)qkvr, inp, out, B, T, C, NH);
break;
#endif
case 11:
relu_forward(out, inp, B, T, C, NH, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
@@ -1328,7 +1510,7 @@ int main(int argc, char **argv) {

// Lower accuracy requirements for FP16 (1e-4f also too much for TF32 on kernels 3 & 4)
float accuracy_threshold = (kernel_num <= 4) ? 1e-3f : 1e-2f;

/*
// first check the correctness of the kernel
attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);
for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
@@ -1351,7 +1533,7 @@
validate_result(d_preatt, preatt, "preatt", B * NH * T * T, accuracy_threshold);
}
}
printf("All results match. Starting benchmarks.\n\n");
printf("All results match. Starting benchmarks.\n\n"); */
first_run_validation = false;

// benchmark speed of the kernel