Skip to content

Commit 777ce73

Browse files
committed
Partial sync of codebase
1 parent bb5805d commit 777ce73

File tree

3 files changed: +41 −2 lines changed

tiktoken/core.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
import functools
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import AbstractSet, Collection, Literal, NoReturn, Sequence
5+
from typing import TYPE_CHECKING, AbstractSet, Collection, Literal, NoReturn, Sequence
66

77
import regex
88

99
from tiktoken import _tiktoken
1010

11+
if TYPE_CHECKING:
12+
import numpy as np
13+
import numpy.typing as npt
14+
1115

1216
class Encoding:
1317
def __init__(
@@ -128,6 +132,32 @@ def encode(
128132
text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
129133
return self._core_bpe.encode(text, allowed_special)
130134

135+
def encode_to_numpy(
136+
self,
137+
text: str,
138+
*,
139+
allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006
140+
disallowed_special: Literal["all"] | Collection[str] = "all",
141+
) -> npt.NDArray[np.uint32]:
142+
"""Encodes a string into tokens, returning a numpy array.
143+
144+
Avoids the overhead of copying the token buffer into a Python list.
145+
"""
146+
if allowed_special == "all":
147+
allowed_special = self.special_tokens_set
148+
if disallowed_special == "all":
149+
disallowed_special = self.special_tokens_set - allowed_special
150+
if disallowed_special:
151+
if not isinstance(disallowed_special, frozenset):
152+
disallowed_special = frozenset(disallowed_special)
153+
if match := _special_token_regex(disallowed_special).search(text):
154+
raise_disallowed_special_token(match.group())
155+
156+
import numpy as np
157+
158+
buffer = self._core_bpe.encode_to_tiktoken_buffer(text, self.special_tokens_set)
159+
return np.frombuffer(buffer, dtype=np.uint32)
160+
131161
def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]:
132162
"""Encodes a list of strings into tokens, in parallel, ignoring special tokens.
133163
@@ -332,6 +362,10 @@ def eot_token(self) -> int:
332362
def special_tokens_set(self) -> set[str]:
333363
return set(self._special_tokens.keys())
334364

365+
def is_special_token(self, token: int) -> bool:
366+
assert isinstance(token, int)
367+
return token in self._special_token_values
368+
335369
@property
336370
def n_vocab(self) -> int:
337371
"""For backwards compatibility. Prefer to use `enc.max_token_value + 1`."""

tiktoken/load.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,5 +154,5 @@ def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None)
154154
token, rank = line.split()
155155
ret[base64.b64decode(token)] = int(rank)
156156
except Exception as e:
157-
raise ValueError(f"Error parsing line {line} in {tiktoken_bpe_file}") from e
157+
raise ValueError(f"Error parsing line {line!r} in {tiktoken_bpe_file}") from e
158158
return ret

tiktoken/model.py

+5
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,25 @@
66
# TODO: these will likely be replaced by an API endpoint
77
MODEL_PREFIX_TO_ENCODING: dict[str, str] = {
88
"o1-": "o200k_base",
9+
"o3-": "o200k_base",
910
# chat
1011
"chatgpt-4o-": "o200k_base",
1112
"gpt-4o-": "o200k_base", # e.g., gpt-4o-2024-05-13
1213
"gpt-4-": "cl100k_base", # e.g., gpt-4-0314, etc., plus gpt-4-32k
1314
"gpt-3.5-turbo-": "cl100k_base", # e.g, gpt-3.5-turbo-0301, -0401, etc.
1415
"gpt-35-turbo-": "cl100k_base", # Azure deployment name
1516
# fine-tuned
17+
"ft:gpt-4o": "o200k_base",
1618
"ft:gpt-4": "cl100k_base",
1719
"ft:gpt-3.5-turbo": "cl100k_base",
1820
"ft:davinci-002": "cl100k_base",
1921
"ft:babbage-002": "cl100k_base",
2022
}
2123

2224
MODEL_TO_ENCODING: dict[str, str] = {
25+
# reasoning
26+
"o1": "o200k_base",
27+
"o3": "o200k_base",
2328
# chat
2429
"gpt-4o": "o200k_base",
2530
"gpt-4": "cl100k_base",

0 commit comments

Comments (0)