@@ -1329,50 +1329,26 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
13291329
13301330 float scale = (1 .0f / sqrt ((float )d_head));
13311331
1332- int kv_pad = 0 ;
13331332 ggml_tensor* kqv = nullptr ;
13341333
13351334 auto build_kqv = [&](ggml_tensor* q_in, ggml_tensor* k_in, ggml_tensor* v_in, ggml_tensor* mask_in) -> ggml_tensor* {
1336- if (kv_pad != 0 ) {
1337- k_in = ggml_pad (ctx, k_in, 0 , kv_pad, 0 , 0 );
1338- }
13391335 if (kv_scale != 1 .0f ) {
13401336 k_in = ggml_ext_scale (ctx, k_in, kv_scale);
13411337 }
13421338 k_in = ggml_cast (ctx, k_in, GGML_TYPE_F16 );
13431339
13441340 v_in = ggml_ext_cont (ctx, ggml_permute (ctx, v_in, 0 , 2 , 1 , 3 ));
13451341 v_in = ggml_reshape_3d (ctx, v_in, d_head, L_k, n_kv_head * N);
1346- if (kv_pad != 0 ) {
1347- v_in = ggml_pad (ctx, v_in, 0 , kv_pad, 0 , 0 );
1348- }
13491342 if (kv_scale != 1 .0f ) {
13501343 v_in = ggml_ext_scale (ctx, v_in, kv_scale);
13511344 }
13521345 v_in = ggml_cast (ctx, v_in, GGML_TYPE_F16 );
13531346
13541347 if (mask_in != nullptr ) {
13551348 mask_in = ggml_transpose (ctx, mask_in);
1356- } else {
1357- if (kv_pad > 0 ) {
1358- mask_in = ggml_ext_zeros (ctx, L_k, L_q, 1 , 1 );
1359- auto pad_tensor = ggml_ext_full (ctx, -INFINITY , kv_pad, L_q, 1 , 1 );
1360- mask_in = ggml_concat (ctx, mask_in, pad_tensor, 0 );
1361- }
13621349 }
13631350
13641351 if (mask_in != nullptr ) {
1365- // the need for padding got removed in ggml 4767bda
1366- // ensure we can still use the old version for now
1367- #ifdef GGML_KQ_MASK_PAD
1368- int mask_pad = 0 ;
1369- if (mask_in->ne [1 ] % GGML_KQ_MASK_PAD != 0 ) {
1370- mask_pad = GGML_PAD (L_q, GGML_KQ_MASK_PAD ) - mask_in->ne [1 ];
1371- }
1372- if (mask_pad > 0 ) {
1373- mask_in = ggml_pad (ctx, mask_in, 0 , mask_pad, 0 , 0 );
1374- }
1375- #endif
13761352 mask_in = ggml_cast (ctx, mask_in, GGML_TYPE_F16 );
13771353 }
13781354
@@ -1387,10 +1363,6 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
13871363 if (flash_attn) {
13881364 // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
13891365 bool can_use_flash_attn = true ;
1390- if (can_use_flash_attn && L_k % 256 != 0 ) {
1391- kv_pad = GGML_PAD (L_k, 256 ) - static_cast <int >(L_k);
1392- }
1393-
13941366 if (mask != nullptr ) {
13951367 // TODO: figure out if we can bend t5 to work too
13961368 can_use_flash_attn = can_use_flash_attn && mask->ne [3 ] == 1 ;
0 commit comments