Skip to content

Commit dc43e1d

Browse files
committed
✨: AsType functions
Signed-off-by: nstarman <[email protected]>
1 parent 466b4f7 commit dc43e1d

File tree

3 files changed

+66
-5
lines changed

3 files changed

+66
-5
lines changed

src/array_api_typing/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==================
1515
# Namespace
1616
"ArrayNamespace",
17+
"DoesAsType",
18+
"HasAsType",
1719
# ==================
1820
"__version__",
1921
"__version_tuple__",
@@ -29,5 +31,5 @@
2931
HasSize,
3032
HasTranspose,
3133
)
32-
from ._namespace import ArrayNamespace
34+
from ._namespace import ArrayNamespace, DoesAsType, HasAsType
3335
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_namespace.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
1-
from typing import Protocol
1+
from typing import Protocol, TypeVar
22

3-
__all__ = ("ArrayNamespace",)
3+
from ._array import HasDType
4+
5+
__all__ = (
6+
"ArrayNamespace",
7+
# Data Type Functions
8+
"DoesAsType",
9+
"HasAsType",
10+
)
11+
12+
DTypeT = TypeVar("DTypeT")
13+
ToDTypeT = TypeVar("ToDTypeT")
414

515
# ===================================================================
616
# Creation Functions
@@ -9,9 +19,48 @@
919

1020
# ===================================================================
1121
# Data Type Functions
12-
# TODO: astype, broadcast_arrays, broadcast_to, can_cast, finfo, iinfo,
22+
# TODO: broadcast_arrays, broadcast_to, can_cast, finfo, iinfo,
1323
# result_type
1424

25+
26+
class DoesAsType(Protocol):
27+
"""Copies an array to a specified data type irrespective of Type Promotion Rules rules.
28+
29+
Note:
30+
Casting floating-point ``NaN`` and ``infinity`` values to integral data
31+
types is not specified and is implementation-dependent.
32+
33+
Note:
34+
When casting a boolean input array to a numeric data type, a value of
35+
`True` must cast to a numeric value equal to ``1``, and a value of
36+
`False` must cast to a numeric value equal to ``0``.
37+
38+
When casting a numeric input array to bool, a value of ``0`` must cast
39+
to `False`, and a non-zero value must cast to `True`.
40+
41+
Args:
42+
x: The array to cast.
43+
dtype: desired data type.
44+
copy: specifies whether to copy an array when the specified `dtype`
45+
matches the data type of the input array `x`. If `True`, a newly
46+
allocated array must always be returned. If `False` and the
47+
specified `dtype` matches the data type of the input array, the
48+
input array must be returned; otherwise, a newly allocated must be
49+
returned. Default: `True`.
50+
51+
""" # noqa: E501
52+
53+
def __call__(
54+
self, x: HasDType[DTypeT], dtype: ToDTypeT, /, *, copy: bool = True
55+
) -> HasDType[ToDTypeT]: ...
56+
57+
58+
class HasAsType(Protocol):
59+
"""Protocol for namespaces that have an ``astype`` function."""
60+
61+
astype: DoesAsType
62+
63+
1564
# ===================================================================
1665
# Element-wise Functions
1766
# TODO: abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and,
@@ -55,5 +104,9 @@
55104
# Full Namespace
56105

57106

58-
class ArrayNamespace(Protocol):
107+
class ArrayNamespace(
108+
# Data Type Functions
109+
HasAsType,
110+
Protocol,
111+
):
59112
"""Protocol for an Array API-compatible namespace."""

tests/integration/test_numpy2p0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,9 @@ assert_type(x_b.size, int | None)
9191
assert_type(x_f32.T, xpt.Array[np.dtype[F32]])
9292
assert_type(x_i32.T, xpt.Array[np.dtype[I32]])
9393
assert_type(x_b.T, xpt.Array[np.dtype[B]])
94+
95+
##############################################################################
96+
# Tests on Namespace Functions
97+
98+
assert_type(np.astype, xpt.DoesAsType)
99+
_: xpt.HasAsType = np

0 commit comments

Comments
 (0)