11use bumpalo:: { collections:: Vec as BumpVec , Bump } ;
2- use hashbrown :: { HashMap , HashSet } ;
2+ use std :: collections :: { HashMap , HashSet , BTreeMap } ;
33use itertools:: Itertools ;
44use pyo3:: prelude:: * ;
55use rayon:: iter:: { IntoParallelRefIterator , ParallelIterator } ;
@@ -33,14 +33,14 @@ impl CodebookConfig {
3333 max_codebook_size : usize ,
3434 max_subtokens : usize ,
3535 pad_token_id : usize ,
36- disabled_ids : HashSet < usize > ,
36+ disabled_ids : Option < HashSet < usize > > ,
3737 ) -> Self {
3838 Self {
3939 initial_vocab_size,
4040 max_codebook_size,
4141 max_subtokens,
4242 pad_token_id,
43- disabled_ids,
43+ disabled_ids : disabled_ids . unwrap_or_default ( ) ,
4444 }
4545 }
4646}
@@ -196,8 +196,8 @@ impl Codebook {
196196 result
197197 }
198198
199- pub fn to_decoding_dict ( & self ) -> HashMap < usize , Vec < usize > > {
200- let mut result = HashMap :: with_capacity ( self . base_ids2hyper_id_map . len ( ) ) ;
199+ pub fn to_dict ( & self ) -> BTreeMap < usize , Vec < usize > > {
200+ let mut result = BTreeMap :: new ( ) ;
201201 for ( ids, id) in self . base_ids2hyper_id_map . iter ( ) {
202202 result. insert ( * id, ids. clone ( ) ) ;
203203 }
@@ -291,6 +291,8 @@ impl LZWCompressor {
291291
292292 let id = ids[ i] ;
293293 i += 1 ;
294+ log:: debug!( "coming id: {}" , id) ;
295+ log:: debug!( "buffer_ids_to_merge: {:?}" , buffer_ids_to_merge) ;
294296
295297 if self . config . disabled_ids . contains ( & id) {
296298 if buffer_ids_to_merge. len ( ) > 0 {
@@ -478,7 +480,7 @@ impl LZWCompressor {
478480 if self . config . disabled_ids . contains ( & id) {
479481 previous_ids. clear ( ) ;
480482 output_ids. push ( id) ;
481- log:: debug!( "emitting id: {}" , id) ;
483+ log:: debug!( "emitting disabled id: {}" , id) ;
482484 continue ;
483485 }
484486
@@ -495,7 +497,7 @@ impl LZWCompressor {
495497 // 1. the id is a new hyper id because of force merge
496498 } else if previous_ids. len ( ) == self . config . max_subtokens {
497499 decoded_ids = previous_ids. clone ( ) ;
498- // 2. the id is a new hyper id because of cSc pattern merge
500+ // 2. the id is a new hyper id because of cScSc pattern merge
499501 } else {
500502 decoded_ids = previous_ids. clone ( ) ;
501503 decoded_ids. push ( previous_ids[ 0 ] ) ;
@@ -507,34 +509,53 @@ impl LZWCompressor {
507509 previous_ids = decoded_ids. clone ( ) ;
508510 output_ids. extend_from_slice ( & decoded_ids) ;
509511 log:: debug!( "emitting id: {:?}" , decoded_ids) ;
512+ continue ;
510513 }
511514 // we have decoded the id, we can add it to the output
512515 output_ids. extend_from_slice ( & decoded_ids) ;
516+ log:: debug!( "emitting id: {:?}" , decoded_ids) ;
513517
514518 // the remaining part is to update the codebook if needed
519+
520+ if next_id == self . config . initial_vocab_size + self . config . max_codebook_size {
521+ log:: debug!( "max codebook size reached, clearing buffer_ids_to_merge" ) ;
522+ previous_ids. clear ( ) ;
523+ continue ;
524+ }
525+
526+ // starting case
515527 if previous_ids. len ( ) == 0 {
516528 previous_ids = decoded_ids;
517529 continue ;
518530 }
519531
520- while next_id < self . config . initial_vocab_size + self . config . max_codebook_size
521- && previous_ids. len ( ) < self . config . max_subtokens && decoded_ids. len ( ) > 0
522- {
523- previous_ids. push ( decoded_ids[ 0 ] ) ;
532+ // the buffer is max size and the buffer is the previous ID
533+ // so it must not be a new hyper id but an existing one
534+ // we just clear the buffer and continue
535+ if previous_ids. len ( ) == self . config . max_subtokens {
536+ assert ! ( existing_codes. contains( & previous_ids) , "previous_ids: {:?} not in existing_codes: {:?}" , previous_ids, existing_codes) ;
537+ previous_ids = decoded_ids. clone ( ) ;
538+ continue ;
539+ } else {
524540
525- if !existing_codes. contains ( & previous_ids) {
526- codebook. insert ( next_id, previous_ids. clone ( ) ) ;
527- log:: debug!( "inserting: {:?} -> {:?}" , previous_ids, next_id) ;
528- next_id += 1 ;
529- existing_codes. insert ( previous_ids. clone ( ) ) ;
530- previous_ids = decoded_ids. clone ( ) ;
531- break ;
532- } else if previous_ids. len ( ) == self . config . max_subtokens {
533- previous_ids = decoded_ids. clone ( ) ;
534- break ;
535- }
541+ while decoded_ids. len ( ) > 0
542+ {
543+ previous_ids. push ( decoded_ids[ 0 ] ) ;
544+
545+ if !existing_codes. contains ( & previous_ids) {
546+ codebook. insert ( next_id, previous_ids. clone ( ) ) ;
547+ log:: debug!( "inserting: {:?} -> {:?}" , previous_ids, next_id) ;
548+ next_id += 1 ;
549+ existing_codes. insert ( previous_ids. clone ( ) ) ;
550+ previous_ids = decoded_ids. clone ( ) ;
551+ break ;
552+ } else if previous_ids. len ( ) == self . config . max_subtokens {
553+ previous_ids = decoded_ids. clone ( ) ;
554+ break ;
555+ }
536556
537- decoded_ids. remove ( 0 ) ;
557+ decoded_ids. remove ( 0 ) ;
558+ }
538559 }
539560 }
540561
@@ -611,7 +632,7 @@ impl LZWCompressor {
611632 max_codebook_size,
612633 max_subtokens,
613634 pad_token_id,
614- disabled_ids,
635+ Some ( disabled_ids)
615636 ) ,
616637 }
617638 }
@@ -875,8 +896,8 @@ impl CodebookManager {
875896 let mut current_ids: Vec < usize > ;
876897 if maybe_hid < config. initial_vocab_size {
877898 current_ids = vec ! [ maybe_hid] ;
878- } else if let Some ( entry ) = codebook. get_base_ids ( maybe_hid) {
879- current_ids = entry . clone ( ) ;
899+ } else if let Some ( base_ids ) = codebook. get_base_ids ( maybe_hid) {
900+ current_ids = base_ids . clone ( ) ;
880901 // the following are cases when the maybe_hid is an unknown hyper-id
881902 // (1) the buffer was full and it was inserted in the codebook
882903 // TODO, I think the following is not needed, because
@@ -898,33 +919,38 @@ impl CodebookManager {
898919
899920 }
900921
922+ if state. next_id == config. initial_vocab_size + config. max_codebook_size {
923+ log:: debug!( "max_codebook_size reached, clearing buffer_ids_to_merge" ) ;
924+ state. buffer_ids_to_merge . clear ( ) ;
925+ continue ;
926+ }
927+
901928 // Starting time
902929 if state. buffer_ids_to_merge . len ( ) == 0 {
903930 state. buffer_ids_to_merge = current_ids. clone ( ) ;
904931 continue ;
905932 }
906933
907- while state. next_id < config. initial_vocab_size + config. max_codebook_size
908- && state. buffer_ids_to_merge . len ( ) < config. max_subtokens && current_ids. len ( ) > 0
909- {
910- state. buffer_ids_to_merge . push ( current_ids[ 0 ] ) ;
911-
912- // check if it's already in the codebook
913- if !codebook. contains_key ( & state. buffer_ids_to_merge ) {
914- codebook. insert ( state. buffer_ids_to_merge . clone ( ) , state. next_id ) ;
915- log:: debug!( "inserting: {:?} -> {:?}" , state. buffer_ids_to_merge, state. next_id) ;
916- state. next_id += 1 ;
917- state. buffer_ids_to_merge = current_ids. clone ( ) ;
918- break ;
919- } // reaching max_subtokens, we need to insert the current_ids into the codebook
920- else if state. buffer_ids_to_merge . len ( ) == config. max_subtokens {
921- state. buffer_ids_to_merge = current_ids. clone ( ) ;
922- break ;
923- }
924- current_ids. remove ( 0 ) ;
934+ if state. buffer_ids_to_merge . len ( ) == config. max_subtokens {
935+ state. buffer_ids_to_merge = current_ids. clone ( ) ;
936+ continue ;
937+ } else {
938+ while current_ids. len ( ) > 0 {
939+ state. buffer_ids_to_merge . push ( current_ids[ 0 ] ) ;
925940
941+ if !codebook. contains_key ( & state. buffer_ids_to_merge ) {
942+ codebook. insert ( state. buffer_ids_to_merge . clone ( ) , state. next_id ) ;
943+ log:: debug!( "inserting: {:?} -> {:?}" , state. buffer_ids_to_merge, state. next_id) ;
944+ state. next_id += 1 ;
945+ state. buffer_ids_to_merge = current_ids. clone ( ) ;
946+ break ;
947+ } else if state. buffer_ids_to_merge . len ( ) == config. max_subtokens {
948+ state. buffer_ids_to_merge = current_ids. clone ( ) ;
949+ break ;
950+ }
951+ current_ids. remove ( 0 ) ;
952+ }
926953 }
927-
928954 }
929955 }
930956}
@@ -1031,7 +1057,7 @@ impl CodebookManager {
10311057 _ => panic ! ( "Invalid algorithm: {}" , self . algorithm) ,
10321058 }
10331059
1034- // collect buffered updates from this state’ s codebook
1060+ // collect buffered updates from this state' s codebook
10351061 state
10361062 . codebook
10371063 . borrow_mut ( py)
0 commit comments