In-code documentation update #383

Open · wants to merge 1 commit into main
120 changes: 105 additions & 15 deletions tiktoken/_educational.py
@@ -11,7 +11,13 @@

class SimpleBytePairEncoding:
def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
"""Creates an Encoding object."""
"""Creates an Encoding object.

Args:
pat_str (str): A regex pattern string that is used to split the input text.
mergeable_ranks (dict[bytes, int]): A dictionary mapping token bytes to their ranks.
The ranks correspond to merge priority.
"""
# A regex pattern string that is used to split the input text
self.pat_str = pat_str
# A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
@@ -23,8 +29,17 @@ def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
def encode(self, text: str, visualise: str | None = "colour") -> list[int]:
"""Encodes a string into tokens.

>>> enc.encode("hello world")
[388, 372]
Args:
text (str): The text to encode.
visualise (str | None, optional): Visualisation mode. Can be 'colour', 'color',
'simple', or None. Defaults to 'colour'.

Returns:
list[int]: The encoded tokens.

Examples:
>>> enc.encode("hello world")
[388, 372]
"""
# Use the regex to split the text into (approximately) words
words = self._pat.findall(text)
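The split step above relies on the third-party `regex` module (the stdlib `re` does not support `\p{...}` character classes). A minimal sketch of what this pre-tokenisation produces, assuming the GPT-2-style split pattern that tiktoken ships:

import regex  # third-party module; stdlib `re` lacks \p{L} / \p{N}

gpt2_pat = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
print(regex.findall(gpt2_pat, "hello world, 123!"))
# ['hello', ' world', ',', ' 123', '!']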
@@ -39,35 +54,70 @@ def encode(self, text: str, visualise: str | None = "colour") -> list[int]:
def decode_bytes(self, tokens: list[int]) -> bytes:
"""Decodes a list of tokens into bytes.

>>> enc.decode_bytes([388, 372])
b'hello world'
Args:
tokens (list[int]): The list of tokens to decode.

Returns:
bytes: The decoded bytes.

Examples:
>>> enc.decode_bytes([388, 372])
b'hello world'
"""
return b"".join(self._decoder[token] for token in tokens)

def decode(self, tokens: list[int]) -> str:
"""Decodes a list of tokens into a string.

Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
the invalid bytes with the replacement character "�".
Args:
tokens (list[int]): The list of tokens to decode.

>>> enc.decode([388, 372])
'hello world'
Returns:
str: The decoded string.

Note:
Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
the invalid bytes with the replacement character "�".

Examples:
>>> enc.decode([388, 372])
'hello world'
"""
return self.decode_bytes(tokens).decode("utf-8", errors="replace")
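A small sketch of the errors="replace" behaviour the Note describes, using raw bytes rather than real token ids:

b"hello world".decode("utf-8", errors="replace")  # 'hello world'
# A token boundary can split a multi-byte character; the dangling fragment then
# decodes to the replacement character instead of raising UnicodeDecodeError:
b"\xe4\xb8".decode("utf-8", errors="replace")  # '�' (first two bytes of '一')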

def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
"""Decodes a list of tokens into a list of bytes.

Useful for visualising how a string is tokenised.
Args:
tokens (list[int]): The list of tokens to decode.

Returns:
list[bytes]: A list of decoded bytes.

>>> enc.decode_tokens_bytes([388, 372])
[b'hello', b' world']
Note:
Useful for visualising how a string is tokenised.

Examples:
>>> enc.decode_tokens_bytes([388, 372])
[b'hello', b' world']
"""
return [self._decoder[token] for token in tokens]
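Hypothetical usage, assuming `enc` is a trained SimpleBytePairEncoding, paired with the visualise_tokens helper defined later in this file:

tokens = enc.encode("hello world", visualise=None)
visualise_tokens(enc.decode_tokens_bytes(tokens))  # colour-codes each token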

@staticmethod
def train(training_data: str, vocab_size: int, pat_str: str):
"""Train a BPE tokeniser on some data!"""
"""Train a BPE tokeniser on some data.

Args:
training_data (str): The text data to train on.
vocab_size (int): The desired size of the vocabulary.
pat_str (str): The regex pattern string used for tokenisation.

Returns:
SimpleBytePairEncoding: A new tokenizer trained on the data.

Note:
This is an educational implementation of BPE training.
"""
mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks)
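An end-to-end sketch of the API these docstrings describe; the corpus filename and vocabulary size here are made up for illustration:

gpt2_pat = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
data = open("training_corpus.txt").read()  # hypothetical training text
enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pat)
tokens = enc.encode("hello world", visualise=None)
assert enc.decode(tokens) == "hello world"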

@@ -81,8 +131,21 @@ def from_tiktoken(encoding):


def bpe_encode(
mergeable_ranks: dict[bytes, int], input: bytes, visualise: str | None = "colour"
mergeable_ranks: dict[bytes, int],
input: bytes,
visualise: str | None = "colour"
) -> list[int]:
"""Encodes input bytes using byte pair encoding.

Args:
mergeable_ranks (dict[bytes, int]): Dictionary mapping token bytes to their ranks.
input (bytes): The input bytes to encode.
visualise (str | None, optional): Visualisation mode. Can be 'colour', 'color',
'simple', or None. Defaults to 'colour'.

Returns:
list[int]: The encoded tokens.
"""
parts = [bytes([b]) for b in input]
while True:
# See the intermediate merges play out!
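The loop body elided from this hunk repeatedly merges the adjacent pair whose concatenation has the lowest rank. A condensed, runnable sketch of a single merge step (not the verbatim source):

def merge_once(parts: list[bytes], mergeable_ranks: dict[bytes, int]) -> list[bytes] | None:
    # Find the adjacent pair whose merged bytes have the lowest (best) rank.
    best_idx, best_rank = None, None
    for i in range(len(parts) - 1):
        rank = mergeable_ranks.get(parts[i] + parts[i + 1])
        if rank is not None and (best_rank is None or rank < best_rank):
            best_idx, best_rank = i, rank
    if best_idx is None:
        return None  # no mergeable pair left; encoding is finished
    # Splice the winning pair into a single merged token.
    return parts[:best_idx] + [parts[best_idx] + parts[best_idx + 1]] + parts[best_idx + 2:]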
@@ -117,8 +180,26 @@ def bpe_encode(


def bpe_train(
data: str, vocab_size: int, pat_str: str, visualise: str | None = "colour"
data: str,
vocab_size: int,
pat_str: str,
visualise: str | None = "colour"
) -> dict[bytes, int]:
"""Trains a byte pair encoding model on the given data.

Args:
data (str): The text data to train on.
vocab_size (int): The desired size of the vocabulary.
pat_str (str): The regex pattern string used for tokenisation.
visualise (str | None, optional): Visualization mode. Can be 'colour', 'color',
'simple', or None. Defaults to 'colour'.

Returns:
dict[bytes, int]: A dictionary mapping token bytes to their ranks.

Raises:
ValueError: If vocab_size is less than 256.
"""
# First, add tokens for each individual byte value
if vocab_size < 2**8:
raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
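The training loop elided from this hunk follows the classic BPE recipe: count adjacent pairs across the corpus, then mint a new token for the most frequent pair. A condensed sketch of the pair-counting step (not the verbatim source):

from collections import Counter

def most_common_pair(words: list[list[bytes]]) -> tuple[bytes, bytes]:
    # Tally every adjacent pair of byte-chunks across all words.
    stats: Counter[tuple[bytes, bytes]] = Counter()
    for word in words:
        for pair in zip(word[:-1], word[1:]):
            stats[pair] += 1
    # The most frequent pair becomes the next merge rule / vocabulary entry.
    return stats.most_common(1)[0][0]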
@@ -186,6 +267,15 @@ def bpe_train(


def visualise_tokens(token_values: list[bytes]) -> None:
"""Visualizes tokens by printing them with different background colors.

Args:
token_values (list[bytes]): List of token bytes to visualize.

Note:
If token boundaries do not occur at unicode character boundaries, it uses the
unicode replacement character to represent some fraction of a character.
"""
background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
# If token boundaries do not occur at unicode character boundaries, it's unclear how best to
# visualise the token. Here, we'll just use the unicode replacement character to represent some
# fraction of a character.
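For reference, the escape strings above come from the ANSI 256-colour background palette; for example, this prints "hello" on palette colour 167 and then resets the terminal:

print("\u001b[48;5;167m" + "hello" + "\u001b[0m")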