aLoRA Support #15327

Open · wants to merge 13 commits into master
28 changes: 27 additions & 1 deletion convert_lora_to_gguf.py
@@ -12,7 +12,7 @@
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from transformers import AutoConfig
from transformers import AutoConfig, AutoTokenizer

import torch

@@ -26,6 +26,8 @@
# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, ModelBase

from gguf.constants import GGUFValueType

logger = logging.getLogger("lora-to-gguf")


@@ -369,7 +371,31 @@ def set_type(self):
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

def set_gguf_parameters(self):
logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
alora_invocation_tokens = lparams.get("alora_invocation_tokens")
invocation_string = lparams.get("invocation_string")
if invocation_string and not alora_invocation_tokens:
logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
base_model_path_or_id = hparams.get("_name_or_path")
try:
tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
except ValueError:
logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
raise
# NOTE: There's an off-by-one with the older aLoRAs where
# the invocation string includes the "<|start_of_turn|>"
# token, but the adapters themselves were trained to
# activate _after_ that first token, so we drop it here.
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
if alora_invocation_tokens:
logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
self.gguf_writer.add_key_value(
gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
alora_invocation_tokens,
GGUFValueType.ARRAY,
GGUFValueType.UINT32,
)

def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters
5 changes: 3 additions & 2 deletions gguf-py/gguf/constants.py
@@ -231,8 +231,9 @@ class Tokenizer:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"

class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
ALORA_INVOCATION_TOKENS = "adapter.alora.invocation_tokens"

class IMatrix:
CHUNK_COUNT = "imatrix.chunk_count"
4 changes: 4 additions & 0 deletions include/llama.h
@@ -557,6 +557,10 @@ extern "C" {
    // Note: loaded adapters will be freed when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);

// Get the invocation tokens if the current lora is an alora
LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter);

// The following functions operate on a llama_context, hence the naming: llama_verb_...

// Add a loaded LoRA adapter to given context
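
For reference, a minimal sketch (not part of this PR) of how a caller might exercise the two new accessors after loading an adapter. The file names are placeholders, and the surrounding model-loading calls are the existing llama.h entry points; error handling is omitted for brevity.

#include "llama.h"

#include <cstdio>

int main() {
    llama_backend_init();

    // hypothetical file names, for illustration only
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("base-model.gguf", mparams);
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "alora-adapter.gguf");

    // an adapter converted from an aLoRA carries its invocation token sequence;
    // a plain LoRA reports a count of 0
    const uint64_t n_inv = llama_adapter_get_alora_n_invocation_tokens(adapter);
    if (n_inv > 0) {
        const llama_token * inv = llama_adapter_get_alora_invocation_tokens(adapter);
        printf("alora invocation sequence (%llu tokens):", (unsigned long long) n_inv);
        for (uint64_t i = 0; i < n_inv; ++i) {
            printf(" %d", inv[i]);
        }
        printf("\n");
    }

    llama_adapter_lora_free(adapter);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
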
43 changes: 43 additions & 0 deletions src/llama-adapter.cpp
@@ -6,6 +6,7 @@

#include <map>
#include <cassert>
#include <sstream>
#include <stdexcept>

// vec
@@ -190,6 +191,36 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}

adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));

// parse alora invocation sequence vector
const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS);
const int kid = gguf_find_key(ctx_gguf.get(), key.c_str());
if (kid >= 0) {
if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) {
throw std::runtime_error("invalid gguf type for " + key);
}
const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid);
if (arr_type != GGUF_TYPE_UINT32) {
throw std::runtime_error("invalid gguf element type for " + key);
}
const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid);
const void * data = gguf_get_arr_data(ctx_gguf.get(), kid);
adapter.alora_invocation_tokens.resize(seq_len);
std::copy(
(const llama_token *)data,
(const llama_token *)data + seq_len,
adapter.alora_invocation_tokens.begin());
std::stringstream ss;
ss << "[";
for (size_t i = 0; i < adapter.alora_invocation_tokens.size(); ++i) {
ss << adapter.alora_invocation_tokens[i];
if (i < adapter.alora_invocation_tokens.size() - 1) {
ss << ", ";
}
}
ss << "]";
LLAMA_LOG_INFO("%s: %s = %s\n", __func__, key.c_str(), ss.str().c_str());
}
}

int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
@@ -386,3 +417,15 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
}

uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {
if (!adapter) {
return 0;
}
return adapter->alora_invocation_tokens.size();
}

const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) {
GGML_ASSERT(adapter);
return adapter->alora_invocation_tokens.data();
}
3 changes: 3 additions & 0 deletions src/llama-adapter.h
@@ -67,6 +67,9 @@ struct llama_adapter_lora {

float alpha;

// activated lora (aLoRA)
std::vector<llama_token> alora_invocation_tokens;

llama_adapter_lora() = default;
~llama_adapter_lora() = default;

5 changes: 3 additions & 2 deletions src/llama-arch.cpp
@@ -233,8 +233,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },

{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },

// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -229,6 +229,7 @@ enum llm_kv {

LLM_KV_ADAPTER_TYPE,
LLM_KV_ADAPTER_LORA_ALPHA,
LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS,

LLM_KV_POSNET_EMBEDDING_LENGTH,
LLM_KV_POSNET_BLOCK_COUNT,
101 changes: 98 additions & 3 deletions tools/server/server.cpp
@@ -116,7 +116,7 @@ struct slot_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters

int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -1322,6 +1322,7 @@ struct server_slot {
common_speculative * spec = nullptr;

std::vector<common_adapter_lora_info> lora;
int32_t alora_invocation_start = -1;

// the index relative to completion multi-task request
size_t index = 0;
@@ -1416,6 +1417,9 @@
// clear speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;

// clear alora start
alora_invocation_start = -1;
}

bool need_embd() const {
@@ -2282,11 +2286,65 @@ struct server_context {
slot.prompt_tokens = std::move(task.prompt_tokens);

if (!are_lora_equal(slot.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
slot.cache_tokens.clear();
// if lora has changed, check to see if the cache should be cleared
if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
slot.cache_tokens.clear();
} else {
SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
}
slot.lora = slot.params.lora;
}

// if using alora, make sure it's only a single one requested and active
size_t alora_invocation_start = slot.prompt_tokens.size();
if (lora_all_alora(slot.lora)) {

const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
// TODO: This will error out if a user requests two aloras, but only
// provides the activation string for one. We could instead search
// for all requested alora activation strings and then either keep
// only the last one, or reject if multiple are found.
if (enabled_ids.size() != 1) {
send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
return false;
}
const auto & lora = slot.lora[enabled_ids[0]].ptr;

// get the pointer and count for the invocation tokens
const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);

// scan backwards through the prompt tokens to find the last
// occurrence of the invocation sequence
int match_idx = static_cast<int>(n_invocation_tokens) - 1;
for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
// the token in this position matches the next token to find in
// the invocation sequence
if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
// if it's a full match, we've found the start
if (match_idx == 0) {
alora_invocation_start = i;
break;
}
// otherwise, check the next token in the sequence
--match_idx;
} else {
// no match in this position, so start looking over again
match_idx = static_cast<int>(n_invocation_tokens) - 1;
}
}

// if the activation string is not found, disable the alora
if (alora_invocation_start == slot.prompt_tokens.size()) {
SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
slot.lora[enabled_ids[0]].scale = 0.0f;
} else {
SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
slot.alora_invocation_start = alora_invocation_start;
}
}

if (!slot.prompt_tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
@@ -3155,6 +3213,8 @@ struct server_context {
int32_t n_ubatch = llama_n_ubatch(ctx);

// next, batch any pending prompts without exceeding n_batch
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
@@ -3275,6 +3335,12 @@
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);

// if an alora is invoked, don't cache after the invocation start
if (slot.alora_invocation_start >= 0) {
SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
}

// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
@@ -3447,6 +3513,20 @@
slot.n_prompt_tokens_processed += n_pos;
}

// If using an alora, there may be uncached tokens that come
// before the invocation sequence. When this happens, the
// tokens before the invocation sequence need to be
// processed without the adapter in a separate batch, then
// the adapter needs to be enabled for the remaining tokens.
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
GGML_ASSERT(enabled_loras.size() == 1);
alora_scale = slot.lora[enabled_loras[0]].scale;
slot.lora[enabled_loras[0]].scale = 0.0f;
alora_disabled_id = enabled_loras[0];
}

// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
@@ -3455,6 +3535,14 @@
break; // end of text chunk
}

// if this is an alora request with pre-invocation
// tokens that are not cached, we need to stop filling
// this batch at those pre-invocation tokens.
if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
break;
}

// embedding requires all tokens in the batch to be output
const bool need_embd = server_task_type_need_embd(slot.task_type);

@@ -3513,6 +3601,13 @@
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx, slot_batched->lora);

// if the lora is temporarily disabled for an alora, re-enable it
// for next time
if (alora_scale > 0.0f) {
SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
slot_batched->lora[alora_disabled_id].scale = alora_scale;
}

llama_set_embeddings(ctx, slot_batched->need_embd());
}

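
For reference, the prompt scan above looks for the last occurrence of the invocation token sequence in the prompt tokens. A small standalone sketch (not part of this PR) that expresses the same lookup with std::find_end, using made-up token values:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// return the index at which the last occurrence of `invocation` starts
// within `prompt`, or -1 if the sequence does not appear
static int find_last_invocation_start(const std::vector<int32_t> & prompt,
                                      const std::vector<int32_t> & invocation) {
    if (invocation.empty()) {
        return -1;
    }
    const auto it = std::find_end(prompt.begin(), prompt.end(),
                                  invocation.begin(), invocation.end());
    return it == prompt.end() ? -1 : (int) std::distance(prompt.begin(), it);
}

int main() {
    const std::vector<int32_t> prompt     = {5, 9, 7, 8, 3, 7, 8, 2};
    const std::vector<int32_t> invocation = {7, 8};
    // prints 5: the last occurrence of {7, 8} begins at index 5
    printf("invocation starts at index %d\n", find_last_invocation_start(prompt, invocation));
    return 0;
}
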
41 changes: 41 additions & 0 deletions tools/server/utils.hpp
@@ -992,6 +992,47 @@ static bool are_lora_equal(
return true;
}

// get the ids of all enabled loras
static std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras) {
std::vector<size_t> enabled_ids;
for (size_t i = 0; i < loras.size(); ++i) {
if (loras[i].scale > 0) {
enabled_ids.push_back(i);
}
}
return enabled_ids;
}

// check whether the given lora set has only aloras activated (empty => false)
static bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras) {
bool found_alora = false;
for (const auto & lora : loras) {
if (lora.scale != 0) {
if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) {
return false;
}
found_alora = true;
}
}
return found_alora;
}

// if the two sets of loras are different, they require a cache clear unless the
// change is from no loras (or only aloras) to only aloras.
static bool lora_should_clear_cache(
const std::vector<common_adapter_lora_info> & current,
const std::vector<common_adapter_lora_info> & next) {

// This should always be called after determining that the two sets are
// _not_ equal. This assert is therefore some slightly wasted work and
// should be safe to remove as long as this method is called correctly.
GGML_ASSERT(!are_lora_equal(current, next));

return (
!(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) ||
!lora_all_alora(next));
}
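
For clarity, the decision above reduces to two predicates: whether the currently applied set has no active adapters or only aloras, and whether the requested set is all aloras; the cache survives only when both hold. A minimal sketch (not part of this PR) of that boolean with the interesting cases spelled out:

#include <cstdio>

// clear the cache unless we are moving from "no loras or only aloras"
// to "only aloras"
static bool should_clear(bool current_none_or_all_alora, bool next_all_alora) {
    return !(current_none_or_all_alora && next_all_alora);
}

int main() {
    printf("%d\n", should_clear(true,  true));  // 0: alora (or nothing) -> alora, keep the cache
    printf("%d\n", should_clear(false, true));  // 1: plain lora -> alora, clear
    printf("%d\n", should_clear(true,  false)); // 1: alora -> plain lora, clear
    return 0;
}
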

// parse lora config from JSON request, returned a copy of lora_base with updated scale
static std::vector<common_adapter_lora_info> parse_lora_request(
const std::vector<common_adapter_lora_info> & lora_base,