diff --git a/common/arg.cpp b/common/arg.cpp
index 40af7e574830f..f52e7fe756d16 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2748,6 +2748,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+ add_opt(common_arg(
+ {"--truncate-embed"},
+ string_format("allow truncation for embedding tasks to handle large inputs (default: %s)", params.truncate_embed ? "enabled" : "disabled"),
+ [](common_params & params) {
+ params.truncate_embed = true;
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TRUNCATE_EMBED"));
add_opt(common_arg(
{"--reranking", "--rerank"},
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
diff --git a/common/common.h b/common/common.h
index 8922090e7b10d..5fa63e569b65d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -356,6 +356,7 @@ struct common_params {
// embedding
bool embedding = false; // get only sentence embedding
+ bool truncate_embed = false; // allow truncation for embedding tasks to handle large inputs
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
diff --git a/tools/server/README.md b/tools/server/README.md
index 6f962664f6774..214afae5aec0a 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -159,6 +159,7 @@ The project is under active development, and we are [looking for feedback and co
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--no-webui` | Disable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_NO_WEBUI) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
+| `--truncate-embed` | allow truncation for embedding tasks to handle large inputs (default: disabled)<br/>(env: LLAMA_ARG_TRUNCATE_EMBED) |
| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
@@ -636,6 +637,8 @@ Returns a JSON object with a field `prompt` containing a string of the input mes
The same as [the embedding example](../embedding) does.
+**Note**: By default, an embedding input must fit within a single batch; inputs that exceed the physical batch size are rejected with an error. Starting the server with the `--truncate-embed` flag enables automatic truncation of such inputs instead. When truncation occurs, the `truncated` field in the response is set to `true`, and the resulting embedding only represents the processed portion of the input.
+
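+A usage sketch (model path and input text are placeholders; the request assumes the non-OpenAI-compatible `/embeddings` endpoint on the default port):
+
+```sh
+# start the server in embedding mode and allow oversized inputs to be truncated
+llama-server -m model.gguf --embedding --truncate-embed
+
+# request an embedding; if the input had to be truncated, the response
+# contains "truncated": true next to the "embedding" field
+curl -s http://localhost:8080/embeddings \
+    -H "Content-Type: application/json" \
+    -d '{"content": "some very long text ..."}'
+```
+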
*Options:*
`content`: Set the text to process.
@@ -1175,6 +1178,8 @@ curl http://localhost:8080/v1/chat/completions \
This endpoint requires that the model uses a pooling type other than `none`. The embeddings are normalized using the Euclidean norm.
+**Note**: By default, an embedding input must fit within a single batch; inputs that exceed the physical batch size are rejected with an error. Starting the server with the `--truncate-embed` flag enables automatic truncation of such inputs instead. When truncation occurs, the `truncated` field in the response is set to `true`, and the resulting embedding only represents the processed portion of the input.
+
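+For example, with `llama-server` started with `--embedding --truncate-embed`, a request might look like this (the request body follows the OpenAI Embeddings API; the `input` text is a placeholder):
+
+```sh
+curl -s http://localhost:8080/v1/embeddings \
+    -H "Content-Type: application/json" \
+    -d '{"input": "some very long text ..."}'
+```
+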
*Options:*
See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index d3f6271931f62..0f6e307c6263f 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1034,6 +1034,7 @@ struct server_task_result_embd : server_task_result {
std::vector<std::vector<float>> embedding;
int32_t n_tokens;
+ bool truncated = false;
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -1052,6 +1053,7 @@ struct server_task_result_embd : server_task_result {
return json {
{"index", index},
{"embedding", embedding},
+ {"truncated", truncated},
};
}
@@ -1060,6 +1062,7 @@ struct server_task_result_embd : server_task_result {
{"index", index},
{"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
+ {"truncated", truncated},
};
}
};
@@ -1360,10 +1363,15 @@ struct server_slot {
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
- bool can_split() const {
+ // if `truncate_embed` is true, splitting is also allowed for embedding tasks so that inputs
+ // larger than a single batch can be processed with automatic truncation
+ // NOTE: a truncated input yields an embedding of only the processed portion of the original text;
+ // the `truncated` field in the response indicates when this occurs
+ bool can_split(bool truncate_embed = false) const {
return
!need_embd() ||
- (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+ (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+ (need_embd() && truncate_embed); // allow splitting for embedding tasks only if truncate_embed is enabled
}
bool can_batch_with(server_slot & other_slot) const {
@@ -2570,12 +2578,15 @@ struct server_context {
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
+ res->truncated = slot.truncated;
res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_model_n_embd(model);
std::vector<float> embd_res(n_embd, 0.0f);
+ // Note: If the input was truncated (slot.truncated == true), this embedding
+ // represents only the processed portion of the original input
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
@@ -3129,7 +3140,7 @@ struct server_context {
continue;
}
- if (!slot.can_split()) {
+ if (!slot.can_split(params_base.truncate_embed)) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3146,7 +3157,8 @@ struct server_context {
// if context shift is disabled, we make sure prompt size is smaller than KV size
// TODO: there should be a separate parameter that control prompt truncation
// context shift should be applied only during the generation phase
- if (slot.n_prompt_tokens >= slot.n_ctx) {
+ // For embedding tasks, allow truncation even when context shift is disabled
+ if (slot.n_prompt_tokens >= slot.n_ctx && !slot.need_embd()) {
slot.release();
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
@@ -3186,6 +3198,11 @@ struct server_context {
SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
+ // Warn specifically for embedding tasks about potential quality impact
+ if (slot.need_embd()) {
+ SLT_WRN(slot, "%s", "embedding input was truncated; the resulting embedding may not fully represent the original input. consider increasing the context size or reducing the input length");
+ }
+
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
@@ -3272,7 +3289,7 @@ struct server_context {
slot.n_prompt_tokens_processed = 0;
}
- if (!slot.can_split()) {
+ if (!slot.can_split(params_base.truncate_embed)) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;