Skip to content

Commit addf2e9

Browse files
authored
[use-qna] use tf.matmul for dot product computation (#886)
* use tf.matMul to do the dot product instead of download to cpu and compute * fix typo on readme
1 parent 5f53a4b commit addf2e9

File tree

4 files changed

+515
-56
lines changed

4 files changed

+515
-56
lines changed

universal-sentence-encoder/README.md

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -125,30 +125,8 @@ use.loadQnA().then(model => {
125125
* And embed_responses[0] is the embedding for the answer
126126
* 'I\'m not feeling very well.'
127127
*/
128-
const embed_query = embeddings['queryEmbedding'].arraySync();
129-
const embed_responses = embeddings['responseEmbedding'].arraySync();
130-
// compute the dotProduct of each query and response pair.
131-
for (let i = 0; i < input['queries'].length; i++) {
132-
for (let j = 0; j < input['responses'].length; j++) {
133-
scores.push(dotProduct(embed_query[i], embed_responses[j]));
134-
}
135-
}
128+
const scores = tf.matMul(embeddings['queryEmbedding'],
129+
embeddings['responseEmbedding'], false, true).dataSync();
136130
});
137131

138-
// Calculate the dot product of two vector arrays.
139-
const dotProduct = (xs, ys) => {
140-
const sum = xs => xs ? xs.reduce((a, b) => a + b, 0) : undefined;
141-
142-
return xs.length === ys.length ?
143-
sum(zipWith((a, b) => a * b, xs, ys))
144-
: undefined;
145-
}
146-
147-
// zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
148-
const zipWith =
149-
(f, xs, ys) => {
150-
const ny = ys.length;
151-
return (xs.length <= ny ? xs : xs.slice(0, ny))
152-
.map((x, i) => f(x, ys[i]));
153-
}
154132
```

universal-sentence-encoder/demo/index.js

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,31 +83,15 @@ const initQnA = async () => {
8383
const model = await use.loadQnA();
8484
document.querySelector('#loadingQnA').style.display = 'none';
8585
let result = model.embed(input);
86-
const query = result['queryEmbedding'].arraySync();
87-
const answers = result['responseEmbedding'].arraySync();
88-
for (let i = 0; i < answers.length; i++) {
86+
const dp = tf.matMul(result['queryEmbedding'], result['responseEmbedding'],
87+
false, true).dataSync();
88+
for (let i = 0; i < dp.length; i++) {
8989
document.getElementById(`answer_${i + 1}`).textContent =
90-
`${dotProduct(query[0], answers[i])}`
90+
`${dp[i]}`
9191
}
9292
};
9393
init();
9494
initQnA();
95-
// zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
96-
const zipWith =
97-
(f, xs, ys) => {
98-
const ny = ys.length;
99-
return (xs.length <= ny ? xs : xs.slice(0, ny))
100-
.map((x, i) => f(x, ys[i]));
101-
}
102-
103-
// dotProduct :: [Int] -> [Int] -> Int
104-
const dotProduct =
105-
(xs, ys) => {
106-
const sum = xs => xs ? xs.reduce((a, b) => a + b, 0) : undefined;
107-
108-
return xs.length === ys.length ? (sum(zipWith((a, b) => a * b, xs, ys))) :
109-
undefined;
110-
}
11195

11296
const renderSentences = () => {
11397
sentences.forEach((sentence, i) => {

universal-sentence-encoder/demo/package.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12+
"@tensorflow-models/universal-sentence-encoder": "file:../dist",
1213
"@tensorflow/tfjs-backend-cpu": "^3.3.0",
1314
"@tensorflow/tfjs-backend-webgl": "^3.3.0",
1415
"@tensorflow/tfjs-converter": "^3.3.0",
1516
"@tensorflow/tfjs-core": "^3.3.0",
16-
"d3-scale-chromatic": "^1.3.3",
17-
"@tensorflow-models/universal-sentence-encoder": "file:../dist"
17+
"d3-scale-chromatic": "^1.3.3"
1818
},
1919
"scripts": {
2020
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open --target browser",
@@ -32,6 +32,7 @@
3232
"@babel/plugin-transform-runtime": "^7.7.6",
3333
"@babel/polyfill": "^7.10.4",
3434
"@babel/preset-env": "^7.7.6",
35+
"babel-preset-env": "^1.7.0",
3536
"clang-format": "~1.2.2",
3637
"cross-env": "^5.2.0",
3738
"dat.gui": "^0.7.2",

0 commit comments

Comments
 (0)