@@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap;
 
 type Rank = u32;
 
-fn _byte_pair_merge(
-    ranks: &HashMap<Vec<u8>, Rank>,
-    piece: &[u8],
-) -> Vec<(usize, Rank)> {
+fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
     // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
+    // The rank is of the pair starting at position start.
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+
+    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
+    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
+    // merge priority from token index or to prevent specific token merges.
+    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
+    for i in 0..piece.len() - 1 {
+        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
+        if rank < min_rank.0 {
+            min_rank = (rank, i);
+        }
+        parts.push((i, rank));
+    }
+    parts.push((piece.len() - 1, Rank::MAX));
+    parts.push((piece.len(), Rank::MAX));
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
+        |parts: &Vec<(usize, Rank)>, i: usize| {
+            if (i + 3) < parts.len() {
+                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
+                // parts[i + 1], see comment in the main loop.
+                *ranks
+                    .get(&piece[parts[i].0..parts[i + 3].0])
+                    .unwrap_or(&Rank::MAX)
             } else {
-                None
+                Rank::MAX
             }
         }
     };
 
-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
-    }
-
     // If you have n parts and m merges, this does O(mn) work.
     // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
+    // n is often very small so considerations like cache-locality outweigh the algorithmic
+    // complexity downsides of the `parts` vector.
+    while min_rank.0 != Rank::MAX {
+        let i = min_rank.1;
+        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
+        // `parts.remove(i + 1)` will thrash the cache.
+        if i > 0 {
+            parts[i - 1].1 = get_rank(&parts, i - 1);
         }
+        parts[i].1 = get_rank(&parts, i);
+        parts.remove(i + 1);
 
-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
+        min_rank = (Rank::MAX, usize::MAX);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
-
-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
-            }
-
-            parts.remove(i + 1);
-        } else {
-            break;
-        }
     }
-
     parts
 }
 
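For reference, and not part of the change above: a minimal sketch of how the boundary list returned by `_byte_pair_merge` can be turned into token ranks, assuming the `HashMap` alias, `Rank` type, and `_byte_pair_merge` from this file. The `byte_pair_encode` name and exact signature here are illustrative assumptions rather than something this diff introduces.

fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    // Pieces of length 1 are assumed to be handled by the caller.
    assert!(piece.len() > 1);
    // Each adjacent pair of surviving split points delimits one token; its
    // byte slice is looked up in `ranks` to get the token's rank.
    _byte_pair_merge(ranks, piece)
        .windows(2)
        .map(|part| ranks[&piece[part[0].0..part[1].0]])
        .collect()
}

Because the final sentinel boundary sits at `piece.len()`, the last window always ends at the end of `piece`, so every byte is covered by exactly one token.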