diff --git a/common/arg.cpp b/common/arg.cpp
index 40af7e574830f..f52e7fe756d16 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2748,6 +2748,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.embedding = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    add_opt(common_arg(
+        {"--truncate-embed"},
+        string_format("allow truncation for embedding tasks to handle large inputs (default: %s)", params.truncate_embed ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.truncate_embed = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TRUNCATE_EMBED"));
     add_opt(common_arg(
         {"--reranking", "--rerank"},
         string_format("enable reranking endpoint on server (default: %s)", "disabled"),
diff --git a/common/common.h b/common/common.h
index 8922090e7b10d..5fa63e569b65d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -356,6 +356,7 @@ struct common_params {
 
     // embedding
     bool embedding = false; // get only sentence embedding
+    bool truncate_embed = false; // allow truncation for embedding tasks to handle large inputs
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
diff --git a/tools/server/README.md b/tools/server/README.md
index 6f962664f6774..214afae5aec0a 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -159,6 +159,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
 | `--no-webui` | Disable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_NO_WEBUI) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
+| `--truncate-embed` | allow truncation for embedding tasks to handle large inputs (default: disabled)<br/>(env: LLAMA_ARG_TRUNCATE_EMBED) |
 | `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
 | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
 | `--api-key-file FNAME` | path to file containing API keys (default: none) |
@@ -636,6 +637,8 @@ Returns a JSON object with a field `prompt` containing a string of the input mes
 
 The same as [the embedding example](../embedding) does.
 
+**Note**: By default, embedding tasks cannot be split across multiple batches for safety. For large inputs that exceed the batch size, use the `--truncate-embed` flag to enable automatic truncation. When truncation occurs, the `truncated` field in the response will indicate this.
+
 *Options:*
 
 `content`: Set the text to process.
@@ -1175,6 +1178,8 @@ curl http://localhost:8080/v1/chat/completions \
 
 This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm.
 
+**Note**: By default, embedding tasks cannot be split across multiple batches for safety. For large inputs that exceed the batch size, use the `--truncate-embed` flag to enable automatic truncation. When truncation occurs, the `truncated` field in the response will indicate this.
+
 *Options:*
 
 See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index d3f6271931f62..0f6e307c6263f 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1034,6 +1034,7 @@ struct server_task_result_embd : server_task_result {
     std::vector<std::vector<float>> embedding;
 
     int32_t n_tokens;
+    bool truncated = false;
 
     // OAI-compat fields
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -1052,6 +1053,7 @@
         return json {
             {"index",     index},
             {"embedding", embedding},
+            {"truncated", truncated},
         };
     }
 
@@ -1060,6 +1062,7 @@
             {"index",            index},
             {"embedding",        embedding[0]},
             {"tokens_evaluated", n_tokens},
+            {"truncated",        truncated},
         };
     }
 };
@@ -1360,10 +1363,15 @@ struct server_slot {
 
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
-    bool can_split() const {
+    // @param truncate_embed: if true, allows splitting for embedding tasks to handle large inputs
+    //        with automatic truncation. If false, uses original conservative logic.
+    // NOTE: When embedding inputs are truncated, the resulting embedding may not fully represent
+    //       the original input. The 'truncated' field in the response indicates when this occurs.
+    bool can_split(bool truncate_embed = false) const {
         return !need_embd() ||
-               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+               (need_embd() && truncate_embed); // allow splitting for embedding tasks only if truncate_embed is enabled
     }
 
     bool can_batch_with(server_slot & other_slot) const {
@@ -2570,12 +2578,15 @@ struct server_context {
         res->id        = slot.id_task;
         res->index     = slot.index;
         res->n_tokens  = slot.n_prompt_tokens;
+        res->truncated = slot.truncated;
         res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_model_n_embd(model);
 
         std::vector<float> embd_res(n_embd, 0.0f);
+        // Note: If the input was truncated (slot.truncated == true), this embedding
+        // represents only the processed portion of the original input
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
@@ -3129,7 +3140,7 @@ struct server_context {
                             continue;
                         }
 
-                        if (!slot.can_split()) {
+                        if (!slot.can_split(params_base.truncate_embed)) {
                             if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.release();
                                 send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3146,7 +3157,8 @@ struct server_context {
                                 // if context shift is disabled, we make sure prompt size is smaller than KV size
                                 // TODO: there should be a separate parameter that control prompt truncation
                                 //       context shift should be applied only during the generation phase
-                                if (slot.n_prompt_tokens >= slot.n_ctx) {
+                                // For embedding tasks, allow truncation even when context shift is disabled
+                                if (slot.n_prompt_tokens >= slot.n_ctx && !slot.need_embd()) {
                                     slot.release();
                                     send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
                                     continue;
                                 }
@@ -3186,6 +3198,11 @@ struct server_context {
                                 SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
 
+                                // Warn specifically for embedding tasks about potential quality impact
+                                if (slot.need_embd()) {
+                                    SLT_WRN(slot, "%s", "WARNING: Embedding input was truncated. The resulting embedding may not fully represent the original input. Consider increasing context size or reducing input length for better embedding quality.");
+                                }
+
                                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                             }
 
@@ -3272,7 +3289,7 @@ struct server_context {
                         slot.n_prompt_tokens_processed = 0;
                     }
 
-                    if (!slot.can_split()) {
+                    if (!slot.can_split(params_base.truncate_embed)) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
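
Usage sketch (not part of the patch; model path and input text are placeholders): one way to exercise the new flag end to end against the plain `/embedding` endpoint. The `--truncate-embed` flag and the `truncated` response field are the ones introduced above; everything else is existing `llama-server` behavior.

```sh
# start the server in embedding mode and allow oversized inputs to be truncated
# (the model path is a placeholder for any GGUF embedding model)
./llama-server -m ./models/embedding-model.gguf --embedding --truncate-embed --port 8080

# send an input longer than the physical batch size; without --truncate-embed a model
# that cannot split embedding batches fails this with "input is too large to process"
curl -s http://localhost:8080/embedding \
    -H "Content-Type: application/json" \
    -d '{"content": "<very long text>"}'

# each returned embedding object now carries a "truncated" field; when it is true, the
# vector was computed from only the processed (truncated) portion of the original input
```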