use std::collections::HashSet;

-use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::PyResult;
use pyo3::types::{PyBytes, PyList, PyTuple};
use rustc_hash::FxHashMap as HashMap;

-use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS};
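+// The tokenizer internals now live in a separate `tiktoken` library crate;
+// this module keeps only the PyO3 bindings over its public `core` API.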
+use tiktoken::core::{byte_pair_encode, CoreBPE};

#[pyclass]
pub struct PyCoreBPE {
@@ -26,47 +25,10 @@ impl PyCoreBPE {
        special_tokens_encoder: HashMap<String, usize>,
        pattern: &str,
    ) -> PyResult<Self> {
-        let regex = Regex::new(pattern)
-            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
-
-        let special_regex = {
-            let _parts = special_tokens_encoder
-                .keys()
-                .map(|s| fancy_regex::escape(s))
-                .collect::<Vec<_>>();
-            Regex::new(&_parts.join("|"))
-                .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
-        };
-
-        let decoder: HashMap<usize, Vec<u8>> =
-            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
-
-        assert!(
-            encoder.len() == decoder.len(),
-            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
-        );
-
-        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
-            .iter()
-            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
-            .collect();
-
-        // Clone because I don't know how to tell Rust I'm not going to change the map
-        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
-        sorted_token_bytes.sort();
-
-        let core_bpe = CoreBPE {
-            encoder,
-            special_tokens_encoder,
-            decoder,
-            special_tokens_decoder,
-            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
-            special_regex_tls: (0..MAX_NUM_THREADS)
-                .map(|_| special_regex.clone())
-                .collect(),
-            sorted_token_bytes,
-        };
-        Ok(PyCoreBPE { core_bpe })
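+        // Building the regexes, the inverse decoder maps and the sorted token
+        // list is now handled by CoreBPE::new in the core crate; construction
+        // errors surface to Python as a ValueError.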
+        CoreBPE::new(encoder, special_tokens_encoder, pattern)
+            .map(|core_bpe| PyCoreBPE { core_bpe })
+            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))
    }

    // ====================
@@ -82,30 +44,7 @@ impl PyCoreBPE {
    }

    fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec<usize> {
-        py.allow_threads(|| {
-            match std::str::from_utf8(bytes) {
-                Ok(text) => self.core_bpe._encode_ordinary_native(text),
-                Err(e) => {
-                    let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
-                    let (tokens, last_piece_token_len) = self.core_bpe._encode_native(text, &HashSet::new());
-                    let (mut tokens, last_piece_token_len) =
-                        self.core_bpe._increase_last_piece_token_len(tokens, last_piece_token_len);
-                    if !tokens.is_empty() && last_piece_token_len > 0 {
-                        // Lop off the tokens from the last piece and run BPE on the remaining bytes
-                        // Somewhat niche, but this may not be correct if we'd have had a regex
-                        // split between the valid UTF-8 and the invalid bytes, which is why this
-                        // method is private
-                        let mut unstable_bytes =
-                            self.core_bpe._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
-                        unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
-
-                        tokens.truncate(tokens.len() - last_piece_token_len);
-                        tokens.extend(byte_pair_encode(&unstable_bytes, &self.core_bpe.encoder));
-                    }
-                    tokens
-                }
-            }
-        })
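+        // The invalid-UTF-8 fallback above (encode the valid prefix, then
+        // byte-pair encode the trailing bytes) now lives in the core crate's
+        // _encode_bytes, so the binding only needs to release the GIL.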
+        py.allow_threads(|| self.core_bpe._encode_bytes(bytes))
    }

    fn encode_with_unstable(
@@ -181,7 +120,7 @@ pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
mod tests {
    use rustc_hash::FxHashMap as HashMap;

-    use crate::tiktoken::byte_pair_split;
+    use tiktoken::core::byte_pair_split;

    #[test]
    fn very_simple_test() {