diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 65b027e2ba..a1ab047b16 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -829,7 +829,12 @@ def repeat(x, repeats, /, *, axis=None): if repeats.size == 1: scalar = True # bring the single element to the host - repeats = int(repeats) + if repeats.ndim == 0: + repeats = int(repeats) + else: + # Get the single element explicitly + # since non-0D arrays can not be converted to scalars + repeats = int(repeats[0]) if repeats < 0: raise ValueError("`repeats` elements must be positive") else: diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 09c10340ef..b278761811 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -1342,6 +1342,21 @@ def test_repeat_strided_repeats(): assert dpt.all(res == x) +def test_repeat_size1_repeats(): + get_queue_or_skip() + + x = dpt.arange(5, dtype="i4") + expected_res = dpt.repeat(x, 2) + # 0D repeats + reps_0d = dpt.asarray(2, dtype="i8") + res = dpt.repeat(x, reps_0d) + assert dpt.all(res == expected_res) + # 1D repeats + reps_1d = dpt.asarray([2], dtype="i8") + res = dpt.repeat(x, reps_1d) + assert dpt.all(res == expected_res) + + def test_repeat_arg_validation(): get_queue_or_skip()