
Commit ab34d55

Authored and committed by Vijay Vasudevan
TensorFlow: more features, performance improvements, and doc fixes.
Changes:
- Add Split()/Concat() methods to TensorUtil (meant for convenience, not speed), by Chris.
- Changes to the linear algebra ops interface, by Rasmus.
- Tests for TensorBoard, by Daniel.
- Fix a bug in the histogram calculation, by Cassandra.
- Added a tool for backwards compatibility of OpDefs: it checks in a history of OpDefs and their changes, and checks for backwards-incompatible changes. All done by @josh11b.
- Fix some protobuf example proto docs, by Oliver.
- Add the derivative of MatrixDeterminant, by @yaroslavvb.
- Add a priority queue, by @ebrevdo.
- Doc and typo fixes, by Aurelien and @dave-andersen.
- Speed improvements to ConvBackwardFilter, by @AndyDavis.
- Improve the speed of AlexNet on Titan X, by @zheng-xq.
- Add some host memory annotations to some GPU kernels, by Yuan.
- Add support for doubles in the histogram summary, by @jmchen-g.

Base CL: 108158338
1 parent 9eb88d5 commit ab34d55

111 files changed, 11229 insertions(+), 2753 deletions(-)


tensorflow/core/example/example.proto

Lines changed: 12 additions & 12 deletions
@@ -11,39 +11,39 @@ package tensorflow;
 // features {
 //   feature {
 //     key: "age"
-//     float_list {
+//     value { float_list {
 //       value: 29.0
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "movie"
-//     bytes_list {
+//     value { bytes_list {
 //       value: "The Shawshank Redemption"
 //       value: "Fight Club"
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "movie_ratings"
-//     float_list {
+//     value { float_list {
 //       value: 9.0
 //       value: 9.7
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "suggestion"
-//     bytes_list {
+//     value { bytes_list {
 //       value: "Inception"
-//     }
+//     }}
 //   }
 //   # Note that this feature exists to be used as a label in training.
 //   # E.g., if training a logistic regression model to predict purchase
 //   # probability in our learning tool we would set the label feature to
 //   # "suggestion_purchased".
 //   feature {
 //     key: "suggestion_purchased"
-//     float_list {
+//     value { float_list {
 //       value: 1.0
-//     }
+//     }}
 //   }
 //   # Similar to "suggestion_purchased" above this feature exists to be used
 //   # as a label in training.
@@ -52,9 +52,9 @@ package tensorflow;
 //   # "purchase_price".
 //   feature {
 //     key: "purchase_price"
-//     float_list {
+//     value { float_list {
 //       value: 9.99
-//     }
+//     }}
 //   }
 // }
 //

tensorflow/core/example/feature.proto

Lines changed: 12 additions & 12 deletions
@@ -14,41 +14,41 @@
 // Example Features for a movie recommendation application:
 //   feature {
 //     key: "age"
-//     float_list {
+//     value { float_list {
 //       value: 29.0
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "movie"
-//     bytes_list {
+//     value { bytes_list {
 //       value: "The Shawshank Redemption"
 //       value: "Fight Club"
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "movie_ratings"
-//     float_list {
+//     value { float_list {
 //       value: 9.0
 //       value: 9.7
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "suggestion"
-//     bytes_list {
+//     value { bytes_list {
 //       value: "Inception"
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "suggestion_purchased"
-//     int64_list {
+//     value { int64_list {
 //       value: 1
-//     }
+//     }}
 //   }
 //   feature {
 //     key: "purchase_price"
-//     float_list {
+//     value { float_list {
 //       value: 9.99
-//     }
+//     }}
 //   }
 
 syntax = "proto3";
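
Note on the doc fix above: the corrected comments follow the proto3 text format for map-style feature entries, where each entry is printed as a key/value pair and the Feature message sits inside "value { ... }". For orientation only, a minimal C++ sketch of building an Example that prints in that form (this assumes the generated classes from example.proto and feature.proto and the standard proto3 C++ API; it is not part of this commit):

    // Sketch only: assumes the tensorflow::Example / tensorflow::Feature
    // classes generated from the protos above.
    #include "tensorflow/core/example/example.pb.h"

    tensorflow::Example MakeExample() {
      tensorflow::Example example;
      auto* feature_map = example.mutable_features()->mutable_feature();

      // Prints as: feature { key: "age" value { float_list { value: 29.0 } } }
      tensorflow::Feature age;
      age.mutable_float_list()->add_value(29.0f);
      (*feature_map)["age"] = age;

      // Prints as: feature { key: "movie" value { bytes_list { ... } } }
      tensorflow::Feature movie;
      movie.mutable_bytes_list()->add_value("The Shawshank Redemption");
      movie.mutable_bytes_list()->add_value("Fight Club");
      (*feature_map)["movie"] = movie;

      return example;
    }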

tensorflow/core/framework/tensor_util.cc

Lines changed: 102 additions & 0 deletions
@@ -24,5 +24,107 @@ Tensor DeepCopy(const Tensor& other) {
   return tmp;
 }
 
+Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) {
+  CHECK_GT(tensors.size(), 0);
+  int64 total_dim0_size = 0;
+  for (const Tensor& tensor : tensors) {
+    CHECK_GT(tensor.dims(), 0);
+    total_dim0_size += tensor.dim_size(0);
+  }
+  TensorShape shape = tensors[0].shape();
+  shape.set_dim(0, total_dim0_size);
+  Tensor result = Tensor(tensors[0].dtype(), shape);
+
+  // We use StringPiece as a convenient map over the tensor buffer,
+  // but we cast the type to get to the underlying buffer to do the
+  // copy.
+  StringPiece to_data = result.tensor_data();
+
+  if (DataTypeCanUseMemcpy(result.dtype())) {
+    int64 offset = 0;
+    for (const Tensor& tensor : tensors) {
+      StringPiece from_data = tensor.tensor_data();
+      CHECK_LE(offset + from_data.size(), to_data.size());
+      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
+             from_data.size());
+
+      offset += from_data.size();
+    }
+  } else {
+    CHECK_EQ(DT_STRING, result.dtype());
+    string* to_strings =
+        reinterpret_cast<string*>(const_cast<char*>(to_data.data()));
+
+    int64 offset = 0;
+    for (const Tensor& tensor : tensors) {
+      auto from_strings = tensor.flat<string>();
+      CHECK_LE(offset + tensor.NumElements(), result.NumElements());
+      for (int i = 0; i < tensor.NumElements(); ++i) {
+        to_strings[offset + i] = from_strings(i);
+      }
+
+      offset += tensor.NumElements();
+    }
+  }
+
+  return result;
+}
+
+std::vector<Tensor> Split(const Tensor& tensor,
+                          const gtl::ArraySlice<int64>& sizes) {
+  CHECK_GT(tensor.dims(), 0);
+  int64 total_size = 0;
+  for (int64 size : sizes) {
+    total_size += size;
+  }
+  CHECK_EQ(total_size, tensor.dim_size(0));
+
+  std::vector<Tensor> result;
+
+  StringPiece from_data = tensor.tensor_data();
+
+  if (DataTypeCanUseMemcpy(tensor.dtype())) {
+    int64 offset = 0;
+    for (int64 size : sizes) {
+      TensorShape shape = tensor.shape();
+      shape.set_dim(0, size);
+      result.emplace_back(tensor.dtype(), shape);
+      Tensor* split = &result[result.size() - 1];
+
+      // We use StringPiece as a convenient map over the tensor buffer,
+      // but we cast the type to get to the underlying buffer to do the
+      // copy.
+      StringPiece to_data = split->tensor_data();
+      CHECK_LE(offset + to_data.size(), from_data.size());
+      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
+             to_data.size());
+
+      offset += to_data.size();
+    }
+  } else {
+    CHECK_EQ(DT_STRING, tensor.dtype());
+    auto from_strings = tensor.flat<string>();
+
+    int64 offset = 0;
+    for (int64 size : sizes) {
+      TensorShape shape = tensor.shape();
+      shape.set_dim(0, size);
+      result.emplace_back(tensor.dtype(), shape);
+      Tensor& split = result[result.size() - 1];
+      string* to_strings = reinterpret_cast<string*>(
+          const_cast<char*>(split.tensor_data().data()));
+
+      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
+      for (int i = 0; i < split.NumElements(); ++i) {
+        to_strings[i] = from_strings(offset + i);
+      }
+
+      offset += split.NumElements();
+    }
+  }
+
+  return result;
+}
+
 } // namespace tensor
 } // namespace tensorflow

tensorflow/core/framework/tensor_util.h

Lines changed: 22 additions & 0 deletions
@@ -15,6 +15,28 @@ namespace tensor {
 // 'other' is not appropriately memory-aligned.
 Tensor DeepCopy(const Tensor& other);
 
+// Concatenates 'tensors' into a single tensor, along their 0th dimension.
+//
+// REQUIRES: All members of 'tensors' must have the same data type parameter.
+// REQUIRES: Each member of 'tensors' must have at least one dimension.
+// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
+// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
+// is not appropriately memory-aligned.
+Tensor Concat(const gtl::ArraySlice<Tensor>& tensors);
+
+// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
+// dimension. The ith output tensor has 0th-dimension size 'sizes[i]'.
+//
+// REQUIRES: 'tensor' must have at least one dimension.
+// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
+// REQUIRES: 'tensor' must point to data stored in CPU memory.
+// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
+// appropriately memory-aligned.
+//
+// Split() and Concat() are inverse operations.
+std::vector<Tensor> Split(const Tensor& tensor,
+                          const gtl::ArraySlice<int64>& sizes);
+
 } // namespace tensor
 } // namespace tensorflow
 
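
As a quick orientation for the new helpers, here is a minimal usage sketch (it mirrors the unit tests below and assumes only the declarations above; it is not part of this commit):

    #include <vector>
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor_util.h"

    namespace tensorflow {

    void RoundTripExample() {
      // A [10, 2] tensor split into row blocks of 1, 4, and 5 rows.
      Tensor t(DT_FLOAT, TensorShape({10, 2}));
      std::vector<Tensor> parts = tensor::Split(t, {1, 4, 5});

      // Concat is the inverse: 'whole' has shape [10, 2] again, and the data
      // is copied rather than aliased (the ConcatSplitStrings test below
      // checks exactly this).
      Tensor whole = tensor::Concat(parts);
    }

    }  // namespace tensorflow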

tensorflow/core/framework/tensor_util_test.cc

Lines changed: 76 additions & 0 deletions
@@ -120,5 +120,81 @@ TEST(TensorUtil, DeepCopySlice) {
   }
 }
 
+TEST(TensorUtil, Concat) {
+  std::vector<int64> sizes = {1, 4, 5};
+  std::vector<Tensor> to_concat;
+  int64 total_size = 0;
+  int offset = 0;
+  for (int entry = 0; entry < sizes.size(); ++entry) {
+    const int64 size = sizes[entry];
+    Tensor tensor(DT_INT32, TensorShape({size, 2}));
+    for (int i = offset; i < offset + size; ++i) {
+      for (int j = 0; j < 2; ++j) {
+        tensor.matrix<int32>()(i - offset, j) = 2 * i + j;
+      }
+    }
+    to_concat.push_back(tensor);
+    total_size += size;
+    offset += size;
+  }
+
+  Tensor concated = tensor::Concat(to_concat);
+  ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
+  for (int i = 0; i < total_size; ++i) {
+    for (int j = 0; j < 2; ++j) {
+      EXPECT_EQ(2 * i + j, concated.matrix<int32>()(i, j));
+    }
+  }
+}
+
+TEST(TensorUtil, Split) {
+  Tensor to_split(DT_INT64, TensorShape({10, 2}));
+  for (int i = 0; i < 10; ++i) {
+    for (int j = 0; j < 2; ++j) {
+      to_split.matrix<int64>()(i, j) = 2 * i + j;
+    }
+  }
+
+  std::vector<int64> sizes = {1, 4, 5};
+  std::vector<Tensor> splits = tensor::Split(to_split, sizes);
+  ASSERT_EQ(sizes.size(), splits.size());
+
+  int offset = 0;
+  for (int entry = 0; entry < splits.size(); ++entry) {
+    const int64 size = sizes[entry];
+    const Tensor& split = splits[entry];
+
+    ASSERT_EQ(TensorShape({size, 2}), split.shape());
+    for (int i = offset; i < offset + size; ++i) {
+      for (int j = 0; j < 2; ++j) {
+        EXPECT_EQ(2 * i + j, split.matrix<int64>()(i - offset, j));
+      }
+    }
+
+    offset += size;
+  }
+}
+
+TEST(TensorUtil, ConcatSplitStrings) {
+  Tensor x(DT_STRING, TensorShape({4, 3}));
+  for (int i = 0; i < 4 * 3; ++i) {
+    x.flat<string>()(i) = strings::StrCat("foo_", i);
+  }
+
+  Tensor x_round_tripped = tensor::Concat(tensor::Split(x, {2, 1, 1}));
+  ASSERT_EQ(x.shape(), x_round_tripped.shape());
+  for (int i = 0; i < 4 * 3; ++i) {
+    EXPECT_EQ(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
+  }
+
+  // Ensure that no memory is being shared between 'x' and 'x_round_tripped'.
+  for (int i = 0; i < 4 * 3; ++i) {
+    x_round_tripped.flat<string>()(i) = strings::StrCat("bar_", i);
+  }
+  for (int i = 0; i < 4 * 3; ++i) {
+    EXPECT_NE(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
+  }
+}
+
 } // namespace
 } // namespace tensorflow

tensorflow/core/framework/types.h

Lines changed: 5 additions & 0 deletions
@@ -5,7 +5,12 @@
 #include <set>
 #include <string>
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+// Disable clang-format to prevent 'FixedPoint' header from being included
+// before 'Tensor' header on which it depends.
+// clang-format off
 #include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
+// clang-format on
 #include "tensorflow/core/framework/bfloat16.h"
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/types.pb.h"

tensorflow/core/kernels/bias_op_gpu.cu.cc

Lines changed: 35 additions & 0 deletions
@@ -2,14 +2,49 @@
 
 #define EIGEN_USE_GPU
 
+#include <algorithm>
+
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/bias_op.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
 
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
 
 // Definition of the GPU implementations declared in bias_op.cc.
+
+namespace functor {
+
+template <typename T>
+__global__ void BiasOpCustomKernel(int nthreads, const T* input, const T* bias,
+                                   int bias_size, int replicate_count,
+                                   T* output) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    int bias_offset = index % bias_size;
+    output[index] = input[index] + bias[bias_offset];
+  }
+}
+
+template <typename T, int Dims>
+struct Bias<GPUDevice, T, Dims> {
+  typedef GPUDevice Device;
+  // Add "bias" to "input", broadcasting it on all dimensions but the last one.
+  void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
+                  typename TTypes<T>::ConstVec bias,
+                  typename TTypes<T, Dims>::Tensor output) {
+    const int bias_size = bias.dimension(0);
+    const int rest_size = input.size() / bias_size;
+    CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
+    BiasOpCustomKernel<<<config.block_count, config.thread_per_block, 0,
+                         d.stream()>>>(config.virtual_thread_count,
+                                       input.data(), bias.data(), bias_size,
+                                       rest_size, output.data());
+  }
+};
+
+} // namespace functor
+
 #define DEFINE_GPU_SPECS(T) \
   template struct functor::Bias<GPUDevice, T, 2>; \
   template struct functor::Bias<GPUDevice, T, 3>; \