Skip to content

Commit 20901f6

Browse files
authored
fix: remove kv padding from flash attention wrapper (#1453)
1 parent 0982807 commit 20901f6

1 file changed

Lines changed: 0 additions & 28 deletions

File tree

src/ggml_extend.hpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,50 +1329,26 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
13291329

13301330
float scale = (1.0f / sqrt((float)d_head));
13311331

1332-
int kv_pad = 0;
13331332
ggml_tensor* kqv = nullptr;
13341333

13351334
auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
1336-
if (kv_pad != 0) {
1337-
k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0);
1338-
}
13391335
if (kv_scale != 1.0f) {
13401336
k_in = ggml_ext_scale(ctx, k_in, kv_scale);
13411337
}
13421338
k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
13431339

13441340
v_in = ggml_ext_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
13451341
v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
1346-
if (kv_pad != 0) {
1347-
v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
1348-
}
13491342
if (kv_scale != 1.0f) {
13501343
v_in = ggml_ext_scale(ctx, v_in, kv_scale);
13511344
}
13521345
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
13531346

13541347
if (mask_in != nullptr) {
13551348
mask_in = ggml_transpose(ctx, mask_in);
1356-
} else {
1357-
if (kv_pad > 0) {
1358-
mask_in = ggml_ext_zeros(ctx, L_k, L_q, 1, 1);
1359-
auto pad_tensor = ggml_ext_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);
1360-
mask_in = ggml_concat(ctx, mask_in, pad_tensor, 0);
1361-
}
13621349
}
13631350

13641351
if (mask_in != nullptr) {
1365-
// the need for padding got removed in ggml 4767bda
1366-
// ensure we can still use the old version for now
1367-
#ifdef GGML_KQ_MASK_PAD
1368-
int mask_pad = 0;
1369-
if (mask_in->ne[1] % GGML_KQ_MASK_PAD != 0) {
1370-
mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask_in->ne[1];
1371-
}
1372-
if (mask_pad > 0) {
1373-
mask_in = ggml_pad(ctx, mask_in, 0, mask_pad, 0, 0);
1374-
}
1375-
#endif
13761352
mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
13771353
}
13781354

@@ -1387,10 +1363,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
13871363
if (flash_attn) {
13881364
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
13891365
bool can_use_flash_attn = true;
1390-
if (can_use_flash_attn && L_k % 256 != 0) {
1391-
kv_pad = GGML_PAD(L_k, 256) - static_cast<int>(L_k);
1392-
}
1393-
13941366
if (mask != nullptr) {
13951367
// TODO: figure out if we can bend t5 to work too
13961368
can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;

0 commit comments

Comments
 (0)