Skip to content

Commit 25412da

Browse files
committed
✨: Add stubs for major protocols
1 parent 8641f7f commit 25412da

File tree

6 files changed

+156
-5
lines changed

6 files changed

+156
-5
lines changed

src/array_api_typing/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
"""Static typing support for the array API standard."""
22

33
__all__ = (
4+
"Array",
5+
"ArrayNamespace",
6+
"DType",
7+
"Device",
48
"HasArrayNamespace",
59
"__version__",
610
"__version_tuple__",
11+
"signature_types",
712
)
813

9-
from ._namespace import HasArrayNamespace
14+
from . import signature_types
15+
from ._array import Array
16+
from ._misc_objects import Device, DType
17+
from ._namespace import ArrayNamespace, HasArrayNamespace
1018
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Static typing support for array API arrays."""
2+
3+
from typing import Protocol
4+
5+
from ._namespace import HasArrayNamespace
6+
7+
8+
class Array(HasArrayNamespace, Protocol):
9+
"""An Array API array of homogenously-typed numbers."""
10+
11+
# TODO(https://github.com/data-apis/array-api-typing/issues/23): Populate this
12+
# protocol with methods defined by the Array API specification.

src/array_api_typing/_misc_objects.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Static typing support for miscellaneous objects in the array API."""
2+
3+
from typing import TypeAlias
4+
5+
Device: TypeAlias = object # The device on which an Array API array is stored.
6+
DType: TypeAlias = object # The type of the numbers contained in an Array API array."""

src/array_api_typing/_namespace.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,54 @@
11
__all__ = ("HasArrayNamespace",)
22

3-
from types import ModuleType
4-
from typing import Protocol, final
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Protocol
56
from typing_extensions import TypeVar
67

7-
T = TypeVar("T", bound=object, default=ModuleType) # PEP 696 default
8+
if TYPE_CHECKING:
9+
# This condition exists to prevent a circular import: _array imports _namespace for
10+
# HasArrayNamespace. Therefore, _namespace cannot import _array except when
11+
# type-checking. The type variable depends on Array, so we create a dummy type
12+
# variable without the same bounds and default for this case. In Python 3.13, this
13+
# is no longer be necessary.
14+
from typing_extensions import Buffer
15+
16+
from ._array import Array
17+
from ._misc_objects import Device, DType
18+
from .signature_types import NestedSequence
19+
20+
A = TypeVar("A", bound=Array, default=Array) # PEP 696 default
21+
else:
22+
A = TypeVar("A")
23+
24+
25+
class ArrayNamespace(Protocol[A]):
26+
"""An Array API namespace."""
27+
28+
def asarray(
29+
self,
30+
obj: Array | complex | NestedSequence[complex] | Buffer,
31+
/,
32+
*,
33+
dtype: DType | None = None,
34+
device: Device | None = None,
35+
copy: bool | None = None,
36+
) -> A: ...
37+
38+
def astype(
39+
self,
40+
x: A,
41+
dtype: DType,
42+
/,
43+
*,
44+
copy: bool = True,
45+
device: Device | None = None,
46+
) -> A: ...
47+
48+
49+
T = TypeVar("T", bound=ArrayNamespace, default=ArrayNamespace) # PEP 696 default
850

951

10-
@final
1152
class HasArrayNamespace(Protocol[T]): # type: ignore[misc] # see python/mypy#17288
1253
"""Protocol for classes that have an `__array_namespace__` method.
1354
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Types that appear in public function signatures."""
2+
3+
__all__ = [
4+
"NestedSequence",
5+
]
6+
7+
from ._signature_types import NestedSequence
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable
4+
5+
if TYPE_CHECKING:
6+
from collections.abc import Iterator
7+
8+
_T_co = TypeVar("_T_co", covariant=True)
9+
10+
11+
@runtime_checkable
12+
class NestedSequence(Protocol[_T_co]):
13+
"""A protocol for representing nested sequences.
14+
15+
Warning:
16+
-------
17+
`NestedSequence` currently does not work in combination with type variables,
18+
*e.g.* ``def func(a: NestedSequnce[T]) -> T: ...``.
19+
20+
See Also:
21+
--------
22+
collections.abc.Sequence:
23+
ABCs for read-only and mutable :term:`sequences`.
24+
25+
Examples:
26+
--------
27+
.. code-block:: python
28+
29+
>>> from typing import TYPE_CHECKING
30+
>>> import numpy as np
31+
>>> import array_api_typing as xpt
32+
33+
>>> def get_dtype(seq: xpt.NestedSequence[float]) -> np.dtype[np.float64]:
34+
... return np.asarray(seq).dtype
35+
36+
>>> a = get_dtype([1.0])
37+
>>> b = get_dtype([[1.0]])
38+
>>> c = get_dtype([[[1.0]]])
39+
>>> d = get_dtype([[[[1.0]]]])
40+
41+
>>> if TYPE_CHECKING:
42+
... reveal_locals()
43+
... # note: Revealed local types are:
44+
... # note: a: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
45+
... # note: b: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
46+
... # note: c: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
47+
... # note: d: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
48+
49+
"""
50+
51+
def __len__(self, /) -> int:
52+
"""Implement ``len(self)``."""
53+
raise NotImplementedError
54+
55+
def __getitem__(self, index: int, /) -> _T_co | NestedSequence[_T_co]:
56+
"""Implement ``self[x]``."""
57+
raise NotImplementedError
58+
59+
def __contains__(self, x: object, /) -> bool:
60+
"""Implement ``x in self``."""
61+
raise NotImplementedError
62+
63+
def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]:
64+
"""Implement ``iter(self)``."""
65+
raise NotImplementedError
66+
67+
def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]:
68+
"""Implement ``reversed(self)``."""
69+
raise NotImplementedError
70+
71+
def count(self, value: object, /) -> int:
72+
"""Return the number of occurrences of `value`."""
73+
raise NotImplementedError
74+
75+
def index(self, value: object, /) -> int:
76+
"""Return the first index of `value`."""
77+
raise NotImplementedError

0 commit comments

Comments
 (0)