77import inspect
88from typing import TYPE_CHECKING , Any , NamedTuple , Optional , Sequence , cast
99
10- from ._helpers import _check_device , array_namespace
10+ from ._helpers import _device_ctx , array_namespace
1111from ._helpers import device as _get_device
1212from ._helpers import is_cupy_namespace as _is_cupy_namespace
1313from ._typing import Array , Device , DType , Namespace
@@ -32,8 +32,8 @@ def arange(
3232 device : Device | None = None ,
3333 ** kwargs : object ,
3434) -> Array :
35- _check_device (xp , device )
36- return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
35+ with _device_ctx (xp , device ):
36+ return xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
3737
3838
3939def empty (
@@ -44,8 +44,8 @@ def empty(
4444 device : Device | None = None ,
4545 ** kwargs : object ,
4646) -> Array :
47- _check_device (xp , device )
48- return xp .empty (shape , dtype = dtype , ** kwargs )
47+ with _device_ctx (xp , device ):
48+ return xp .empty (shape , dtype = dtype , ** kwargs )
4949
5050
5151def empty_like (
@@ -57,8 +57,8 @@ def empty_like(
5757 device : Device | None = None ,
5858 ** kwargs : object ,
5959) -> Array :
60- _check_device (xp , device )
61- return xp .empty_like (x , dtype = dtype , ** kwargs )
60+ with _device_ctx (xp , device , like = x ):
61+ return xp .empty_like (x , dtype = dtype , ** kwargs )
6262
6363
6464def eye (
@@ -72,8 +72,8 @@ def eye(
7272 device : Device | None = None ,
7373 ** kwargs : object ,
7474) -> Array :
75- _check_device (xp , device )
76- return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
75+ with _device_ctx (xp , device ):
76+ return xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
7777
7878
7979def full (
@@ -85,8 +85,8 @@ def full(
8585 device : Device | None = None ,
8686 ** kwargs : object ,
8787) -> Array :
88- _check_device (xp , device )
89- return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
88+ with _device_ctx (xp , device ):
89+ return xp .full (shape , fill_value , dtype = dtype , ** kwargs )
9090
9191
9292def full_like (
@@ -99,8 +99,8 @@ def full_like(
9999 device : Device | None = None ,
100100 ** kwargs : object ,
101101) -> Array :
102- _check_device (xp , device )
103- return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
102+ with _device_ctx (xp , device , like = x ):
103+ return xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
104104
105105
106106def linspace (
@@ -115,8 +115,8 @@ def linspace(
115115 endpoint : bool = True ,
116116 ** kwargs : object ,
117117) -> Array :
118- _check_device (xp , device )
119- return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
118+ with _device_ctx (xp , device ):
119+ return xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
120120
121121
122122def ones (
@@ -127,8 +127,8 @@ def ones(
127127 device : Device | None = None ,
128128 ** kwargs : object ,
129129) -> Array :
130- _check_device (xp , device )
131- return xp .ones (shape , dtype = dtype , ** kwargs )
130+ with _device_ctx (xp , device ):
131+ return xp .ones (shape , dtype = dtype , ** kwargs )
132132
133133
134134def ones_like (
@@ -140,8 +140,8 @@ def ones_like(
140140 device : Device | None = None ,
141141 ** kwargs : object ,
142142) -> Array :
143- _check_device (xp , device )
144- return xp .ones_like (x , dtype = dtype , ** kwargs )
143+ with _device_ctx (xp , device , like = x ):
144+ return xp .ones_like (x , dtype = dtype , ** kwargs )
145145
146146
147147def zeros (
@@ -152,8 +152,8 @@ def zeros(
152152 device : Device | None = None ,
153153 ** kwargs : object ,
154154) -> Array :
155- _check_device (xp , device )
156- return xp .zeros (shape , dtype = dtype , ** kwargs )
155+ with _device_ctx (xp , device ):
156+ return xp .zeros (shape , dtype = dtype , ** kwargs )
157157
158158
159159def zeros_like (
@@ -165,8 +165,8 @@ def zeros_like(
165165 device : Device | None = None ,
166166 ** kwargs : object ,
167167) -> Array :
168- _check_device (xp , device )
169- return xp .zeros_like (x , dtype = dtype , ** kwargs )
168+ with _device_ctx (xp , device , like = x ):
169+ return xp .zeros_like (x , dtype = dtype , ** kwargs )
170170
171171
172172# np.unique() is split into four functions in the array API:
0 commit comments