Skip to content

Commit 0923e47

Browse files
Fix type mismatch error in freq_mask and time_mask (#1170)
* Fix type mismatch error in freq_mask This PR fixes type mismatch error in freq_mask, reported in issue 1158 (credit to BryanWBear) Signed-off-by: Yong Tang <[email protected]> * fix type mismatch in time_mask * Update audio_ops.py Co-authored-by: Vignesh Kothapalli <[email protected]>
1 parent 29f9667 commit 0923e47

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tensorflow_io/core/python/experimental/audio_ops.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def freq_mask(input, param, name=None):
273273
Returns:
274274
A tensor of spectrogram.
275275
"""
276+
input = tf.convert_to_tensor(input)
276277
# TODO: Support audio with channel > 1.
277278
freq_max = tf.shape(input)[1]
278279
f = tf.random.uniform(shape=(), minval=0, maxval=param, dtype=tf.dtypes.int32)
@@ -283,7 +284,7 @@ def freq_mask(input, param, name=None):
283284
condition = tf.math.logical_and(
284285
tf.math.greater_equal(indices, f0), tf.math.less(indices, f0 + f)
285286
)
286-
return tf.where(condition, 0, input)
287+
return tf.where(condition, tf.cast(0, input.dtype), input)
287288

288289

289290
def time_mask(input, param, name=None):
@@ -297,6 +298,7 @@ def time_mask(input, param, name=None):
297298
Returns:
298299
A tensor of spectrogram.
299300
"""
301+
input = tf.convert_to_tensor(input)
300302
# TODO: Support audio with channel > 1.
301303
time_max = tf.shape(input)[0]
302304
t = tf.random.uniform(shape=(), minval=0, maxval=param, dtype=tf.dtypes.int32)
@@ -307,7 +309,7 @@ def time_mask(input, param, name=None):
307309
condition = tf.math.logical_and(
308310
tf.math.greater_equal(indices, t0), tf.math.less(indices, t0 + t)
309311
)
310-
return tf.where(condition, 0, input)
312+
return tf.where(condition, tf.cast(0, input.dtype), input)
311313

312314

313315
def fade(input, fade_in, fade_out, mode, name=None):

0 commit comments

Comments
 (0)