
Commit 65c50f0

Make it possible to use encoder separately (tensorflow#150)
1 parent 80a223b · commit 65c50f0

File tree

3 files changed: +33 -7 lines changed

universal-sentence-encoder/README.md

Lines changed: 9 additions & 1 deletion
````diff
@@ -49,4 +49,12 @@ use.load().then(model => {
     embeddings.print(true /* verbose */);
   });
 });
-```
+```
+
+To use the Tokenizer separately:
+
+```js
+use.loadTokenizer().then(tokenizer => {
+  tokenizer.encode('Hello, how are you?'); // [341, 4125, 8, 140, 31, 19, 54]
+});
+```
````
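As a companion to the README snippet above, here is a minimal TypeScript sketch of the same tokenizer-only flow written with `async`/`await`. It assumes the package is installed from npm as `@tensorflow-models/universal-sentence-encoder` and that `loadTokenizer` is exported from the package entry point, as the src/index.ts change below shows.

```ts
// Minimal sketch: load only the Tokenizer, without the full encoder graph.
// Package name assumed to be @tensorflow-models/universal-sentence-encoder.
import * as use from '@tensorflow-models/universal-sentence-encoder';

async function inspectTokens(sentence: string): Promise<number[]> {
  // Resolves with a Tokenizer backed by the default 8k vocabulary.
  const tokenizer = await use.loadTokenizer();
  // encode() maps the sentence to vocabulary token ids, e.g.
  // 'Hello, how are you?' -> [341, 4125, 8, 140, 31, 19, 54].
  return tokenizer.encode(sentence);
}

inspectTokens('Hello, how are you?').then(ids => console.log(ids));
```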

universal-sentence-encoder/src/index.ts

Lines changed: 24 additions & 6 deletions
```diff
@@ -28,6 +28,29 @@ export async function load() {
   return use;
 }
 
+
+/**
+ * Load the Tokenizer for use independently from the UniversalSentenceEncoder.
+ *
+ * @param pathToVocabulary (optional) Provide a path to the vocabulary file.
+ */
+export async function loadTokenizer(pathToVocabulary?: string) {
+  const vocabulary = await loadVocabulary(pathToVocabulary);
+  const tokenizer = new Tokenizer(vocabulary);
+  return tokenizer;
+}
+
+/**
+ * Load a vocabulary for the Tokenizer.
+ *
+ * @param pathToVocabulary Defaults to the path to the 8k vocabulary used by the
+ * UniversalSentenceEncoder.
+ */
+async function loadVocabulary(pathToVocabulary = `${BASE_PATH}vocab.json`) {
+  const vocabulary = await fetch(pathToVocabulary);
+  return vocabulary.json();
+}
+
 export class UniversalSentenceEncoder {
   private model: tf.FrozenModel;
   private tokenizer: Tokenizer;
@@ -38,14 +61,9 @@ export class UniversalSentenceEncoder {
         `${BASE_PATH}weights_manifest.json`);
   }
 
-  async loadVocabulary() {
-    const vocabulary = await fetch(`${BASE_PATH}vocab.json`);
-    return vocabulary.json();
-  }
-
   async load() {
     const [model, vocabulary] =
-        await Promise.all([this.loadModel(), this.loadVocabulary()]);
+        await Promise.all([this.loadModel(), loadVocabulary()]);
 
     this.model = model;
     this.tokenizer = new Tokenizer(vocabulary);
```
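Because the new `loadTokenizer` accepts an optional `pathToVocabulary`, a self-hosted vocabulary file can be substituted for the default one. A short sketch follows, with a hypothetical URL standing in for the real location of the file; everything else mirrors the API introduced in the diff above.

```ts
// Sketch of the new entry point with the optional vocabulary override.
// The URL is hypothetical; when the argument is omitted, loadTokenizer()
// falls back to `${BASE_PATH}vocab.json`, the 8k vocabulary used by the
// encoder itself.
import {loadTokenizer} from '@tensorflow-models/universal-sentence-encoder';

async function tokenizeWithCustomVocab(sentence: string): Promise<number[]> {
  const tokenizer = await loadTokenizer('https://example.com/my-vocab.json');
  return tokenizer.encode(sentence);
}

tokenizeWithCustomVocab('Hello, how are you?').then(ids => console.log(ids));
```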

0 commit comments
