Skip to content

Commit 917312e

Browse files
committed
ENH: eye default dtype
1 parent 6d72df1 commit 917312e

2 files changed

Lines changed: 11 additions & 4 deletions

File tree

array_api_strict/_creation_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def eye(
193193

194194
_check_device(device)
195195
_check_valid_dtype(dtype, device)
196+
if dtype is None:
197+
dtype = get_default_dtypes(device)["real floating"]
196198

197199
return Array._new(
198200
np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device

array_api_strict/tests/test_creation_functions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,20 @@ def test_ones_like_etc_incorrect(self, func):
273273
with pytest.raises((TypeError, ValueError)):
274274
func(a, device=Device('F32_device'), dtype=float64)
275275

276+
def test_eye(self):
277+
device = Device('F32_device')
278+
a = eye(3, device=device)
279+
assert a.dtype == self.info.default_dtypes(device=device)["real floating"]
280+
281+
with pytest.raises((TypeError, ValueError)):
282+
eye(3, device=device, dtype=float64)
283+
276284

277285
# TODO:
278286
# def asarray(
279287
# def arange(
280-
# def eye(
281288
# def linspace(
282-
# def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> tuple[Array, ...]:
283-
# def tril(x: Array, /, *, k: int = 0) -> Array:
284-
# def triu(x: Array, /, *, k: int = 0) -> Array:
289+
285290

286291

287292
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])

0 commit comments

Comments
 (0)