-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add NDArray protocol class for nd-array annotations
- Loading branch information
1 parent
8e5fd14
commit 2651557
Showing
3 changed files
with
180 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,160 +1,168 @@ | ||
from typing import Annotated, overload | ||
from collections.abc import Buffer | ||
|
||
from numpy.typing import ArrayLike | ||
import numpy | ||
import numpy.typing | ||
import torch | ||
|
||
from typing import Annotated, Protocol, overload | ||
|
||
class DLPackBuffer(Protocol): | ||
def __dlpack__(self) -> object: ... | ||
|
||
type NDArray = Buffer | DLPackBuffer | ||
|
||
class Cls: | ||
def __init__(self) -> None: ... | ||
|
||
def f1(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... | ||
def f1(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... | ||
|
||
def f2(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... | ||
def f2(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... | ||
|
||
def f1_ri(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... | ||
def f1_ri(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... | ||
|
||
def f2_ri(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ... | ||
def f2_ri(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... | ||
|
||
def f3_ri(self, arg: object, /) -> Annotated[ArrayLike, dict(dtype='float32')]: ... | ||
def f3_ri(self, arg: object, /) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... | ||
|
||
def accept_ro(arg: Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2))], /) -> float: ... | ||
def accept_ro(arg: Annotated[NDArray, dict(dtype='float32', writable=False, shape=(2))], /) -> float: ... | ||
|
||
def accept_rw(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(2))], /) -> float: ... | ||
def accept_rw(arg: Annotated[NDArray, dict(dtype='float32', shape=(2))], /) -> float: ... | ||
|
||
def cast(arg: bool, /) -> ArrayLike: ... | ||
def cast(arg: bool, /) -> numpy.ndarray: ... | ||
|
||
def check(arg: object, /) -> bool: ... | ||
|
||
def check_bool(arg: ArrayLike, /) -> bool: ... | ||
def check_bool(arg: NDArray, /) -> bool: ... | ||
|
||
@overload | ||
def check_device(arg: Annotated[ArrayLike, dict(device='cpu')], /) -> str: ... | ||
def check_device(arg: Annotated[NDArray, dict(device='cpu')], /) -> str: ... | ||
|
||
@overload | ||
def check_device(arg: Annotated[ArrayLike, dict(device='cuda')], /) -> str: ... | ||
def check_device(arg: Annotated[NDArray, dict(device='cuda')], /) -> str: ... | ||
|
||
def check_float(arg: ArrayLike, /) -> bool: ... | ||
def check_float(arg: NDArray, /) -> bool: ... | ||
|
||
@overload | ||
def check_order(arg: Annotated[ArrayLike, dict(order='C')], /) -> str: ... | ||
def check_order(arg: Annotated[NDArray, dict(order='C')], /) -> str: ... | ||
|
||
@overload | ||
def check_order(arg: Annotated[ArrayLike, dict(order='F')], /) -> str: ... | ||
def check_order(arg: Annotated[NDArray, dict(order='F')], /) -> str: ... | ||
|
||
@overload | ||
def check_order(arg: ArrayLike, /) -> str: ... | ||
def check_order(arg: NDArray, /) -> str: ... | ||
|
||
def check_ro_by_const_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... | ||
def check_ro_by_const_ref_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... | ||
|
||
def check_ro_by_const_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ... | ||
def check_ro_by_const_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ... | ||
|
||
def check_ro_by_rvalue_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... | ||
def check_ro_by_rvalue_ref_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... | ||
|
||
def check_ro_by_rvalue_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ... | ||
def check_ro_by_rvalue_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ... | ||
|
||
def check_ro_by_value_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... | ||
def check_ro_by_value_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ... | ||
|
||
def check_ro_by_value_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ... | ||
def check_ro_by_value_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ... | ||
|
||
def check_rw_by_const_ref(arg: ArrayLike, /) -> bool: ... | ||
def check_rw_by_const_ref(arg: NDArray, /) -> bool: ... | ||
|
||
def check_rw_by_const_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ... | ||
def check_rw_by_const_ref_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ... | ||
|
||
def check_rw_by_rvalue_ref(arg: ArrayLike, /) -> bool: ... | ||
def check_rw_by_rvalue_ref(arg: NDArray, /) -> bool: ... | ||
|
||
def check_rw_by_rvalue_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ... | ||
def check_rw_by_rvalue_ref_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ... | ||
|
||
def check_rw_by_value(arg: ArrayLike, /) -> bool: ... | ||
def check_rw_by_value(arg: NDArray, /) -> bool: ... | ||
|
||
def check_rw_by_value_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ... | ||
def check_rw_by_value_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ... | ||
|
||
def check_shape_ptr(arg: ArrayLike, /) -> bool: ... | ||
def check_shape_ptr(arg: NDArray, /) -> bool: ... | ||
|
||
def check_stride_ptr(arg: ArrayLike, /) -> bool: ... | ||
def check_stride_ptr(arg: NDArray, /) -> bool: ... | ||
|
||
def destruct_count() -> int: ... | ||
|
||
def fill_view_1(x: ArrayLike) -> None: ... | ||
def fill_view_1(x: NDArray) -> None: ... | ||
|
||
def fill_view_2(x: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None), device='cpu')]) -> None: ... | ||
def fill_view_2(x: Annotated[NDArray, dict(dtype='float32', shape=(None, None), device='cpu')]) -> None: ... | ||
|
||
def fill_view_3(x: Annotated[ArrayLike, dict(dtype='float32', shape=(3, 4), order='C', device='cpu')]) -> None: ... | ||
def fill_view_3(x: Annotated[NDArray, dict(dtype='float32', shape=(3, 4), order='C', device='cpu')]) -> None: ... | ||
|
||
def fill_view_4(x: Annotated[ArrayLike, dict(dtype='float32', shape=(3, 4), order='F', device='cpu')]) -> None: ... | ||
def fill_view_4(x: Annotated[NDArray, dict(dtype='float32', shape=(3, 4), order='F', device='cpu')]) -> None: ... | ||
|
||
def fill_view_5(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... | ||
def fill_view_5(x: Annotated[NDArray, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... | ||
|
||
def fill_view_6(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... | ||
def fill_view_6(x: Annotated[NDArray, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ... | ||
|
||
def get_is_valid(array: Annotated[ArrayLike, dict(writable=False)] | None) -> bool: ... | ||
def get_is_valid(array: Annotated[NDArray, dict(writable=False)] | None) -> bool: ... | ||
|
||
def get_itemsize(array: ArrayLike | None) -> int: ... | ||
def get_itemsize(array: NDArray | None) -> int: ... | ||
|
||
def get_nbytes(array: ArrayLike | None) -> int: ... | ||
def get_nbytes(array: NDArray | None) -> int: ... | ||
|
||
def get_shape(array: Annotated[ArrayLike, dict(writable=False)]) -> list: ... | ||
def get_shape(array: Annotated[NDArray, dict(writable=False)]) -> list: ... | ||
|
||
def get_size(array: ArrayLike | None) -> int: ... | ||
def get_size(array: NDArray | None) -> int: ... | ||
|
||
def get_stride(array: ArrayLike, i: int) -> int: ... | ||
def get_stride(array: NDArray, i: int) -> int: ... | ||
|
||
def implicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... | ||
def implicit(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... | ||
|
||
@overload | ||
def initialize(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(10), device='cpu')], /) -> None: ... | ||
def initialize(arg: Annotated[NDArray, dict(dtype='float32', shape=(10), device='cpu')], /) -> None: ... | ||
|
||
@overload | ||
def initialize(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(10, None), device='cpu')], /) -> None: ... | ||
def initialize(arg: Annotated[NDArray, dict(dtype='float32', shape=(10, None), device='cpu')], /) -> None: ... | ||
|
||
def inspect_ndarray(arg: ArrayLike, /) -> None: ... | ||
def inspect_ndarray(arg: NDArray, /) -> None: ... | ||
|
||
def make_contig(arg: Annotated[ArrayLike, dict(order='C')], /) -> Annotated[ArrayLike, dict(order='C')]: ... | ||
def make_contig(arg: Annotated[NDArray, dict(order='C')], /) -> Annotated[NDArray, dict(order='C')]: ... | ||
|
||
def noimplicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... | ||
def noimplicit(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... | ||
|
||
def noop_2d_f_contig(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None), order='F')], /) -> None: ... | ||
def noop_2d_f_contig(arg: Annotated[NDArray, dict(dtype='float32', shape=(None, None), order='F')], /) -> None: ... | ||
|
||
def noop_3d_c_contig(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None, None), order='C')], /) -> None: ... | ||
def noop_3d_c_contig(arg: Annotated[NDArray, dict(dtype='float32', shape=(None, None, None), order='C')], /) -> None: ... | ||
|
||
def pass_bool(array: Annotated[ArrayLike, dict(dtype='bool')]) -> None: ... | ||
def pass_bool(array: Annotated[NDArray, dict(dtype='bool')]) -> None: ... | ||
|
||
def pass_complex64(array: Annotated[ArrayLike, dict(dtype='complex64')]) -> None: ... | ||
def pass_complex64(array: Annotated[NDArray, dict(dtype='complex64')]) -> None: ... | ||
|
||
def pass_complex64_const(array: Annotated[ArrayLike, dict(dtype='complex64', writable=False)]) -> None: ... | ||
def pass_complex64_const(array: Annotated[NDArray, dict(dtype='complex64', writable=False)]) -> None: ... | ||
|
||
def pass_float32(array: Annotated[ArrayLike, dict(dtype='float32')]) -> None: ... | ||
def pass_float32(array: Annotated[NDArray, dict(dtype='float32')]) -> None: ... | ||
|
||
def pass_float32_const(array: Annotated[ArrayLike, dict(dtype='float32', writable=False)]) -> None: ... | ||
def pass_float32_const(array: Annotated[NDArray, dict(dtype='float32', writable=False)]) -> None: ... | ||
|
||
def pass_float32_shaped(array: Annotated[ArrayLike, dict(dtype='float32', shape=(3, None, 4))]) -> None: ... | ||
def pass_float32_shaped(array: Annotated[NDArray, dict(dtype='float32', shape=(3, None, 4))]) -> None: ... | ||
|
||
def pass_float32_shaped_ordered(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(None, None, 4))]) -> None: ... | ||
def pass_float32_shaped_ordered(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(None, None, 4))]) -> None: ... | ||
|
||
def pass_uint32(array: Annotated[ArrayLike, dict(dtype='uint32')]) -> None: ... | ||
def pass_uint32(array: Annotated[NDArray, dict(dtype='uint32')]) -> None: ... | ||
|
||
def passthrough(arg: ArrayLike, /) -> ArrayLike: ... | ||
def passthrough(arg: NDArray, /) -> NDArray: ... | ||
|
||
def passthrough_arg_none(arg: ArrayLike | None) -> ArrayLike: ... | ||
def passthrough_arg_none(arg: NDArray | None) -> NDArray: ... | ||
|
||
def passthrough_copy(arg: ArrayLike, /) -> ArrayLike: ... | ||
def passthrough_copy(arg: NDArray, /) -> NDArray: ... | ||
|
||
def process(arg: Annotated[ArrayLike, dict(dtype='uint8', shape=(None, None, 3), order='C', device='cpu')], /) -> None: ... | ||
def process(arg: Annotated[NDArray, dict(dtype='uint8', shape=(None, None, 3), order='C', device='cpu')], /) -> None: ... | ||
|
||
def ret_array_scalar() -> Annotated[ArrayLike, dict(dtype='float32')]: ... | ||
def ret_array_scalar() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ... | ||
|
||
def ret_numpy() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... | ||
def ret_numpy() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', shape=(2, 4))]: ... | ||
|
||
def ret_numpy_const() -> Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2, 4))]: ... | ||
def ret_numpy_const() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', writable=False, shape=(2, 4))]: ... | ||
|
||
def ret_numpy_const_ref() -> Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2, 4))]: ... | ||
def ret_numpy_const_ref() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', writable=False, shape=(2, 4))]: ... | ||
|
||
def ret_numpy_half() -> Annotated[ArrayLike, dict(dtype='float16', shape=(2, 4))]: ... | ||
def ret_numpy_half() -> Annotated[numpy.typing.NDArray[numpy.float16], dict(dtype='float16', shape=(2, 4))]: ... | ||
|
||
def ret_pytorch() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... | ||
def ret_pytorch() -> Annotated[torch.Tensor, dict(dtype='float32', shape=(2, 4))]: ... | ||
|
||
def return_dlpack() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... | ||
def return_dlpack() -> Annotated[NDArray, dict(dtype='float32', shape=(2, 4))]: ... | ||
|
||
@overload | ||
def set_item(arg0: Annotated[ArrayLike, dict(dtype='float64', shape=(None), order='C')], arg1: int, /) -> None: ... | ||
def set_item(arg0: Annotated[NDArray, dict(dtype='float64', shape=(None), order='C')], arg1: int, /) -> None: ... | ||
|
||
@overload | ||
def set_item(arg0: Annotated[ArrayLike, dict(dtype='complex128', shape=(None), order='C')], arg1: int, /) -> None: ... | ||
def set_item(arg0: Annotated[NDArray, dict(dtype='complex128', shape=(None), order='C')], arg1: int, /) -> None: ... |
Oops, something went wrong.