@@ -69,8 +69,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         const char * dev_name = "CPU";
 
@@ -1326,7 +1326,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1348,7 +1348,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1372,7 +1372,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1515,7 +1515,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1545,7 +1545,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1575,7 +1575,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -2014,8 +2014,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         const char * dev_name = "CPU";
 
@@ -2717,7 +2717,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -2737,7 +2737,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2758,7 +2758,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2905,7 +2905,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -2933,7 +2933,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -2961,7 +2961,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
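For context on the pattern above: `n_embd_k_s` / `n_embd_v_s` are now queried per layer index rather than globally. The sketch below is illustrative only and does not reproduce llama.cpp's actual `llama_hparams`; the struct, the `recurrent_layer` flag, and the field values are assumptions, with state-size formulas merely following a Mamba-style conv/state layout.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of per-layer recurrent-state size accessors, in the
// spirit of the diff above. Names and fields are illustrative, not llama.cpp's.
struct hparams_sketch {
    uint32_t ssm_d_conv  = 4;    // conv window of the recurrent (SSM) block
    uint32_t ssm_d_inner = 512;  // inner dimension of the recurrent block
    uint32_t ssm_d_state = 16;   // state dimension of the recurrent block

    std::vector<bool> recurrent_layer; // per-layer flag for hybrid models (assumed)

    // Extra K-like state size for layer il; zero for pure-attention layers.
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer[il]) {
            return 0;
        }
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    // Extra V-like state size for layer il; zero for pure-attention layers.
    uint32_t n_embd_v_s(uint32_t il) const {
        if (!recurrent_layer[il]) {
            return 0;
        }
        return ssm_d_state * ssm_d_inner;
    }
};
```

Under that assumption, a pure-attention layer reports zero extra state size, so the allocation and state (de)serialization loops in the diff can size each layer's K/V rows individually in hybrid stacks.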