Skip to content

Commit 664b8b6

Browse files
Fix bug with correlate for tensorflow (#21778)
* fix correlate bug * fix correlate function * add pydoc * Update keras/src/backend/tensorflow/numpy.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * comment, and remove recursion * address comments around shape_op and test * explicit about which axes to squeeze * fix max call * fix `max` call --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 6d06085 commit 664b8b6

File tree

2 files changed

+48
-21
lines changed

2 files changed

+48
-21
lines changed

keras/src/backend/tensorflow/numpy.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3186,30 +3186,57 @@ def correlate(x1, x2, mode="valid"):
31863186
x1 = tf.cast(x1, dtype)
31873187
x2 = tf.cast(x2, dtype)
31883188

3189-
x1_len, x2_len = int(x1.shape[0]), int(x2.shape[0])
3189+
def _pack(a, b):
3190+
# a: input [N] -> [1,N,1];
3191+
# b: filter [M] -> [M,1,1]
3192+
return (
3193+
tf.reshape(a, (1, shape_op(a)[0], 1)),
3194+
tf.reshape(b, (shape_op(b)[0], 1, 1)),
3195+
)
31903196

3191-
if mode == "full":
3192-
full_len = x1_len + x2_len - 1
3197+
def _full_corr(x1, x2):
3198+
"""Compute 'full' correlation result (length = n + m - 1)."""
3199+
m = shape_op(x2)[0]
3200+
pad = (
3201+
builtins.max(m - 1, 0)
3202+
if isinstance(m, int)
3203+
else tf.maximum(m - 1, 0)
3204+
)
3205+
x1 = tf.pad(x1, [[pad, pad]]) # pad input with zeros
3206+
x1, x2 = _pack(x1, x2)
3207+
out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID")
3208+
return tf.squeeze(out, axis=[0, 2])
31933209

3194-
x1_pad = (full_len - x1_len) / 2
3195-
x2_pad = (full_len - x2_len) / 2
3210+
n = shape_op(x1)[0]
3211+
m = shape_op(x2)[0]
31963212

3197-
x1 = tf.pad(
3198-
x1, paddings=[[tf.math.floor(x1_pad), tf.math.ceil(x1_pad)]]
3213+
if mode == "full":
3214+
return _full_corr(x1, x2)
3215+
elif mode == "same":
3216+
# unfortunately we can't leverage 'SAME' padding directly like
3217+
# we can with "valid"
3218+
# it works fine for odd-length filters, but for even-length filters
3219+
# the output is off by 1 compared to numpy, due to how
3220+
# tf handles centering
3221+
full_corr = _full_corr(x1, x2)
3222+
full_len = n + m - 1
3223+
out_len = (
3224+
max([n, m])
3225+
if isinstance(n, int) and isinstance(m, int)
3226+
else tf.maximum(n, m)
31993227
)
3200-
x2 = tf.pad(
3201-
x2, paddings=[[tf.math.floor(x2_pad), tf.math.ceil(x2_pad)]]
3228+
start = (full_len - out_len) // 2
3229+
return tf.slice(full_corr, [start], [out_len])
3230+
elif mode == "valid":
3231+
x1, x2 = _pack(x1, x2)
3232+
return tf.squeeze(
3233+
tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2]
3234+
)
3235+
else:
3236+
raise ValueError(
3237+
f"Invalid mode: '{mode}'. Mode must be one of:"
3238+
f" 'full', 'same', 'valid'."
32023239
)
3203-
3204-
x1 = tf.reshape(x1, (1, full_len, 1))
3205-
x2 = tf.reshape(x2, (full_len, 1, 1))
3206-
3207-
return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME"))
3208-
3209-
x1 = tf.reshape(x1, (1, x1_len, 1))
3210-
x2 = tf.reshape(x2, (x2_len, 1, 1))
3211-
3212-
return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper()))
32133240

32143241

32153242
def select(condlist, choicelist, default=0):

keras/src/ops/numpy_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5230,8 +5230,8 @@ def test_correlate(self):
52305230
)
52315231

52325232
def test_correlate_different_size(self):
5233-
x = np.array([1, 2, 3, 4, 5, 6])
5234-
y = np.array([0, 1, 0.5])
5233+
x = np.array([1, 3, 5])
5234+
y = np.array([7, 9])
52355235
self.assertAllClose(knp.correlate(x, y), np.correlate(x, y))
52365236
self.assertAllClose(
52375237
knp.correlate(x, y, mode="same"), np.correlate(x, y, mode="same")

0 commit comments

Comments
 (0)