Skip to content

Commit 053c00f

Browse files
Lőrinc, hauntsaninja
Lőrinc
authored and committed
Inline custom mapping function in _byte_pair_merge
1 parent 6e4851a commit 053c00f

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

src/lib.rs

+16 -17
Original file line number | Diff line number | Diff line change
@@ -8,17 +8,17 @@ use std::thread;
88
use fancy_regex::Regex;
99
use pyo3::exceptions;
1010
use pyo3::prelude::*;
11+
use pyo3::pyclass;
1112
use pyo3::PyResult;
1213
use pyo3::types::{PyBytes, PyList, PyTuple};
1314
use rustc_hash::FxHashMap as HashMap;
1415

1516
type Rank = u32;
1617

17-
fn _byte_pair_merge<T>(
18-
piece: &[u8],
18+
fn _byte_pair_merge(
1919
ranks: &HashMap<Vec<u8>, Rank>,
20-
f: impl Fn(std::ops::Range<usize>) -> T,
21-
) -> Vec<T> {
20+
piece: &[u8],
21+
) -> Vec<(usize, Rank)> {
2222
// This is a vector of (start, rank).
2323
// The rank is of the byte pair starting at position start.
2424
// The rank of the last item in the vector is not a valid value.
@@ -93,25 +93,24 @@ fn _byte_pair_merge<T>(
9393
break;
9494
}
9595
}
96-
let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
97-
for i in 0..parts.len() - 1 {
98-
out.push(f(parts[i].0..parts[i + 1].0));
99-
}
100-
out
96+
97+
parts
10198
}
10299

103100
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
104-
if piece.len() == 1 {
105-
return vec![ranks[piece]];
106-
}
107-
_byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
101+
assert!(piece.len() > 1);
102+
_byte_pair_merge(&ranks, &piece)
103+
.windows(2)
104+
.map(|part| ranks[&piece[part[0].0..part[1].0]])
105+
.collect()
108106
}
109107

110108
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
111-
if piece.len() == 1 {
112-
return vec![piece];
113-
}
114-
_byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
109+
assert!(piece.len() > 1);
110+
_byte_pair_merge(&ranks, &piece)
111+
.windows(2)
112+
.map(|part| &piece[part[0].0..part[1].0])
113+
.collect()
115114
}
116115

117116
// Various performance notes:

0 commit comments

Comments
 (0)