Skip to content

Commit 5008ae5

Browse files
committed
Sync from upstream TF.
1 parent 6264a69 commit 5008ae5

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

tensorflow/compiler/mlir/lite/schema/schema_utils.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ limitations under the License.
1515
#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
1616

1717
#include <algorithm>
18+
#include <complex>
19+
#include <cstddef>
20+
#include <cstdint>
1821

1922
#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h"
23+
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
2024

2125
namespace tflite {
2226

@@ -59,4 +63,51 @@ BuiltinOperator GetBuiltinCode(const OperatorCodeT* op_code) {
5963
op_code->deprecated_builtin_code));
6064
}
6165

66+
size_t TensorTypeGetSize(::tflite::TensorType data_type) {
67+
switch (data_type) {
68+
case ::tflite::TensorType_FLOAT32:
69+
static_assert(sizeof(float) == 4, "");
70+
return 4;
71+
case ::tflite::TensorType_FLOAT16:
72+
static_assert(sizeof(int16_t) == 2, "");
73+
return 2;
74+
case ::tflite::TensorType_INT32:
75+
static_assert(sizeof(int32_t) == 4, "");
76+
return 4;
77+
case ::tflite::TensorType_UINT8:
78+
static_assert(sizeof(uint8_t) == 1, "");
79+
return 1;
80+
case ::tflite::TensorType_INT64:
81+
static_assert(sizeof(int64_t) == 8, "");
82+
return 8;
83+
case ::tflite::TensorType_BOOL:
84+
return sizeof(bool);
85+
case ::tflite::TensorType_INT16:
86+
static_assert(sizeof(int16_t) == 2, "");
87+
return 2;
88+
case ::tflite::TensorType_COMPLEX64:
89+
static_assert(sizeof(std::complex<float>) == 8, "");
90+
return 8;
91+
case ::tflite::TensorType_INT8:
92+
static_assert(sizeof(int8_t) == 1, "");
93+
return 1;
94+
case ::tflite::TensorType_FLOAT64:
95+
static_assert(sizeof(double) == 8, "");
96+
return 8;
97+
case ::tflite::TensorType_COMPLEX128:
98+
static_assert(sizeof(std::complex<double>) == 16, "");
99+
return 16;
100+
case ::tflite::TensorType_UINT64:
101+
static_assert(sizeof(uint64_t) == 8, "");
102+
return 8;
103+
case ::tflite::TensorType_UINT32:
104+
static_assert(sizeof(uint32_t) == 4, "");
105+
return 4;
106+
case ::tflite::TensorType_UINT16:
107+
static_assert(sizeof(uint16_t) == 2, "");
108+
return 2;
109+
default:
110+
return 0;
111+
}
112+
}
62113
} // namespace tflite

tensorflow/compiler/mlir/lite/schema/schema_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
1616
#define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
1717

18+
#include <cstddef>
19+
1820
#include "flatbuffers/flatbuffers.h"
1921
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
2022

@@ -28,6 +30,11 @@ BuiltinOperator GetBuiltinCode(const OperatorCode *op_code);
2830

2931
BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code);
3032

33+
// Returns the size of the given TensorType in bytes, or 0 if the TensorType is
34+
// not supported, this function should be aligned with TfLiteTypeGetSize in
35+
// lite/kernels/kernel_util.h.
36+
size_t TensorTypeGetSize(::tflite::TensorType data_type);
37+
3138
} // namespace tflite
3239

3340
#endif // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_

tensorflow/lite/kernels/internal/reference/broadcast_to.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
1616
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
1717

18+
#include <cstddef>
19+
1820
#include "tensorflow/lite/kernels/internal/common.h"
1921
#include "tensorflow/lite/kernels/kernel_util.h"
2022

@@ -83,7 +85,8 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
8385
// If non-broadcasting, just copy data from input to output tensor.
8486
if (last_broadcast_dim == -1) {
8587
memcpy(output_data, input_data,
86-
unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
88+
static_cast<size_t>(unextended_input_shape.FlatSize()) *
89+
static_cast<size_t>(TfLiteTypeGetSize(data_type)));
8790
return;
8891
}
8992

0 commit comments

Comments
 (0)