diff --git a/keras/api/ops/core/quantile_test.py b/keras/api/ops/core/quantile_test.py new file mode 100644 index 00000000000..45f0fb11a23 --- /dev/null +++ b/keras/api/ops/core/quantile_test.py @@ -0,0 +1,14 @@ +import numpy as np +import tensorflow as tf +from keras import ops + +def test_quantile_graph_mode(): + @tf.function + def run_quantile(): + x = np.array([[1, 2, 3], [4, 5, 6]]) + q = [0.5] + return ops.quantile(x, q, axis=1) + + result = run_quantile() + expected = np.array([[2, 5]]) + np.testing.assert_allclose(result, expected) diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index d6f54719bfc..119696fd4f4 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2301,7 +2301,7 @@ def _get_indices(method): return gathered_y perm = collections.deque(range(ndims)) perm.rotate(shift_value_static) - return tf.transpose(a=gathered_y, perm=perm) + return tf.transpose(a=gathered_y, perm=list(perm)) def quantile(x, q, axis=None, method="linear", keepdims=False):