Skip to content

Commit 358cebb

Browse files
[tflite] Add task library wrapper client (tensorflow#4979)
* exports stuff from tflite web api client better * save * save * update * update * update * address comments * update * fix * update version
1 parent b785953 commit 358cebb

26 files changed

+679
-40
lines changed
-2.92 MB
Binary file not shown.

tfjs-tflite/demo/package.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
},
2121
"staticFiles": {
2222
"staticPath": [
23-
"models",
2423
"../dist"
2524
]
2625
},

tfjs-tflite/demo/src/script.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ const PHYSICAL_CORES = navigator.hardwareConcurrency / 2;
2929
async function start() {
3030
// Load model runner with the cartoonizer tflite model.
3131
const start = Date.now();
32-
const tfliteModel = await loadTFLiteModel('cartoongan_fp16.tflite', {
33-
numThreads: PHYSICAL_CORES,
34-
});
32+
const tfliteModel = await loadTFLiteModel(
33+
'https://tfhub.dev/sayakpaul/lite-model/cartoongan/fp16/1', {
34+
numThreads: PHYSICAL_CORES,
35+
});
3536
ele('.loading-msg').innerHTML = `Loaded WASM module and <a href='${
3637
CARTOONIZER_LINK}' target='blank'>TFLite model</a> in ${
3738
Date.now() - start}ms`;

tfjs-tflite/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@tensorflow/tfjs-tflite",
3-
"version": "0.0.1-alpha.0",
3+
"version": "0.0.1-alpha.2",
44
"description": "TFLite support for TensorFlow.js",
55
"main": "dist/tf-tflite.node.js",
66
"jsnext:main": "dist/index.js",

tfjs-tflite/rollup.config.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function config({
8585
module.exports = cmdOptions => {
8686
const bundles = [];
8787

88-
const name = 'tf';
88+
const name = 'tflite';
8989
const extend = true;
9090
const fileName = 'tf-tflite';
9191

tfjs-tflite/scripts/download-tflite-web-api.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ set -e
2121
cd "$(dirname "$0")"
2222

2323
# The default version.
24-
CURRENT_VERSION=0.0.2
24+
CURRENT_VERSION=0.0.3
2525

2626
# Get the version from the first parameter.
2727
# Default to the value in CURRENT_VERSION.
@@ -37,3 +37,7 @@ fi
3737
mkdir -p ../deps
3838
GCP_DIR="gs://tfweb/${VERSION}/dist"
3939
gsutil -m cp "${GCP_DIR}/*" ../deps/
40+
41+
# Append module exports to the JS client to make it a valid CommonJS module.
42+
# This is needed to help bundler correctly initialize the tfweb namespace.
43+
echo "var tfweb = (window && window['tfweb']) || this['tfweb']; exports.tfweb = tfweb;" >> ../deps/tflite_web_api_client.js

tfjs-tflite/src/index.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
export * from './tflite_model';
1919
export * from './types/tflite_web_model_runner';
20-
export * from './types/nl_classifier';
21-
export * from './types/image_classifier';
22-
export * from './types/tflite_web_api';
20+
export * from './tflite_task_library_client/image_classifier';
21+
export * from './tflite_task_library_client/image_segmenter';
22+
export * from './tflite_task_library_client/object_detector';
23+
export * from './tflite_task_library_client/nl_classifier';
24+
export * from './tflite_task_library_client/bert_nl_classifier';
25+
export * from './tflite_task_library_client/bert_qa';

tfjs-tflite/src/tflite_model.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,24 +269,24 @@ export class TFLiteModel implements InferenceModel {
269269
/**
270270
* Loads a TFLiteModel from the given model url.
271271
*
272-
* @param modelUrl The path to the model.
272+
* @param model The path to the model (string), or the model content in memory
273+
* (ArrayBuffer).
273274
* @param options Options related to model inference.
274275
*
275276
* @doc {heading: 'Models', subheading: 'TFLiteModel'}
276277
*/
277278
export async function loadTFLiteModel(
278-
modelUrl: string,
279+
model: string|ArrayBuffer,
279280
options: TFLiteWebModelRunnerOptions =
280281
DEFAULT_TFLITE_MODEL_RUNNER_OPTIONS): Promise<TFLiteModel> {
281282
// Handle tfhub links.
282-
if (modelUrl.includes('tfhub.dev')) {
283-
if (!modelUrl.endsWith(TFHUB_SEARCH_PARAM)) {
284-
modelUrl = `${modelUrl}${TFHUB_SEARCH_PARAM}`;
285-
}
283+
if (typeof model === 'string' && model.includes('tfhub.dev') &&
284+
model.includes('lite-model') && !model.endsWith(TFHUB_SEARCH_PARAM)) {
285+
model = `${model}${TFHUB_SEARCH_PARAM}`;
286286
}
287287
const tfliteModelRunner =
288288
await tfliteWebAPIClient.tfweb.TFLiteWebModelRunner.create(
289-
modelUrl, options);
289+
model, options);
290290
return new TFLiteModel(tfliteModelRunner);
291291
}
292292

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tfliteWebAPIClient from '../tflite_web_api_client';
19+
import {BertNLClassifier as TaskLibraryBertNLClassifier} from '../types/bert_nl_classifier';
20+
import {Class} from './common';
21+
22+
export interface BertNLClassifierOptions {
23+
/**
24+
* Max number of tokens to pass to the model.
25+
*
26+
* Default to 128.
27+
*/
28+
maxSeqLen?: number;
29+
}
30+
31+
/**
32+
* Client for BertNLClassifier TFLite Task Library.
33+
*
34+
* It is a wrapper around the underlying javascript API to make it more
35+
* convenient to use. See comments in the corresponding type declaration file in
36+
* src/types for more info.
37+
*/
38+
export class BertNLClassifier {
39+
constructor(private instance: TaskLibraryBertNLClassifier) {}
40+
41+
static async create(
42+
model: string|ArrayBuffer,
43+
options?: BertNLClassifierOptions): Promise<BertNLClassifier> {
44+
const protoOptions = new tfliteWebAPIClient.tfweb.BertNLClassifierOptions();
45+
if (options) {
46+
if (options.maxSeqLen) {
47+
protoOptions.setMaxSeqLen(options.maxSeqLen);
48+
}
49+
}
50+
const instance = await tfliteWebAPIClient.tfweb.BertNLClassifier.create(
51+
model, protoOptions);
52+
return new BertNLClassifier(instance);
53+
}
54+
55+
classify(input: string): Class[] {
56+
return this.instance.classify(input).map(category => {
57+
return {
58+
probability: category.score,
59+
className: category.className,
60+
};
61+
});
62+
}
63+
64+
cleanUp() {
65+
this.instance.cleanUp();
66+
}
67+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tfliteWebAPIClient from '../tflite_web_api_client';
19+
import {BertQuestionAnswerer as TaskLibraryBertQuestionAnswerer} from '../types/bert_qa';
20+
21+
/** A single answer. */
22+
export interface QaAnswer {
23+
/** The text of the answer. */
24+
text: string;
25+
/** The position and logit of the answer. */
26+
pos: Pos;
27+
}
28+
29+
/** Answer position. */
30+
export interface Pos {
31+
/** The start position. */
32+
start: number;
33+
/** The end position. */
34+
end: number;
35+
/** The logit. */
36+
logit: number;
37+
}
38+
39+
/**
40+
* Client for BertQA TFLite Task Library.
41+
*
42+
* It is a wrapper around the underlying javascript API to make it more
43+
* convenient to use. See comments in the corresponding type declaration file in
44+
* src/types for more info.
45+
*/
46+
export class BertQuestionAnswerer {
47+
constructor(private instance: TaskLibraryBertQuestionAnswerer) {}
48+
49+
static async create(model: string|
50+
ArrayBuffer): Promise<BertQuestionAnswerer> {
51+
const instance =
52+
await tfliteWebAPIClient.tfweb.BertQuestionAnswerer.create(model);
53+
return new BertQuestionAnswerer(instance);
54+
}
55+
56+
answer(context: string, question: string): QaAnswer[] {
57+
const result = this.instance.answer(context, question);
58+
if (!result) {
59+
return [];
60+
}
61+
62+
return result.map(answer => {
63+
return {
64+
text: answer.text,
65+
pos: answer.pos,
66+
};
67+
});
68+
}
69+
70+
cleanUp() {
71+
this.instance.cleanUp();
72+
}
73+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Class as ProtoClass} from '../types/common';
19+
20+
/** Common options for all task library tasks. */
21+
export interface CommonTaskLibraryOptions {
22+
/**
23+
* The number of threads to be used for TFLite ops that support
24+
* multi-threading when running inference with CPU. num_threads should be
25+
* greater than 0 or equal to -1. Setting num_threads to -1 has the effect to
26+
* let TFLite runtime set the value.
27+
*
28+
* Default to -1.
29+
*/
30+
numThreads?: number;
31+
}
32+
33+
/** A single class in the classification result. */
34+
export interface Class {
35+
/** The name of the class. */
36+
className: string;
37+
38+
/** The probability/score of the class. */
39+
probability: number;
40+
}
41+
42+
/** Convert proto Class array to our own Class array. */
43+
export function convertProtoClassesToClasses(protoClasses: ProtoClass[]):
44+
Class[] {
45+
const classes: Class[] = [];
46+
protoClasses.forEach(cls => {
47+
classes.push({
48+
className: cls.getDisplayName() || cls.getClassName(),
49+
probability: cls.getScore(),
50+
});
51+
});
52+
return classes;
53+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/**
2+
* @license
3+
* Copyright 2021 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tfliteWebAPIClient from '../tflite_web_api_client';
19+
import {ImageClassifier as TaskLibraryImageClassifier} from '../types/image_classifier';
20+
21+
import {Class, CommonTaskLibraryOptions, convertProtoClassesToClasses} from './common';
22+
23+
/** ImageClassifier options. */
24+
export interface ImageClassifierOptions extends CommonTaskLibraryOptions {
25+
/**
26+
* Maximum number of top scored results to return. If < 0, all results will
27+
* be returned. If 0, an invalid argument error is returned.
28+
*/
29+
maxResults?: number;
30+
31+
/**
32+
* Score threshold in [0,1), overrides the ones provided in the model metadata
33+
* (if any). Results below this value are rejected.
34+
*/
35+
scoreThreshold?: number;
36+
}
37+
38+
/**
39+
* Client for ImageClassifier TFLite Task Library.
40+
*
41+
* It is a wrapper around the underlying javascript API to make it more
42+
* convenient to use. See comments in the corresponding type declaration file in
43+
* src/types for more info.
44+
*/
45+
export class ImageClassifier {
46+
constructor(private instance: TaskLibraryImageClassifier) {}
47+
48+
static async create(
49+
model: string|ArrayBuffer,
50+
options?: ImageClassifierOptions): Promise<ImageClassifier> {
51+
const optionsProto = new tfliteWebAPIClient.tfweb.ImageClassifierOptions();
52+
if (options) {
53+
if (options.maxResults !== undefined) {
54+
optionsProto.setMaxResults(options.maxResults);
55+
}
56+
if (options.scoreThreshold !== undefined) {
57+
optionsProto.setScoreThreshold(options.scoreThreshold);
58+
}
59+
if (options.numThreads !== undefined) {
60+
optionsProto.setNumThreads(options.numThreads);
61+
}
62+
}
63+
const instance = await tfliteWebAPIClient.tfweb.ImageClassifier.create(
64+
model, optionsProto);
65+
return new ImageClassifier(instance);
66+
}
67+
68+
classify(input: ImageData|HTMLImageElement|HTMLCanvasElement|
69+
HTMLVideoElement): Class[] {
70+
const result = this.instance.classify(input);
71+
if (!result) {
72+
return [];
73+
}
74+
75+
let classes: Class[] = [];
76+
if (result.getClassificationsList().length > 0) {
77+
classes = convertProtoClassesToClasses(
78+
result.getClassificationsList()[0].getClassesList());
79+
}
80+
return classes;
81+
}
82+
83+
cleanUp() {
84+
this.instance.cleanUp();
85+
}
86+
}

0 commit comments

Comments
 (0)