@@ -954,7 +954,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
@@ -971,7 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
 
@@ -1025,7 +1025,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
 
@@ -1056,6 +1056,30 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
+    const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
+    if (!umstate) {
+        const auto * hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            umstate = hmstate->get_state_attn();
+        }
+    }
+    GGML_ASSERT(umstate);
+    return umstate;
+}
+
+const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
+    const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!rmstate) {
+        const auto * hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            rmstate = hmstate->get_state_recurrent();
+        }
+    }
+    GGML_ASSERT(rmstate);
+    return rmstate;
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_cgraph * gf,
          ggml_tensor * q,
@@ -1231,7 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
 
@@ -1268,7 +1292,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     // store to KV cache
     {
@@ -1446,7 +1470,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
             int32_t   n_state,
             int32_t   n_seqs) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
@@ -1478,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
  const llama_ubatch & ubatch,
                 int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1499,7 +1523,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
  const llama_ubatch & ubatch,
                 int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
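
Note on the change above: the new get_state_unified()/get_state_recurrent() helpers replace the bare static_casts because, with the hybrid-recurrent cache, mstate may be a llama_kv_cache_hybrid_recurrent_state that wraps both an attention (unified) sub-state and a recurrent sub-state. Each helper tries the direct downcast first, falls back to unwrapping the hybrid wrapper, and asserts if neither matches. Below is a minimal standalone sketch of the same dispatch pattern; the types (memory_state, unified_state, recurrent_state, hybrid_state) are simplified stand-ins for illustration, not llama.cpp's real class definitions.

    // Sketch of the downcast-with-hybrid-fallback dispatch, under simplified
    // stand-in types. The recurrent-side helper would be symmetric.
    #include <cassert>
    #include <cstdio>

    struct memory_state { virtual ~memory_state() = default; };  // polymorphic base
    struct unified_state   : memory_state {};
    struct recurrent_state : memory_state {};

    // A hybrid state owns one sub-state of each kind and exposes accessors,
    // mirroring get_state_attn()/get_state_recurrent() in the patch.
    struct hybrid_state : memory_state {
        unified_state   attn;
        recurrent_state recr;
        const unified_state   * get_state_attn()      const { return &attn; }
        const recurrent_state * get_state_recurrent() const { return &recr; }
    };

    // Same shape as llm_graph_context::get_state_unified(): direct downcast
    // first, then unwrap the hybrid wrapper, and assert on failure.
    const unified_state * get_state_unified(const memory_state * mstate) {
        const auto * u = dynamic_cast<const unified_state *>(mstate);
        if (!u) {
            if (const auto * h = dynamic_cast<const hybrid_state *>(mstate)) {
                u = h->get_state_attn();
            }
        }
        assert(u);
        return u;
    }

    int main() {
        unified_state plain;
        hybrid_state  hybrid;
        // Both a plain unified state and a hybrid wrapper resolve correctly,
        // so call sites no longer need to know which cache type is active.
        assert(get_state_unified(&plain)  == &plain);
        assert(get_state_unified(&hybrid) == &hybrid.attn);
        std::puts("ok");
    }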