Skip to content

Commit e669565

Browse files
authored
Addressing array api failure for sort() function (#862)
1 parent b019f98 commit e669565

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

ci/Numba-array-api-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
5555
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
5656
array_api_tests/test_signatures.py::test_array_method_signature[__setitem__]
5757
array_api_tests/test_sorting_functions.py::test_argsort
58-
array_api_tests/test_sorting_functions.py::test_sort
5958
array_api_tests/test_special_cases.py::test_nan_propagation[max]
6059
array_api_tests/test_special_cases.py::test_nan_propagation[mean]
6160
array_api_tests/test_special_cases.py::test_nan_propagation[min]

sparse/numba_backend/_coo/common.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1307,8 +1307,12 @@ def sort(x, /, *, axis=-1, descending=False, stable=False):
13071307

13081308
x = x.reshape(x_shape[:-1] + (x_shape[-1],))
13091309
x = moveaxis(x, source=-1, destination=axis)
1310-
1311-
return x if original_ndim == x.ndim else x.squeeze()
1310+
if original_ndim == x.ndim:
1311+
return x
1312+
x = x.squeeze()
1313+
if x.shape == ():
1314+
return x[None]
1315+
return x
13121316

13131317

13141318
def take(x, indices, /, *, axis=None):

0 commit comments

Comments
 (0)