@@ -77,8 +77,8 @@ def int_or_str(text):
77
77
input_details_2 = interpreter_2 .get_input_details ()
78
78
output_details_2 = interpreter_2 .get_output_details ()
79
79
# create states for the lstms
80
- states_1 = np .zeros (input_details_1 [0 ]['shape' ]).astype ('float32' )
81
- states_2 = np .zeros (input_details_1 [ 0 ]['shape' ]).astype ('float32' )
80
+ states_1 = np .zeros (input_details_1 [1 ]['shape' ]).astype ('float32' )
81
+ states_2 = np .zeros (input_details_2 [ 1 ]['shape' ]).astype ('float32' )
82
82
# calculate shift and length
83
83
block_shift = int (np .round (fs_target * (block_shift_ms / 1000 )))
84
84
block_len = int (np .round (fs_target * (block_len_ms / 1000 )))
@@ -102,8 +102,8 @@ def callback(indata, outdata, frames, time, status):
102
102
# reshape magnitude to input dimensions
103
103
in_mag = np .reshape (in_mag , (1 ,1 ,- 1 )).astype ('float32' )
104
104
# set tensors to the first model
105
- interpreter_1 .set_tensor (input_details_1 [0 ]['index' ], states_1 )
106
- interpreter_1 .set_tensor (input_details_1 [1 ]['index' ], in_mag )
105
+ interpreter_1 .set_tensor (input_details_1 [1 ]['index' ], states_1 )
106
+ interpreter_1 .set_tensor (input_details_1 [0 ]['index' ], in_mag )
107
107
# run calculation
108
108
interpreter_1 .invoke ()
109
109
# get the output of the first block
@@ -115,8 +115,8 @@ def callback(indata, outdata, frames, time, status):
115
115
# reshape the time domain block
116
116
estimated_block = np .reshape (estimated_block , (1 ,1 ,- 1 )).astype ('float32' )
117
117
# set tensors to the second block
118
- interpreter_2 .set_tensor (input_details_1 [ 0 ]['index' ], states_2 )
119
- interpreter_2 .set_tensor (input_details_1 [ 1 ]['index' ], estimated_block )
118
+ interpreter_2 .set_tensor (input_details_2 [ 1 ]['index' ], states_2 )
119
+ interpreter_2 .set_tensor (input_details_2 [ 0 ]['index' ], estimated_block )
120
120
# run calculation
121
121
interpreter_2 .invoke ()
122
122
# get output tensors
@@ -144,4 +144,4 @@ def callback(indata, outdata, frames, time, status):
144
144
parser .exit ('' )
145
145
except Exception as e :
146
146
parser .exit (type (e ).__name__ + ': ' + str (e ))
147
-
147
+
0 commit comments