@@ -73,7 +73,11 @@ def assert_gcxs_slicing(s, x):
73
73
74
74
75
75
def assert_nnz (s , x ):
76
- fill_value = s .fill_value if hasattr (s , "fill_value" ) else _zero_of_dtype (s .dtype , s .device )
76
+ from ._settings import NUMPY_DEVICE
77
+
78
+ fill_value = (
79
+ s .fill_value if hasattr (s , "fill_value" ) else _zero_of_dtype (s .dtype , getattr (s , "device" , NUMPY_DEVICE ))
80
+ )
77
81
assert np .sum (~ equivalent (x , fill_value )) == s .nnz
78
82
79
83
@@ -442,7 +446,7 @@ def equivalent(x, y, /, loose=False):
442
446
443
447
from ._common import _coerce_to_supported_dense
444
448
445
- namespace = array_api_compat .array_namespace (x , y )
449
+ xp = array_api_compat .array_namespace (x , y )
446
450
x = _coerce_to_supported_dense (x )
447
451
y = _coerce_to_supported_dense (y )
448
452
# Can't contain NaNs
@@ -458,9 +462,9 @@ def equivalent(x, y, /, loose=False):
458
462
return (x == y ) | ((x != x ) & (y != y ))
459
463
460
464
if x .size == 0 or y .size == 0 :
461
- shape = namespace .broadcast_shapes (x .shape , y .shape )
462
- return namespace .empty (shape , dtype = np .bool_ )
463
- x , y = namespace .broadcast_arrays (x [..., None ], y [..., None ])
465
+ shape = xp .broadcast_shapes (x .shape , y .shape )
466
+ return xp .empty (shape , dtype = np .bool_ )
467
+ x , y = xp .broadcast_arrays (x [..., None ], y [..., None ])
464
468
return (x .astype (dt ).view (np .uint8 ) == y .astype (dt ).view (np .uint8 )).all (axis = - 1 )
465
469
466
470
0 commit comments