
Commit ce847d5

Merge pull request #77 from ljn917/cppflow2-refactor-tensor
Add tensor::get_tensor() and tensor::get_eager_handle()
2 parents d1aad32 + f9bd7a1 commit ce847d5
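This merge adds two accessors to cppflow::tensor. A minimal usage sketch, assuming (as the new example and the model.h change below suggest) that both accessors return std::shared_ptr wrappers around the underlying TF_Tensor and TFE_TensorHandle; the exact signatures are defined in include/cppflow/tensor.h, which is part of this commit but not shown in this view:

#include <iostream>
#include "cppflow/cppflow.h"

int main() {
    cppflow::tensor t = cppflow::fill({2, 2}, 1.0f);

    // Presumed: get_tensor() returns a shared pointer owning the resolved TF_Tensor,
    // so the raw pointer can be passed to the TensorFlow C API without manual deletion.
    auto tf_tensor = t.get_tensor();
    std::cout << "bytes: " << TF_TensorByteSize(tf_tensor.get()) << std::endl;

    // Presumed: get_eager_handle() exposes the shared TFE_TensorHandle behind the tensor.
    auto handle = t.get_eager_handle();
    TF_Status* status = TF_NewStatus();
    std::cout << "dims: " << TFE_TensorHandleNumDims(handle.get(), status) << std::endl;
    TF_DeleteStatus(status);

    return 0;
}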

File tree

7 files changed: +1099 -81 lines changed


examples/tensor/CMakeLists.txt (+10)

@@ -0,0 +1,10 @@
+cmake_minimum_required(VERSION 3.10)
+project(example)
+
+find_library(TENSORFLOW_LIB tensorflow HINT $ENV{HOME}/libtensorflow2/lib)
+
+set(CMAKE_CXX_STANDARD 17)
+
+add_executable(example main.cpp)
+target_include_directories(example PRIVATE ../../include $ENV{HOME}/libtensorflow2/include)
+target_link_libraries(example "${TENSORFLOW_LIB}")

examples/tensor/main.cpp (+88)

@@ -0,0 +1,88 @@
+#include <cmath>
+#include <vector>
+#include <iostream>
+#include <stdexcept>
+
+#include "cppflow/cppflow.h"
+
+bool float_equal(const float f1, const float f2) {
+    return std::abs(f1/f2-1.0f) < 1e-6;
+}
+
+void test1(const bool is_cpu) {
+    std::cout << "test1 starts: is_cpu=" << is_cpu << std::endl;
+    float target = 1.0;
+    int64_t ndim = 2;
+    cppflow::tensor t1;
+
+    if(is_cpu) {
+        std::vector<float> _data(ndim, target);
+        t1 = cppflow::tensor(_data, {ndim});
+    } else {
+        t1 = cppflow::fill({ndim}, target);
+    }
+
+    std::cout << "tensor::device(true) : " << t1.device(true) << std::endl;
+    std::cout << "tensor::device(false) : " << t1.device(false) << std::endl;
+
+    auto t1_tensor = t1.get_tensor();
+    auto raw_data = static_cast<float*>(TF_TensorData(t1_tensor.get()));
+    float result_value = raw_data[0];
+    if(float_equal(result_value, target)) {
+        std::cout << "tensor::get_tensor() test1-1: pass" << std::endl;
+    } else {
+        std::cout << "tensor::get_tensor() test1-1: result_value=" << result_value << ", target=" << target << std::endl;
+        throw std::runtime_error("tensor::get_tensor() test1-1: failed");
+    }
+
+    // IMPORTANT NOTE: CANNOT modify the returned cache
+    float target2 = target + 10.0;
+    raw_data[1] = target2;
+    result_value = t1.get_data<float>()[0];
+    float result_value2 = t1.get_data<float>()[1];
+    if(float_equal(result_value, target)) {
+        std::cout << "tensor::get_tensor() test1-2: pass" << std::endl;
+    } else {
+        std::cout << "tensor::get_tensor() test1-2: failed, result_value=" << result_value << ", target=" << target << std::endl;
+        throw std::runtime_error("tensor::get_tensor() test1-2: failed");
+    }
+    if(float_equal(result_value2, target2)) {
+        std::cout << "tensor::get_tensor() test1-3: pass" << std::endl;
+    } else {
+        std::cout << "The failure of test1-3 is not considered as a bug." << std::endl;
+        std::cout << "tensor::get_tensor() test1-3: failed, result_value=" << result_value2 << ", target2=" << target2 << std::endl;
+    }
+
+    auto t2 = t1 + cppflow::tensor(0.f);
+    std::cout << "Can NOT modify the cache!" << std::endl;
+    std::cout << "t2: " << t2 << std::endl;
+
+    auto dt = cppflow::to_string(t1.dtype());
+    std::string expected_dtype{"TF_FLOAT"};
+    if(dt == expected_dtype) {
+        std::cout << "tensor::get_tensor() test1-4: pass" << std::endl;
+    } else {
+        std::cout << "tensor::get_tensor() test1-4: dtype=" << dt << ", expected_dtype=" << expected_dtype << std::endl;
+        throw std::runtime_error("tensor::get_tensor() test1-4: failed");
+    }
+
+    auto shape_tensor = t1.shape();
+    auto shape = shape_tensor.get_data<int32_t>()[0];
+    if(shape == ndim) {
+        std::cout << "tensor::get_tensor() test1-5: pass" << std::endl;
+    } else {
+        std::cout << "tensor::get_tensor() test1-5: shape_tensor.dtype()=" << cppflow::to_string(shape_tensor.dtype()) << std::endl;
+        std::cout << "tensor::get_tensor() test1-5: shape_tensor=" << shape_tensor << std::endl;
+        std::cout << "tensor::get_tensor() test1-5: shape()=" << shape << ", ndim=" << ndim << std::endl;
+        throw std::runtime_error("tensor::get_tensor() test1-5: failed");
+    }
+
+    std::cout << std::endl;
+}
+
+int main() {
+    test1(true);
+    test1(false);
+
+    return 0;
+}

include/cppflow/model.h (+1 -8)

@@ -86,11 +86,6 @@ namespace cppflow {
        std::vector<TF_Output> inp_ops(inputs.size());
        std::vector<TF_Tensor*> inp_val(inputs.size(), nullptr);

-        defer d([&inp_val]{
-            for (auto* tf_tensor : inp_val) {
-                TF_DeleteTensor(tf_tensor);
-            }
-        });
        for (int i=0; i<inputs.size(); i++) {

            // Operations
@@ -102,9 +97,7 @@ namespace cppflow {
                throw std::runtime_error("No operation named \"" + op_name + "\" exists");

            // Values
-            auto inp_tensor = TFE_TensorHandleResolve(std::get<1>(inputs[i]).tfe_handle.get(), context::get_status());
-            status_check(context::get_status());
-            inp_val[i] = inp_tensor;
+            inp_val[i] = std::get<1>(inputs[i]).get_tensor().get();
        }

        std::vector<TF_Output> out_ops(outputs.size());
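Note on the removal above: with get_tensor(), model.h no longer owns the resolved TF_Tensor, so the defer block that called TF_DeleteTensor is dropped. A hypothetical sketch of the presumed ownership pattern (resolve_and_own is an illustrative name, not the actual tensor.h code, which is part of this commit but not shown in this view):

#include <memory>
#include <tensorflow/c/c_api.h>
#include <tensorflow/c/eager/c_api.h>

// Resolve an eager handle once and hand ownership of the TF_Tensor to a shared_ptr
// whose deleter is TF_DeleteTensor; callers such as model.h can then use the raw
// pointer without deleting it themselves.
std::shared_ptr<TF_Tensor> resolve_and_own(TFE_TensorHandle* handle, TF_Status* status) {
    TF_Tensor* resolved = TFE_TensorHandleResolve(handle, status);
    return std::shared_ptr<TF_Tensor>(resolved, TF_DeleteTensor);
}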

include/cppflow/ops.h (+4 -7)

@@ -80,24 +80,21 @@ namespace cppflow {

    std::string to_string(const tensor &t) {
        auto res_tensor = string_format({t.shape(), t}, "(tensor: shape=%s, data=\n%s)");
-        auto res_tensor_h = TFE_TensorHandleResolve(res_tensor.tfe_handle.get(), context::get_status());
-        status_check(context::get_status());
+        auto res_tensor_h = res_tensor.get_tensor();

#ifdef TENSORFLOW_C_TF_TSTRING_H_
        // For future version TensorFlow 2.4
-        //auto *t_str = reinterpret_cast<TF_TString *>(TF_TensorData(res_tensor_h));
-        auto *t_str = (TF_TString *)(TF_TensorData(res_tensor_h));
+        //auto *t_str = reinterpret_cast<TF_TString *>(TF_TensorData(res_tensor_h.get()));
+        auto *t_str = (TF_TString *)(TF_TensorData(res_tensor_h.get()));
        auto result = std::string(TF_TString_GetDataPointer(t_str), TF_TString_GetSize(t_str));
#else
        const char* dst[1] = {nullptr};
        size_t dst_len[1] = {3};
-        TF_StringDecode(static_cast<char*>(TF_TensorData(res_tensor_h)) + 8, TF_TensorByteSize(res_tensor_h), dst, dst_len, context::get_status());
+        TF_StringDecode(static_cast<char*>(TF_TensorData(res_tensor_h.get())) + 8, TF_TensorByteSize(res_tensor_h.get()), dst, dst_len, context::get_status());
        status_check(context::get_status());
        auto result = std::string(dst[0], *dst_len);
#endif // TENSORFLOW_C_TF_TSTRING_H_

-        TF_DeleteTensor(res_tensor_h);
-
        return result;
    }

include/cppflow/ops_generator/generator.py (+2 -2)

@@ -106,7 +106,7 @@ def code(self):
            'type' : 'TFE_OpSetAttrType(op.get(), "{orig:}", {0});',
            'bool' : 'TFE_OpSetAttrBool(op.get(), "{orig:}", (unsigned char){0});',
            'tensor': '''
-                        TFE_OpSetAttrTensor(op.get(), "{orig:}", {0}.tf_tensor.get(), context::get_status());
+                        TFE_OpSetAttrTensor(op.get(), "{orig:}", {0}.get_tensor().get(), context::get_status());
                        status_check(context::get_status());
                        ''',
            'n_attr': 'TFE_OpSetAttrInt(op.get(), "{orig:}", {n_attr:}.size());'
@@ -268,4 +268,4 @@ def code(self):


with open('../raw_ops.h', 'w') as f:
-    f.write(ops_file.format(ops_code))
+    f.write(ops_file.format(ops_code))