Inconsistent behaviours between backends on ops.tile #20914

Open
yshao-aim-solutions opened this issue Feb 17, 2025 · 0 comments

I have the following function that broadcasts two arrays so that their element-wise product covers all possible pairwise combinations. It uses both `ops.tile` and `ops.repeat`, and I found that `ops.tile` behaves inconsistently between backends. The printed outputs (`y1` from `ops.tile`, `y2` from `ops.repeat`) are:

  • JAX:
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_2>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_3>)
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_4>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_5>)
  • TensorFlow:
(<KerasTensor shape=(None, None, None, None), dtype=float32, sparse=False, name=keras_tensor_2>, <KerasTensor shape=(None, None, None, None), dtype=float32, sparse=False, name=keras_tensor_3>)
(<KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_4>, <KerasTensor shape=(None, 6, 2, 2), dtype=float32, sparse=False, name=keras_tensor_5>)

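With the TensorFlow backend, the `ops.tile` output (`y1`) loses its static shape entirely (every dimension is `None`), while the `ops.repeat` output (`y2`) is still inferred as `(None, 6, 2, 2)`. It seems TensorFlow cannot properly infer the shape of the symbolic tensor returned by `ops.tile`.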
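Because `repeats` is a plain Python list here, the output shape is fully determined statically. Assuming `ops.tile` follows NumPy's `np.tile` shape rule (the JAX result is consistent with this), the expected symbolic shape can be computed with a small helper. `expected_tile_shape` is a hypothetical name for illustration, not a Keras API.

```python
def expected_tile_shape(shape, repeats):
    """Hypothetical helper: static output shape of a NumPy-style tile.

    `shape` may contain None for unknown dimensions (e.g. the batch axis);
    `repeats` must be concrete Python ints.
    """
    shape, repeats = list(shape), list(repeats)
    # NumPy rule: pad the shorter sequence with leading 1s.
    ndim = max(len(shape), len(repeats))
    shape = [1] * (ndim - len(shape)) + shape
    repeats = [1] * (ndim - len(repeats)) + repeats

    out = []
    for dim, rep in zip(shape, repeats):
        if rep == 0:
            out.append(0)           # tiling zero times gives an empty axis
        elif dim is None:
            out.append(None)        # an unknown dimension stays unknown
        else:
            out.append(dim * rep)   # a known dimension is multiplied
    return tuple(out)


# Input(shape=(3, 2, 2)) tiled with repeats = [1, 2, 1, 1]:
print(expected_tile_shape((None, 3, 2, 2), [1, 2, 1, 1]))  # (None, 6, 2, 2)
```

Only the batch axis is genuinely unknown, so an all-`None` result discards information that is available at graph-construction time.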

Another issue: when the `repeats` argument of `ops.tile` is computed from the shape of a symbolic tensor, TensorFlow still works (although all output dimensions are `None`), but JAX raises an error: "'str' object has no attribute '_error_repr'". To reproduce this, replace the hard-coded `repeats = [1, 2, 1, 1]` with the commented-out `ops.scatter_update(...)` line.
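A plausible explanation for the JAX error (not confirmed in this thread) is that the commented-out line yields a tensor, whereas JAX's `jnp.tile` needs concrete integer repeats at trace time. A minimal workaround sketch, assuming the mode dimension is statically known (as it is for `Input(shape=(2, 2, 2))`, where `ops.shape(x2real)[-3]` should be the Python int `2`), builds `repeats` as a plain Python list instead. `make_tile_repeats` is a hypothetical helper, not part of Keras.

```python
def make_tile_repeats(ndim, axis, count):
    """Hypothetical helper: a repeats list that tiles only `axis` by `count`.

    All arguments must be plain Python ints (not tensors), so every backend
    sees concrete repeats.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for ndim {ndim}")
    repeats = [1] * ndim
    repeats[axis] = count
    return repeats


# 4D inputs, tiling the mode axis (-3) by n2 = 2:
print(make_tile_repeats(4, -3, 2))  # [1, 2, 1, 1]
```

Inside `broadcast`, this would replace the commented-out line with `repeats = make_tile_repeats(x1dims, -3, x2shape)`. It does not address the TensorFlow shape-inference problem, which occurs even with the hard-coded list.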

Reproduction Code

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # switch to "tensorflow" to compare backends

from keras import ops, layers
from keras import Input

# %%
def broadcast(x1, x2):
    """
    Broadcast the shapes of x1 and x2 to allow the computation of cross-interaction.

    - repeating input1: (a, b) -> (a, b, a, b)
    - repeating input2: (c, d) -> (c, c, d, d)
    - result: (a, b) * (c, d) = (a * c, b * c, a * d, b * d)

    Args:
        x1: tuple (real, imag) of nD arrays in shape (..., n1, ny1, nx1) to be broadcast
        x2: tuple (real, imag) of nD arrays in shape (..., n2, ny2, nx2) to be broadcast

    Returns:
        Broadcast (real, imag) pairs of nD arrays in shape (..., n1 * n2, ...)

    Examples:
        >>> import numpy as np
        >>> x1 = np.array([[[[0., 1., 2.]],
        ...                 [[3., 4., 5.]]]])
        >>> x2 = np.array([[[[0., 1., 2.]],
        ...                 [[3., 4., 5.]]]])
        >>> x1, x2 = broadcast((x1, np.zeros(np.shape(x1))), (x2, np.zeros(np.shape(x2))))

        >>> np.array(x1[0]) + 1j * np.array(x1[1])
        array([[[[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]],
                [[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]]]], dtype=complex64)
        >>> np.array(x2[0]) + 1j * np.array(x2[1])
        array([[[[0.+0.j, 1.+0.j, 2.+0.j]],
                [[0.+0.j, 1.+0.j, 2.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]],
                [[3.+0.j, 4.+0.j, 5.+0.j]]]], dtype=complex64)
    """

    x1real, x1imag = x1
    x2real, x2imag = x2

    x1shape = ops.shape(x1real)[-3]  # spatial mode dimension (n1)
    x2shape = ops.shape(x2real)[-3]  # spatial mode dimension (n2)

    x1dims = len(ops.shape(x1real))
    # Deriving repeats from the symbolic shape triggers the JAX error described above:
    # repeats = ops.scatter_update(ops.cast(ops.ones(x1dims), dtype="int32"), [[-3 + x1dims]], [x2shape])
    # Hard-coded equivalent for 4D inputs with n2 == 2: tile axis -3 by n2.
    repeats = [1, 2, 1, 1]
    x1real = ops.tile(x1real, repeats)
    x1imag = ops.tile(x1imag, repeats)

    # Repeat each slice of x2 n1 times along the mode axis.
    x2real = ops.repeat(x2real, x1shape, axis=-3)
    x2imag = ops.repeat(x2imag, x1shape, axis=-3)

    return ((x1real, x1imag), (x2real, x2imag))

# %%

class Test(layers.Layer):

    def call(self, inputs1, inputs2):

        return broadcast(inputs1, inputs2)

test = Test()

# %%

x1 = Input(shape=(3, 2, 2))  # n1 = 3
x2 = Input(shape=(2, 2, 2))  # n2 = 2

y1, y2 = test((x1, x1), (x2, x2))

print(y1)  # output of ops.tile
print(y2)  # output of ops.repeat
```
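For comparison, here is a NumPy-only sketch of what `broadcast` is meant to compute according to its docstring, which can serve as a backend-independent baseline. The helper name `pairwise_product` is hypothetical; it handles a single real-valued array pair rather than the (real, imag) tuples, and derives the tile repeats from the static shape instead of hard-coding `[1, 2, 1, 1]`.

```python
import numpy as np


def pairwise_product(x1, x2):
    """Hypothetical NumPy reference for the docstring's cross-interaction:
    every slice of x1 along axis -3 times every slice of x2 along axis -3."""
    n1 = x1.shape[-3]
    n2 = x2.shape[-3]
    # Tile x1 along axis -3 (n2 copies): (a, b) -> (a, b, a, b) when n2 == 2.
    reps = [1] * x1.ndim
    reps[-3] = n2
    x1_b = np.tile(x1, reps)
    # Repeat each slice of x2 n1 times: (c, d) -> (c, c, d, d) when n1 == 2.
    x2_b = np.repeat(x2, n1, axis=-3)
    return x1_b * x2_b


rng = np.random.default_rng(0)
x1 = rng.standard_normal((1, 3, 2, 2))  # n1 = 3, as in the reproduction
x2 = rng.standard_normal((1, 2, 2, 2))  # n2 = 2
out = pairwise_product(x1, x2)
print(out.shape)  # (1, 6, 2, 2)

# Cross-check against an explicit outer product over the mode axis:
# ref[b, j, i] = x1[b, i] * x2[b, j], flattened so that k = j * n1 + i.
ref = np.einsum("biyx,bjyx->bjiyx", x1, x2).reshape(1, 6, 2, 2)
print(np.allclose(out, ref))  # True
```

The `(1, 6, 2, 2)` shape matches the JAX output above, and the einsum cross-check confirms that tile-then-repeat enumerates every (i, j) pair exactly once.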

Environment
jax 0.5.0
jaxlib 0.5.0
keras 3.8.0
tensorboard 2.18.0
