File tree Expand file tree Collapse file tree 4 files changed +8
-8
lines changed Expand file tree Collapse file tree 4 files changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
83
83
for (int i0 = 0 ; i0 < D/2 ; i0 += WARP_SIZE) {
84
84
const int i = i0 + threadIdx .x ;
85
85
86
- const float2 tmp = Q_f2[j*(nb01/sizeof (float2 )) + i];
86
+ const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof (float2 )) + i] : make_float2 ( 0 . 0f , 0 . 0f ) ;
87
87
Q_h2[j][i] = make_half2 (scale, scale) * make_half2 (tmp.x , tmp.y );
88
88
}
89
89
}
Original file line number Diff line number Diff line change @@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32(
79
79
80
80
#pragma unroll
81
81
for (int i0 = 0 ; i0 < D; i0 += 2 *WARP_SIZE) {
82
- float2 tmp = Q_f2[j*(nb01/sizeof (float2 )) + i0/2 + threadIdx .x ];
82
+ float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof (float2 )) + i0/2 + threadIdx .x ] : make_float2 ( 0 . 0f , 0 . 0f ) ;
83
83
Q_f[j][i0 + 0 *WARP_SIZE + threadIdx .x ] = tmp.x * scale;
84
84
Q_f[j][i0 + 1 *WARP_SIZE + threadIdx .x ] = tmp.y * scale;
85
85
}
Original file line number Diff line number Diff line change @@ -94,7 +94,7 @@ static __global__ void flash_attn_vec_ext_f16(
94
94
for (int i0 = 0 ; i0 < D/2 ; i0 += WARP_SIZE) {
95
95
const int i = i0 + threadIdx .x ;
96
96
97
- const float2 tmp = Q_f2[j*(nb01/sizeof (float2 )) + i];
97
+ const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof (float2 )) + i] : make_float2 ( 0 . 0f , 0 . 0f ) ;
98
98
Q_h2[j][i0/WARP_SIZE] = make_half2 (scale, scale) * make_half2 (tmp.x , tmp.y );
99
99
}
100
100
}
@@ -212,7 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
212
212
213
213
#pragma unroll
214
214
for (int j_VKQ = 0 ; j_VKQ < ncols; ++j_VKQ) {
215
- if (ic0 + j_VKQ >= ne01) {
215
+ if (ncols > 2 && ic0 + j_VKQ >= ne01) {
216
216
break ;
217
217
}
218
218
@@ -227,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
227
227
dst[j_dst*D*gridDim .y + D*blockIdx .y + tid] = dst_val;
228
228
}
229
229
230
- if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
230
+ if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01) ) {
231
231
dst_meta[(ic0 + tid)*gridDim .y *parallel_blocks + blockIdx .y *parallel_blocks + ip] = make_float2 (kqmax[tid], kqsum[tid]);
232
232
}
233
233
#else
Original file line number Diff line number Diff line change @@ -91,7 +91,7 @@ static __global__ void flash_attn_vec_ext_f32(
91
91
for (int i0 = 0 ; i0 < D/2 ; i0 += WARP_SIZE) {
92
92
const int i = i0 + threadIdx .x ;
93
93
94
- Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof (float2 )) + i];
94
+ Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof (float2 )) + i] : make_float2 ( 0 . 0f , 0 . 0f ) ;
95
95
Q_h2[j][i0/WARP_SIZE].x *= scale;
96
96
Q_h2[j][i0/WARP_SIZE].y *= scale;
97
97
}
@@ -200,7 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(
200
200
201
201
#pragma unroll
202
202
for (int j_VKQ = 0 ; j_VKQ < ncols; ++j_VKQ) {
203
- if (ic0 + j_VKQ >= ne01) {
203
+ if (ncols > 2 && ic0 + j_VKQ >= ne01) {
204
204
break ;
205
205
}
206
206
@@ -215,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
215
215
dst[j_dst*D*gridDim .y + D*blockIdx .y + tid] = dst_val;
216
216
}
217
217
218
- if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
218
+ if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01) ) {
219
219
dst_meta[(ic0 + tid)*gridDim .y *parallel_blocks + blockIdx .y *parallel_blocks + ip] = make_float2 (kqmax[tid], kqsum[tid]);
220
220
}
221
221
}
You can’t perform that action at this time.
0 commit comments