Skip to content

Commit a977a77

Browse files
Add bidirectional sequence RNN to TFLite Ops.
PiperOrigin-RevId: 183465032
1 parent e9dc418 commit a977a77

File tree

7 files changed

+1205
-0
lines changed

7 files changed

+1205
-0
lines changed

tensorflow/contrib/lite/kernels/BUILD

+13
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ cc_library(
104104
"add.cc",
105105
"basic_rnn.cc",
106106
"batch_to_space_nd.cc",
107+
"bidirectional_sequence_rnn.cc",
107108
"concatenation.cc",
108109
"conv.cc",
109110
"depthwise_conv.cc",
@@ -288,6 +289,18 @@ tf_cc_test(
288289
],
289290
)
290291

292+
# Unit tests for the bidirectional sequence RNN kernel registered in
# :builtin_ops; uses the shared kernel test harness in :test_util.
tf_cc_test(
    name = "bidirectional_sequence_rnn_test",
    size = "small",
    srcs = ["bidirectional_sequence_rnn_test.cc"],
    deps = [
        ":builtin_ops",
        "//tensorflow/contrib/lite:framework",
        "//tensorflow/contrib/lite/kernels:test_util",
        "@com_google_googletest//:gtest",
    ],
)
303+
291304
tf_cc_test(
292305
name = "unidirectional_sequence_rnn_test",
293306
size = "small",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <cstdio>
#include <iostream>
#include <limits>

#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

// Indices into node->inputs: the sequence input followed by the weights and
// bias of each direction's cell.
constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kBwWeightsTensor = 4;
constexpr int kBwRecurrentWeightsTensor = 5;
constexpr int kBwBiasTensor = 6;
// State and output tensors (indices into node->outputs). The hidden states
// are exposed as outputs so they can persist between invocations.
constexpr int kFwHiddenStateTensor = 0;
constexpr int kFwOutputTensor = 1;
constexpr int kBwHiddenStateTensor = 2;
constexpr int kBwOutputTensor = 3;
46+
47+
// Validates tensor counts/shapes and resizes the hidden-state and output
// tensors. Expects input of shape [batch, time, features], input weights of
// shape [num_units, features] and recurrent weights of shape
// [num_units, num_units] for each direction.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 7);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);

  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
  TfLiteTensor* fw_input_weights =
      &context->tensors[node->inputs->data[kFwWeightsTensor]];
  TfLiteTensor* fw_recurrent_weights =
      &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
  TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
  TfLiteTensor* bw_input_weights =
      &context->tensors[node->inputs->data[kBwWeightsTensor]];
  TfLiteTensor* bw_recurrent_weights =
      &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
  TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];

  // Check all the parameters of tensor match within themselves and match the
  // input configuration.
  const int batch_size = input->dims->data[0];
  const int max_time = input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]);
  TF_LITE_ASSERT_EQ(input->dims->data[2], bw_input_weights->dims->data[1]);
  TF_LITE_ASSERT_EQ(fw_input_weights->dims->data[0], fw_bias->dims->data[0]);
  TF_LITE_ASSERT_EQ(bw_input_weights->dims->data[0], bw_bias->dims->data[0]);
  // Both recurrent matrices are checked on dim 1 so the two directions are
  // validated consistently (the original checked fw on dim 0 and bw on dim 1;
  // the matrices are square, so valid models are unaffected).
  TF_LITE_ASSERT_EQ(fw_recurrent_weights->dims->data[1],
                    fw_bias->dims->data[0]);
  TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);

  TfLiteTensor* fw_output =
      &context->tensors[node->outputs->data[kFwOutputTensor]];
  TfLiteTensor* bw_output =
      &context->tensors[node->outputs->data[kBwOutputTensor]];

  // Resize hidden states to [batch, num_units] for each direction.
  TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2);
  fw_hidden_state_size_array->data[0] = batch_size;
  fw_hidden_state_size_array->data[1] = fw_num_units;
  TfLiteTensor* fw_hidden_state =
      &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state,
                                                   fw_hidden_state_size_array));

  TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2);
  bw_hidden_state_size_array->data[0] = batch_size;
  // Bug fix: the backward hidden state must be sized with bw_num_units. The
  // original used fw_num_units, which mis-sizes the buffer (and Eval then
  // reads/writes out of bounds) whenever the two cells differ in size.
  bw_hidden_state_size_array->data[1] = bw_num_units;
  TfLiteTensor* bw_hidden_state =
      &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state,
                                                   bw_hidden_state_size_array));

  // Mark hidden states as a persistent tensor so they survive across
  // invocations (the RNN state carries over between Eval calls).
  fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
  bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;

  // Resize outputs to [batch, time, num_units] for each direction.
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = batch_size;
  fw_output_size_array->data[1] = max_time;
  fw_output_size_array->data[2] = fw_num_units;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, fw_output, fw_output_size_array));
  TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
  bw_output_size_array->data[0] = batch_size;
  bw_output_size_array->data[1] = max_time;
  bw_output_size_array->data[2] = bw_num_units;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, bw_output, bw_output_size_array));

  return kTfLiteOk;
}
121+
122+
namespace {
123+
// Performs one RNN computation step for the input specified by input_ptr_batch.
124+
// The RNN cell is specified by the pointers to its weights and biases, along
125+
// with the input size, number of units, strides, activation.
126+
// The pointers to the hidden state and the output are updated as a result.
127+
// TODO(mirkov): factor out this function to a shared library.
128+
void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr,
129+
const float* recurrent_weights_ptr, const float* bias_ptr,
130+
int input_size, int num_units, int input_weights_stride,
131+
int recurrent_weights_stride, TfLiteFusedActivation activation,
132+
float* hidden_state_ptr_batch, float* output_ptr_batch) {
133+
// Output = bias
134+
for (int o = 0; o < num_units; o++) {
135+
output_ptr_batch[o] = bias_ptr[o];
136+
}
137+
138+
// Output += input * input_weights
139+
for (int o = 0; o < num_units; o++) {
140+
for (int i = 0; i < input_size; i++) {
141+
output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
142+
}
143+
input_weights_ptr += input_weights_stride;
144+
}
145+
146+
// Output += recurrent_weights * hidden_state
147+
for (int o = 0; o < num_units; o++) {
148+
for (int h = 0; h < num_units; h++) {
149+
output_ptr_batch[o] +=
150+
hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
151+
}
152+
recurrent_weights_ptr += recurrent_weights_stride;
153+
}
154+
155+
// Output = activation(Output) and update hidden_state
156+
for (int o = 0; o < num_units; o++) {
157+
output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]);
158+
hidden_state_ptr_batch[o] = output_ptr_batch[o];
159+
}
160+
}
161+
} // namespace
162+
163+
// Runs the bidirectional RNN: a forward pass (t = 0..max_time-1) and a
// backward pass (t = max_time-1..0) per batch entry, each with its own cell
// weights, hidden state, and output tensor.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // The fused activation comes from the op's builtin options.
  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);

  // Input and forward-cell tensors.
  TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
  TfLiteTensor* fw_input_weights =
      &context->tensors[node->inputs->data[kFwWeightsTensor]];
  TfLiteTensor* fw_recurrent_weights =
      &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
  TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
  TfLiteTensor* fw_hidden_state =
      &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
  TfLiteTensor* fw_output =
      &context->tensors[node->outputs->data[kFwOutputTensor]];

  // Backward-cell tensors.
  TfLiteTensor* bw_input_weights =
      &context->tensors[node->inputs->data[kBwWeightsTensor]];
  TfLiteTensor* bw_recurrent_weights =
      &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
  TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
  TfLiteTensor* bw_hidden_state =
      &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
  TfLiteTensor* bw_output =
      &context->tensors[node->outputs->data[kBwOutputTensor]];

  // Input layout is [batch, time, features], as validated in Prepare.
  const int batch_size = input->dims->data[0];
  const int max_time = input->dims->data[1];
  const int input_size = input->dims->data[2];

  // dims[0] of each weights matrix is the unit count; dims[1] doubles as the
  // row stride when walking the flat float buffers in RnnStep.
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int fw_input_weights_stride = fw_input_weights->dims->data[1];
  const int fw_recurrent_weights_stride = fw_recurrent_weights->dims->data[1];
  const float* fw_bias_ptr = fw_bias->data.f;
  const float* fw_input_weights_ptr = fw_input_weights->data.f;
  const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const int bw_input_weights_stride = bw_input_weights->dims->data[1];
  const int bw_recurrent_weights_stride = bw_recurrent_weights->dims->data[1];
  const float* bw_bias_ptr = bw_bias->data.f;
  const float* bw_input_weights_ptr = bw_input_weights->data.f;
  const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;

  // NOTE(review): the hidden states are persistent (see Prepare) and are not
  // reset here, so state appears to carry across invocations — confirm this
  // is the intended streaming semantics.
  for (int b = 0; b < batch_size; b++) {
    // Forward cell: walk time steps left to right, carrying this batch
    // entry's hidden-state row ([b, :]) across steps.
    float* fw_hidden_state_ptr_batch =
        fw_hidden_state->data.f + b * fw_num_units;
    for (int s = 0; s < max_time; s++) {
      // Offsets assume contiguous [batch, time, features] / [batch, time,
      // units] row-major layouts.
      const float* input_ptr_batch =
          input->data.f + b * input_size * max_time + s * input_size;
      float* output_ptr_batch =
          fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;

      RnnStep(input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
              fw_bias_ptr, input_size, fw_num_units, fw_input_weights_stride,
              fw_recurrent_weights_stride, params->activation,
              fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell: same computation over the reversed time axis; the
    // output for step s is still stored at time position s.
    float* bw_hidden_state_ptr_batch =
        bw_hidden_state->data.f + b * bw_num_units;
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          input->data.f + b * input_size * max_time + s * input_size;
      float* output_ptr_batch =
          bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;

      RnnStep(input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
              bw_bias_ptr, input_size, bw_num_units, bw_input_weights_stride,
              bw_recurrent_weights_stride, params->activation,
              bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  }
  return kTfLiteOk;
}
237+
238+
} // namespace bidirectional_sequence_rnn
239+
240+
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
241+
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
242+
bidirectional_sequence_rnn::Prepare,
243+
bidirectional_sequence_rnn::Eval};
244+
return &r;
245+
}
246+
247+
} // namespace builtin
248+
} // namespace ops
249+
} // namespace tflite

0 commit comments

Comments
 (0)