Commit 0cffe93

logit_bias: apply configurable escalating EOG bias at low n_remain
Give EOG an increasing bias (per token generated; it could be per codepoint in the future), but only after a configured amount has been generated.

Add an `n_remain` param to `sample_apply`, which is safer than having logit_bias maintain state for how many times it is called (that would lead to wrong assumptions, e.g. when it is called multiple times per token).

See the new command line options (including a requested 'after' variant in addition to 'remain'):

-eog,    --eog-bias-per-tok N     when fewer than --start-eog-at-remain tokens are left to generate after -n, add this bias to EOG for each subsequent token (default: 0.0)
-remain, --start-eog-at-remain N  start applying the -eog bias when this many tokens remain of the -n max (default: 0.0)
-after,  --start-eog-after N      start applying the -eog bias after this many tokens have been generated (default: 1000000000.0); whichever happens first between -remain and -after applies

Verified that the EOG bias was effective at avoiding overgeneration and is a reasonable supplement or alternative to editing the prompt; a *constant* EOG bias, already supported in the samplers, is likely to allow pathologically short outputs.
1 parent e434e69 commit 0cffe93
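The sampler-side change that actually applies the bias is in a file not shown on this page. As a rough sketch of the rule described above (an assumption drawn from the commit message and the common/arg.cpp hunk below, not the committed code), the escalation could be computed as:

// Sketch only: the escalation rule implied by the commit message above.
// common/arg.cpp (below) folds --start-eog-after into start_eog_at_remain up front,
// so the sampler only needs the remaining-token threshold and the per-token step.
static float eog_escalating_bias(float n_remain, float eog_bias_per_tok, float start_eog_at_remain) {
    const float tokens_past_threshold = start_eog_at_remain - n_remain;
    // no effect until fewer than start_eog_at_remain tokens remain,
    // then the EOG bias grows by eog_bias_per_tok for each further token
    return tokens_past_threshold > 0.0f ? tokens_past_threshold * eog_bias_per_tok : 0.0f;
}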

File tree: 19 files changed, 171 additions and 72 deletions

common/arg.cpp: 30 additions & 0 deletions

@@ -1205,6 +1205,15 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
         exit(1); // for other exceptions, we exit with status code 1
     }

+    float &pafter = params.sampling.start_eog_after;
+    float &premain = params.sampling.start_eog_at_remain;
+    float const premain0 = premain;
+    float remain = params.n_predict - pafter;
+    if (premain < remain)
+        premain = remain;
+    if (params.sampling.eog_bias_per_tok)
+        LOG_INF("%s: n_predict=%d (first of start_eog_at_remain=%0.3g start_eog_after=%0.3g) => (remain=%0.3g) eog-bias-per-tok=%0.3g\n", __func__, (int) params.n_predict,
+                (double) premain0, (double) pafter, (double)premain, (double) params.sampling.eog_bias_per_tok);
     return true;
 }

@@ -1937,6 +1946,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-eog", "--eog-bias-per-tok"}, "N",
+        string_format("when fewer than -start-eog-at-remain tokens are left to generate after -n, add this bias eog for each subsequent token (default: %.1f)", (double)params.sampling.eog_bias_per_tok),
+        [](common_params & params, const std::string & value) {
+            params.sampling.eog_bias_per_tok = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-remain", "--start-eog-at-remain"}, "N",
+        string_format("start applying -eog bias when this many tokens remain of the -n max (default: %.1f)", (double)params.sampling.start_eog_at_remain),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_at_remain = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-after", "--start-eog-after"}, "N",
+        string_format("start applying -eog bias after this many tokens generated (default: %.1f); whichever happens first between -remain and -after applies", (double)params.sampling.start_eog_after),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_after = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--grammar"}, "GRAMMAR",
         string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),

common/common.h: 7 additions & 0 deletions

@@ -178,6 +178,13 @@ struct common_params_sampling {

     std::vector<llama_logit_bias> logit_bias; // logit biases to apply

+    float eog_bias_per_tok = 0; // escalating bias added to eog per token after:
+    /// this many remaining tokens (before applying eog_bias_per_tok) ...
+    float start_eog_at_remain = 0;
+    // or (whichever is first) after start_eog_after many generated:
+    /// (i.e. EOG logit bias = max(0,start_eog_after = max(start_eog_after, n_remain - start_eog_at_remain)) * eog_bias_per_tok)
+    float start_eog_after = 1e9;
+
     // print the parameters into a string
     std::string print() const;
 };
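The formula in the comment above is hard to parse; under the reading suggested by the commit message and the common/arg.cpp clamp, a worked example (illustrative values only) would be:

// Assume: -n 200, --eog-bias-per-tok 0.25, --start-eog-at-remain 40, --start-eog-after 180
// arg.cpp clamp: start_eog_at_remain = max(40, 200 - 180) = 40, so the -remain rule fires first
//   50 tokens remaining -> EOG bias = 0                       (threshold not yet reached)
//   30 tokens remaining -> EOG bias = (40 - 30) * 0.25 = 2.5
//   10 tokens remaining -> EOG bias = (40 - 10) * 0.25 = 7.5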

common/sampling.cpp: 15 additions & 12 deletions

@@ -226,7 +226,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             llama_sampler_init_logit_bias(
                 llama_vocab_n_tokens(vocab),
                 params.logit_bias.size(),
-                params.logit_bias.data()));
+                params.logit_bias.data(),
+                params.eog_bias_per_tok,
+                params.start_eog_at_remain,
+                vocab));

     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {

@@ -335,18 +338,18 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }

-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first, float n_remain) {
     gsmpl->set_logits(ctx, idx);

     auto & grmr = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits

     if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
+        llama_sampler_apply(grmr, &cur_p, n_remain);
     }

-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(chain, &cur_p, n_remain);

     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

@@ -361,7 +364,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     llama_token_data single_token_data = { id, 1.0f, 0.0f };
     llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

-    llama_sampler_apply(grmr, &single_token_data_array);
+    llama_sampler_apply(grmr, &single_token_data_array, n_remain);

     const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
     if (is_valid) {

@@ -373,23 +376,23 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
     gsmpl->set_logits(ctx, idx);

-    llama_sampler_apply(grmr, &cur_p);
-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(grmr, &cur_p, n_remain);
+    llama_sampler_apply(chain, &cur_p, n_remain);

     GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

     return cur_p.data[cur_p.selected].id;
 }

-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float n_remain) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

     std::vector<llama_token> result;
     result.reserve(idxs.size());

     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);

         common_sampler_accept(gsmpl, id, true);

@@ -401,7 +404,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }

     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);

         common_sampler_accept(gsmpl, id, true);

@@ -411,13 +414,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }

-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float n_remain) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }

-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, n_remain);
 }

 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

common/sampling.h: 3 additions & 3 deletions

@@ -58,7 +58,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 // if grammar_first is true, the grammar is applied before the samplers (slower)
 // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
 //
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false, float n_remain = 0);

 // generalized version of common_sampler_sample
 //

@@ -76,10 +76,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);

 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);

 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

common/speculative.cpp: 2 additions & 2 deletions

@@ -238,12 +238,12 @@ llama_tokens common_speculative_gen_draft(
     llama_decode(ctx, batch);

     common_sampler_reset(smpl);
-
+    int n_remain = params.n_draft;
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);

-        common_sampler_sample(smpl, ctx, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true, --n_remain);

         const auto * cur_p = common_sampler_get_candidates(smpl);

examples/batched/batched.cpp: 3 additions & 1 deletion

@@ -162,7 +162,9 @@ int main(int argc, char ** argv) {

     const auto t_main_start = ggml_time_us();

+    int n_remain = n_predict;
     while (n_cur <= n_predict) {
+        --n_remain;
         // prepare the next batch
         common_batch_clear(batch);

@@ -173,7 +175,7 @@ int main(int argc, char ** argv) {
                 continue;
             }

-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i], n_remain);

             // is it an end of generation? -> mark the stream as finished
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {

examples/gritlm/gritlm.cpp: 2 additions & 2 deletions

@@ -108,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;
-
+    int n_remain = 32;
     while (true) {
         common_batch_clear(bat);
         {

@@ -122,7 +122,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

         llama_decode(ctx, bat);

-        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
+        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1, --n_remain);

         if (token == eos_token) {
             break;

examples/lookahead/lookahead.cpp: 5 additions & 2 deletions

@@ -253,6 +253,7 @@ int main(int argc, char ** argv) {

         int seq_id_best = 0;

+        int n_remain = N;
         for (int v = 0; v < N; ++v) {
             int i_batch = 0;

@@ -274,8 +275,9 @@ int main(int argc, char ** argv) {
                 }
             }

+            --n_remain;
             // sample the next token
-            id = common_sampler_sample(smpl, ctx, i_batch);
+            id = common_sampler_sample(smpl, ctx, i_batch, n_remain);

             common_sampler_accept(smpl, id, true);

@@ -349,10 +351,11 @@ int main(int argc, char ** argv) {
                 tokens_j[j] = tokens_j[j + 1];
             }

+            unsigned constexpr NA = (unsigned)-1;
             if (v == 0) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
-                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i, NA);
                 }
             } else {
                 for (int i = 0; i < W; i++) {

examples/lookup/lookup.cpp: 2 additions & 1 deletion

@@ -117,7 +117,8 @@ int main(int argc, char ** argv){
     int i_dft = 0;
     while (true) {
         // sample from the target model
-        llama_token id = common_sampler_sample(smpl, ctx, i_dft);
+        unsigned const n_remain = params.n_predict - n_predict;
+        llama_token id = common_sampler_sample(smpl, ctx, i_dft, n_remain);

         common_sampler_accept(smpl, id, true);

examples/passkey/passkey.cpp: 3 additions & 1 deletion

@@ -217,10 +217,12 @@ int main(int argc, char ** argv) {

     const auto t_main_start = ggml_time_us();

+    int n_remain = n_len - n_cur;
     while (n_cur <= n_len) {
+        --n_remain;
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1, n_remain);

             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {

examples/save-load-state/save-load-state.cpp: 9 additions & 3 deletions

@@ -76,8 +76,10 @@ int main(int argc, char ** argv) {
     // first run
     printf("\nfirst run: %s", params.prompt.c_str());

+    int n_remain = params.n_predict;
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sampler_sample(smpl, ctx, -1);
+        --n_remain;
+        auto next_token = llama_sampler_sample(smpl, ctx, -1, n_remain);
         auto next_token_str = common_token_to_piece(ctx, next_token);

         printf("%s", next_token_str.c_str());

@@ -128,8 +130,10 @@ int main(int argc, char ** argv) {
     n_past = n_past_saved;

     // second run
+    n_remain = params.n_predict;
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
+        --n_remain;
+        auto next_token = llama_sampler_sample(smpl2, ctx2, -1, n_remain);
         auto next_token_str = common_token_to_piece(ctx2, next_token);

         printf("%s", next_token_str.c_str());

@@ -209,8 +213,10 @@ int main(int argc, char ** argv) {
     }

     // third run with seq 1 instead of 0
+    n_remain = params.n_predict;
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
+        --n_remain;
+        auto next_token = llama_sampler_sample(smpl3, ctx3, -1, n_remain);
         auto next_token_str = common_token_to_piece(ctx3, next_token);

         printf("%s", next_token_str.c_str());

examples/simple-chat/simple-chat.cpp: 3 additions & 1 deletion

@@ -110,7 +110,9 @@ int main(int argc, char ** argv) {
     // prepare a batch for the prompt
     llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
     llama_token new_token_id;
+    int n_remain = batch.n_tokens;
     while (true) {
+        --n_remain;
         // check if we have enough space in the context to evaluate this batch
         int n_ctx = llama_n_ctx(ctx);
         int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0);

@@ -125,7 +127,7 @@ int main(int argc, char ** argv) {
         }

         // sample the next token
-        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        new_token_id = llama_sampler_sample(smpl, ctx, -1, n_remain);

         // is it an end of generation?
         if (llama_vocab_is_eog(vocab, new_token_id)) {

examples/simple/simple.cpp: 3 additions & 1 deletion

@@ -151,6 +151,8 @@ int main(int argc, char ** argv) {
     int n_decode = 0;
     llama_token new_token_id;

+    int n_remain = n_predict;
+
     for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
         // evaluate the current batch with the transformer model
         if (llama_decode(ctx, batch)) {

@@ -162,7 +164,7 @@ int main(int argc, char ** argv) {

         // sample the next token
         {
-            new_token_id = llama_sampler_sample(smpl, ctx, -1);
+            new_token_id = llama_sampler_sample(smpl, ctx, -1, --n_remain);

             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id)) {

include/llama.h: 7 additions & 4 deletions

@@ -1197,7 +1197,7 @@ extern "C" {
     struct llama_sampler_i {
         const char * (*name) (const struct llama_sampler * smpl); // can be NULL
         void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL
-        void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+        void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain); // required
         void (*reset) ( struct llama_sampler * smpl); // can be NULL
         struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
         void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL

@@ -1215,7 +1215,7 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
     LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
-    LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain);
    LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl);
     LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
     // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)

@@ -1346,7 +1346,10 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
             int32_t n_vocab,
             int32_t n_logit_bias,
-           const llama_logit_bias * logit_bias);
+           const llama_logit_bias * logit_bias,
+           float eog_bias_per_tok,
+           float start_eog_at_remain,
+           const struct llama_vocab *vocab);

     // this sampler is meant to be used for fill-in-the-middle infilling
     // it's supposed to be used after top_k + top_p sampling

@@ -1384,7 +1387,7 @@ extern "C" {
     // llama_sampler_accept(smpl, token);
     // return token;
     // Returns the sampled token
-    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
+    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx, float n_remain);

     // TODO: extend in the future
     //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
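Taken together with the common/sampling.cpp hunk earlier in this diff, a caller-side sketch of the extended API (assumptions: model, ctx, params, and a sampler chain are already set up the usual llama.cpp way; the role of the trailing vocab argument, presumably locating EOG tokens, is inferred rather than documented here):

// Sketch: constructing the extended logit-bias sampler and passing n_remain at sample time.
const llama_vocab * vocab = llama_model_get_vocab(model);

llama_sampler_chain_add(chain,
        llama_sampler_init_logit_bias(
            llama_vocab_n_tokens(vocab),
            params.logit_bias.size(),
            params.logit_bias.data(),
            params.eog_bias_per_tok,    // new: per-token escalation step for EOG
            params.start_eog_at_remain, // new: threshold in remaining tokens
            vocab));                    // new: presumably used to identify EOG tokens

int n_remain = params.n_predict;
while (n_remain > 0) {
    // ... decode a batch ...
    const llama_token id = llama_sampler_sample(chain, ctx, -1, /*n_remain=*/--n_remain);
    if (llama_vocab_is_eog(vocab, id)) {
        break;
    }
    // ... append id to the output and to the next batch ...
}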
