
Allow truncation when embedding #14493


Open · wants to merge 4 commits into master
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -2748,6 +2748,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(common_arg(
{"--truncate-embed"},
string_format("allow truncation for embedding tasks to handle large inputs (default: %s)", params.truncate_embed ? "enabled" : "disabled"),
[](common_params & params) {
params.truncate_embed = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TRUNCATE_EMBED"));
add_opt(common_arg(
{"--reranking", "--rerank"},
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
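A minimal launch sketch for the new `--truncate-embed` flag, not part of this diff: it assumes the server binary is named `llama-server` and that a GGUF embedding model and the usual `-m`/`--port` options are available. Either the CLI flag or the environment variable registered above should enable truncation.

```sh
# Hypothetical invocation: serve embeddings and allow oversized inputs to be truncated
llama-server -m embedding-model.gguf --embedding --truncate-embed --port 8080

# The same via the environment variable registered above (value handling assumed)
LLAMA_ARG_TRUNCATE_EMBED=1 llama-server -m embedding-model.gguf --embedding --port 8080
```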
1 change: 1 addition & 0 deletions common/common.h
@@ -356,6 +356,7 @@ struct common_params {

// embedding
bool embedding = false; // get only sentence embedding
bool truncate_embed = false; // allow truncation for embedding tasks to handle large inputs
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
5 changes: 5 additions & 0 deletions tools/server/README.md
@@ -159,6 +159,7 @@ The project is under active development, and we are [looking for feedback and co
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--no-webui` | Disable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_NO_WEBUI) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
| `--truncate-embed` | allow truncation for embedding tasks to handle large inputs (default: disabled)<br/>(env: LLAMA_ARG_TRUNCATE_EMBED) |
| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
| `--api-key-file FNAME` | path to file containing API keys (default: none) |
@@ -636,6 +637,8 @@ Returns a JSON object with a field `prompt` containing a string of the input mes

The same as [the embedding example](../embedding) does.

**Note**: By default, embedding tasks cannot be split across multiple batches for safety. For large inputs that exceed the batch size, use the `--truncate-embed` flag to enable automatic truncation. When truncation occurs, the `truncated` field in the response will indicate this.

*Options:*

`content`: Set the text to process.
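The note above introduces the `truncated` response field. A hedged request sketch, assuming the native endpoint is served at `/embedding` on the default port used elsewhere in this README:

```sh
# Hypothetical request: with --truncate-embed, an input longer than the physical
# batch size is truncated instead of being rejected.
curl http://localhost:8080/embedding \
    -H "Content-Type: application/json" \
    -d '{"content": "a very long text ..."}'

# Abridged result shape, inferred from the server.cpp changes in this PR
# (top-level wrapping may differ):
# [{"index": 0, "embedding": [[...]], "truncated": true}]
```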
@@ -1175,6 +1178,8 @@ curl http://localhost:8080/v1/chat/completions \

This endpoint requires that the model uses a pooling type different from `none`. The embeddings are normalized using the Euclidean norm.

**Note**: By default, embedding tasks cannot be split across multiple batches for safety. For large inputs that exceed the batch size, use the `--truncate-embed` flag to enable automatic truncation. When truncation occurs, the `truncated` field in the response will indicate this.

*Options:*

See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings).
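For the OpenAI-compatible endpoint, a similar sketch; the request shape follows the linked OpenAI documentation, and per the note above the response again reports truncation via the `truncated` field:

```sh
# Hypothetical OpenAI-style request; "model" is accepted for API compatibility.
curl http://localhost:8080/v1/embeddings \
    -H "Content-Type: application/json" \
    -d '{"input": "a very long text ...", "model": "default"}'
```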
27 changes: 22 additions & 5 deletions tools/server/server.cpp
@@ -1034,6 +1034,7 @@ struct server_task_result_embd : server_task_result {
std::vector<std::vector<float>> embedding;

int32_t n_tokens;
bool truncated = false;

// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -1052,6 +1053,7 @@ struct server_task_result_embd : server_task_result {
return json {
{"index", index},
{"embedding", embedding},
{"truncated", truncated},
};
}

@@ -1060,6 +1062,7 @@ struct server_task_result_embd : server_task_result {
{"index", index},
{"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
{"truncated", truncated},
};
}
};
@@ -1360,10 +1363,15 @@ struct server_slot {

// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
bool can_split() const {
// @param truncate_embed: if true, allows splitting for embedding tasks to handle large inputs
// with automatic truncation. If false, the original conservative logic is used.
// NOTE: When embedding inputs are truncated, the resulting embedding may not fully represent
// the original input. The 'truncated' field in the response indicates when this occurs.
bool can_split(bool truncate_embed = false) const {
return
!need_embd() ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
(need_embd() && truncate_embed); // allow splitting for embedding tasks only if truncate_embed is enabled
}

bool can_batch_with(server_slot & other_slot) const {
@@ -2570,12 +2578,15 @@ struct server_context {
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
res->truncated = slot.truncated;
res->oaicompat = slot.params.oaicompat;

const int n_embd = llama_model_n_embd(model);

std::vector<float> embd_res(n_embd, 0.0f);

// Note: If the input was truncated (slot.truncated == true), this embedding
// represents only the processed portion of the original input
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
Expand Down Expand Up @@ -3129,7 +3140,7 @@ struct server_context {
continue;
}

if (!slot.can_split()) {
if (!slot.can_split(params_base.truncate_embed)) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3146,7 +3157,8 @@ struct server_context {
// if context shift is disabled, we make sure prompt size is smaller than KV size
// TODO: there should be a separate parameter that control prompt truncation
// context shift should be applied only during the generation phase
if (slot.n_prompt_tokens >= slot.n_ctx) {
// For embedding tasks, allow truncation even when context shift is disabled
if (slot.n_prompt_tokens >= slot.n_ctx && !slot.need_embd()) {
slot.release();
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
@@ -3186,6 +3198,11 @@ struct server_context {

SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);

// Warn specifically for embedding tasks about potential quality impact
if (slot.need_embd()) {
SLT_WRN(slot, "%s", "WARNING: Embedding input was truncated. The resulting embedding may not fully represent the original input. Consider increasing context size or reducing input length for better embedding quality.");
}

GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}

@@ -3272,7 +3289,7 @@ struct server_context {
slot.n_prompt_tokens_processed = 0;
}

if (!slot.can_split()) {
if (!slot.can_split(params_base.truncate_embed)) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
Expand Down