-
Notifications
You must be signed in to change notification settings - Fork 27
Add vocabulary extraction option to tokenization tools #681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -64,6 +64,7 @@ def parse_arguments(): | |||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Additional options | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("-T", "--track_token_counts", action="store_true", help="Track how often each token appears and store in meta.pkl") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument("-E", "--extract-vocab", action="store_true", help="Export the tokenizer vocabulary to a JSON file") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| return parser.parse_args() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -152,6 +153,20 @@ def main(): | |||||||||||||||||||||||||||||||||||||||||||||||||||
| if val_ids is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| save_tokens(val_ids, args.val_output, dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if args.extract_vocab: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not hasattr(tokenizer, "get_vocabulary"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Vocabulary extraction is not supported for tokenizer method '{args.method}'.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| vocab = tokenizer.get_vocabulary() | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+156
to
+160
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not hasattr(tokenizer, "get_vocabulary"): | |
| raise ValueError(f"Vocabulary extraction is not supported for tokenizer method '{args.method}'.") | |
| vocab = tokenizer.get_vocabulary() | |
| try: | |
| vocab = tokenizer.get_vocabulary() | |
| except NotImplementedError: | |
| raise ValueError(f"Vocabulary extraction is not supported for tokenizer method '{args.method}'.") |
Copilot
AI
Nov 15, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The sorting strategy sort(key=lambda token: (-len(token), token)) sorts by descending length first, then alphabetically. While this may be intentional for some use cases, the choice of sorting order is not documented.
For typical vocabulary files, users might expect alphabetical sorting or sorting by token ID order. Consider documenting why this particular sorting order was chosen, or making it configurable via a command-line option.
Copilot
AI
Nov 15, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The SineWaveTokenizer class doesn't inherit from the Tokenizer base class and therefore doesn't have a get_vocabulary method. When --extract-vocab is used with --method sinewave, the code will fail at line 160 when calling tokenizer.get_vocabulary() with an AttributeError.
Either SineWaveTokenizer should inherit from Tokenizer and implement get_vocabulary, or the vocabulary extraction logic in prepare.py should explicitly handle the sinewave case.
| if not hasattr(tokenizer, "get_vocabulary"): | |
| raise ValueError(f"Vocabulary extraction is not supported for tokenizer method '{args.method}'.") | |
| vocab = tokenizer.get_vocabulary() | |
| # Ensure string representations for all tokens | |
| vocab_strings = [token if isinstance(token, str) else str(token) for token in vocab] | |
| vocab_strings.sort(key=lambda token: (-len(token), token)) | |
| vocab_filename = f"{args.method}_vocab.json" | |
| with open(vocab_filename, 'w', encoding='utf-8') as f: | |
| json.dump(vocab_strings, f, ensure_ascii=False, indent=2) | |
| if args.method == "sinewave": | |
| # SineWaveTokenizer does not have get_vocabulary; generate vocabulary as 0-255 | |
| vocab = [str(i) for i in range(256)] | |
| else: | |
| if not hasattr(tokenizer, "get_vocabulary"): | |
| raise ValueError(f"Vocabulary extraction is not supported for tokenizer method '{args.method}'.") | |
| vocab = tokenizer.get_vocabulary() | |
| # Ensure string representations for all tokens | |
| vocab = [token if isinstance(token, str) else str(token) for token in vocab] | |
| vocab.sort(key=lambda token: (-len(token), token)) | |
| vocab_filename = f"{args.method}_vocab.json" | |
| with open(vocab_filename, 'w', encoding='utf-8') as f: | |
| json.dump(vocab, f, ensure_ascii=False, indent=2) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -35,6 +35,10 @@ def finalize_meta(self, meta): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| meta["token_counts"] = dict(self.token_counts) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.save_meta(meta) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_vocabulary(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Return the list of string representations that make up the tokenizer's vocabulary.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError("Vocabulary extraction is not implemented for this tokenizer.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+38
to
+41
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_key_from_meta(keyname): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| meta_path = 'meta.pkl' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -117,6 +121,11 @@ def detokenize(self, ids): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("SentencePiece model is not loaded.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.sp.decode_ids(ids) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_vocabulary(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.sp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("SentencePiece model is not loaded.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [self.sp.id_to_piece(i) for i in range(self.sp.GetPieceSize())] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class TiktokenTokenizer(Tokenizer): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, args): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(args) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -218,6 +227,29 @@ def detokenize(self, token_ids): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ''.join(result) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_vocabulary(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vocab = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seen = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for token_id in range(self.enc.n_vocab): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_bytes = self.enc.decode_single_token_bytes(token_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_str = token_bytes.decode('utf-8', errors='replace') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if token_str not in seen: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seen.add(token_str) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vocab.append(token_str) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Include known special tokens (base and additional) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| special_tokens = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(self.enc, "_special_tokens") and isinstance(self.enc._special_tokens, dict): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| special_tokens.update(self.enc._special_tokens) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| special_tokens.update(self.special_tokens) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for token in special_tokens.keys(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if token not in seen: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seen.add(token) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vocab.append(token) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+231
to
+250
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vocab = [] | |
| seen = set() | |
| for token_id in range(self.enc.n_vocab): | |
| token_bytes = self.enc.decode_single_token_bytes(token_id) | |
| token_str = token_bytes.decode('utf-8', errors='replace') | |
| if token_str not in seen: | |
| seen.add(token_str) | |
| vocab.append(token_str) | |
| # Include known special tokens (base and additional) | |
| special_tokens = {} | |
| if hasattr(self.enc, "_special_tokens") and isinstance(self.enc._special_tokens, dict): | |
| special_tokens.update(self.enc._special_tokens) | |
| special_tokens.update(self.special_tokens) | |
| for token in special_tokens.keys(): | |
| if token not in seen: | |
| seen.add(token) | |
| vocab.append(token) | |
| """ | |
| Returns the vocabulary as a list of strings, where each entry corresponds to the decoded string for each token ID. | |
| The length of the returned list matches self.enc.n_vocab, so vocab[i] is the string for token ID i. | |
| Special tokens are not included in this list. | |
| """ | |
| vocab = [] | |
| for token_id in range(self.enc.n_vocab): | |
| token_bytes = self.enc.decode_single_token_bytes(token_id) | |
| token_str = token_bytes.decode('utf-8', errors='replace') | |
| vocab.append(token_str) |
Copilot
AI
Nov 15, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The ByteTokenizer.get_vocabulary() method returns characters using chr(i) for byte values 0-255. However, many byte values in this range (0-31, 127-159) are control characters that don't have printable representations. Some values like 0-31 and 127 are non-printable ASCII control characters, and attempting to use chr() on values 128-255 may produce unexpected Unicode characters.
This could lead to issues when the vocabulary is serialized to JSON. Consider using a more appropriate representation for non-printable bytes, such as their hexadecimal or escaped form:
def get_vocabulary(self):
vocab = []
for i in range(256):
if 32 <= i < 127: # Printable ASCII range
vocab.append(chr(i))
else:
vocab.append(f"\\x{i:02x}")
return vocab| return [chr(i) for i in range(256)] | |
| vocab = [] | |
| for i in range(256): | |
| if 32 <= i < 127: # Printable ASCII range | |
| vocab.append(chr(i)) | |
| else: | |
| vocab.append(f"\\x{i:02x}") | |
| return vocab |
Copilot
AI
Nov 15, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The vocabulary is built by iterating over self.itos.values(), which returns dictionary values in arbitrary order (though insertion order is preserved in Python 3.7+). Since itos is a mapping from token IDs to tokens, consider iterating over sorted token IDs to ensure a consistent and predictable vocabulary order:
def get_vocabulary(self):
vocab = []
for token_id in sorted(self.itos.keys()):
token = self.itos[token_id]
if isinstance(token, bytes):
vocab.append(token.decode('utf-8', errors='replace'))
else:
vocab.append(token)
return vocabThis ensures that the vocabulary reflects the actual token ID ordering, which is more useful for debugging and understanding the tokenizer.
Copilot
AI
Nov 15, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The vocabulary is built by iterating over self.itos.values(), which returns dictionary values in arbitrary order (though insertion order is preserved in Python 3.7+). Since itos is a mapping from token IDs to tokens, consider iterating over sorted token IDs to ensure a consistent and predictable vocabulary order:
def get_vocabulary(self):
vocab = []
for token_id in sorted(self.itos.keys()):
token = self.itos[token_id]
if isinstance(token, bytes):
vocab.append(token.decode('utf-8', errors='replace'))
else:
vocab.append(token)
return vocabThis ensures that the vocabulary reflects the actual token ID ordering, which is more useful for debugging and understanding the tokenizer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The command-line argument uses a hyphen (
--extract-vocab) which is inconsistent with Python conventions. The argument will be accessible asargs.extract_vocab(with underscore) due to argparse's automatic conversion, but for consistency with other arguments in the file (e.g.,--track_token_counts), consider using underscores in the flag name:--extract_vocab.