Skip to content

Commit f8e2648

Browse files
committed
Partial sync of codebase
1 parent bb5805d commit f8e2648

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

tiktoken/core.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
import functools
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import AbstractSet, Collection, Literal, NoReturn, Sequence
5+
from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence
66

77
import regex
88

99
from tiktoken import _tiktoken
1010

11+
if TYPE_CHECKING:
12+
import numpy as np
13+
import numpy.typing as npt
14+
1115

1216
class Encoding:
1317
def __init__(
@@ -128,6 +132,32 @@ def encode(
128132
text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
129133
return self._core_bpe.encode(text, allowed_special)
130134

135+
def encode_to_numpy(
    self,
    text: str,
    *,
    allowed_special: Literal["all"] | AbstractSet[str] = set(),  # noqa: B006
    disallowed_special: Literal["all"] | Collection[str] = "all",
) -> npt.NDArray[np.uint32]:
    """Encodes a string into tokens, returning a numpy array.

    Avoids the overhead of copying the token buffer into a Python list.

    Args:
        text: The string to encode.
        allowed_special: Special tokens that may appear in ``text`` and be
            encoded as special tokens; ``"all"`` allows every special token.
        disallowed_special: Special tokens whose presence in ``text`` raises
            an error; ``"all"`` disallows everything not in ``allowed_special``.

    Returns:
        A ``uint32`` numpy array view over the Rust-side token buffer
        (no Python-list copy is made).

    Raises:
        ValueError: if a disallowed special token is found in ``text``
            (via ``raise_disallowed_special_token``).
    """
    if allowed_special == "all":
        allowed_special = self.special_tokens_set
    if disallowed_special == "all":
        disallowed_special = self.special_tokens_set - allowed_special
    if disallowed_special:
        # frozenset lets the compiled regex for this token set be cached.
        if not isinstance(disallowed_special, frozenset):
            disallowed_special = frozenset(disallowed_special)
        if match := _special_token_regex(disallowed_special).search(text):
            raise_disallowed_special_token(match.group())

    # Imported lazily so importing tiktoken does not require numpy.
    import numpy as np

    # Pass `allowed_special` (not `self.special_tokens_set`) so that special
    # tokens outside the allowed set are encoded as ordinary text, exactly
    # as `encode` does — keeping the two methods' outputs consistent.
    buffer = self._core_bpe.encode_to_tiktoken_buffer(text, allowed_special)
    return np.frombuffer(buffer, dtype=np.uint32)
160+
131161
def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
132162
"""Encodes a list of strings into tokens, in parallel, ignoring special tokens.
133163
@@ -332,6 +362,10 @@ def eot_token(self) -> int:
332362
def special_tokens_set(self) -> set[str]:
    """Return a fresh set of the text of every special token."""
    return {special for special in self._special_tokens}
334364

365+
def is_special_token(self, token: int) -> bool:
    """Return True if ``token`` is the integer id of a special token.

    Args:
        token: A token id to test for membership in the special-token ids.

    Raises:
        TypeError: if ``token`` is not an ``int``.  (An explicit raise
            rather than ``assert``, which is stripped under ``python -O``.)
    """
    if not isinstance(token, int):
        raise TypeError(f"expected int, got {type(token).__name__}")
    return token in self._special_token_values
368+
335369
@property
336370
def n_vocab(self) -> int:
337371
"""For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""

0 commit comments

Comments (0)