Skip to content

Commit 5596e67

Browse files
authored
Update real_time_dtln_audio.py
1 parent eb1d8e9 commit 5596e67

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

real_time_dtln_audio.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def int_or_str(text):
7777
input_details_2 = interpreter_2.get_input_details()
7878
output_details_2 = interpreter_2.get_output_details()
7979
# create states for the lstms
80-
states_1 = np.zeros(input_details_1[0]['shape']).astype('float32')
81-
states_2 = np.zeros(input_details_1[0]['shape']).astype('float32')
80+
states_1 = np.zeros(input_details_1[1]['shape']).astype('float32')
81+
states_2 = np.zeros(input_details_2[1]['shape']).astype('float32')
8282
# calculate shift and length
8383
block_shift = int(np.round(fs_target * (block_shift_ms / 1000)))
8484
block_len = int(np.round(fs_target * (block_len_ms / 1000)))
@@ -102,8 +102,8 @@ def callback(indata, outdata, frames, time, status):
102102
# reshape magnitude to input dimensions
103103
in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
104104
# set tensors to the first model
105-
interpreter_1.set_tensor(input_details_1[0]['index'], states_1)
106-
interpreter_1.set_tensor(input_details_1[1]['index'], in_mag)
105+
interpreter_1.set_tensor(input_details_1[1]['index'], states_1)
106+
interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
107107
# run calculation
108108
interpreter_1.invoke()
109109
# get the output of the first block
@@ -115,8 +115,8 @@ def callback(indata, outdata, frames, time, status):
115115
# reshape the time domain block
116116
estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
117117
# set tensors to the second block
118-
interpreter_2.set_tensor(input_details_1[0]['index'], states_2)
119-
interpreter_2.set_tensor(input_details_1[1]['index'], estimated_block)
118+
interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
119+
interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
120120
# run calculation
121121
interpreter_2.invoke()
122122
# get output tensors
@@ -144,4 +144,4 @@ def callback(indata, outdata, frames, time, status):
144144
parser.exit('')
145145
except Exception as e:
146146
parser.exit(type(e).__name__ + ': ' + str(e))
147-
147+

0 commit comments

Comments
 (0)