diff --git a/Cargo.toml b/Cargo.toml index 4dbfef9..60f3d10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,8 @@ fancy-regex = "0.13.0" regex = "1.10.3" rustc-hash = "1.1.0" bstr = "1.5.0" +# cli dependencies +tui-textarea = "0.7.0" +crossterm = "0.28.0" +ratatui = "0.29.0" + diff --git a/MANIFEST.in b/MANIFEST.in index 7f25b27..e74e83d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ +include *.png include *.svg include *.toml include *.md diff --git a/README.md b/README.md index 4f36c53..39c9433 100644 --- a/README.md +++ b/README.md @@ -129,3 +129,14 @@ setup( Then simply `pip install ./my_tiktoken_extension` and you should be able to use your custom encodings! Make sure **not** to use an editable install. + +## Tiktoken tokenizer environment + +Test your tokenizer through a terminal-based environment that allows you to visualize tokenized data points. This tool helps you better grasp model information by providing immediate feedback on how input text is being tokenized. You can see token types, and their positions in the input text, making it easier to understand and debug your tokenizer. + +```python +import tiktoken +enc = tiktoken.get_encoding("gpt2") +enc.environment() +``` +![image](/environment.png) diff --git a/environment.png b/environment.png new file mode 100644 index 0000000..899374a Binary files /dev/null and b/environment.png differ diff --git a/src/lib.rs b/src/lib.rs index a64c3de..1d9a04e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,15 +4,36 @@ use std::collections::HashSet; use std::num::NonZeroU64; use std::thread; +use bstr::ByteSlice; use fancy_regex::Regex; #[cfg(feature = "python")] use pyo3::prelude::*; +use pyo3::pyclass; +use pyo3::PyResult; +use pyo3::types::{PyBytes, PyList, PyTuple}; + use rustc_hash::FxHashMap as HashMap; +// so many imports :D +use crossterm::terminal::{ + disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen, +}; +use ratatui::prelude::{Span, Constraint}; +use ratatui::Terminal; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::{Layout, Direction, Margin}; +use ratatui::text::Line; +use ratatui::widgets::{ + block::Title, Block, Padding, Borders, Paragraph, Scrollbar, ScrollbarOrientation, + ScrollbarState, Wrap, +}; +use ratatui::style::{Color, Style}; +use tui_textarea::TextArea; + #[cfg(feature = "python")] mod py; -pub type Rank = u32; +type Rank = u32; fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { // This is a vector of (start, rank). @@ -513,12 +534,160 @@ impl CoreBPE { .collect() } - pub fn encode_with_special_tokens(&self, text: &str) -> Vec { - let allowed_special = self.special_tokens(); - self.encode(text, &allowed_special).0 + fn _environment(&self, name : &str, allowed_special: HashSet<&str>) -> PyResult<()> { + let stdout = std::io::stdout(); + let mut stdout = stdout.lock(); + + enable_raw_mode()?; + crossterm::execute!(stdout, EnterAlternateScreen)?; + let backend = CrosstermBackend::new(stdout); + let mut term = Terminal::new(backend)?; + + let mut textarea = TextArea::default(); + textarea.set_block( + Block::default() + .borders(Borders::ALL) + .title(format!("{} Encoder",name)).padding(Padding::new(1, 1, 1, 0)) + ); + + let colours = vec![Color::Red, Color::Green, Color::Blue, Color::Yellow, Color::Magenta, Color::Cyan]; + + let parent_layout = Layout::default() + .constraints([Constraint::Percentage(100), Constraint::Min(1)]); + + let layout = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref()); + + + let mut vertical_scroll = 0; + let mut line_count : usize = 0; + loop { + let mut current_color_index = 0; + term.draw(|f| { + let chunks = parent_layout.split(f.size()); + + let sub_chunk = layout.split(chunks[0]); + + let tokens: Vec> = textarea.lines().iter().map(|line| { + + // Encode the line + let encoding = self.encode(line, &allowed_special).0; + + // Decode the encoded line + let decoding: Vec> = encoding.iter() + .map(|&token| self.decode_bytes(&[token]).unwrap()) + .collect(); + + // Convert decoded tokens to Strings + let tokens: Vec = decoding.iter().map(|bytes| { + bytes.to_str() + .unwrap() + .to_string() + + }).collect(); + + tokens + // Create spans for the line + + }).collect(); + + let mut lines : Vec = Vec::new(); + let mut token_count = 0; + for token in &tokens{ + let span : Vec = token.iter().map(|token| { + let color = colours[current_color_index]; + current_color_index = (current_color_index + 1) % colours.len(); + token_count += 1; // Increment the token count + Span::styled(token, Style::default().bg(color).fg(Color::White)) + }).collect(); + lines.push(Line::from(span)); + + } + + + let scrollbar = Scrollbar::new(ScrollbarOrientation::VerticalRight) + .begin_symbol(Some("↑")) + .end_symbol(Some("↓")); + + + let paragraph = Paragraph::new(lines.clone()) + .block(Block::default().borders(Borders::ALL) + .title("Decoded Tokens") + .title(Title::from(Line::from(vec![ + Span::styled(token_count.to_string(), + Style::new().fg(Color::Green)), + Span::from(" token(s)")])) + .alignment(ratatui::layout::Alignment::Center) + .position(ratatui::widgets::block::Position::Bottom)) + .padding(Padding::new(1, 1, 1, 1))) + .scroll((vertical_scroll as u16, 0)) + .wrap(Wrap { trim: true }); + + let menu: Block<'_> = Block::new() + .title(Title::from("[Esc] Exit").alignment(ratatui::layout::Alignment::Left)) + .title(Title::from("[Ctrl+S] Scroll Down").alignment(ratatui::layout::Alignment::Center)) + .title(Title::from("[Ctrl+A] Scroll Up").alignment(ratatui::layout::Alignment::Center)) + .padding(Padding::horizontal(5u16)) + .border_style(Style::default().fg(Color::White)) + .borders(Borders::TOP); + + line_count = lines.len(); + let mut scrollbar_state = ScrollbarState::new(line_count) + .position(vertical_scroll); + + f.render_widget(menu, chunks[1]); + f.render_widget(textarea.widget(), sub_chunk[0]); + f.render_widget(paragraph, sub_chunk[1]); + + f.render_stateful_widget( + scrollbar, + sub_chunk[1].inner(Margin { + // using an inner vertical margin of 1 unit makes the scrollbar inside the block + vertical: 1, + horizontal: 0, + }), + &mut scrollbar_state, + ); + + })?; + + match crossterm::event::read()? { + crossterm::event::Event::Key(key) => { + match key.code { + crossterm::event::KeyCode::Esc => break, + crossterm::event::KeyCode::Char('s') if key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL) => { + if vertical_scroll < line_count - 1 { + vertical_scroll += 1; + } + }, + crossterm::event::KeyCode::Char('a') if key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL) => { + if vertical_scroll > 0 { + vertical_scroll -= 1; + } + }, + _ => { + textarea.input(key); + } + } + }, + _ => {} + } + } + + disable_raw_mode()?; + crossterm::execute!( + term.backend_mut(), + LeaveAlternateScreen, + )?; + term.show_cursor()?; + + Ok(()) + } } + #[cfg(test)] mod tests { use fancy_regex::Regex; diff --git a/src/py.rs b/src/py.rs index 8485462..e01fc8a 100644 --- a/src/py.rs +++ b/src/py.rs @@ -170,7 +170,21 @@ impl CoreBPE { .map(|x| PyBytes::new_bound(py, x).into()) .collect() } -} + + // ==================== + // TUI Environment + // ==================== + + #[pyo3(name = "_environment")] + fn py_environment(&self, _py : Python, name : &str, allowed_special: HashSet) -> PyResult<()>{ + // Convert PyBackedStr to &str + let allowed_special: HashSet<&str> = + allowed_special.iter().map(|s| s.as_ref()).collect(); + + self._environment(name, allowed_special) + } + +} #[pyclass] struct TiktokenBuffer { diff --git a/tiktoken/core.py b/tiktoken/core.py index 6bc9736..464381f 100644 --- a/tiktoken/core.py +++ b/tiktoken/core.py @@ -119,6 +119,10 @@ def encode( disallowed_special = frozenset(disallowed_special) if match := _special_token_regex(disallowed_special).search(text): raise_disallowed_special_token(match.group()) + + # https://github.com/PyO3/pyo3/pull/3632 + if isinstance(allowed_special, frozenset): + allowed_special = set(allowed_special) try: return self._core_bpe.encode(text, allowed_special) @@ -371,6 +375,18 @@ def n_vocab(self) -> int: """For backwards compatibility. Prefer to use `enc.max_token_value + 1`.""" return self.max_token_value + 1 + def environment(self, + *, + allowed_special: Literal["all"] | AbstractSet[str] = set(), # noqa: B006 + disallowed_special: Literal["all"] | Collection[str] = "all",) -> None: + """Builds a Text User Interface (TUI) environment to test out encoding.""" + + if allowed_special == "all": + allowed_special = self.special_tokens_set + if disallowed_special == "all": + disallowed_special = self.special_tokens_set - allowed_special + + return self._core_bpe._environment(self.name, allowed_special) # ==================== # Private # ====================