Commit f3f0028

up stream the changes from g3 (#1128)
1 parent c7914da commit f3f0028

File tree: 3 files changed, +28 −17 lines


universal-sentence-encoder/src/index.ts

Lines changed: 15 additions & 8 deletions
@@ -18,7 +18,7 @@
 import * as tfconv from '@tensorflow/tfjs-converter';
 import * as tf from '@tensorflow/tfjs-core';
 
-import {loadTokenizer, loadVocabulary, Tokenizer} from './tokenizer';
+import {loadTokenizer as loadTokenizerInternal, loadVocabulary, Tokenizer} from './tokenizer';
 import {loadQnA} from './use_qna';
 
 export {version} from './version';
@@ -47,12 +47,11 @@ export class UniversalSentenceEncoder
   private tokenizer: Tokenizer;
 
   async loadModel(modelUrl?: string) {
-    return modelUrl
-        ? tfconv.loadGraphModel(modelUrl)
-        : tfconv.loadGraphModel(
-              'https://tfhub.dev/tensorflow/tfjs-model/universal-sentence-encoder-lite/1/default/1',
-              {fromTFHub: true}
-          );
+    return modelUrl ?
+        tfconv.loadGraphModel(modelUrl) :
+        tfconv.loadGraphModel(
+            'https://tfhub.dev/tensorflow/tfjs-model/universal-sentence-encoder-lite/1/default/1',
+            {fromTFHub: true});
   }
 
   async load(config: LoadConfig = {}) {
@@ -102,6 +101,14 @@ export class UniversalSentenceEncoder
   }
 }
 
+/**
+ * Load the Tokenizer for use independently from the UniversalSentenceEncoder.
+ *
+ * @param pathToVocabulary (optional) Provide a path to the vocabulary file.
+ */
+export async function loadTokenizer(pathToVocabulary?: string) {
+  return loadTokenizerInternal(pathToVocabulary || BASE_PATH + '/vocab.json');
+}
+
 export {Tokenizer};
-export {loadTokenizer};
 export {loadQnA};
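
For context on the new top-level export: it lets callers tokenize text without loading the graph model, falling back to the packaged vocabulary when no path is given. A minimal usage sketch, assuming the published package name '@tensorflow-models/universal-sentence-encoder' and a Tokenizer.encode(text) method as found in the tokenizer source; this is not an official example from the commit.

// Sketch only: the package name and encode() call are assumptions, not part of this diff.
import {loadTokenizer} from '@tensorflow-models/universal-sentence-encoder';

async function tokenizeExample() {
  // With no argument, the new wrapper falls back to BASE_PATH + '/vocab.json'.
  const tokenizer = await loadTokenizer();
  const ids = tokenizer.encode('Hello, how are you?');
  console.log(ids);  // vocabulary ids produced by the SentencePiece-style tokenizer
}

tokenizeExample();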

universal-sentence-encoder/src/tokenizer/index.ts

Lines changed: 5 additions & 5 deletions
@@ -54,8 +54,8 @@ export class Tokenizer
   trie: Trie;
 
   constructor(
-      private vocabulary: Vocabulary,
-      private reservedSymbolsCount = RESERVED_SYMBOLS_COUNT) {
+      private readonly vocabulary: Vocabulary,
+      private readonly reservedSymbolsCount = RESERVED_SYMBOLS_COUNT) {
     this.trie = new Trie();
 
     for (let i = this.reservedSymbolsCount; i < this.vocabulary.length; i++) {
@@ -121,7 +121,7 @@ export class Tokenizer
     }
 
     // Merge consecutive unks.
-    const merged = [];
+    const merged: number[] = [];
     let isPreviousUnk = false;
     for (let i = 0; i < results.length; i++) {
      const id = results[i];
@@ -139,9 +139,9 @@
 /**
  * Load the Tokenizer for use independently from the UniversalSentenceEncoder.
  *
- * @param pathToVocabulary (optional) Provide a path to the vocabulary file.
+ * @param pathToVocabulary Provide a path to the vocabulary file.
  */
-export async function loadTokenizer(pathToVocabulary?: string) {
+export async function loadTokenizer(pathToVocabulary: string) {
   const vocabulary = await loadVocabulary(pathToVocabulary);
   const tokenizer = new Tokenizer(vocabulary);
   return tokenizer;
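
The constructor change above adopts TypeScript parameter properties with readonly, and the explicit number[] annotation pins the element type of an initially empty array. A minimal sketch of the readonly parameter-property pattern in isolation, with a placeholder Vocabulary type (the repo's actual type is not shown in this diff):

// Sketch of 'private readonly' parameter properties, as adopted in the Tokenizer constructor.
// The Vocabulary alias here is a placeholder, not the definition from './tokenizer'.
type Vocabulary = Array<[string, number]>;

class Example {
  // Declares and assigns this.vocabulary in one step; readonly forbids later reassignment.
  constructor(private readonly vocabulary: Vocabulary) {}

  size(): number {
    // Reading the field is fine; `this.vocabulary = []` would be a compile-time error.
    return this.vocabulary.length;
  }
}

console.log(new Example([['hello', -1.5]]).size());  // 1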

universal-sentence-encoder/src/tokenizer/trie.ts

Lines changed: 8 additions & 4 deletions
@@ -21,7 +21,7 @@ import {stringToChars} from '../util';
 type OutputNode = [string[], number, number];
 
 class TrieNode {
-  public parent: TrieNode;
+  public parent: TrieNode|null;
   public end: boolean;
   public children: {[firstSymbol: string]: TrieNode};
   public word: OutputNode;
@@ -74,12 +74,16 @@ export class Trie
     const output: OutputNode[] = [];
     let node = this.root.children[ss[0]];
 
-    for (let i = 0; i < ss.length && node; i++){
-      if (node.end){ output.push(node.word); }
+    for (let i = 0; i < ss.length && node; i++) {
+      if (node.end) {
+        output.push(node.word);
+      }
       node = node.children[ss[i + 1]];
     }
 
-    if (!output.length){ output.push([[ss[0]], 0, 0]); }
+    if (!output.length) {
+      output.push([[ss[0]], 0, 0]);
+    }
 
     return output;
   }
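
One plausible reason for the parent: TrieNode|null change in the first hunk is that the root node has no parent, so under strictNullChecks the field must admit null explicitly; the diff itself does not state the motivation. A minimal standalone sketch (not the repo's TrieNode):

// Standalone sketch: illustrates why a nullable parent type is needed for the root node.
class Node {
  parent: Node|null = null;            // root keeps null; children point back at their parent
  children: {[symbol: string]: Node} = {};
}

const root = new Node();
const child = new Node();
child.parent = root;
root.children['a'] = child;
console.log(root.parent === null);     // true; compiles only because the type admits null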
