Skip to content

Commit 145a88e

Browse files
committed
✨ feat(rust): Convert project to a multi-crate workspace
This commit restructures the project from a single-crate workspace into a multi-crate workspace, dividing it into 'rs-tiktoken' and 'py-tiktoken'. This is done to improve the clarity of the organization of the codebase and make the Rust and Python modules separate for easier code maintenance. The setup.py is also updated to reflect these changes in the directory structure. Refs: openai#24
1 parent f28ce4c commit 145a88e

File tree

14 files changed

+50608
-93
lines changed

14 files changed

+50608
-93
lines changed

.github/workflows/build_wheels.yml

+34-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ jobs:
3030
name: dist
3131
path: ./wheelhouse/*.whl
3232

33-
build_wheels_aarch64:
34-
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64)
33+
build_wheels_aarch64_glibc:
34+
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/glibc)
3535
runs-on: ${{ matrix.os }}
3636
strategy:
3737
fail-fast: false
@@ -52,6 +52,38 @@ jobs:
5252
env:
5353
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
5454
CIBW_ARCHS: aarch64
55+
CIBW_SKIP: "*musllinux*"
56+
CIBW_BUILD_VERBOSITY: 3
57+
# https://github.com/rust-lang/cargo/issues/10583
58+
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true
59+
- uses: actions/upload-artifact@v3
60+
with:
61+
name: dist
62+
path: ./wheelhouse/*.whl
63+
64+
build_wheels_aarch64_musl:
65+
name: py${{ matrix.python-version }} on ${{ matrix.os }} (aarch64/musl)
66+
runs-on: ${{ matrix.os }}
67+
strategy:
68+
fail-fast: false
69+
matrix:
70+
os: [ubuntu-latest]
71+
python-version: [38, 39, 310, 311]
72+
73+
steps:
74+
- uses: actions/checkout@v3
75+
76+
- name: Setup up QEMU
77+
uses: docker/setup-qemu-action@v2
78+
with:
79+
platforms: arm64
80+
81+
- name: Build wheels
82+
uses: pypa/[email protected]
83+
env:
84+
CIBW_BUILD: "cp${{ matrix.python-version}}-*"
85+
CIBW_ARCHS: aarch64
86+
CIBW_SKIP: "*manylinux*"
5587
CIBW_BUILD_VERBOSITY: 3
5688
# https://github.com/rust-lang/cargo/issues/10583
5789
CIBW_ENVIRONMENT_LINUX: PATH="$PATH:$HOME/.cargo/bin" CARGO_NET_GIT_FETCH_WITH_CLI=true

Cargo.toml

+8-17
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,12 @@
1-
[package]
2-
name = "tiktoken"
1+
[workspace]
2+
resolver = "2"
3+
members = [
4+
"rs-tiktoken",
5+
"py-tiktoken",
6+
]
7+
8+
[workspace.package]
39
version = "0.4.0"
4-
edition = "2021"
5-
rust-version = "1.57.0"
6-
7-
[lib]
8-
name = "_tiktoken"
9-
crate-type = ["cdylib"]
10-
11-
[dependencies]
12-
pyo3 = { version = "0.19.0", features = ["extension-module"] }
13-
14-
# tiktoken dependencies
15-
fancy-regex = "0.11.0"
16-
regex = "1.8.3"
17-
rustc-hash = "1.1.0"
18-
bstr = "1.5.0"
1910

2011
[profile.release]
2112
incremental = true

MANIFEST.in

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ include *.svg
22
include *.toml
33
include *.md
44
include Makefile
5+
include py-tiktoken/*.toml
6+
include rs-tiktoken/*.toml
7+
include rs-tiktoken/tests/gpt2_encoder
58
global-include py.typed
69
recursive-include scripts *.py
710
recursive-include tests *.py
8-
recursive-include src *.rs
11+
recursive-include py-tiktoken *.rs
12+
recursive-include rs-tiktoken *.rs

py-tiktoken/Cargo.toml

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[package]
2+
name = "py-tiktoken"
3+
version.workspace = true
4+
edition = "2021"
5+
rust-version = "1.57.0"
6+
7+
[lib]
8+
name = "_tiktoken"
9+
crate-type = ["cdylib"]
10+
11+
[dependencies]
12+
tiktoken = { path = "../rs-tiktoken" }
13+
pyo3 = { version = "0.19.0", features = ["extension-module"] }
14+
once_cell = "1.18.0"
15+
16+
# tiktoken dependencies
17+
fancy-regex = "0.11.0"
18+
regex = "1.8.3"
19+
rustc-hash = "1.1.0"
20+
bstr = "1.5.0"

py-tiktoken/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod tiktoken_py;

src/tiktoken_py.rs renamed to py-tiktoken/src/tiktoken_py.rs

+7-68
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44
use std::collections::HashSet;
55

6-
use fancy_regex::Regex;
76
use pyo3::exceptions;
87
use pyo3::prelude::*;
98
use pyo3::PyResult;
109
use pyo3::types::{PyBytes, PyList, PyTuple};
1110
use rustc_hash::FxHashMap as HashMap;
1211

13-
use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS};
12+
use tiktoken::core::{byte_pair_encode, CoreBPE};
1413

1514
#[pyclass]
1615
pub struct PyCoreBPE {
@@ -26,47 +25,10 @@ impl PyCoreBPE {
2625
special_tokens_encoder: HashMap<String, usize>,
2726
pattern: &str,
2827
) -> PyResult<Self> {
29-
let regex = Regex::new(pattern)
30-
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
31-
32-
let special_regex = {
33-
let _parts = special_tokens_encoder
34-
.keys()
35-
.map(|s| fancy_regex::escape(s))
36-
.collect::<Vec<_>>();
37-
Regex::new(&_parts.join("|"))
38-
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
39-
};
40-
41-
let decoder: HashMap<usize, Vec<u8>> =
42-
encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
43-
44-
assert!(
45-
encoder.len() == decoder.len(),
46-
"Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
47-
);
48-
49-
let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
50-
.iter()
51-
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
52-
.collect();
53-
54-
// Clone because I don't know how to tell Rust I'm not going to change the map
55-
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
56-
sorted_token_bytes.sort();
57-
58-
let core_bpe = CoreBPE {
59-
encoder,
60-
special_tokens_encoder,
61-
decoder,
62-
special_tokens_decoder,
63-
regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
64-
special_regex_tls: (0..MAX_NUM_THREADS)
65-
.map(|_| special_regex.clone())
66-
.collect(),
67-
sorted_token_bytes,
68-
};
69-
Ok(PyCoreBPE { core_bpe })
28+
println!("encoder: {:?}", encoder);
29+
CoreBPE::new(encoder, special_tokens_encoder, pattern)
30+
.map(|core_bpe| PyCoreBPE { core_bpe })
31+
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
7032
}
7133

7234
// ====================
@@ -82,30 +44,7 @@ impl PyCoreBPE {
8244
}
8345

8446
fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
85-
py.allow_threads(|| {
86-
match std::str::from_utf8(bytes) {
87-
Ok(text) => self.core_bpe._encode_ordinary_native(text),
88-
Err(e) => {
89-
let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
90-
let (tokens, last_piece_token_len) = self.core_bpe._encode_native(text, &HashSet::new());
91-
let (mut tokens, last_piece_token_len) =
92-
self.core_bpe._increase_last_piece_token_len(tokens, last_piece_token_len);
93-
if !tokens.is_empty() && last_piece_token_len > 0 {
94-
// Lop off the tokens from the last piece and run BPE on the remaining bytes
95-
// Somewhat niche, but this may not be correct if we'd have had a regex
96-
// split between the valid UTF-8 and the invalid bytes, which is why this
97-
// method is private
98-
let mut unstable_bytes =
99-
self.core_bpe._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
100-
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
101-
102-
tokens.truncate(tokens.len() - last_piece_token_len);
103-
tokens.extend(byte_pair_encode(&unstable_bytes, &self.core_bpe.encoder));
104-
}
105-
tokens
106-
}
107-
}
108-
})
47+
py.allow_threads(|| self.core_bpe._encode_bytes(bytes))
10948
}
11049

11150
fn encode_with_unstable(
@@ -181,7 +120,7 @@ pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
181120
mod tests {
182121
use rustc_hash::FxHashMap as HashMap;
183122

184-
use crate::tiktoken::byte_pair_split;
123+
use tiktoken::core::byte_pair_split;
185124

186125
#[test]
187126
fn very_simple_test() {

rs-tiktoken/Cargo.toml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[package]
2+
name = "tiktoken"
3+
version.workspace = true
4+
edition = "2021"
5+
rust-version = "1.57.0"
6+
7+
[dependencies]
8+
fancy-regex = "0.11.0"
9+
regex = "1.8.3"
10+
rustc-hash = "1.1.0"
11+
bstr = "1.5.0"
12+
once_cell = "1.18.0"
13+
parse-display = "0.8.2"

0 commit comments

Comments
 (0)