Skip to content

Commit 1470505

Browse files
committed
ENH: linspace default dtype
1 parent 917312e commit 1470505

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

array_api_strict/_creation_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ def linspace(
318318

319319
_check_device(device)
320320
_check_valid_dtype(dtype, device)
321+
if dtype is None:
322+
if isinstance(start, complex) or isinstance(stop, complex):
323+
dtype = get_default_dtypes(device)["complex floating"]
324+
else:
325+
dtype = get_default_dtypes(device)["real floating"]
321326

322327
return Array._new(
323328
np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint),

array_api_strict/tests/test_creation_functions.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,23 @@ def test_eye(self):
282282
eye(3, device=device, dtype=float64)
283283

284284

285+
def test_linspace(self):
286+
device = Device('F32_device')
287+
288+
a = linspace(1, 10, 11, device=device)
289+
assert a.dtype == self.info.default_dtypes(device=device)["real floating"]
290+
291+
a = linspace(1+0j, 10, 11, device=device)
292+
assert a.dtype == self.info.default_dtypes(device=device)["complex floating"]
293+
294+
with pytest.raises((TypeError, ValueError)):
295+
linspace(1, 10, 11, device=device, dtype=float64)
296+
297+
298+
285299
# TODO:
286300
# def asarray(
287301
# def arange(
288-
# def linspace(
289-
290302

291303

292304
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])

0 commit comments

Comments
 (0)