Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 62 additions & 38 deletions llama_cpp/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@ def get_embeddings_seq(self, seq_id: int):
# Sampling functions - deprecated, use LlamaSampler instead

def set_rng_seed(self, seed: int):
raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"set_rng_seed is deprecated, use LlamaSampler instead"
)

def sample_repetition_penalties(
self,
Expand All @@ -366,30 +368,44 @@ def sample_repetition_penalties(
penalty_freq: float,
penalty_present: float,
):
raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_repetition_penalties is deprecated, use LlamaSampler instead"
)

def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_softmax is deprecated, use LlamaSampler instead"
)

def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_top_k is deprecated, use LlamaSampler instead"
)

def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_top_p is deprecated, use LlamaSampler instead"
)

def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_min_p is deprecated, use LlamaSampler instead"
)

def sample_typical(
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
):
raise NotImplementedError("sample_typical is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_typical is deprecated, use LlamaSampler instead"
)

def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead")

def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_grammar is deprecated, use LlamaSampler instead"
)

def sample_token_mirostat(
self,
Expand All @@ -399,7 +415,9 @@ def sample_token_mirostat(
m: int,
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_token_mirostat is deprecated, use LlamaSampler instead"
)

def sample_token_mirostat_v2(
self,
Expand All @@ -408,17 +426,25 @@ def sample_token_mirostat_v2(
eta: float,
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_token_mirostat_v2 is deprecated, use LlamaSampler instead"
)

def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_token_greedy is deprecated, use LlamaSampler instead"
)

def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"sample_token is deprecated, use LlamaSampler instead"
)

# Grammar
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"grammar_accept_token is deprecated, use LlamaSampler instead"
)

def reset_timings(self):
llama_cpp.llama_perf_context_reset(self.ctx)
Expand Down Expand Up @@ -529,6 +555,8 @@ def normalize_embedding(embedding):
norm = float(np.linalg.norm(embedding))
if norm == 0.0:
return embedding
if isinstance(embedding, np.ndarray):
return embedding / norm
return [v / norm for v in embedding]


Expand Down Expand Up @@ -602,16 +630,16 @@ def sample(
logits_array: Optional[npt.NDArray[np.single]] = None,
):
# This method is deprecated in favor of using LlamaSampler directly
raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead")
raise NotImplementedError(
"LlamaSamplingContext.sample is deprecated, use LlamaSampler instead"
)

def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
self.prev.append(id)


class CustomSampler:
def __init__(
self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
):
def __init__(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
self.apply_func = apply_func

def apply_wrapper(
Expand Down Expand Up @@ -723,28 +751,28 @@ def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

def add_grammar_lazy_patterns(
self,
model: LlamaModel,
self,
model: LlamaModel,
grammar: LlamaGrammar,
trigger_patterns: List[str],
trigger_tokens: List[int]
trigger_tokens: List[int],
):
# Convert patterns to C array
pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))()
for i, pattern in enumerate(trigger_patterns):
pattern_ptrs[i] = pattern.encode("utf-8")

# Convert tokens to C array
token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens)

sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
model.vocab,
grammar._grammar.encode("utf-8"),
grammar._root.encode("utf-8"),
pattern_ptrs,
len(trigger_patterns),
token_array,
len(trigger_tokens)
len(trigger_tokens),
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

Expand All @@ -771,13 +799,13 @@ def add_dry(
dry_base: float,
dry_allowed_length: int,
dry_penalty_last_n: int,
seq_breakers: List[str]
seq_breakers: List[str],
):
# Convert seq_breakers to C array
breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))()
for i, breaker in enumerate(seq_breakers):
breaker_ptrs[i] = breaker.encode("utf-8")

sampler = llama_cpp.llama_sampler_init_dry(
model.vocab,
n_ctx_train,
Expand All @@ -786,25 +814,19 @@ def add_dry(
dry_allowed_length,
dry_penalty_last_n,
breaker_ptrs,
len(seq_breakers)
len(seq_breakers),
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

def add_logit_bias(
self,
n_vocab: int,
logit_bias: Dict[int, float]
):
def add_logit_bias(self, n_vocab: int, logit_bias: Dict[int, float]):
# Convert logit_bias dict to C array
bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
for i, (token, bias) in enumerate(logit_bias.items()):
bias_array[i].token = token
bias_array[i].bias = bias

sampler = llama_cpp.llama_sampler_init_logit_bias(
n_vocab,
len(logit_bias),
bias_array
n_vocab, len(logit_bias), bias_array
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

Expand Down Expand Up @@ -838,15 +860,17 @@ def reset(self):
def clone(self):
# NOTE: Custom samplers cannot be cloned due to Python callback limitations
if self.custom_samplers:
raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers")

raise NotImplementedError(
"Cannot clone LlamaSampler that contains custom samplers"
)

cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler)
# Create a new wrapper around the cloned sampler
new_sampler = LlamaSampler.__new__(LlamaSampler)
new_sampler.sampler = cloned_sampler
new_sampler.custom_samplers = []
new_sampler._exit_stack = ExitStack()

def free_sampler():
if new_sampler.sampler is not None:
llama_cpp.llama_sampler_free(new_sampler.sampler)
Expand Down
Loading