diff --git a/data/template/prepare.py b/data/template/prepare.py
index e4b5e10155..e6946107b2 100644
--- a/data/template/prepare.py
+++ b/data/template/prepare.py
@@ -64,6 +64,7 @@ def parse_arguments():
 
     # Additional options
     parser.add_argument("-T", "--track_token_counts", action="store_true", help="Track how often each token appears and store in meta.pkl")
+    parser.add_argument("-E", "--extract-vocab", action="store_true", help="Export the tokenizer vocabulary to a JSON file")
 
     return parser.parse_args()
 
@@ -152,6 +153,20 @@ def main():
     if val_ids is not None:
         save_tokens(val_ids, args.val_output, dtype)
 
+    if args.extract_vocab:
+        if not hasattr(tokenizer, "get_vocabulary"):
+            raise ValueError(f"Vocabulary extraction is not supported for tokenizer method '{args.method}'.")
+
+        vocab = tokenizer.get_vocabulary()
+        # Ensure string representations for all tokens
+        vocab_strings = [token if isinstance(token, str) else str(token) for token in vocab]
+        vocab_strings.sort(key=lambda token: (-len(token), token))
+
+        vocab_filename = f"{args.method}_vocab.json"
+        with open(vocab_filename, 'w', encoding='utf-8') as f:
+            json.dump(vocab_strings, f, ensure_ascii=False, indent=2)
+        print(f"Saved vocabulary to {vocab_filename}")
+
     if args.method == "sinewave":
         meta = {
             "tokenizer": "sinewave",
diff --git a/data/template/tokenizers.py b/data/template/tokenizers.py
index d80ad10e70..d7b629d091 100644
--- a/data/template/tokenizers.py
+++ b/data/template/tokenizers.py
@@ -35,6 +35,10 @@ def finalize_meta(self, meta):
         meta["token_counts"] = dict(self.token_counts)
         self.save_meta(meta)
 
+    def get_vocabulary(self):
+        """Return the list of string representations that make up the tokenizer's vocabulary."""
+        raise NotImplementedError("Vocabulary extraction is not implemented for this tokenizer.")
+
     @staticmethod
     def get_key_from_meta(keyname):
         meta_path = 'meta.pkl'
@@ -117,6 +121,11 @@ def detokenize(self, ids):
             raise ValueError("SentencePiece model is not loaded.")
         return self.sp.decode_ids(ids)
 
+    def get_vocabulary(self):
+        if not self.sp:
+            raise ValueError("SentencePiece model is not loaded.")
+        return [self.sp.id_to_piece(i) for i in range(self.sp.GetPieceSize())]
+
 class TiktokenTokenizer(Tokenizer):
     def __init__(self, args):
         super().__init__(args)
@@ -218,6 +227,29 @@ def detokenize(self, token_ids):
 
         return ''.join(result)
 
+    def get_vocabulary(self):
+        vocab = []
+        seen = set()
+        for token_id in range(self.enc.n_vocab):
+            token_bytes = self.enc.decode_single_token_bytes(token_id)
+            token_str = token_bytes.decode('utf-8', errors='replace')
+            if token_str not in seen:
+                seen.add(token_str)
+                vocab.append(token_str)
+
+        # Include known special tokens (base and additional)
+        special_tokens = {}
+        if hasattr(self.enc, "_special_tokens") and isinstance(self.enc._special_tokens, dict):
+            special_tokens.update(self.enc._special_tokens)
+        special_tokens.update(self.special_tokens)
+
+        for token in special_tokens.keys():
+            if token not in seen:
+                seen.add(token)
+                vocab.append(token)
+
+        return vocab
+
 
 class CustomTokenizer(Tokenizer):
     def __init__(self, args):
@@ -261,6 +293,9 @@ def tokenize(self, data):
     def detokenize(self, ids):
         return ''.join([self.itos[id] for id in ids])
 
+    def get_vocabulary(self):
+        return list(self.tokens)
+
 class ByteTokenizer(Tokenizer):
     def __init__(self, args):
         super().__init__(args)
@@ -281,6 +316,9 @@ def tokenize(self, data):
     def detokenize(self, ids):
         return bytes(ids).decode('utf-8', errors='replace')
 
+    def get_vocabulary(self):
+        return [chr(i) for i in range(256)]
+
 class CharTokenizer(Tokenizer):
     def __init__(self, args, train_data, val_data):
@@ -315,6 +353,9 @@ def tokenize(self, data):
     def detokenize(self, ids):
         return ''.join([self.itos[id] for id in ids])
 
+    def get_vocabulary(self):
+        return list(self.chars)
+
 class CustomCharTokenizerWithByteFallback(Tokenizer):
     """
     In this version, we assign IDs 0..255 to raw bytes,
@@ -445,6 +486,15 @@ def detokenize(self, ids):
 
         return ''.join(out_pieces)
 
+    def get_vocabulary(self):
+        vocab = []
+        for token in self.itos.values():
+            if isinstance(token, bytes):
+                vocab.append(token.decode('utf-8', errors='replace'))
+            else:
+                vocab.append(token)
+        return vocab
+
 class JsonByteTokenizerWithByteFallback(Tokenizer):
     """
     Similar to CustomCharTokenizerWithByteFallback, but loads tokens from a JSON array.
@@ -577,6 +627,15 @@ def detokenize(self, ids):
 
         return ''.join(out_pieces)
 
+    def get_vocabulary(self):
+        vocab = []
+        for token in self.itos.values():
+            if isinstance(token, bytes):
+                vocab.append(token.decode('utf-8', errors='replace'))
+            else:
+                vocab.append(token)
+        return vocab
+
 class SineWaveTokenizer:
     """Generate a deterministic sequence of sine wave samples."""
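
For illustration, a minimal sketch of reading the file that --extract-vocab writes. The "{method}_vocab.json" filename pattern, the JSON array format, and the longest-first sort order follow the prepare.py hunk above; the method name "tiktoken" is only an assumed example.

    import json

    # Load the vocabulary exported by prepare.py --extract-vocab
    # ("tiktoken" as the method name is an assumption for this example).
    with open("tiktoken_vocab.json", encoding="utf-8") as f:
        vocab = json.load(f)

    # Entries are sorted longest-first, then lexicographically,
    # matching the sort key used in prepare.py.
    print(f"{len(vocab)} tokens; longest entry: {vocab[0]!r}")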