@@ -274,13 +274,16 @@ llama_context::llama_context(
         // simulate full KV cache
         llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -291,7 +294,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1);
+            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -302,7 +305,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -430,12 +433,15 @@ void llama_context::kv_self_update() {
 
     if (kv_self->update(*this)) {
         // if the KV cache did any computation, we have to reserve a new worst-case graph
-        kv_self->set_full();
+        const auto kv_state = kv_self->init_full();
+        if (!kv_state) {
+            throw std::runtime_error("failed to initialize KV cache");
+        }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
@@ -633,32 +639,32 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_graph_type gtype, ggml_status * ret) {
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
+    if (mstate && !mstate->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
+
     auto * gf = graph_init();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_FAILED;
-        }
+        ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_FAILED;
-        }
+        ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        if (ret) {
-            *ret = GGML_STATUS_ALLOC_FAILED;
-        }
+        ret = GGML_STATUS_ALLOC_FAILED;
         return nullptr;
     }
 
@@ -667,12 +673,12 @@ llm_graph_result_ptr llama_context::process(const llama_ubatch & ubatch, llm_gra
     const auto status = graph_compute(gf, ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
-        if (ret) {
-            *ret = status;
-        }
+        ret = status;
         return nullptr;
     }
 
+    ret = GGML_STATUS_SUCCESS;
+
     return res;
 }
 
@@ -748,7 +754,7 @@ int llama_context::encode(llama_batch & inp_batch) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    auto res = process(ubatch, LLM_GRAPH_TYPE_ENCODER, &status);
+    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -927,12 +933,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!decode_state) {
+    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+    if (!kv_state) {
         return -2;
     }
 
-    switch (decode_state->get_status()) {
+    switch (kv_state->get_status()) {
         case LLAMA_MEMORY_STATUS_SUCCESS:
             {
             } break;
@@ -955,8 +961,8 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     int64_t n_outputs_prev = 0;
 
-    while (const auto * ubatch_ptr = decode_state->next()) {
-        const auto & ubatch = *ubatch_ptr;
+    do {
+        const auto & ubatch = kv_state->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -979,7 +985,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        auto res = process(ubatch, LLM_GRAPH_TYPE_DECODER, &status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1092,7 +1098,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    }
+    } while (kv_state->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1101,7 +1107,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = decode_state->out_ids();
+        auto & out_ids = kv_state->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1261,7 +1267,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1281,7 +1287,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr };
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
 
     this->n_outputs = save_n_outputs;
 
@@ -1302,10 +1308,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-         ggml_cgraph * gf,
-  const llama_ubatch & ubatch,
-      llm_graph_type   gtype) {
+                ggml_context * ctx,
+                 ggml_cgraph * gf,
+          const llama_ubatch & ubatch,
+              llm_graph_type   gtype,
+  const llama_memory_state_i * mstate) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
@@ -1317,7 +1324,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.memory      =*/ memory.get(),
+            /*.mstate      =*/ mstate,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
@@ -2020,8 +2027,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto decode_state = kv_self->init(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!decode_state || decode_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2033,13 +2040,13 @@ void llama_context::opt_epoch_iter(
         };
 
         uint32_t pos_batch = 0;
-        while (const auto * ubatch_ptr = decode_state->next()) {
-            const auto & ubatch = *ubatch_ptr;
+        do {
+            const auto & ubatch = kv_state->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2073,7 +2080,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        }
+        } while (kv_state->next());
     }
 }
 
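Taken together, these hunks replace the implicit `set_full()` / `memory.get()` flow with explicit memory-state objects: `init_full()` or `init_batch()` produces a state, the state is threaded through `graph_reserve()` / `graph_build()` / `process_ubatch()`, `apply()` runs before each graph build, and ubatches are consumed with a `do { ... } while (state->next())` loop. The sketch below is a minimal, self-contained mock of that consumer pattern; `mock_memory_state` and `mock_ubatch` are illustrative placeholders, not the actual llama.cpp types.

```cpp
// Minimal mock of the memory-state pattern introduced by this diff.
// The types below are placeholders, not llama.cpp's real classes.
#include <cstdio>
#include <memory>
#include <vector>

struct mock_ubatch {
    int n_tokens;
};

// Stands in for llama_memory_state_i: apply() prepares the cache for the
// current ubatch, get_ubatch() exposes it, next() advances to the following one.
class mock_memory_state {
public:
    explicit mock_memory_state(std::vector<mock_ubatch> ubatches)
        : ubatches(std::move(ubatches)) {}

    bool apply() { return !ubatches.empty(); }

    const mock_ubatch & get_ubatch() const { return ubatches[i_cur]; }

    bool next() { return ++i_cur < ubatches.size(); }

private:
    std::vector<mock_ubatch> ubatches;
    size_t i_cur = 0;
};

int main() {
    // Stands in for kv_self->init_batch(...): the batch is split into ubatches up front.
    auto state = std::make_unique<mock_memory_state>(
        std::vector<mock_ubatch>{ {512}, {512}, {256} });

    // Consumer loop mirroring llama_context::decode() after this change:
    // the state is applied before building/computing the graph for each ubatch.
    do {
        if (!state->apply()) {
            std::fprintf(stderr, "failed to apply memory state\n");
            return 1;
        }
        const auto & ubatch = state->get_ubatch();
        std::printf("processing ubatch with %d tokens\n", ubatch.n_tokens);
    } while (state->next());

    return 0;
}
```

The switch from `while (next())` to `do/while` suggests that a successfully constructed state is expected to expose at least one ubatch via `get_ubatch()` before the first `next()` call.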