@@ -3,14 +3,13 @@

 use std::collections::HashSet;

-use fancy_regex::Regex;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::PyResult;
 use pyo3::types::{PyBytes, PyList, PyTuple};
 use rustc_hash::FxHashMap as HashMap;

-use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS};
+use tiktoken::core::{byte_pair_encode, CoreBPE};

 #[pyclass]
 pub struct PyCoreBPE {
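The struct body is cut off by the hunk boundary here; from the constructor below, it is evidently now just a thin wrapper over the relocated core type, along the lines of:

    pub struct PyCoreBPE {
        // all tokenizer state now lives in the `tiktoken` core crate
        core_bpe: CoreBPE,
    }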
@@ -26,47 +25,9 @@ impl PyCoreBPE {
         special_tokens_encoder: HashMap<String, usize>,
         pattern: &str,
     ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
-            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
-
-        let special_regex = {
-            let _parts = special_tokens_encoder
-                .keys()
-                .map(|s| fancy_regex::escape(s))
-                .collect::<Vec<_>>();
-            Regex::new(&_parts.join("|"))
-                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
-        };
-
-        let decoder: HashMap<usize, Vec<u8>> =
-            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
-
-        assert!(
-            encoder.len() == decoder.len(),
-            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
-        );
-
-        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
-            .iter()
-            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
-            .collect();
-
-        // Clone because I don't know how to tell Rust I'm not going to change the map
-        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
-        sorted_token_bytes.sort();
-
-        let core_bpe = CoreBPE {
-            encoder,
-            special_tokens_encoder,
-            decoder,
-            special_tokens_decoder,
-            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
-            special_regex_tls: (0..MAX_NUM_THREADS)
-                .map(|_| special_regex.clone())
-                .collect(),
-            sorted_token_bytes,
-        };
-        Ok(PyCoreBPE { core_bpe })
+        CoreBPE::new(encoder, special_tokens_encoder, pattern)
+            .map(|core_bpe| PyCoreBPE { core_bpe })
+            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
     }

     // ====================
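Everything deleted above (compiling the pattern and the escaped special-token alternation, inverting the encoder into the decoder maps, the duplicate-index assertion, and sorting the token bytes) now happens behind `CoreBPE::new` in the core crate. That constructor is not shown in this diff; judging from the call site, it plausibly has the shape below. The `fancy_regex::Error` return type is an assumption (the binding only needs an error that implements `to_string`), and `Regex` and `MAX_NUM_THREADS` are assumed to be in scope in the core crate.

    impl CoreBPE {
        pub fn new(
            encoder: HashMap<Vec<u8>, usize>,
            special_tokens_encoder: HashMap<String, usize>,
            pattern: &str,
        ) -> Result<Self, fancy_regex::Error> {
            // Compile the main pattern and the alternation of escaped special tokens.
            let regex = Regex::new(pattern)?;
            let special_regex = {
                let parts = special_tokens_encoder
                    .keys()
                    .map(|s| fancy_regex::escape(s))
                    .collect::<Vec<_>>();
                Regex::new(&parts.join("|"))?
            };

            // Invert the encoder; a length mismatch means duplicate token indices.
            let decoder: HashMap<usize, Vec<u8>> =
                encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
            assert!(
                encoder.len() == decoder.len(),
                "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
            );
            let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
                .iter()
                .map(|(k, v)| (*v, k.as_bytes().to_vec()))
                .collect();

            let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
            sorted_token_bytes.sort();

            Ok(CoreBPE {
                encoder,
                special_tokens_encoder,
                decoder,
                special_tokens_decoder,
                regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
                special_regex_tls: (0..MAX_NUM_THREADS)
                    .map(|_| special_regex.clone())
                    .collect(),
                sorted_token_bytes,
            })
        }
    }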
@@ -82,30 +43,7 @@ impl PyCoreBPE {
     }

     fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
-        py.allow_threads(|| {
-            match std::str::from_utf8(bytes) {
-                Ok(text) => self.core_bpe._encode_ordinary_native(text),
-                Err(e) => {
-                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
-                    let (tokens, last_piece_token_len) = self.core_bpe._encode_native(text, &HashSet::new());
-                    let (mut tokens, last_piece_token_len) =
-                        self.core_bpe._increase_last_piece_token_len(tokens, last_piece_token_len);
-                    if !tokens.is_empty() && last_piece_token_len > 0 {
-                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes =
-                            self.core_bpe._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
-                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
-
-                        tokens.truncate(tokens.len() - last_piece_token_len);
-                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.core_bpe.encoder));
-                    }
-                    tokens
-                }
-            }
-        })
+        py.allow_threads(|| self.core_bpe._encode_bytes(bytes))
     }

     fn encode_with_unstable(
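The UTF-8 recovery logic deleted here (encode the valid prefix, then re-run BPE over the last unstable piece plus the invalid trailing bytes) moves into the core crate as `CoreBPE::_encode_bytes`. Presumably the closure body carries over nearly verbatim, with `self.core_bpe` becoming `self`; a sketch under that assumption, with the same `HashSet` and `byte_pair_encode` imports assumed in scope:

    impl CoreBPE {
        pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<usize> {
            match std::str::from_utf8(bytes) {
                // Valid UTF-8 can be encoded directly.
                Ok(text) => self._encode_ordinary_native(text),
                Err(e) => {
                    // Encode the longest valid UTF-8 prefix...
                    let text =
                        unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
                    let (tokens, last_piece_token_len) =
                        self._encode_native(text, &HashSet::new());
                    let (mut tokens, last_piece_token_len) =
                        self._increase_last_piece_token_len(tokens, last_piece_token_len);
                    if !tokens.is_empty() && last_piece_token_len > 0 {
                        // ...then lop off the tokens of the last piece and run BPE over
                        // those bytes plus the invalid suffix.
                        let mut unstable_bytes =
                            self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
                        tokens.truncate(tokens.len() - last_piece_token_len);
                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
                    }
                    tokens
                }
            }
        }
    }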
@@ -181,7 +119,7 @@ pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
 mod tests {
     use rustc_hash::FxHashMap as HashMap;

-    use crate::tiktoken::byte_pair_split;
+    use tiktoken::core::byte_pair_split;

     #[test]
     fn very_simple_test() {
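The test body is truncated by the diff here. Its assertions are not shown, but a minimal `byte_pair_split` test consistent with this setup might look like the following; the ranks and input bytes are illustrative:

    #[test]
    fn very_simple_test() {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"cd".to_vec(), 1);

        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }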