@@ -15,8 +15,12 @@ limitations under the License.
1515#include " tensorflow/compiler/mlir/lite/schema/schema_utils.h"
1616
1717#include < algorithm>
18+ #include < complex>
19+ #include < cstddef>
20+ #include < cstdint>
1821
1922#include " tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h"
23+ #include " tensorflow/compiler/mlir/lite/schema/schema_generated.h"
2024
2125namespace tflite {
2226
@@ -59,4 +63,51 @@ BuiltinOperator GetBuiltinCode(const OperatorCodeT* op_code) {
5963 op_code->deprecated_builtin_code ));
6064}
6165
66+ size_t TensorTypeGetSize (::tflite::TensorType data_type) {
67+ switch (data_type) {
68+ case ::tflite::TensorType_FLOAT32:
69+ static_assert (sizeof (float ) == 4 , " " );
70+ return 4 ;
71+ case ::tflite::TensorType_FLOAT16:
72+ static_assert (sizeof (int16_t ) == 2 , " " );
73+ return 2 ;
74+ case ::tflite::TensorType_INT32:
75+ static_assert (sizeof (int32_t ) == 4 , " " );
76+ return 4 ;
77+ case ::tflite::TensorType_UINT8:
78+ static_assert (sizeof (uint8_t ) == 1 , " " );
79+ return 1 ;
80+ case ::tflite::TensorType_INT64:
81+ static_assert (sizeof (int64_t ) == 8 , " " );
82+ return 8 ;
83+ case ::tflite::TensorType_BOOL:
84+ return sizeof (bool );
85+ case ::tflite::TensorType_INT16:
86+ static_assert (sizeof (int16_t ) == 2 , " " );
87+ return 2 ;
88+ case ::tflite::TensorType_COMPLEX64:
89+ static_assert (sizeof (std::complex <float >) == 8 , " " );
90+ return 8 ;
91+ case ::tflite::TensorType_INT8:
92+ static_assert (sizeof (int8_t ) == 1 , " " );
93+ return 1 ;
94+ case ::tflite::TensorType_FLOAT64:
95+ static_assert (sizeof (double ) == 8 , " " );
96+ return 8 ;
97+ case ::tflite::TensorType_COMPLEX128:
98+ static_assert (sizeof (std::complex <double >) == 16 , " " );
99+ return 16 ;
100+ case ::tflite::TensorType_UINT64:
101+ static_assert (sizeof (uint64_t ) == 8 , " " );
102+ return 8 ;
103+ case ::tflite::TensorType_UINT32:
104+ static_assert (sizeof (uint32_t ) == 4 , " " );
105+ return 4 ;
106+ case ::tflite::TensorType_UINT16:
107+ static_assert (sizeof (uint16_t ) == 2 , " " );
108+ return 2 ;
109+ default :
110+ return 0 ;
111+ }
112+ }
62113} // namespace tflite
0 commit comments