Skip to content

Commit cf7dd4b

Browse files
committed
server : cache prompts and checkpoints only for completion tasks
1 parent d637973 commit cf7dd4b

File tree

1 file changed

+33
-18
lines changed

1 file changed

+33
-18
lines changed

tools/server/server.cpp

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,7 +1436,8 @@ struct server_slot_prompt {
14361436
struct server_prompt_cache {
14371437
std::list<server_slot_prompt> states;
14381438

1439-
size_t limit_size = 0; // 0 = no limit
1439+
// in bytes, 0 = no limit
1440+
size_t limit_size = 2ull*1024*1024*1024;
14401441

14411442
size_t size() const {
14421443
size_t res = 0;
@@ -1532,7 +1533,7 @@ struct server_slot {
15321533
std::vector<std::string> generated_tool_call_ids;
15331534

15341535
// stats
1535-
size_t n_sent_text = 0; // number of sent text character
1536+
size_t n_sent_text = 0; // number of sent text character
15361537

15371538
int64_t t_start_process_prompt;
15381539
int64_t t_start_generation;
@@ -1792,7 +1793,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
17921793
const int cur_lcs_len = cached_prompt.get_common_prefix(prompt.tokens);
17931794

17941795
if (cur_lcs_len == (int) prompt.tokens.size()) {
1795-
SRV_INF("%s", " - prompt is already cached, skipping\n");
1796+
SRV_WRN("%s", " - prompt is already cached, skipping\n");
17961797
return;
17971798
}
17981799
}
@@ -1804,7 +1805,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18041805
const int len = cached_prompt.get_common_prefix(prompt.tokens);
18051806

18061807
if (len == (int) cached_prompt.size()) {
1807-
SRV_INF(" - removing cached prompt with length %d\n", len);
1808+
SRV_WRN(" - removing cached prompt with length %d\n", len);
18081809

18091810
it = states.erase(it);
18101811
} else {
@@ -1814,7 +1815,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18141815

18151816
const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
18161817

1817-
SRV_INF(" - saving prompt with length %d, total cache size = %.3f MiB\n",
1818+
SRV_WRN(" - saving prompt with length %d, total cache size = %.3f MiB\n",
18181819
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));
18191820

18201821
// if there is a limit, remove the oldest entries to make room
@@ -1824,6 +1825,8 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18241825
break;
18251826
}
18261827

1828+
SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
1829+
18271830
states.pop_front();
18281831
}
18291832
} else {
@@ -1833,6 +1836,8 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18331836
break;
18341837
}
18351838

1839+
SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
1840+
18361841
states.pop_front();
18371842
}
18381843
}
@@ -1847,15 +1852,19 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
18471852

18481853
llama_state_seq_get_data_ext(ctx, cur.data.data(), cur_size, id, 0);
18491854

1850-
SRV_INF(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0));
1855+
SRV_WRN(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0));
1856+
1857+
for (const auto & state : states) {
1858+
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
1859+
}
18511860
}
18521861

18531862
void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
18541863
auto & states = prompt_cache.states;
18551864

18561865
int lcs_len = prompt.tokens.get_common_prefix(tokens);
18571866

1858-
SRV_INF(" - looking for better prompt, base lcs_len = %d\n", lcs_len);
1867+
SRV_WRN(" - looking for better prompt, base lcs_len = %d\n", lcs_len);
18591868

18601869
auto it_best = states.end();
18611870

@@ -1872,7 +1881,7 @@ void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_t
18721881
}
18731882

18741883
if (it_best != states.end()) {
1875-
SRV_INF(" - found better prompt with lcs_len = %d\n", lcs_len);
1884+
SRV_WRN(" - found better prompt with lcs_len = %d\n", lcs_len);
18761885

18771886
const size_t size = it_best->data.size();
18781887
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id, 0);
@@ -2454,7 +2463,7 @@ struct server_context {
24542463
SRV_ERR("%s", "failed to create speculator\n");
24552464
return;
24562465
}
2457-
for (auto &pair : params_base.speculative.replacements) {
2466+
for (auto & pair : params_base.speculative.replacements) {
24582467
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
24592468
}
24602469
}
@@ -2483,7 +2492,7 @@ struct server_context {
24832492
// 1. It's not explicitly disabled (reasoning_budget == 0)
24842493
// 2. The chat template supports it
24852494
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
2486-
SRV_INF("Enable thinking? %d\n", enable_thinking);
2495+
SRV_INF("thinking = %d\n", enable_thinking);
24872496

24882497
oai_parser_opt = {
24892498
/* use_jinja */ params_base.use_jinja,
@@ -2585,21 +2594,24 @@ struct server_context {
25852594
if (ret) {
25862595
const auto & tokens = ret->prompt.tokens;
25872596

2597+
// cache prompts only for completion tasks
2598+
update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION;
2599+
25882600
// don't update the cache if the slot's context is empty
25892601
update_cache = update_cache && tokens.size() > 0;
25902602

25912603
// TODO: mtmd does not support prompt cache
25922604
update_cache = update_cache && (ret->mctx == nullptr);
25932605

25942606
if (update_cache) {
2595-
SRV_INF("%s", "updating prompt cache\n");
2607+
SRV_WRN("%s", "updating prompt cache\n");
25962608

25972609
const int64_t t_start = ggml_time_us();
25982610

25992611
ret->prompt_save(prompt_cache);
26002612
ret->prompt_load(prompt_cache, task.tokens);
26012613

2602-
SRV_INF("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
2614+
SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
26032615
}
26042616
}
26052617

@@ -3734,16 +3746,16 @@ struct server_context {
37343746

37353747
if (!do_reset) {
37363748
// restore the context checkpoint
3737-
const size_t ctx_checkpoint_size = it->data.size();
3738-
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
3749+
const size_t checkpoint_size = it->data.size();
3750+
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
37393751

3740-
if (n != ctx_checkpoint_size) {
3741-
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
3752+
if (n != checkpoint_size) {
3753+
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
37423754
do_reset = true;
37433755
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
37443756
} else {
37453757
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
3746-
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
3758+
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
37473759
}
37483760
}
37493761

@@ -3842,6 +3854,9 @@ struct server_context {
38423854

38433855
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
38443856

3857+
// make checkpoints only for completion tasks
3858+
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
3859+
38453860
// make a checkpoint of the parts of the memory that cannot be rolled back.
38463861
// checkpoints are created only if:
38473862
// - the model uses SWA and we are not using `swa_full`
@@ -3941,7 +3956,7 @@ struct server_context {
39413956

39423957
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
39433958

3944-
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
3959+
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
39453960
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
39463961
}
39473962
}

0 commit comments

Comments (0)