Skip to content

Commit cd93a28

Browse files
CUDA: fix FA out-of-bounds reads (ggml-org#7479)
1 parent 1e37436 commit cd93a28

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

ggml-cuda/fattn-tile-f16.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
8383
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
8484
const int i = i0 + threadIdx.x;
8585

86-
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
86+
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
8787
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
8888
}
8989
}

ggml-cuda/fattn-tile-f32.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32(
7979

8080
#pragma unroll
8181
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
82-
float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
82+
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
8383
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
8484
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
8585
}

ggml-cuda/fattn-vec-f16.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ static __global__ void flash_attn_vec_ext_f16(
9494
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
9595
const int i = i0 + threadIdx.x;
9696

97-
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
97+
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
9898
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
9999
}
100100
}
@@ -212,7 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
212212

213213
#pragma unroll
214214
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
215-
if (ic0 + j_VKQ >= ne01) {
215+
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
216216
break;
217217
}
218218

@@ -227,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
227227
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
228228
}
229229

230-
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
230+
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
231231
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
232232
}
233233
#else

ggml-cuda/fattn-vec-f32.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ static __global__ void flash_attn_vec_ext_f32(
9191
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
9292
const int i = i0 + threadIdx.x;
9393

94-
Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i];
94+
Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
9595
Q_h2[j][i0/WARP_SIZE].x *= scale;
9696
Q_h2[j][i0/WARP_SIZE].y *= scale;
9797
}
@@ -200,7 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(
200200

201201
#pragma unroll
202202
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
203-
if (ic0 + j_VKQ >= ne01) {
203+
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
204204
break;
205205
}
206206

@@ -215,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
215215
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
216216
}
217217

218-
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
218+
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
219219
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
220220
}
221221
}

0 commit comments

Comments (0)