diff --git a/include/llama.h b/include/llama.h
index a0a660bff88da..b96f2417951f2 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -461,6 +461,7 @@ extern "C" {
     LLAMA_API bool llama_supports_rpc       (void);
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ctx_seq  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f6192a36e0ee5..0190475458822 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -112,11 +112,17 @@ llama_context::llama_context(
         }
     }
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    cparams.n_ctx_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
+
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n", __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
+
+        cparams.n_ctx_seq = hparams.n_ctx_train;
+    }
 
     LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_seq     = %u\n",   __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
@@ -125,14 +131,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -453,8 +459,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
 
-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }
 
 uint32_t llama_context::n_batch() const {
@@ -2383,6 +2389,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }
 
+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index ed6d82cb396f9..20cbd78955412 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -43,11 +43,11 @@ struct llama_context {
 
     ggml_backend_sched_t get_sched() const;
 
-    uint32_t n_ctx()         const;
-    uint32_t n_ctx_per_seq() const;
-    uint32_t n_batch()       const;
-    uint32_t n_ubatch()      const;
-    uint32_t n_seq_max()     const;
+    uint32_t n_ctx()     const;
+    uint32_t n_ctx_seq() const;
+    uint32_t n_batch()   const;
+    uint32_t n_ubatch()  const;
+    uint32_t n_seq_max() const;
 
     uint32_t n_threads()       const;
     uint32_t n_threads_batch() const;
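
Note: the per-sequence context is now computed once at context creation (the full n_ctx with a unified KV cache, otherwise n_ctx / n_seq_max, in both cases capped at n_ctx_train) and exposed through the new llama_n_ctx_seq() getter. A minimal caller-side sketch of how the getter might be used instead of dividing n_ctx by hand; the helper name and budgeting logic below are illustrative, not part of the patch:

    #include "llama.h"

    // sketch: derive a per-request token budget from the per-sequence context
    static uint32_t remaining_tokens_for_seq(const struct llama_context * ctx, uint32_t n_prompt) {
        // with a unified KV cache this equals llama_n_ctx(ctx); otherwise it is
        // llama_n_ctx(ctx) / llama_n_seq_max(ctx), and it is never larger than
        // the model's training context
        const uint32_t n_ctx_seq = llama_n_ctx_seq(ctx);
        return n_prompt < n_ctx_seq ? n_ctx_seq - n_prompt : 0;
    }
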
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index eae7b839f4857..fcef8fa976038 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -8,6 +8,7 @@
 
 struct llama_cparams {
     uint32_t n_ctx;         // context size used during inference
+    uint32_t n_ctx_seq;     // context for a single sequence
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ea6f59ed482bb..8d2ec4763bbdc 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -6581,14 +6581,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
 }
 
 ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    const uint32_t n_ctx_seq = cparams.n_ctx_seq;
 
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+    if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
         return layers[il].rope_long;
     }
 
@@ -19710,12 +19710,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     /* filter_attn */ std::move(filter_attn),
                     /* filter_recr */ std::move(filter_recr));
         } else {
-            uint32_t n_ctx_per_stream = cparams.n_ctx;
-
-            if (!cparams.kv_unified) {
-                n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
-            }
-
             llama_memory_i::layer_reuse_cb reuse = nullptr;
 
             if (arch == LLM_ARCH_GEMMA3N) {
@@ -19739,7 +19733,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.offload_kqv,
                         params.swa_full,
                         cparams.kv_unified,
-                        n_ctx_per_stream,
+                        cparams.n_ctx_seq,
                         cparams.n_seq_max,
                         cparams.n_ubatch,
                         1,
@@ -19755,7 +19749,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         !cparams.flash_attn,
                         cparams.offload_kqv,
                         cparams.kv_unified,
-                        n_ctx_per_stream,
+                        cparams.n_ctx_seq,
                         cparams.n_seq_max,
                         1,
                         hparams.n_swa,
diff --git a/tests/test-thread-safety.cpp b/tests/test-thread-safety.cpp
index e5158fb5062f0..bcb86c35e6652 100644
--- a/tests/test-thread-safety.cpp
+++ b/tests/test-thread-safety.cpp
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
                 }
 
                 batch = llama_batch_get_one(&token, 1);
-                if (llama_decode(ctx.get(), batch)) {
+
+                int ret = llama_decode(ctx.get(), batch);
+                if (ret == 1 && i > 0) {
+                    LOG_INF("Context full, stopping generation.\n");
+                    break;
+                }
+
+                if (ret != 0) {
                     LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
                     failed.store(true);
                     return;
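
Note: llama_decode() returns 0 on success and 1 when no KV slot could be found for the batch, so the thread-safety test now treats "context full" as a normal stop condition once at least one token has been generated, and only other non-zero codes as failures. A small self-contained sketch of that classification, assuming a context and batch set up as in the test (the enum and helper are illustrative, not part of the test):

    #include "llama.h"

    enum class decode_status { ok, context_full, error };

    // sketch: map llama_decode() results onto the three outcomes the test distinguishes
    static decode_status classify_decode(struct llama_context * ctx, struct llama_batch batch, bool produced_any_token) {
        const int ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return decode_status::ok;
        }
        if (ret == 1 && produced_any_token) {
            return decode_status::context_full; // no KV slot for the batch -> stop generating gracefully
        }
        return decode_status::error;            // anything else is treated as a hard failure
    }
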
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index cb794ab647eba..4752c42a3f5eb 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2413,7 +2413,7 @@ struct server_context {
 
             params_dft.devices      = params_base.speculative.devices;
             params_dft.model        = params_base.speculative.model;
-            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
             params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2501,8 +2501,6 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
         for (int i = 0; i < params_base.n_parallel; i++) {
@@ -2510,7 +2508,7 @@ struct server_context {
 
             slot.id  = i;
             slot.ctx = ctx;
-            slot.n_ctx = n_ctx_slot;
+            slot.n_ctx = llama_n_ctx_seq(ctx);
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
@@ -2533,7 +2531,7 @@ struct server_context {
                 }
             }
 
-            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -2705,6 +2703,39 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
@@ -3640,9 +3671,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float alora_scale        = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -3666,7 +3698,7 @@ struct server_context {
                     slot.n_past = 0;
                     slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                    SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
+                    SLT_INF(slot, "new prompt, n_ctx = %d, n_keep = %d, n_prompt_tokens = %d\n",
                             slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
 
                     // print prompt tokens (for debugging)
@@ -4123,6 +4155,8 @@ struct server_context {
                 std::string err;
 
                 if (n_batch == 1 && ret == 1) {
+                    // TODO: try to terminate only the largest active slot and continue
+                    //       need to remove the tokens from the current batch too
                     err = "Context size has been exceeded.";
                 }
 
@@ -4138,17 +4172,23 @@ struct server_context {
                 // TODO: handle ret == 2 (abort) when we start aborting
 
                 if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                     for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                     }
+
                     break;
                 }
             }
 
             // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }
 
             SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
@@ -4942,7 +4982,7 @@ int main(int argc, char ** argv) {
             // Everything else, including multimodal completions.
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
-        const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+        const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
         tasks.reserve(inputs.size());
        for (size_t i = 0; i < inputs.size(); i++) {
             auto n_prompt_tokens = inputs[i].size();
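
Note: the server's recovery path for a failed decode now tries to purge an idle slot's cached tokens first and only halves the batch size when nothing could be purged. A simplified, self-contained illustration of that retry policy, assuming hypothetical helper callbacks rather than the actual server_context members:

    #include <cstdint>
    #include <functional>

    // sketch: purge-first retry policy for a decode that reports "no KV slot" (ret == 1)
    static bool decode_with_retry(int32_t n_batch,
                                  const std::function<int(int32_t)> & decode_chunk,          // wraps llama_decode for the current chunk
                                  const std::function<bool()>       & try_purge_idle_slots)  // frees KV cells of an idle slot, if any
    {
        while (true) {
            const int ret = decode_chunk(n_batch);
            if (ret == 0) {
                return true;             // decoded successfully
            }
            if (ret != 1) {
                return false;            // fatal decode error
            }
            if (n_batch == 1) {
                return false;            // context size exceeded, nothing left to free
            }
            // prefer reclaiming KV cells from idle slots over shrinking the batch;
            // only fall back to a smaller batch when nothing could be purged
            if (!try_purge_idle_slots()) {
                n_batch /= 2;
            }
        }
    }

The purge-first order only applies with a unified KV cache: idle slots then share the same KV buffer as active ones, so dropping their cached tokens can free exactly the cells the failing batch needs.
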
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index d56d3d5f178b8..392e0efecdbbd 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 15, False),
+        (64, 3, False),
         (64, 1, True),
     ]
 )
 def test_return_progress(n_batch, batch_count, reuse_cache):
     global server
     server.n_batch = n_batch
-    server.n_ctx = 2048
+    server.n_ctx = 256
     server.n_slots = 1
     server.start()
     def make_cmpl_request():
         return server.make_stream_request("POST", "/chat/completions", data={
             "max_tokens": 10,
             "messages": [
-                {"role": "user", "content": "This is a test" * 100},
+                {"role": "user", "content": "This is a test" * 10},
             ],
             "stream": True,
             "return_progress": True,
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index 00ba78cf67c09..acb893d495899 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -368,6 +368,37 @@ def check_slots_status():
     # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "n_ctx,n_slots,n_predict_vals,expected_success",
+    [
+        (256, 4, [80, 40, 80, 80], [True, True, True, True]),
+        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
+        (256, 4, [90, 90, 40, 90], [False, False, True, False]),
+        (256, 4, [90, 90, 40, 80], [True, True, True, True]),
+    ],
+)
+def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
+    global server
+    server.n_slots = n_slots
+    server.kv_unified = True
+    server.n_ctx = n_ctx
+    server.start()
+    prompt = "A"
+    tasks = []
+    for n_predict in n_predict_vals:
+        tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
+    results = parallel_function_calls(tasks)
+    for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
+        if expect_ok:
+            assert res.status_code == 200
+            assert "content" in res.body
+            if "timings" in res.body:
+                assert res.body["timings"]["predicted_n"] == n_predict
+        else:
+            assert res.status_code == 500
+            assert "content" not in res.body
+
+
 @pytest.mark.parametrize(
     "prompt,n_predict,response_fields",
     [
diff --git a/tools/server/tests/unit/test_infill.py b/tools/server/tests/unit/test_infill.py
index 73dacdae812b8..cd1a391b4adbc 100644
--- a/tools/server/tests/unit/test_infill.py
+++ b/tools/server/tests/unit/test_infill.py
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Dad|excited|park)+", res.body["content"])
+    assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index 4ba3d43c33044..da703c4c51a15 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -78,6 +78,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    kv_unified: bool | None = False
     server_slots: bool | None = False
     pooling: str | None = None
     draft: int | None = None
@@ -159,6 +160,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.kv_unified:
+            server_args.append("--kv-unified")
         if self.server_slots:
             server_args.append("--slots")
         else: