Commit 4d197ed

context : fix n_ctx_per_seq computation
1 parent 02d1011 commit 4d197ed

2 files changed: +7 −9 lines changed

src/llama-context.cpp

Lines changed: 6 additions & 8 deletions

@@ -112,11 +112,9 @@ llama_context::llama_context(
         }
     }
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq());
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +123,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
+    if (n_ctx_per_seq() < hparams.n_ctx_train) {
         LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
+    if (n_ctx_per_seq() > hparams.n_ctx_train) {
         LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+                __func__, n_ctx_per_seq(), hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -449,7 +447,7 @@ uint32_t llama_context::n_ctx() const {
 }
 
 uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+    return cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
 }
 
 uint32_t llama_context::n_batch() const {
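
The substance of the fix is in n_ctx_per_seq(): with a unified KV cache all sequences share the full context window, so dividing n_ctx by n_seq_max would under-report the usable per-sequence context and trigger the warnings above spuriously. A minimal standalone sketch of the before/after behaviour (the struct below is an illustrative stand-in, not the actual llama_cparams definition):

    #include <cstdint>
    #include <cstdio>

    // Illustrative stand-in for the relevant llama_cparams fields (assumption,
    // not the full struct from llama.cpp).
    struct cparams_t {
        uint32_t n_ctx;      // total context size of the llama_context
        uint32_t n_seq_max;  // maximum number of parallel sequences
        bool     kv_unified; // sequences share a single unified KV cache
    };

    // Old computation: always split the context across sequences.
    static uint32_t n_ctx_per_seq_old(const cparams_t & cp) {
        return cp.n_ctx / cp.n_seq_max;
    }

    // New computation: with a unified KV cache every sequence can use the
    // full context, so only split when the cache is per-sequence.
    static uint32_t n_ctx_per_seq_new(const cparams_t & cp) {
        return cp.kv_unified ? cp.n_ctx : cp.n_ctx / cp.n_seq_max;
    }

    int main() {
        const cparams_t cp = { 8192, 4, true };
        // old: 2048, new: 8192 -- the n_ctx_train comparisons now use the
        // context a sequence can actually occupy.
        printf("old = %u, new = %u\n", n_ctx_per_seq_old(cp), n_ctx_per_seq_new(cp));
        return 0;
    }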

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion

@@ -6575,7 +6575,7 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
 }
 
 ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    const uint32_t n_ctx_per_seq = cparams.kv_unified ? cparams.n_ctx : cparams.n_ctx / cparams.n_seq_max;
 
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
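
This second hunk matters because n_ctx_per_seq feeds the long/short RoPE frequency-factor selection just below the shown context. A hedged sketch of that selection pattern (the parameter names, the n_ctx_orig threshold, and the enum are illustrative placeholders, not the exact fields or return values from llama.cpp):

    #include <cstdint>

    enum class rope_factors { per_layer, long_factors, short_factors };

    // Illustrative selector: with the fix, a unified KV cache reports the full
    // n_ctx as the per-sequence context, so a long-context setup picks the
    // long-context frequency factors instead of being scaled down by n_seq_max.
    static rope_factors choose_rope_factors(uint32_t n_ctx, uint32_t n_seq_max,
                                            bool kv_unified, uint32_t n_ctx_orig,
                                            bool has_per_layer_freqs) {
        const uint32_t n_ctx_per_seq = kv_unified ? n_ctx : n_ctx / n_seq_max;

        if (has_per_layer_freqs) {
            return rope_factors::per_layer;
        }

        return n_ctx_per_seq > n_ctx_orig ? rope_factors::long_factors
                                          : rope_factors::short_factors;
    }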
