Skip to content

Commit 053bd2d

Browse files
authored
Add the Universal Sentence Encoder lite. (#139)
1 parent 9bc9643 commit 053bd2d

File tree

15 files changed

+2010
-0
lines changed

15 files changed

+2010
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
node_modules/
2+
.cache/
3+
dist/
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Universal Sentence Encoder lite
2+
3+
The Universal Sentence Encoder ([Cer et al., 2018](https://arxiv.org/pdf/1803.11175.pdf)) is a model that encodes text into 512-dimensional embeddings. These embeddings can then be used as inputs to natural language processing tasks such as [sentiment classification](https://en.wikipedia.org/wiki/Sentiment_analysis) and [textual similarity](https://en.wikipedia.org/wiki/Semantic_similarity) analysis.
4+
5+
This module is a TensorFlow.js [`FrozenModel`](https://js.tensorflow.org/api/latest/#loadFrozenModel) converted from the Universal Sentence Encoder lite ([module on TFHub](https://tfhub.dev/google/universal-sentence-encoder-lite/2)), a lightweight version of the original. The lite model is based on the Transformer ([Vaswani et al., 2017](https://arxiv.org/pdf/1706.03762.pdf)) architecture, and uses an 8k word piece [vocabulary](https://storage.googleapis.com/tfjs-models/savedmodel/universal_sentence_encoder/vocab.json).
6+
7+
## Usage
8+
9+
```js
10+
11+
import * as use from '@tensorflow-models/universal-sentence-encoder';
12+
13+
// Load the model.
14+
const model = await use.load();
15+
16+
// Embed an array of sentences.
17+
const sentences = [
18+
'Hello.',
19+
'How are you?'
20+
];
21+
22+
const embeddings = await model.embed(sentences);
23+
24+
// `embeddings` is a 2D tensor consisting of the 512-dimensional embeddings for each sentence.
25+
// So in this example `embeddings` has the shape [2, 512].
26+
const verbose = true;
27+
embeddings.print(verbose);
28+
29+
```
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
{
2+
"name": "@tensorflow-models/universal-sentence-encoder",
3+
"version": "0.0.1",
4+
"description": "Universal Sentence Encoder lite in TensorFlow.js",
5+
"main": "dist/index.js",
6+
"jsnext:main": "dist/universal-sentence-encoder.esm.js",
7+
"module": "dist/universal-sentence-encoder.esm.js",
8+
"unpkg": "dist/universal-sentence-encoder.min.js",
9+
"jsdelivr": "dist/universal-sentence-encoder.min.js",
10+
"types": "dist/index.d.ts",
11+
"repository": {
12+
"type": "git",
13+
"url": "https://github.com/tensorflow/tfjs-models.git"
14+
},
15+
"peerDependencies": {
16+
"@tensorflow/tfjs": "^0.14.2"
17+
},
18+
"devDependencies": {
19+
"@tensorflow/tfjs": "^0.14.2",
20+
"@types/jasmine": "~2.5.53",
21+
"jasmine": "^3.3.1",
22+
"jasmine-core": "^3.3.0",
23+
"rimraf": "~2.6.2",
24+
"rollup": "~0.58.2",
25+
"rollup-plugin-node-resolve": "~3.3.0",
26+
"rollup-plugin-typescript2": "~0.13.0",
27+
"rollup-plugin-uglify": "~3.0.0",
28+
"ts-node": "~5.0.0",
29+
"tslint": "~5.8.0",
30+
"typescript": "2.9.2"
31+
},
32+
"scripts": {
33+
"build": "rimraf dist && tsc",
34+
"publish-local": "yarn build && yalc push",
35+
"test": "ts-node run_tests.ts",
36+
"publish-npm": "yarn build && rollup -c && npm publish",
37+
"lint": "tslint -p . -t verbose"
38+
},
39+
"license": "Apache-2.0"
40+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import node from 'rollup-plugin-node-resolve';
19+
import typescript from 'rollup-plugin-typescript2';
20+
import uglify from 'rollup-plugin-uglify';
21+
22+
// Copyright line placed at the top of every generated bundle; the year is
// taken from the clock at build time.
const PREAMBLE =
    `// @tensorflow/tfjs-models Copyright ${new Date().getFullYear()} Google`;

/**
 * Builds an uglify plugin instance that emits PREAMBLE as the preamble of the
 * minified output.
 */
function minify() {
  const uglifyOptions = {output: {preamble: PREAMBLE}};
  return uglify(uglifyOptions);
}
28+
29+
/**
 * Produces one rollup build configuration for `src/index.ts`.
 *
 * @param plugins Extra plugins appended after the TypeScript and
 *     node-resolve plugins.
 * @param output Output options merged over the shared defaults (banner and
 *     the `tf` global for `@tensorflow/tfjs`); keys given here win.
 */
function config({plugins = [], output = {}}) {
  const tsPlugin =
      typescript({tsconfigOverride: {compilerOptions: {module: 'ES2015'}}});
  const pluginList = [tsPlugin, node(), ...plugins];
  const mergedOutput = {
    banner: PREAMBLE,
    globals: {'@tensorflow/tfjs': 'tf'},
    ...output
  };

  return {
    input: 'src/index.ts',
    plugins: pluginList,
    output: mergedOutput,
    // @tensorflow/tfjs is listed as external, so it is not bundled.
    external: ['@tensorflow/tfjs']
  };
}
40+
41+
// Unminified UMD bundle.
const umdBundle = config({
  output: {
    format: 'umd',
    name: 'universal-sentence-encoder',
    file: 'dist/universal-sentence-encoder.js'
  }
});

// Minified UMD bundle.
const umdMinBundle = config({
  plugins: [minify()],
  output: {
    format: 'umd',
    name: 'universal-sentence-encoder',
    file: 'dist/universal-sentence-encoder.min.js'
  }
});

// Minified ES module bundle.
const esmMinBundle = config({
  plugins: [minify()],
  output: {format: 'es', file: 'dist/universal-sentence-encoder.esm.js'}
});

export default [umdBundle, umdMinBundle, esmMinBundle];
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as jasmine_util from '@tensorflow/tfjs-core/dist/jasmine_util';
19+
import {runTests} from '../test_util';
20+
21+
// Test entry point: passes the tfjs-core jasmine utilities to the `runTests`
// helper imported from '../test_util'.
runTests(jasmine_util);
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
20+
import {Tokenizer} from './tokenizer';
21+
22+
const BASE_PATH =
23+
'https://storage.googleapis.com/tfjs-models/savedmodel/universal_sentence_encoder/';
24+
25+
/**
 * Creates a `UniversalSentenceEncoder`, waits for its `load()` call to
 * complete, and resolves with the instance.
 */
export async function load() {
  const encoder = new UniversalSentenceEncoder();
  await encoder.load();
  return encoder;
}
30+
31+
/**
 * Wraps the frozen model and the tokenizer built from the vocabulary fetched
 * from `BASE_PATH`. Call `load()` before `embed()`.
 */
export class UniversalSentenceEncoder {
  private model: tf.FrozenModel;
  private tokenizer: Tokenizer;

  /**
   * Loads the frozen model (graph definition and weights manifest) from
   * `BASE_PATH`.
   */
  async loadModel() {
    return tf.loadFrozenModel(
        `${BASE_PATH}tensorflowjs_model.pb`,
        `${BASE_PATH}weights_manifest.json`);
  }

  /**
   * Fetches `vocab.json` from `BASE_PATH` and resolves with the parsed JSON.
   *
   * Rejects with a descriptive `Error` when the server answers with a
   * non-success status, rather than failing later while parsing an error page
   * as JSON.
   */
  async loadVocabulary() {
    const response = await fetch(`${BASE_PATH}vocab.json`);
    if (!response.ok) {
      throw new Error(
          `Failed to load vocabulary from ${BASE_PATH}vocab.json: ` +
          `${response.status} ${response.statusText}`);
    }
    return response.json();
  }

  /**
   * Loads the model and the vocabulary in parallel, then stores the model and
   * a `Tokenizer` constructed from the vocabulary on this instance.
   */
  async load() {
    const [model, vocabulary] =
        await Promise.all([this.loadModel(), this.loadVocabulary()]);

    this.model = model;
    this.tokenizer = new Tokenizer(vocabulary);
  }

  /**
   *
   * Returns a 2D Tensor of shape [inputs.length, 512] that contains the
   * Universal Sentence Encoder embeddings for each input. A single string is
   * treated as a one-element array.
   *
   * @param inputs A string or an array of strings to embed.
   */
  async embed(inputs: string[]|string): Promise<tf.Tensor2D> {
    const sentences = typeof inputs === 'string' ? [inputs] : inputs;

    const encodings = sentences.map(d => this.tokenizer.encode(d));

    // One [sentence index, token position] pair per encoded token, built in a
    // single pass instead of repeatedly concatenating arrays.
    const flattenedIndicesArr: Array<[number, number]> = [];
    for (let i = 0; i < encodings.length; i++) {
      for (let j = 0; j < encodings[i].length; j++) {
        flattenedIndicesArr.push([i, j]);
      }
    }
    const flattenedValues = tf.util.flatten(encodings) as number[];

    const indices = tf.tensor2d(
        flattenedIndicesArr, [flattenedIndicesArr.length, 2], 'int32');
    const values = tf.tensor1d(flattenedValues, 'int32');

    try {
      const embeddings = await this.model.executeAsync({indices, values});
      return embeddings as tf.Tensor2D;
    } finally {
      // Release the input tensors even when execution rejects.
      indices.dispose();
      values.dispose();
    }
  }
}
88+
89+
export {Tokenizer};
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
 * Small hand-written vocabulary fixture: an array of [string, number] pairs.
 *
 * NOTE(review): the number is presumably the score the tokenizer assigns to
 * the piece, and '▁' presumably marks a word boundary; confirm against
 * ./tokenizer.
 * The first six entries all carry 0. The entry ['i', -3] appears twice;
 * entry order presumably determines token ids, so do not dedupe without
 * checking the tests that use this fixture.
 */
export const stubbedTokenizerVocab = [
19+
['�', 0],
20+
['<s>', 0],
21+
['</s>', 0],
22+
['extra_token_id_1', 0],
23+
['extra_token_id_2', 0],
24+
['extra_token_id_3', 0],
25+
['▁', -2],
26+
['▁a', -1],
27+
['▁ç', -2],
28+
['a', -3],
29+
['.', -1],
30+
['▁I', -1],
31+
['▁like', -1],
32+
['▁it', -1],
33+
['I', -2],
34+
['like', -2],
35+
['it', -2],
36+
['l', -3],
37+
['i', -3],
38+
['k', -3],
39+
['e', -3],
40+
['i', -3],
41+
['t', -3]
42+
];

0 commit comments

Comments
 (0)