Skip to content

Commit 4e274bf

Browse files
author
Nikhil Thorat
authored
[WASM] Add output_min, output_max to cache key for conv2d. (#2507)
BUG: Previously we weren't including output_min and output_max in the conv2d operator cache key, so two calls differing only in fused activation clamp bounds could reuse the same cached XNNPACK operator and give incorrect results at runtime.
1 parent f956373 commit 4e274bf

File tree

5 files changed

+46
-21
lines changed

5 files changed

+46
-21
lines changed

tfjs-backend-wasm/src/cc/conv2d_impl.cc

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "src/cc/conv2d_impl.h"
2020

2121
#include <xnnpack.h>
22-
#include <array>
22+
#include <tuple>
2323
#include <cmath>
2424
#include <limits>
2525
#include <map>
@@ -34,10 +34,11 @@
3434
#include "src/cc/util.h"
3535

3636
namespace {
37-
// These integer values are keys to creating the conv2d operator. We use
38-
// std::array instead of a vanilla array as it implements the compare operator
37+
// We use std::tuple as the cache key as it implements the compare operator
3938
// needed for std::map.
40-
typedef std::array<int, 19> OperatorCacheKey;
39+
typedef std::tuple<int, int, int, int, int, int, int, int, int, int, int, int,
40+
int, int, int, int, int, int, int, float, float>
41+
OperatorCacheKey;
4142

4243
struct CachedInfo {
4344
xnn_operator_t op;
@@ -162,6 +163,16 @@ void conv2d(const int x_id, const int batch_size, const int input_height,
162163
clamp_method = tfjs::wasm::FusableActivation::LINEAR;
163164
}
164165

166+
float output_min = -std::numeric_limits<float>::infinity();
167+
float output_max = std::numeric_limits<float>::infinity();
168+
169+
if (activation == FusableActivation::RELU) {
170+
output_min = 0;
171+
} else if (activation == FusableActivation::RELU6) {
172+
output_min = 0;
173+
output_max = 6;
174+
}
175+
165176
OperatorCacheKey cache_key = {pad_top,
166177
pad_right,
167178
pad_bottom,
@@ -180,20 +191,12 @@ void conv2d(const int x_id, const int batch_size, const int input_height,
180191
clamp_method,
181192
filter_id,
182193
bias_id,
183-
flags};
194+
flags,
195+
output_min,
196+
output_max};
184197

185198
auto operator_cache_idx = operator_cache.find(cache_key);
186199
if (operator_cache_idx == operator_cache.end()) {
187-
float output_min = -std::numeric_limits<float>::infinity();
188-
float output_max = std::numeric_limits<float>::infinity();
189-
190-
if (activation == FusableActivation::RELU) {
191-
output_min = 0;
192-
} else if (activation == FusableActivation::RELU6) {
193-
output_min = 0;
194-
output_max = 6;
195-
}
196-
197200
// This lives outside the if statement so the data survives the scope.
198201
std::vector<float> transposed_filter;
199202

tfjs-backend-wasm/src/cc/kernels/AvgPool.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void AvgPool(const int x_id, const int batch_size, const int input_height,
6565
auto operator_cache_idx = operator_cache.find(cache_key);
6666

6767
if (operator_cache_idx == operator_cache.end()) {
68-
float output_min = -std::numeric_limits<float>::infinity();
69-
float output_max = std::numeric_limits<float>::infinity();
68+
const float output_min = -std::numeric_limits<float>::infinity();
69+
const float output_max = std::numeric_limits<float>::infinity();
7070

7171
xnn_status status = xnn_create_average_pooling2d_nhwc_f32(
7272
pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,

tfjs-backend-wasm/src/cc/kernels/FusedConv2D_test.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,20 @@ TEST(FUSEDCONV2D, xnn_operator_lifetime) {
168168
prelu_weights_id, out_id);
169169
ASSERT_EQ(6, tfjs::backend::xnn_operator_count);
170170

171+
// One new XNN operator should be created for the next call to conv2d with a
172+
// different activation.
173+
const int activation2 = tfjs::wasm::FusableActivation::RELU6;
174+
tfjs::wasm::FusedConv2D(
175+
x1_id, batch_size, input_height, input_width, weights1_id, filter_height,
176+
filter_width, bias1_id, pad_top1, pad_right, pad_bottom1, pad_left,
177+
is_same_pad1, dilation_height, dilation_width, stride_height,
178+
stride_width, input_channels, output_channels, activation2,
179+
prelu_weights_id, out_id);
180+
ASSERT_EQ(7, tfjs::backend::xnn_operator_count);
181+
171182
// Disposing the first weights should remove 2 operators.
172183
tfjs::wasm::dispose_data(weights0_id);
173-
ASSERT_EQ(4, tfjs::backend::xnn_operator_count);
184+
ASSERT_EQ(5, tfjs::backend::xnn_operator_count);
174185

175186
// Disposing the second bias should remove 2 operators it's associated with.
176187
tfjs::wasm::dispose_data(bias1_id);

tfjs-backend-wasm/src/cc/kernels/FusedDepthwiseConv2D_test.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,20 @@ TEST(FUSEDDEPTHWISECONV2D, xnn_operator_lifetime) {
184184
prelu_weights_id, out_id);
185185
ASSERT_EQ(7, tfjs::backend::xnn_operator_count);
186186

187+
// One new XNN operator should be created for the next call to conv2d with a
188+
// different activation.
189+
const int activation2 = tfjs::wasm::FusableActivation::RELU6;
190+
tfjs::wasm::FusedDepthwiseConv2D(
191+
x1_id, batch_size, input_height, input_width, weights1_id, filter_height,
192+
filter_width, bias1_id, pad_top1, pad_right, pad_bottom1, pad_left,
193+
is_same_pad1, dilation_height, dilation_width, stride_height,
194+
stride_width, input_channels, output_channels, activation2,
195+
prelu_weights_id, out_id);
196+
ASSERT_EQ(8, tfjs::backend::xnn_operator_count);
197+
187198
// Disposing the first weights should remove 2 operators.
188199
tfjs::wasm::dispose_data(weights0_id);
189-
ASSERT_EQ(5, tfjs::backend::xnn_operator_count);
200+
ASSERT_EQ(6, tfjs::backend::xnn_operator_count);
190201

191202
// Disposing the second bias should remove 2 operators it's associated with.
192203
tfjs::wasm::dispose_data(bias1_id);

tfjs-backend-wasm/src/cc/kernels/MaxPool.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ void MaxPool(const int x_id, const int batch_size, const int input_height,
6767
auto operator_cache_idx = operator_cache.find(cache_key);
6868

6969
if (operator_cache_idx == operator_cache.end()) {
70-
float output_min = -std::numeric_limits<float>::infinity();
71-
float output_max = std::numeric_limits<float>::infinity();
70+
const float output_min = -std::numeric_limits<float>::infinity();
71+
const float output_max = std::numeric_limits<float>::infinity();
7272

7373
xnn_status status = xnn_create_max_pooling2d_nhwc_f32(
7474
pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,

0 commit comments

Comments (0)