Skip to content

Commit e5c96fe

Browse files
tensorflow.convert_model
1 parent 462ecff commit e5c96fe

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

everywhereml/code_generators/tensorflow.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,41 @@
11
import hexdump
2+
import numpy as np
23
import tensorflow as tf
34
from everywhereml.code_generators import GeneratesCode
5+
from everywhereml.code_generators.jinja.Jinja import Jinja
6+
7+
8+
def convert_model(model, X: np.ndarray, y: np.ndarray, model_name: str = 'tfData') -> str:
9+
"""
10+
Convert model to C++ header
11+
:param model_name:
12+
:param model:
13+
:param X:
14+
:param y:
15+
:return:
16+
"""
17+
assert y.dtype != int or len(y.shape) == 2, 'y must be of dtype=float (regression) or one-hot encoded'
18+
19+
num_inputs = X.shape[1] if len(X.shape) > 1 else 1
20+
num_outputs = 1 if y.dtype != int else y.shape[1]
21+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
22+
bytes = hexdump.dump(converter.convert()).split(' ')
23+
bytes_array = ', '.join(['0x%02x' % int(byte, 16) for byte in bytes])
24+
model_size = len(bytes)
25+
26+
return Jinja(base_folder='', language='cpp', dialect=None).render('convert_tf_model', {
27+
'num_inputs': num_inputs,
28+
'num_outputs': num_outputs,
29+
'bytes_array': bytes_array,
30+
'model_size': model_size,
31+
'model_name': model_name or 'tfData'
32+
})
433

534

635
class TensorFlowPorter(GeneratesCode):
736
"""
837
Convert TF models to C++
38+
@deprecated
939
"""
1040
def __init__(self, model, X, y):
1141
"""
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#ifdef __has_attribute
4+
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
5+
#else
6+
#define HAVE_ATTRIBUTE(x) 0
7+
#endif
8+
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
9+
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
10+
#else
11+
#define DATA_ALIGN_ATTRIBUTE
12+
#endif
13+
14+
/** model size = {{ model_size }} bytes **/
15+
const unsigned char {{ model_name }}[] DATA_ALIGN_ATTRIBUTE = { {{ bytes_array }} };

setup_template.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
'jinja2_workarounds',
3232
'requests',
3333
'pySerial',
34-
'tqdm'
34+
'tqdm',
35+
'hexdump'
3536
],
3637
extras_require={
3738
'tf': ['tensorflow']

0 commit comments

Comments
 (0)