Skip to content

Commit 3ceb47b

Browse files
committed
Fix mirastat requiring c_float
1 parent 9797394 commit 3ceb47b

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def generate(self):
357357

358358
# Apply params.logit_bias map
359359
for key, value in self.params.logit_bias.items():
360-
logits[key] += value
360+
logits[key] += llama_cpp.c_float(value)
361361

362362
_arr = (llama_cpp.llama_token_data * n_vocab)(*[
363363
llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
@@ -372,34 +372,34 @@ def generate(self):
372372
_arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
373373
llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
374374
_arr,
375-
last_n_repeat, self.params.repeat_penalty)
375+
last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
376376
llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
377377
_arr,
378-
last_n_repeat, self.params.frequency_penalty, self.params.presence_penalty)
378+
last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
379379

380380
if not self.params.penalize_nl:
381381
logits[llama_cpp.llama_token_nl()] = nl_logit
382-
382+
383383
if self.params.temp <= 0:
384384
# Greedy sampling
385385
id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
386386
else:
387387
if self.params.mirostat == 1:
388388
mirostat_mu = 2.0 * self.params.mirostat_tau
389389
mirostat_m = 100
390-
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
391-
id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_m, mirostat_mu)
390+
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
391+
id = llama_cpp.llama_sample_token_mirostat(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), llama_cpp.c_int(mirostat_m), llama_cpp.c_float(mirostat_mu))
392392
elif self.params.mirostat == 2:
393393
mirostat_mu = 2.0 * self.params.mirostat_tau
394-
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
395-
id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, self.params.mirostat_tau, self.params.mirostat_eta, mirostat_mu)
394+
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
395+
id = llama_cpp.llama_sample_token_mirostat_v2(self.ctx, candidates_p, llama_cpp.c_float(self.params.mirostat_tau), llama_cpp.c_float(self.params.mirostat_eta), llama_cpp.c_float(mirostat_mu))
396396
else:
397397
# Temperature sampling
398398
llama_cpp.llama_sample_top_k(self.ctx, candidates_p, top_k)
399-
llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, self.params.tfs_z)
400-
llama_cpp.llama_sample_typical(self.ctx, candidates_p, self.params.typical_p)
401-
llama_cpp.llama_sample_top_p(self.ctx, candidates_p, self.params.top_p)
402-
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, self.params.temp)
399+
llama_cpp.llama_sample_tail_free(self.ctx, candidates_p, llama_cpp.c_float(self.params.tfs_z))
400+
llama_cpp.llama_sample_typical(self.ctx, candidates_p, llama_cpp.c_float(self.params.typical_p))
401+
llama_cpp.llama_sample_top_p(self.ctx, candidates_p, llama_cpp.c_float(self.params.top_p))
402+
llama_cpp.llama_sample_temperature(self.ctx, candidates_p, llama_cpp.c_float(self.params.temp))
403403
id = llama_cpp.llama_sample_token(self.ctx, candidates_p)
404404
# print("`{}`".format(candidates_p.size))
405405

0 commit comments

Comments
 (0)