diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 90994c200fa..6d300db50f3 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -1010,7 +1010,11 @@ def view(x, dtype=None): old_itemsize = old_dtype.size new_itemsize = new_dtype.size - if list(x.shape)[-1] * old_itemsize % new_itemsize != 0: + old_shape = list(shape_op(x)) + last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1 + if (last_dim_size == -1 and old_itemsize != new_itemsize) or ( + last_dim_size * old_itemsize % new_itemsize != 0 + ): raise ValueError( f"Cannot view array of shape {x.shape} and dtype {old_dtype} " f"as dtype {new_dtype} because the total number of bytes " @@ -1027,8 +1031,6 @@ def view(x, dtype=None): cast_tensor = tf.bitcast(flat_tensor, type=new_dtype) return tf.reshape(cast_tensor, new_shape) else: - old_shape = list(shape_op(x)) - last_dim_size = old_shape[-1] ratio = new_itemsize // old_itemsize if isinstance(last_dim_size, int) and last_dim_size % ratio != 0: raise ValueError( diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 3f76446725e..759333eb967 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -4136,6 +4136,16 @@ def test_concatenate(self): ) def test_view(self): + x = np.array(1, dtype="int16") + result = knp.view(x, dtype="float16") + assert backend.standardize_dtype(result.dtype) == "float16" + + with self.assertRaises(Exception): + result = knp.view(x, dtype="int8") + + with self.assertRaises(Exception): + result = knp.view(x, dtype="int32") + x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype="int16") result = knp.view(x, dtype="int16") assert backend.standardize_dtype(result.dtype) == "int16"