diff --git a/common/arg.cpp b/common/arg.cpp
index 40af7e574830f..5833466510785 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1205,6 +1205,15 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
         exit(1); // for other exceptions, we exit with status code 1
     }
 
+    float & pafter  = params.sampling.start_eog_after;
+    float & premain = params.sampling.start_eog_at_remain;
+    const float premain0 = premain;
+    const float remain   = params.n_predict - pafter;
+    if (premain < remain)
+        premain = remain;
+    if (params.sampling.eog_bias_per_tok)
+        LOG_INF("%s: n_predict=%d (first of start_eog_at_remain=%0.3g start_eog_after=%0.3g) => (remain=%0.3g) eog-bias-per-tok=%0.3g\n", __func__, (int) params.n_predict,
+                (double) premain0, (double) pafter, (double) premain, (double) params.sampling.eog_bias_per_tok);
     return true;
 }
 
@@ -1937,6 +1946,27 @@
             }
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-eog", "--eog-bias-per-tok"}, "N",
+        string_format("when fewer than --start-eog-at-remain tokens are left to generate (of the -n limit), add this much bias to EOG for each subsequent token (default: %.1f)", (double)params.sampling.eog_bias_per_tok),
+        [](common_params & params, const std::string & value) {
+            params.sampling.eog_bias_per_tok = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-remain", "--start-eog-at-remain"}, "N",
+        string_format("start applying the -eog bias when this many tokens remain of the -n limit (default: %.1f)", (double)params.sampling.start_eog_at_remain),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_at_remain = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-after", "--start-eog-after"}, "N",
+        string_format("start applying the -eog bias after this many tokens have been generated (default: %.1f); whichever of -remain and -after is reached first applies", (double)params.sampling.start_eog_after),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_after = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--grammar"}, "GRAMMAR",
         string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
diff --git a/common/common.h b/common/common.h
index 8922090e7b10d..e506801f10e09 100644
--- a/common/common.h
+++ b/common/common.h
@@ -179,6 +179,13 @@ struct common_params_sampling {
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 
+    float eog_bias_per_tok = 0; // escalating bias added to EOG tokens per generated token, starting once:
+    /// ...this many tokens remain of the -n limit (no bias is applied before then) ...
+    float start_eog_at_remain = 0;
+    // ...or (whichever comes first) once start_eog_after tokens have been generated:
+    /// (i.e. EOG logit bias = max(0, start_eog_at_remain - n_remain) * eog_bias_per_tok, where start_eog_at_remain is first raised to n_predict - start_eog_after if that is larger)
+    float start_eog_after = 1e9;
+
     // print the parameters into a string
     std::string print() const;
 };
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 9c04d35fd00a2..6e56e1bb1582d 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -226,7 +226,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
             llama_sampler_init_logit_bias(
                 llama_vocab_n_tokens(vocab),
                 params.logit_bias.size(),
-                params.logit_bias.data()));
+                params.logit_bias.data(),
+                params.eog_bias_per_tok,
+                params.start_eog_at_remain,
+                vocab));
 
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
@@ -335,7 +338,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
     }
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first, float n_remain) {
     gsmpl->set_logits(ctx, idx);
 
     auto & grmr  = gsmpl->grmr;
@@ -343,10 +346,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
     if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
+        llama_sampler_apply(grmr, &cur_p, n_remain);
     }
 
-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(chain, &cur_p, n_remain);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
@@ -361,7 +364,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     llama_token_data single_token_data = { id, 1.0f, 0.0f };
     llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
 
-    llama_sampler_apply(grmr, &single_token_data_array);
+    llama_sampler_apply(grmr, &single_token_data_array, n_remain);
 
     const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
     if (is_valid) {
@@ -373,15 +376,15 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
     gsmpl->set_logits(ctx, idx);
 
-    llama_sampler_apply(grmr,  &cur_p);
-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(grmr,  &cur_p, n_remain);
+    llama_sampler_apply(chain, &cur_p, n_remain);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
 
     return cur_p.data[cur_p.selected].id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float n_remain) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -389,7 +392,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -401,7 +404,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -411,13 +414,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float n_remain) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, n_remain);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
diff --git a/common/sampling.h b/common/sampling.h
index 2064421db4e80..e9bbdbcfd0884 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -58,7 +58,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
 // if grammar_first is true, the grammar is applied before the samplers (slower)
 // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
 //
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false, float n_remain = 0);
 
 // generalized version of common_sampler_sample
 //
@@ -76,10 +76,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 843bd1ddbdbd7..2bc053b79533f 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -238,12 +238,12 @@ llama_tokens common_speculative_gen_draft(
     llama_decode(ctx, batch);
 
     common_sampler_reset(smpl);
-
+    int n_remain = params.n_draft;
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
-        common_sampler_sample(smpl, ctx, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true, --n_remain);
 
         const auto * cur_p = common_sampler_get_candidates(smpl);
 
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 1a5de5928a526..f90f79bb2b453 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -162,7 +162,9 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
+    int n_remain = n_predict;
     while (n_cur <= n_predict) {
+        --n_remain;
         // prepare the next batch
         common_batch_clear(batch);
 
@@ -173,7 +175,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i], n_remain);
 
             // is it an end of generation? -> mark the stream as finished
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index bdab052c3390f..c6aaa303ae347 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -108,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
 
     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;
-
+    int n_remain = 32;
     while (true) {
         common_batch_clear(bat);
         {
@@ -122,7 +122,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
 
         llama_decode(ctx, bat);
 
-        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
+        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1, --n_remain);
 
         if (token == eos_token) {
             break;
diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp
index 1e26d8221b86b..e2d802dce6739 100644
--- a/examples/lookahead/lookahead.cpp
+++ b/examples/lookahead/lookahead.cpp
@@ -253,6 +253,7 @@ int main(int argc, char ** argv) {
 
         int seq_id_best = 0;
 
+        int n_remain = N;
         for (int v = 0; v < N; ++v) {
             int i_batch = 0;
 
@@ -274,8 +275,9 @@ int main(int argc, char ** argv) {
                 }
             }
 
+            --n_remain;
             // sample the next token
-            id = common_sampler_sample(smpl, ctx, i_batch);
+            id = common_sampler_sample(smpl, ctx, i_batch, n_remain);
 
             common_sampler_accept(smpl, id, true);
 
@@ -349,10 +351,11 @@ int main(int argc, char ** argv) {
                 tokens_j[j] = tokens_j[j + 1];
             }
 
+            constexpr unsigned NA = (unsigned) -1;
             if (v == 0) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
-                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i, NA);
                 }
             } else {
                 for (int i = 0; i < W; i++) {
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index 2bfa26b55f0a6..af318f8060ff5 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -117,7 +117,8 @@ int main(int argc, char ** argv){
         int i_dft = 0;
         while (true) {
             // sample from the target model
-            llama_token id = common_sampler_sample(smpl, ctx, i_dft);
+            const unsigned n_remain = params.n_predict - n_predict;
+            llama_token id = common_sampler_sample(smpl, ctx, i_dft, n_remain);
 
             common_sampler_accept(smpl, id, true);
 
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index 8a4faa383bf32..f2b4b251afb19 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -217,10 +217,12 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
+    int n_remain = n_len - n_cur;
     while (n_cur <= n_len) {
+        --n_remain;
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1, n_remain);
 
             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index db79588f1a5a4..5a3d95399473a 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -76,8 +76,10 @@ int main(int argc, char ** argv) {
     // first run
     printf("\nfirst run: %s", params.prompt.c_str());
 
+    int n_remain = params.n_predict;
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token     = llama_sampler_sample(smpl, ctx, -1);
+        --n_remain;
+        auto next_token     = llama_sampler_sample(smpl, ctx, -1, n_remain);
         auto next_token_str = common_token_to_piece(ctx, next_token);
 
         printf("%s", next_token_str.c_str());
@@ -128,8 +130,10 @@ int main(int argc, char ** argv) {
     n_past = n_past_saved;
 
     // second run
+    n_remain = params.n_predict;
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token     = llama_sampler_sample(smpl2, ctx2, -1);
+        --n_remain;
+        auto next_token     = llama_sampler_sample(smpl2, ctx2, -1, n_remain);
         auto next_token_str = common_token_to_piece(ctx2, next_token);
 
         printf("%s", next_token_str.c_str());
@@ -209,8 +213,10 @@ int main(int argc, char ** argv) {
     }
 
     // third run with seq 1 instead of 0
+    n_remain = params.n_predict;
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token     = llama_sampler_sample(smpl3, ctx3, -1);
+        --n_remain;
+        auto next_token     = llama_sampler_sample(smpl3, ctx3, -1, n_remain);
         auto next_token_str = common_token_to_piece(ctx3, next_token);
 
         printf("%s", next_token_str.c_str());
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index 57195df331628..7741730aab1e2 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -110,7 +110,9 @@ int main(int argc, char ** argv) {
         // prepare a batch for the prompt
         llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
         llama_token new_token_id;
+        int n_remain = batch.n_tokens;
         while (true) {
+            --n_remain;
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
            int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) + 1;
@@ -126,7 +128,7 @@ int main(int argc, char ** argv) {
             }
 
             // sample the next token
-            new_token_id = llama_sampler_sample(smpl, ctx, -1);
+            new_token_id = llama_sampler_sample(smpl, ctx, -1, n_remain);
 
             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id)) {
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 633b87e58406e..3f5e723141512 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -151,6 +151,8 @@ int main(int argc, char ** argv) {
     int n_decode = 0;
     llama_token new_token_id;
 
+    int n_remain = n_predict;
+
     for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
         // evaluate the current batch with the transformer model
         if (llama_decode(ctx, batch)) {
@@ -162,7 +164,7 @@ int main(int argc, char ** argv) {
 
         // sample the next token
         {
-            new_token_id = llama_sampler_sample(smpl, ctx, -1);
+            new_token_id = llama_sampler_sample(smpl, ctx, -1, --n_remain);
 
             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id)) {
diff --git a/include/llama.h b/include/llama.h
index 3eda9bc68608c..c77ec29369b99 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1202,7 +1202,7 @@ extern "C" {
     struct llama_sampler_i {
         const char *           (*name)  (const struct llama_sampler * smpl);                                 // can be NULL
         void                   (*accept)(      struct llama_sampler * smpl, llama_token token);              // can be NULL
-        void                   (*apply) (      struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+        void                   (*apply) (      struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain); // required
         void                   (*reset) (      struct llama_sampler * smpl);                                 // can be NULL
         struct llama_sampler * (*clone) (const struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
         void                   (*free)  (      struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
@@ -1220,7 +1220,7 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
-    LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain);
     LLAMA_API void                   llama_sampler_reset (      struct llama_sampler * smpl);
     LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
     // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
@@ -1351,7 +1351,10 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
                              int32_t   n_vocab,
                              int32_t   n_logit_bias,
-                      const llama_logit_bias * logit_bias);
+                      const llama_logit_bias * logit_bias,
+                               float   eog_bias_per_tok,
+                               float   start_eog_at_remain,
+            const struct llama_vocab * vocab);
 
     // this sampler is meant to be used for fill-in-the-middle infilling
     // it's supposed to be used after top_k + top_p sampling
@@ -1389,7 +1392,7 @@ extern "C" {
     //    llama_sampler_accept(smpl, token);
     //    return token;
     // Returns the sampled token
-    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
+    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx, float n_remain);
 
     // TODO: extend in the future
     //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index bfbf5fa230112..ad0ea86a63f8b 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -338,9 +338,9 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
     }
 }
 
-void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p, float n_remain) {
     GGML_ASSERT(smpl->iface->apply);
-    smpl->iface->apply(smpl, cur_p);
+    smpl->iface->apply(smpl, cur_p, n_remain);
 }
 
 void llama_sampler_reset(struct llama_sampler * smpl) {
@@ -376,7 +376,7 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }
 
-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx, float n_remain) {
     const auto * logits = llama_get_logits_ith(ctx, idx);
 
     const llama_model * model = llama_get_model(ctx);
@@ -398,7 +398,7 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
         /* .sorted    = */ false,
     };
 
-    llama_sampler_apply(smpl, &cur_p);
+    llama_sampler_apply(smpl, &cur_p, n_remain);
 
     GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
 
@@ -427,13 +427,13 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
     chain->n_sample++;
 }
 
-static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
     time_meas tm(chain->t_sample_us, chain->params.no_perf);
 
     for (auto * smpl : chain->samplers) {
-        llama_sampler_apply(smpl, cur_p);
+        llama_sampler_apply(smpl, cur_p, n_remain);
     }
 }
 
@@ -535,7 +535,7 @@ static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
     return "greedy";
 }
 
-static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p, float) {
     cur_p->selected = 0;
     for (size_t i = 1; i < cur_p->size; ++i) {
         if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
@@ -573,7 +573,7 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
     return "dist";
 }
 
-static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
     llama_sampler_softmax_impl(cur_p);
@@ -632,7 +632,7 @@ static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
     return "softmax";
 }
 
-static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p, float) {
     llama_sampler_softmax_impl(cur_p);
 }
 
@@ -662,7 +662,7 @@ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
     return "top-k";
 }
 
-static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
     llama_sampler_top_k_impl(cur_p, ctx->k);
 }
@@ -705,7 +705,7 @@ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
     return "top-p";
 }
 
-static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
 
     if (ctx->p >= 1.0f) {
@@ -772,7 +772,7 @@ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
     return "min-p";
 }
 
-static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
 
     if (ctx->p <= 0.0f || !cur_p->size) {
@@ -868,7 +868,7 @@ static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
     return "typical";
 }
 
-static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     const auto * ctx = (llama_sampler_typical *) smpl->ctx;
 
     // Reference implementation:
@@ -966,7 +966,7 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
     return "temp";
 }
 
-static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     const auto * ctx = (llama_sampler_temp *) smpl->ctx;
 
     llama_sampler_temp_impl(cur_p, ctx->temp);
@@ -1011,7 +1011,7 @@ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
     return "temp-ext";
 }
 
-static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
     if (ctx->delta > 0) {
         const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
@@ -1128,7 +1128,7 @@ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
     return "xtc";
 }
 
-static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_xtc *) smpl->ctx;
 
     if (ctx->probability <= 0.0f
@@ -1228,7 +1228,7 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
     return "mirostat";
 }
 
-static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
 
     llama_sampler_softmax_impl(cur_p);
@@ -1333,7 +1333,7 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
     return "mirostat-v2";
 }
 
-static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
 
     llama_sampler_softmax_impl(cur_p);
@@ -1434,7 +1434,7 @@ static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
     }
 }
 
-static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
     if (ctx->grammar) {
         llama_grammar_apply_impl(*ctx->grammar, cur_p);
@@ -1647,7 +1647,7 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
 #endif
 }
 
-static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
 
     if ((ctx->penalty_last_n == 0) ||
@@ -1747,7 +1747,7 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
     return "top-n-sigma";
 }
 
-static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
 
     if (ctx->n <= 0.0f || cur_p->size <= 1) {
@@ -1889,7 +1889,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
 }
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
-static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_dry *) smpl->ctx;
 
     if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
@@ -2213,31 +2213,57 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
 
 // logit-bias
 
+struct length_eog_bias {
+    /// once n_remain is at or below start_eog_at_remain, each further token adds this much more bias to EOG:
+    float eog_bias_per_tok = 0;
+    /// n_remain at which the bias starts
+    float start_eog_at_remain = 0;
+    const llama_vocab * vocab = nullptr;
+    float effective_bias(float n_remain) const {
+        float d = start_eog_at_remain - n_remain;
+        return d > 0 ? d * eog_bias_per_tok : 0;
+    }
+};
+
 struct llama_sampler_logit_bias {
     const int32_t n_vocab;
 
     const std::vector<llama_logit_bias> logit_bias;
 
     std::vector<llama_logit_bias> to_search;
+
+    struct length_eog_bias eog;
+    int32_t eog_token = -1;
 };
 
 static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
     return "logit-bias";
 }
 
-static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain) {
     auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+    float eog_bias = ctx->eog.effective_bias(n_remain);
 
-    if (ctx->logit_bias.empty()) {
+    if (!eog_bias && ctx->logit_bias.empty()) {
         return;
     }
 
     ctx->to_search.clear();
+    llama_token_data * tok = cur_p->data;
+    const size_t ntok = cur_p->size;
+    if (eog_bias) {
+        const int32_t eog = ctx->eog_token;
+        if ((size_t) eog < ntok && tok[eog].id == eog) {
+            tok[eog].logit += eog_bias;
+        } else {
+            ctx->to_search.push_back(llama_logit_bias{ eog, eog_bias });
+        }
+    }
 
     // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
     for (const auto & lb : ctx->logit_bias) {
-        if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
-            cur_p->data[lb.token].logit += lb.bias;
+        if (lb.token >= 0 && ntok > (size_t) lb.token && tok[lb.token].id == lb.token) {
+            tok[lb.token].logit += lb.bias;
         } else {
             ctx->to_search.push_back(lb);
         }
@@ -2247,20 +2273,25 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         return;
     }
 
+    const llama_vocab * vocab = eog_bias ? ctx->eog.vocab : nullptr;
     // search for the remaining candidates that were not found in the previous step
-    for (size_t i = 0; i < cur_p->size; ++i) {
+    for (size_t i = 0; i < ntok; ++i) {
         for (const auto & lb : ctx->to_search) {
-            if (cur_p->data[i].id == lb.token) {
-                cur_p->data[i].logit += lb.bias;
-                break;
+            if (tok[i].id == lb.token) {
+                tok[i].logit += lb.bias;
+                goto next;
             }
         }
+        if (vocab && llama_vocab_is_eog(vocab, tok[i].id)) {
+            tok[i].logit += eog_bias;
+        }
+        next:;
     }
 }
 
 static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
-    return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
+    return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data(), ctx->eog.eog_bias_per_tok, ctx->eog.start_eog_at_remain, ctx->eog.vocab);
 }
 
 static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
@@ -2279,13 +2310,18 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
 struct llama_sampler * llama_sampler_init_logit_bias(
         int32_t   n_vocab,
         int32_t   n_logit_bias,
-        const llama_logit_bias * logit_bias) {
+        const llama_logit_bias * logit_bias,
+        float eog_bias_per_tok,
+        float start_eog_at_remain,
+        const struct llama_vocab * vocab) {
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
            /* .n_vocab    = */ n_vocab,
            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
            /* .to_search  = */ {},
+           /* .eog        = */ length_eog_bias{eog_bias_per_tok, start_eog_at_remain, vocab},
+           /* .eog_token  = */ -1,
         }
     );
 }
@@ -2305,7 +2341,7 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
     return "infill";
 }
 
-static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float) {
     auto * ctx = (llama_sampler_infill *) smpl->ctx;
 
     llama_sampler_softmax_impl(cur_p);
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 6300f25caebe3..a9e0c73b86f3e 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -42,7 +42,7 @@ struct sampler_tester {
     }
 
     void apply(llama_sampler * sampler) {
-        llama_sampler_apply(sampler, &cur_p);
+        llama_sampler_apply(sampler, &cur_p, 0);
         llama_sampler_free(sampler);
     }
 
@@ -271,13 +271,13 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
     std::vector<llama_token_data> cur(data.size());
     std::copy(data.begin(), data.end(), cur.begin());
     llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_apply(cnstr, &cur_p, 0);
     llama_sampler_reset(cnstr);
     const int64_t t_start = ggml_time_us();
     for (int i = 0; i < n_iter; i++) {
         std::copy(data.begin(), data.end(), cur.begin());
         llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_apply(cnstr, &cur_p, 0);
         llama_sampler_reset(cnstr);
     }
     const int64_t t_end = ggml_time_us();
diff --git a/tools/main/main.cpp b/tools/main/main.cpp
index 516bf09652484..ed5578d333e82 100644
--- a/tools/main/main.cpp
+++ b/tools/main/main.cpp
@@ -699,7 +699,7 @@ int main(int argc, char ** argv) {
                 LOG_DBG("saved session to %s\n", path_session.c_str());
             }
 
-            const llama_token id = common_sampler_sample(smpl, ctx, -1);
+            const llama_token id = common_sampler_sample(smpl, ctx, -1, false, n_remain);
 
             common_sampler_accept(smpl, id, /* accept_grammar= */ true);
 
diff --git a/tools/run/run.cpp b/tools/run/run.cpp
index 6fe728c685358..fa4624e7b5334 100644
--- a/tools/run/run.cpp
+++ b/tools/run/run.cpp
@@ -1012,7 +1012,9 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
     // prepare a batch for the prompt
     llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
     llama_token new_token_id;
+    int n_remain = batch.n_tokens;
     while (true) {
+        --n_remain;
         check_context_size(llama_data.context, batch);
         if (llama_decode(llama_data.context.get(), batch)) {
             printe("failed to decode\n");
@@ -1020,7 +1022,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
         }
 
         // sample the next token, check is it an end of generation?
-        new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
+        new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1, n_remain);
         if (llama_vocab_is_eog(vocab, new_token_id)) {
             break;
         }
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index d3f6271931f62..fe286e6c392e7 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3582,7 +3582,7 @@ struct server_context {
                 llama_decode(ctx, slot.batch_spec);
 
                 // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft, false, slot.n_remaining);
 
                 slot.n_past    += ids.size();
                 slot.n_decoded += ids.size();
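
Usage note (illustrative, not part of the patch): the escalating EOG bias introduced here is max(0, start_eog_at_remain - n_remain) * eog_bias_per_tok, with start_eog_at_remain first raised to n_predict - start_eog_after in common_params_parse. A minimal standalone sketch of that arithmetic, using hypothetical example values for -n/-remain/-eog and a helper that mirrors length_eog_bias::effective_bias from src/llama-sampling.cpp:

    // build (assumption, not part of the repo): g++ -std=c++17 eog_bias_demo.cpp
    #include <algorithm>
    #include <cstdio>
    #include <initializer_list>

    // mirrors length_eog_bias::effective_bias: the bias grows linearly once
    // n_remain drops below the start threshold, and is 0 before that
    static float effective_bias(float start_eog_at_remain, float eog_bias_per_tok, float n_remain) {
        const float d = start_eog_at_remain - n_remain;
        return d > 0 ? d * eog_bias_per_tok : 0;
    }

    int main() {
        // hypothetical flags: -n 64 -remain 16 -eog 0.5 (start_eog_after left at its 1e9 default)
        const float n_predict           = 64;
        const float eog_bias_per_tok    = 0.5f;
        const float start_eog_after     = 1e9f; // effectively disabled
        const float start_eog_at_remain = std::max(16.0f, n_predict - start_eog_after);

        for (float n_remain : { 20.0f, 16.0f, 8.0f, 1.0f }) {
            printf("n_remain=%2.0f -> EOG bias %+.2f\n",
                   n_remain, effective_bias(start_eog_at_remain, eog_bias_per_tok, n_remain));
        }
        return 0;
    }

With these example numbers, 8 tokens remaining adds (16 - 8) * 0.5 = +4.0 to every EOG logit, so the closer generation gets to the -n limit, the more strongly the sampler is pushed to emit an end-of-generation token instead of being cut off mid-sentence.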