diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 07b83b5e7bc..6afebae6d23 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -3130,30 +3130,57 @@ def correlate(x1, x2, mode="valid"): x1 = tf.cast(x1, dtype) x2 = tf.cast(x2, dtype) - x1_len, x2_len = int(x1.shape[0]), int(x2.shape[0]) + def _pack(a, b): + # a: input [N] -> [1,N,1]; + # b: filter [M] -> [M,1,1] + return ( + tf.reshape(a, (1, shape_op(a)[0], 1)), + tf.reshape(b, (shape_op(b)[0], 1, 1)), + ) - if mode == "full": - full_len = x1_len + x2_len - 1 + def _full_corr(x1, x2): + """Compute 'full' correlation result (length = n + m - 1).""" + m = shape_op(x2)[0] + pad = ( + builtins.max(m - 1, 0) + if isinstance(m, int) + else tf.maximum(m - 1, 0) + ) + x1 = tf.pad(x1, [[pad, pad]]) # pad input with zeros + x1, x2 = _pack(x1, x2) + out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID") + return tf.squeeze(out, axis=[0, 2]) - x1_pad = (full_len - x1_len) / 2 - x2_pad = (full_len - x2_len) / 2 + n = shape_op(x1)[0] + m = shape_op(x2)[0] - x1 = tf.pad( - x1, paddings=[[tf.math.floor(x1_pad), tf.math.ceil(x1_pad)]] + if mode == "full": + return _full_corr(x1, x2) + elif mode == "same": + # unfortunately we can't leverage 'SAME' padding directly like + # we can with "valid" + # it works fine for odd-length filters, but for even-length filters + # the output is off by 1 compared to numpy, due to how + # tf handles centering + full_corr = _full_corr(x1, x2) + full_len = n + m - 1 + out_len = ( + max([n, m]) + if isinstance(n, int) and isinstance(m, int) + else tf.maximum(n, m) ) - x2 = tf.pad( - x2, paddings=[[tf.math.floor(x2_pad), tf.math.ceil(x2_pad)]] + start = (full_len - out_len) // 2 + return tf.slice(full_corr, [start], [out_len]) + elif mode == "valid": + x1, x2 = _pack(x1, x2) + return tf.squeeze( + tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2] + ) + else: + raise ValueError( + f"Invalid mode: '{mode}'. Mode must be one of:" + f" 'full', 'same', 'valid'." ) - - x1 = tf.reshape(x1, (1, full_len, 1)) - x2 = tf.reshape(x2, (full_len, 1, 1)) - - return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME")) - - x1 = tf.reshape(x1, (1, x1_len, 1)) - x2 = tf.reshape(x2, (x2_len, 1, 1)) - - return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper())) def select(condlist, choicelist, default=0): diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 78a47dfd1ee..54f39724511 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -5115,8 +5115,8 @@ def test_correlate(self): ) def test_correlate_different_size(self): - x = np.array([1, 2, 3, 4, 5, 6]) - y = np.array([0, 1, 0.5]) + x = np.array([1, 3, 5]) + y = np.array([7, 9]) self.assertAllClose(knp.correlate(x, y), np.correlate(x, y)) self.assertAllClose( knp.correlate(x, y, mode="same"), np.correlate(x, y, mode="same")