Skip to content

Commit f956373

Browse files
authored
[wasm] Use xnnpack for Add/Sub/Mul/Relu/Relu6 (#2506)
Call into xnnpack for Add/Sub/Mul/Relu/Relu6. This provides general axis broadcast support for Add, Sub and Mul. Perf improvements (average of 200 runs on Macbook Pro 15 2018): - Detector (15.8ms --> 13.9ms) (~14%) - Mesh (9.2ms --> 8.2ms) (~12%) - PoseNet (197ms --> 179ms) (~10%) - MobileNet (103ms --> 101ms) (~0%) PERF
1 parent 8267a95 commit f956373

22 files changed

+300
-90
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
"forward_list": "cpp",
131131
"typeindex": "cpp",
132132
"*.inc": "cpp",
133-
"hash_map": "cpp"
133+
"hash_map": "cpp",
134+
"__refstring": "cpp"
134135
}
135136
}

tfjs-backend-wasm/scripts/cpplint.js

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,25 @@ const ignoreCode = true;
2626
const commandOpts = null;
2727

2828
let pythonVersion = exec('python --version', commandOpts, ignoreCode);
29-
if(pythonVersion['stderr'].includes('Python 2')) {
29+
if (pythonVersion['stderr'].includes('Python 2')) {
3030
python2Cmd = 'python';
3131
} else {
3232
pythonVersion = exec('python2 --version', commandOpts, ignoreCode);
33-
if(pythonVersion.code === 0) {
33+
if (pythonVersion.code === 0) {
3434
python2Cmd = 'python2';
3535
}
3636
}
3737

38-
if(python2Cmd != null) {
38+
if (python2Cmd != null) {
3939
const result = shell.find('src/cc').filter(
40-
fileName => fileName.endsWith('.cc') || fileName.endsWith('.h'));
41-
42-
console.log(`C++ linting files:`);
43-
console.log(result);
40+
fileName => fileName.endsWith('.cc') || fileName.endsWith('.h'));
4441

4542
const cwd = process.cwd() + '/' + CC_FILEPATH;
4643
const filenameArgument = result.join(' ');
4744

4845
exec(`${python2Cmd} tools/cpplint.py --root ${cwd} ${filenameArgument}`);
4946
} else {
50-
console.warn('No python2.x version found - please install python2. ' +
51-
'cpplint.py only works correctly with python 2.x.');
47+
console.warn(
48+
'No python2.x version found - please install python2. ' +
49+
'cpplint.py only works correctly with python 2.x.');
5250
}

tfjs-backend-wasm/src/cc/BUILD

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ tfjs_cc_library(
4949

5050
tfjs_cc_library(
5151
name = "binary",
52-
srcs = ["binary.h"],
52+
hdrs = ["binary.h"],
53+
srcs = ["binary.cc"],
5354
deps = [
5455
":backend",
5556
],
@@ -94,6 +95,16 @@ tfjs_cc_library(
9495
],
9596
)
9697

98+
tfjs_cc_library(
99+
name = "clamp_impl",
100+
hdrs = ["clamp_impl.h"],
101+
srcs = ["clamp_impl.cc"],
102+
deps = [
103+
":backend",
104+
":util"
105+
],
106+
)
107+
97108
tfjs_cc_library(
98109
name = "all_kernels",
99110
deps = [
@@ -360,6 +371,7 @@ tfjs_cc_library(
360371
srcs = ["kernels/Relu.cc"],
361372
deps = [
362373
":backend",
374+
":clamp_impl",
363375
":unary",
364376
],
365377
)
@@ -369,6 +381,7 @@ tfjs_cc_library(
369381
srcs = ["kernels/Relu6.cc"],
370382
deps = [
371383
":backend",
384+
":clamp_impl",
372385
":unary",
373386
],
374387
)

tfjs-backend-wasm/src/cc/binary.cc

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/* Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ===========================================================================*/

#include "src/cc/binary.h"

#include <xnnpack.h>
#include <limits>
#include <unordered_map>

#include "src/cc/backend.h"
#include "src/cc/util.h"

namespace {
// Maps an `xnn_create_*_nd_f32` function pointer to an instantiated operator.
// The create function uniquely identifies the operation kind (add, sub, mul,
// ...) and every call site uses the same (-inf, +inf) output clamp, so a
// single cached operator per kind is sufficient.
std::unordered_map<tfjs::wasm::xnn_create_binary_op, xnn_operator_t> op_cache;
}  // namespace

namespace tfjs {
namespace wasm {

// Runs an XNNPACK elementwise binary float32 op (with N-d broadcasting over
// the given shapes) on tensors `a` and `b`, writing the result into `out`.
//   a_id, b_id, out_id: backend tensor ids.
//   a_shape_ptr/a_shape_len, b_shape_ptr/b_shape_len: input shapes.
//   create_op / setup_op: the matching xnn_create_*_nd_f32 /
//     xnn_setup_*_nd_f32 pair for the desired operation.
// On any XNNPACK failure a warning is logged and `out` is left unmodified.
void binary_xnn_f32(const int a_id, const size_t* a_shape_ptr,
                    const int a_shape_len, const int b_id,
                    const size_t* b_shape_ptr, const int b_shape_len,
                    const int out_id, xnn_create_binary_op create_op,
                    xnn_setup_binary_op setup_op) {
  auto& a_info = backend::get_tensor_info(a_id);
  auto& b_info = backend::get_tensor_info(b_id);
  auto& out_info = backend::get_tensor_info_out(out_id);
  const float* a_buf = a_info.f32();
  const float* b_buf = b_info.f32();
  float* out_buf = out_info.f32_write();

  xnn_operator_t binary_op = nullptr;

  auto cache_result = op_cache.find(create_op);
  if (cache_result == op_cache.end()) {
    // No output clamping: pass the widest possible range.
    const float sum_min = -std::numeric_limits<float>::infinity(),
                sum_max = std::numeric_limits<float>::infinity();
    const int flags = 0;
    xnn_status status = create_op(sum_min, sum_max, flags, &binary_op);
    if (status != xnn_status_success) {
      // BUG FIX: the "%d" specifier previously had no matching argument,
      // which is undefined behavior for printf-style formatting and logged
      // a garbage status value.
      util::warn(
          "XNN status for xnn_create_*_nd_f32 is not successful. Got "
          "status %d. Use -c dbg to see XNN logs.",
          status);
      return;
    }
    op_cache.insert({create_op, binary_op});
    backend::xnn_operator_count++;
  } else {
    binary_op = cache_result->second;
  }

  xnn_status status =
      setup_op(binary_op, a_shape_len, a_shape_ptr, b_shape_len, b_shape_ptr,
               a_buf, b_buf, out_buf, nullptr /* thread pool */);
  if (status != xnn_status_success) {
    util::warn(
        "XNN status for xnn_setup_*_nd_f32 is not successful. Got "
        "status %d. Use -c dbg to see XNN logs.",
        status);
    return;
  }

  xnn_run_operator(binary_op, nullptr /* thread pool */);
}

}  // namespace wasm
}  // namespace tfjs

tfjs-backend-wasm/src/cc/binary.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef BINARY_H_
1616
#define BINARY_H_
1717

18+
#include <xnnpack.h>
1819
#include <algorithm>
1920

2021
#include "src/cc/backend.h"
@@ -58,6 +59,18 @@ inline void binary_bool(const int a_id, const int b_id, const int out_id,
5859
out_info.b_write(), operation);
5960
}
6061

62+
typedef xnn_status (*xnn_create_binary_op)(float, float, uint32_t,
63+
xnn_operator_t*);
64+
typedef xnn_status (*xnn_setup_binary_op)(xnn_operator_t, size_t, const size_t*,
65+
size_t, const size_t*, const float*,
66+
const float*, float*, pthreadpool_t);
67+
68+
void binary_xnn_f32(const int a_id, const size_t* a_shape_ptr,
69+
const int a_shape_len, const int b_id,
70+
const size_t* b_shape_ptr, const int b_shape_len,
71+
const int out_id, xnn_create_binary_op create_op,
72+
xnn_setup_binary_op setup_op);
73+
6174
} // namespace wasm
6275
} // namespace tfjs
6376

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ===========================================================================*/

#include "src/cc/clamp_impl.h"

#include <xnnpack.h>
#include <cstdint>
#include <map>
#include <tuple>

#include "src/cc/backend.h"
#include "src/cc/util.h"

namespace {
// These values are keys to creating the xnn clamp operator. We use
// std::tuple since it implements the compare operator needed for std::map.
typedef std::tuple<float, float> CacheKey;
// The operator cache maps the params of xnn_create_clamp_nc_f32 to an
// operator, so each distinct (min, max) pair is instantiated only once.
std::map<CacheKey, xnn_operator_t> op_cache;
}  // namespace

namespace tfjs {
namespace wasm {

// Clamps every element of the float32 tensor with id `x_id` to the range
// [min, max] and writes the result into the tensor with id `out_id`.
// Backs Relu (min=0, max=+inf) and Relu6 (min=0, max=6).
// On any XNNPACK failure a warning is logged and `out` is left unmodified.
void xnn_clamp(const int x_id, const int out_id, const float min,
               const float max) {
  auto& x_info = backend::get_tensor_info(x_id);
  auto& out_info = backend::get_tensor_info_out(out_id);
  const float* x_buf = x_info.f32();
  float* out_buf = out_info.f32_write();

  xnn_operator_t op = nullptr;
  CacheKey cache_key = {min, max};
  auto cache_result = op_cache.find(cache_key);
  if (cache_result == op_cache.end()) {
    // Treat the tensor as a flat single-channel NC buffer so one operator
    // works for any shape; the element count is supplied as batch_size in
    // the setup call below.
    const size_t channels = 1, input_stride = 1, output_stride = 1;
    // BUG FIX: previously 1 was passed for flags; XNNPACK defines no flags
    // for the clamp NC operator (and the parameter type is uint32_t), so
    // pass 0 as everywhere else in this backend.
    const uint32_t flags = 0;
    xnn_status status = xnn_create_clamp_nc_f32(
        channels, input_stride, output_stride, min, max, flags, &op);
    if (status != xnn_status_success) {
      util::warn(
          "XNN status for xnn_create_clamp_nc_f32 is not successful. "
          "Got status %d. Use -c dbg to see XNN logs.",
          status);
      return;
    }
    op_cache.emplace(cache_key, op);
    backend::xnn_operator_count++;
  } else {
    op = cache_result->second;
  }

  const size_t batch_size = out_info.size;
  xnn_status status = xnn_setup_clamp_nc_f32(op, batch_size, x_buf, out_buf,
                                             nullptr /* thread pool */);
  if (status != xnn_status_success) {
    util::warn(
        "XNN status for xnn_setup_clamp_nc_f32 is not successful. "
        "Got status %d. Use -c dbg to see XNN logs.",
        status);
    return;
  }

  xnn_run_operator(op, nullptr /* thread pool */);
}

}  // namespace wasm
}  // namespace tfjs

tfjs-backend-wasm/src/cc/clamp_impl.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ===========================================================================*/

#ifndef CLAMP_IMPL_H_
#define CLAMP_IMPL_H_

namespace tfjs {
namespace wasm {

// Clamps each element of the float32 tensor with id `x_id` to [min, max],
// writing the result into the tensor with id `out_id`. Shared implementation
// for the Relu and Relu6 kernels.
void xnn_clamp(const int x_id, const int out_id, const float min,
               const float max);

}  // namespace wasm
}  // namespace tfjs

#endif  // CLAMP_IMPL_H_

tfjs-backend-wasm/src/cc/kernels/Add.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
#ifdef __EMSCRIPTEN__
1616
#include <emscripten.h>
1717
#endif
18+
#include <xnnpack.h>
1819

19-
#include "src/cc/backend.h"
2020
#include "src/cc/binary.h"
2121
#include "src/cc/util.h"
2222

@@ -35,11 +35,14 @@ extern "C" {
3535
#ifdef __EMSCRIPTEN__
3636
EMSCRIPTEN_KEEPALIVE
3737
#endif
38-
void Add(const int a_id, const int b_id, const DType dtype, const int out_id) {
39-
auto& a_info = backend::get_tensor_info(a_id);
38+
void Add(const int a_id, const size_t* a_shape_ptr, const int a_shape_len,
39+
const int b_id, const size_t* b_shape_ptr, const int b_shape_len,
40+
const DType dtype, const int out_id) {
4041
switch (dtype) {
4142
case DType::float32:
42-
binary_f32(a_id, b_id, out_id, add<float>);
43+
binary_xnn_f32(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr,
44+
b_shape_len, out_id, xnn_create_add_nd_f32,
45+
xnn_setup_add_nd_f32);
4346
break;
4447
case DType::int32:
4548
binary_i32(a_id, b_id, out_id, add<int>);

tfjs-backend-wasm/src/cc/kernels/Div.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ extern "C" {
3535
#ifdef __EMSCRIPTEN__
3636
EMSCRIPTEN_KEEPALIVE
3737
#endif
38-
void Div(const int a_id, const int b_id, const DType dtype, const int out_id) {
38+
void Div(const int a_id, const size_t* a_shape_ptr, const int a_shape_len,
39+
const int b_id, const size_t* b_shape_ptr, const int b_shape_len,
40+
const DType dtype, const int out_id) {
3941
auto& a_info = backend::get_tensor_info(a_id);
4042
switch (dtype) {
4143
case DType::float32:

tfjs-backend-wasm/src/cc/kernels/FloorDiv.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ extern "C" {
3030
#ifdef __EMSCRIPTEN__
3131
EMSCRIPTEN_KEEPALIVE
3232
#endif
33-
void FloorDiv(const int a_id, const int b_id, const DType dtype,
34-
const int out_id) {
33+
void FloorDiv(const int a_id, const size_t* a_shape_ptr, const int a_shape_len,
34+
const int b_id, const size_t* b_shape_ptr, const int b_shape_len,
35+
const DType dtype, const int out_id) {
3536
auto& a_info = backend::get_tensor_info(a_id);
3637
switch (dtype) {
3738
case DType::float32:

tfjs-backend-wasm/src/cc/kernels/Mul.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
#ifdef __EMSCRIPTEN__
1616
#include <emscripten.h>
1717
#endif
18+
#include <xnnpack.h>
1819

19-
#include "src/cc/backend.h"
2020
#include "src/cc/binary.h"
2121
#include "src/cc/util.h"
2222

@@ -35,11 +35,14 @@ extern "C" {
3535
#ifdef __EMSCRIPTEN__
3636
EMSCRIPTEN_KEEPALIVE
3737
#endif
38-
void Mul(const int a_id, const int b_id, const DType dtype, const int out_id) {
39-
auto& a_info = backend::get_tensor_info(a_id);
38+
void Mul(const int a_id, const size_t* a_shape_ptr, const int a_shape_len,
39+
const int b_id, const size_t* b_shape_ptr, const int b_shape_len,
40+
const DType dtype, const int out_id) {
4041
switch (dtype) {
4142
case DType::float32:
42-
binary_f32(a_id, b_id, out_id, mul<float>);
43+
binary_xnn_f32(a_id, a_shape_ptr, a_shape_len, b_id, b_shape_ptr,
44+
b_shape_len, out_id, xnn_create_multiply_nd_f32,
45+
xnn_setup_multiply_nd_f32);
4346
break;
4447
case DType::int32:
4548
binary_i32(a_id, b_id, out_id, mul<int>);

0 commit comments

Comments
 (0)