1 change: 1 addition & 0 deletions include/llama.h
@@ -461,6 +461,7 @@ extern "C" {
LLAMA_API bool llama_supports_rpc (void);

LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
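The new llama_n_ctx_seq getter exposes the per-sequence context size to API consumers. Below is a minimal usage sketch, not part of this change; the helper name check_prompt_fits and the error handling are illustrative only.

#include "llama.h"
#include <cstdint>
#include <cstdio>

// Reject prompts that cannot fit into the context available to a single sequence.
static bool check_prompt_fits(const llama_context * ctx, uint32_t n_prompt_tokens) {
    const uint32_t n_ctx_seq = llama_n_ctx_seq(ctx); // new per-sequence getter
    if (n_prompt_tokens > n_ctx_seq) {
        fprintf(stderr, "prompt (%u tokens) exceeds per-sequence context (%u)\n",
                n_prompt_tokens, n_ctx_seq);
        return false;
    }
    return true;
}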
30 changes: 20 additions & 10 deletions src/llama-context.cpp
@@ -112,11 +112,17 @@ llama_context::llama_context(
}
}

const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
cparams.n_ctx_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;

if (cparams.n_ctx_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n", __func__, cparams.n_ctx_seq, hparams.n_ctx_train);

cparams.n_ctx_seq = hparams.n_ctx_train;
}

LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +131,14 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

if (n_ctx_per_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}

if (n_ctx_per_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, n_ctx_per_seq, hparams.n_ctx_train);
if (cparams.n_ctx_seq > hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}

if (!hparams.vocab_only) {
@@ -453,8 +459,8 @@ uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}

uint32_t llama_context::n_ctx_per_seq() const {
return cparams.n_ctx / cparams.n_seq_max;
uint32_t llama_context::n_ctx_seq() const {
return cparams.n_ctx_seq;
}

uint32_t llama_context::n_batch() const {
@@ -2383,6 +2389,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
return ctx->n_ctx();
}

uint32_t llama_n_ctx_seq(const llama_context * ctx) {
return ctx->n_ctx_seq();
}

uint32_t llama_n_batch(const llama_context * ctx) {
return ctx->n_batch();
}
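For reference, the new derivation can be reproduced in isolation. This standalone sketch uses hypothetical values and mirrors the logic above: a unified KV cache keeps the full context per sequence, a split cache divides it by n_seq_max, and the result is capped at the training context.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_ctx       = 16384; // total context requested
    const uint32_t n_seq_max   = 4;     // number of parallel sequences
    const uint32_t n_ctx_train = 8192;  // model's training context
    const bool     kv_unified  = false; // split the KV cache across sequences

    // unified cache: every sequence can use the whole context
    // split cache:   the context is divided evenly between sequences
    uint32_t n_ctx_seq = kv_unified ? n_ctx : n_ctx / n_seq_max;

    // never advertise more per-sequence context than the model was trained for
    n_ctx_seq = std::min(n_ctx_seq, n_ctx_train);

    printf("n_ctx_seq = %u\n", n_ctx_seq); // prints 4096
    return 0;
}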
10 changes: 5 additions & 5 deletions src/llama-context.h
@@ -43,11 +43,11 @@ struct llama_context {

ggml_backend_sched_t get_sched() const;

uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_ctx() const;
uint32_t n_ctx_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;

uint32_t n_threads() const;
uint32_t n_threads_batch() const;
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -8,6 +8,7 @@

struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_ctx_seq; // context for a single sequence
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
11 changes: 5 additions & 6 deletions src/llama-kv-cache.cpp
@@ -957,10 +957,14 @@ bool llama_kv_cache::get_has_shift() const {
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;

// pad the n_kv value so that the graph remains constant across batches and can be reused
// note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
const uint32_t n_pad_cur = std::max(n_pad, 256u);

for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
const auto & cells = v_cells[sinfo.strm[s]];

result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
}

return result;
@@ -2010,8 +2014,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}

uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
// the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u;
}
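The padding change above rounds n_kv up to at least a multiple of 256 so the graph shape stays constant across batches and can be reused. A standalone sketch of the arithmetic, with hypothetical values and a local stand-in for GGML_PAD (round x up to a multiple of n):

#include <algorithm>
#include <cstdint>
#include <cstdio>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) // n must be a power of two

int main() {
    const uint32_t n_pad       = 32;                    // cache's own padding
    const uint32_t n_pad_cur   = std::max(n_pad, 256u); // new minimum of 256
    const uint32_t used_max_p1 = 1000;                  // highest used cell + 1
    const uint32_t cells_size  = 4096;                  // total cells in the stream

    // same expression as in get_n_kv(): pad up, but never beyond the cache size
    const uint32_t n_kv = std::min(cells_size,
                                   std::max(n_pad_cur, GGML_PAD(used_max_p1, n_pad_cur)));

    printf("n_kv = %u\n", n_kv); // 1024: 1000 rounded up to a multiple of 256
    return 0;
}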
2 changes: 0 additions & 2 deletions src/llama-kv-cache.h
@@ -19,8 +19,6 @@ struct llama_context;

class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);

struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
37 changes: 8 additions & 29 deletions src/llama-model.cpp
@@ -6581,14 +6581,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
}

ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
const uint32_t n_ctx_seq = cparams.n_ctx_seq;

// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}

if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}

@@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
}
};

llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
llama_memory_i * res;

switch (arch) {
@@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
};
}

const auto padding = llama_kv_cache::get_padding(cparams);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
@@ -19714,23 +19710,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
const auto padding = llama_kv_cache::get_padding(cparams);

uint32_t n_ctx_per_stream = cparams.n_ctx;

if (!cparams.kv_unified) {
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);

cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
} else {
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);

cparams.n_ctx = n_ctx_per_stream;
}

LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

llama_memory_i::layer_reuse_cb reuse = nullptr;

if (arch == LLM_ARCH_GEMMA3N) {
@@ -19754,10 +19733,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
padding,
1,
nullptr,
reuse);
} else {
@@ -19770,9 +19749,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.kv_unified,
n_ctx_per_stream,
cparams.n_ctx_seq,
cparams.n_seq_max,
padding,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
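With cparams now const in create_memory, the rope-factor selection also reads the precomputed n_ctx_seq instead of recomputing a per-sequence value. A schematic of that selection; the struct below is illustrative and mirrors only the relevant fields, not the actual llama_model layer layout.

#include <cstdint>

struct ggml_tensor; // forward declaration, only to keep the sketch self-contained

struct layer_rope_factors {
    ggml_tensor * rope_freqs = nullptr; // explicit per-layer frequency factors
    ggml_tensor * rope_long  = nullptr; // YaRN "long" factors
    ggml_tensor * rope_short = nullptr; // YaRN "short" factors
};

static ggml_tensor * pick_rope_factors(const layer_rope_factors & l,
                                       uint32_t n_ctx_seq, uint32_t n_ctx_orig_yarn) {
    if (l.rope_freqs != nullptr) {
        return l.rope_freqs; // explicit factors always win
    }
    // a per-sequence context beyond the original YaRN context selects the long factors
    return n_ctx_seq > n_ctx_orig_yarn ? l.rope_long : l.rope_short;
}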
3 changes: 1 addition & 2 deletions src/llama-model.h
@@ -500,9 +500,8 @@ struct llama_model {

ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;

// note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;

// TODO: move this to new llm_arch_model_i interface
ggml_cgraph * build_graph(const llm_graph_params & params) const;
9 changes: 8 additions & 1 deletion tests/test-thread-safety.cpp
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
}

batch = llama_batch_get_one(&token, 1);
if (llama_decode(ctx.get(), batch)) {

int ret = llama_decode(ctx.get(), batch);
if (ret == 1 && i > 0) {
LOG_INF("Context full, stopping generation.\n");
break;
}

if (ret != 0) {
LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
failed.store(true);
return;
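The test now distinguishes a full context from a hard failure. Below is a sketch of the same decode pattern as a reusable helper; the return-code convention (0 = success, 1 = no KV slot for the batch, anything else = error) and the helper name decode_one are assumptions for illustration, not part of the test.

#include "llama.h"
#include <cstdio>

// Returns true if generation can continue, false if it should stop (context full or error).
static bool decode_one(llama_context * ctx, llama_token token, bool * out_error) {
    llama_batch batch = llama_batch_get_one(&token, 1);

    const int ret = llama_decode(ctx, batch);
    if (ret == 1) {
        // not fatal: no KV slot left for this batch, treat the context as full
        *out_error = false;
        return false;
    }
    if (ret != 0) {
        fprintf(stderr, "llama_decode failed with %d\n", ret);
        *out_error = true;
        return false;
    }
    *out_error = false;
    return true;
}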