Skip to content

Commit

Permalink
Add NDArray protocol class for nd-array annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
yosh-matsuda committed Aug 17, 2024
1 parent 8e5fd14 commit 2651557
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 87 deletions.
109 changes: 94 additions & 15 deletions src/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,28 @@ def __init__(
+ sep_after
)

# Precompile RE to extract nanobind nd-arrays
self.ndarray_re = re.compile(
sep_before + r"(numpy\.ndarray|ndarray|torch\.Tensor)\[([^\]]*)\]"
# Precompile RE to extract known nd-arrays
self.known_ndarray_re = re.compile(
sep_before
+ "("
+ "|".join(
[
r"numpy\.ndarray",
r"torch\.Tensor",
r"tensorflow\.python\.framework\.ops\.EagerTensor",
r"jaxlib\.xla_extension\.DeviceArray",
]
)
+ ")"
+ r"\[([^\]]*)\]"
)

# Precompile RE to extract nanobind nd-arrays
self.nb_ndarray_re = re.compile(sep_before + "(ndarray)" + r"\[([^\]]*)\]")

# Insert ndarray class
self.ndarray_class = False

# Types which moved from typing.* to collections.abc in Python 3.9
self.abc_re = re.compile(
'typing.(AsyncGenerator|AsyncIterable|AsyncIterator|Awaitable|Callable|'
Expand Down Expand Up @@ -606,7 +623,10 @@ def simplify_types(self, s: str) -> str:
- "NoneType" -> "None"
- "ndarray[...]" -> "Annotated[ArrayLike, dict(...)]"
- "<numpy|torch|tensorflow|jax array>[...]" -> "Annotated[<array>, dict(...)]"
- "ndarray[...]" -> "Annotated[NDArray, dict(...)]"
(with array protocol class added at top)
- "collections.abc.X" -> "X"
(with "from collections.abc import X" added at top)
Expand All @@ -616,22 +636,62 @@ def simplify_types(self, s: str) -> str:
changed to 'collections.abc' on newer Python versions)
"""

# Process nd-array type annotations so that MyPy accepts them
def process_ndarray(m: Match[str]) -> str:
s = m.group(2)
# Process nd-array type annotations with metadata
def process_known_ndarray(m: re.Match[str]) -> str:
ndarray_type = m.group(1)
meta = m.group(2)

ndarray = self.import_object("numpy.typing", "ArrayLike")
assert ndarray
s = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", s)
s = s.replace("*", "None")
if not meta:
return ndarray_type

if s:
if ndarray_type == "numpy.ndarray":
dm = re.search(r"dtype=([\w]*)\b", meta)
if dm and dm.group(1):
dtype = dm.group(1).replace("bool", "bool_")
ndarray_type = f"numpy.typing.NDArray[numpy.{dtype}]"

meta = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", meta)
meta = meta.replace("*", "None")

if sys.version_info >= (3, 9, 0):
annotated = self.import_object("typing", "Annotated")
return f"{annotated}[{ndarray}, dict({s})]"
else:
return ndarray
annotated = self.import_object("typing_extensions", "Annotated")
return f"{annotated}[{ndarray_type}, dict({meta})]"

s = self.known_ndarray_re.sub(process_known_ndarray, s)

# Process nb-ndarray type annotations with metadata
def process_nb_ndarray(m: re.Match[str]) -> str:
ndarray_type = "NDArray"
meta = m.group(2)

s = self.ndarray_re.sub(process_ndarray, s)
self.ndarray_class = True

self.import_object("typing", "Protocol")
if sys.version_info >= (3, 12, 0):
self.import_object("collections.abc", "Buffer")
else:
self.import_object("typing_extensions", "Buffer")
if sys.version_info >= (3, 10, 0):
self.import_object("typing", "TypeAlias")
else:
self.import_object("typing", "Union")
self.import_object("typing_extensions", "TypeAlias")

if not meta:
return ndarray_type

meta = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", meta)
meta = meta.replace("*", "None")

if sys.version_info >= (3, 9, 0):
annotated = self.import_object("typing", "Annotated")
else:
annotated = self.import_object("typing_extensions", "Annotated")
return f"{annotated}[{ndarray_type}, dict({meta})]"

s = self.nb_ndarray_re.sub(process_nb_ndarray, s)

if sys.version_info >= (3, 9, 0):
s = self.abc_re.sub(r'collections.abc.\1', s)
Expand Down Expand Up @@ -1143,12 +1203,31 @@ def get(self) -> str:
s += items_v0 if len(items_v0) <= 70 else items_v1

s += "\n\n"
s += self.put_ndarray_class()

# Append the main generated stub
s += self.output

return s.rstrip() + "\n"

def put_ndarray_class(self) -> str:
s = ""
if not self.ndarray_class:
return s

s += "class DLPackBuffer(Protocol):\n"
s += " def __dlpack__(self) -> object: ...\n"
s += "\n"
if sys.version_info >= (3, 12, 0):
s += "type NDArray = Buffer | DLPackBuffer\n"
elif sys.version_info >= (3, 10, 0):
s += "NDArray: TypeAlias = Buffer | DLPackBuffer\n"
else:
s += "NDArray: TypeAlias = Union[Buffer, DLPackBuffer]\n"
s += "\n"

return s

def parse_options(args: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="python -m nanobind.stubgen",
Expand Down
152 changes: 80 additions & 72 deletions tests/test_ndarray_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,160 +1,168 @@
from typing import Annotated, overload
from collections.abc import Buffer

from numpy.typing import ArrayLike
import numpy
import numpy.typing
import torch

from typing import Annotated, Protocol, overload

class DLPackBuffer(Protocol):
def __dlpack__(self) -> object: ...

type NDArray = Buffer | DLPackBuffer

class Cls:
def __init__(self) -> None: ...

def f1(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ...
def f1(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ...

def f2(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ...
def f2(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ...

def f1_ri(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ...
def f1_ri(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ...

def f2_ri(self) -> Annotated[ArrayLike, dict(dtype='float32')]: ...
def f2_ri(self) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ...

def f3_ri(self, arg: object, /) -> Annotated[ArrayLike, dict(dtype='float32')]: ...
def f3_ri(self, arg: object, /) -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ...

def accept_ro(arg: Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2))], /) -> float: ...
def accept_ro(arg: Annotated[NDArray, dict(dtype='float32', writable=False, shape=(2))], /) -> float: ...

def accept_rw(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(2))], /) -> float: ...
def accept_rw(arg: Annotated[NDArray, dict(dtype='float32', shape=(2))], /) -> float: ...

def cast(arg: bool, /) -> ArrayLike: ...
def cast(arg: bool, /) -> numpy.ndarray: ...

def check(arg: object, /) -> bool: ...

def check_bool(arg: ArrayLike, /) -> bool: ...
def check_bool(arg: NDArray, /) -> bool: ...

@overload
def check_device(arg: Annotated[ArrayLike, dict(device='cpu')], /) -> str: ...
def check_device(arg: Annotated[NDArray, dict(device='cpu')], /) -> str: ...

@overload
def check_device(arg: Annotated[ArrayLike, dict(device='cuda')], /) -> str: ...
def check_device(arg: Annotated[NDArray, dict(device='cuda')], /) -> str: ...

def check_float(arg: ArrayLike, /) -> bool: ...
def check_float(arg: NDArray, /) -> bool: ...

@overload
def check_order(arg: Annotated[ArrayLike, dict(order='C')], /) -> str: ...
def check_order(arg: Annotated[NDArray, dict(order='C')], /) -> str: ...

@overload
def check_order(arg: Annotated[ArrayLike, dict(order='F')], /) -> str: ...
def check_order(arg: Annotated[NDArray, dict(order='F')], /) -> str: ...

@overload
def check_order(arg: ArrayLike, /) -> str: ...
def check_order(arg: NDArray, /) -> str: ...

def check_ro_by_const_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...
def check_ro_by_const_ref_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...

def check_ro_by_const_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ...
def check_ro_by_const_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ...

def check_ro_by_rvalue_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...
def check_ro_by_rvalue_ref_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...

def check_ro_by_rvalue_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ...
def check_ro_by_rvalue_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ...

def check_ro_by_value_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...
def check_ro_by_value_const_float64(arg: Annotated[NDArray, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...

def check_ro_by_value_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ...
def check_ro_by_value_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ...

def check_rw_by_const_ref(arg: ArrayLike, /) -> bool: ...
def check_rw_by_const_ref(arg: NDArray, /) -> bool: ...

def check_rw_by_const_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ...
def check_rw_by_const_ref_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ...

def check_rw_by_rvalue_ref(arg: ArrayLike, /) -> bool: ...
def check_rw_by_rvalue_ref(arg: NDArray, /) -> bool: ...

def check_rw_by_rvalue_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ...
def check_rw_by_rvalue_ref_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ...

def check_rw_by_value(arg: ArrayLike, /) -> bool: ...
def check_rw_by_value(arg: NDArray, /) -> bool: ...

def check_rw_by_value_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ...
def check_rw_by_value_float64(arg: Annotated[NDArray, dict(dtype='float64', shape=(None))], /) -> bool: ...

def check_shape_ptr(arg: ArrayLike, /) -> bool: ...
def check_shape_ptr(arg: NDArray, /) -> bool: ...

def check_stride_ptr(arg: ArrayLike, /) -> bool: ...
def check_stride_ptr(arg: NDArray, /) -> bool: ...

def destruct_count() -> int: ...

def fill_view_1(x: ArrayLike) -> None: ...
def fill_view_1(x: NDArray) -> None: ...

def fill_view_2(x: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None), device='cpu')]) -> None: ...
def fill_view_2(x: Annotated[NDArray, dict(dtype='float32', shape=(None, None), device='cpu')]) -> None: ...

def fill_view_3(x: Annotated[ArrayLike, dict(dtype='float32', shape=(3, 4), order='C', device='cpu')]) -> None: ...
def fill_view_3(x: Annotated[NDArray, dict(dtype='float32', shape=(3, 4), order='C', device='cpu')]) -> None: ...

def fill_view_4(x: Annotated[ArrayLike, dict(dtype='float32', shape=(3, 4), order='F', device='cpu')]) -> None: ...
def fill_view_4(x: Annotated[NDArray, dict(dtype='float32', shape=(3, 4), order='F', device='cpu')]) -> None: ...

def fill_view_5(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ...
def fill_view_5(x: Annotated[NDArray, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ...

def fill_view_6(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ...
def fill_view_6(x: Annotated[NDArray, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ...

def get_is_valid(array: Annotated[ArrayLike, dict(writable=False)] | None) -> bool: ...
def get_is_valid(array: Annotated[NDArray, dict(writable=False)] | None) -> bool: ...

def get_itemsize(array: ArrayLike | None) -> int: ...
def get_itemsize(array: NDArray | None) -> int: ...

def get_nbytes(array: ArrayLike | None) -> int: ...
def get_nbytes(array: NDArray | None) -> int: ...

def get_shape(array: Annotated[ArrayLike, dict(writable=False)]) -> list: ...
def get_shape(array: Annotated[NDArray, dict(writable=False)]) -> list: ...

def get_size(array: ArrayLike | None) -> int: ...
def get_size(array: NDArray | None) -> int: ...

def get_stride(array: ArrayLike, i: int) -> int: ...
def get_stride(array: NDArray, i: int) -> int: ...

def implicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ...
def implicit(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ...

@overload
def initialize(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(10), device='cpu')], /) -> None: ...
def initialize(arg: Annotated[NDArray, dict(dtype='float32', shape=(10), device='cpu')], /) -> None: ...

@overload
def initialize(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(10, None), device='cpu')], /) -> None: ...
def initialize(arg: Annotated[NDArray, dict(dtype='float32', shape=(10, None), device='cpu')], /) -> None: ...

def inspect_ndarray(arg: ArrayLike, /) -> None: ...
def inspect_ndarray(arg: NDArray, /) -> None: ...

def make_contig(arg: Annotated[ArrayLike, dict(order='C')], /) -> Annotated[ArrayLike, dict(order='C')]: ...
def make_contig(arg: Annotated[NDArray, dict(order='C')], /) -> Annotated[NDArray, dict(order='C')]: ...

def noimplicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ...
def noimplicit(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ...

def noop_2d_f_contig(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None), order='F')], /) -> None: ...
def noop_2d_f_contig(arg: Annotated[NDArray, dict(dtype='float32', shape=(None, None), order='F')], /) -> None: ...

def noop_3d_c_contig(arg: Annotated[ArrayLike, dict(dtype='float32', shape=(None, None, None), order='C')], /) -> None: ...
def noop_3d_c_contig(arg: Annotated[NDArray, dict(dtype='float32', shape=(None, None, None), order='C')], /) -> None: ...

def pass_bool(array: Annotated[ArrayLike, dict(dtype='bool')]) -> None: ...
def pass_bool(array: Annotated[NDArray, dict(dtype='bool')]) -> None: ...

def pass_complex64(array: Annotated[ArrayLike, dict(dtype='complex64')]) -> None: ...
def pass_complex64(array: Annotated[NDArray, dict(dtype='complex64')]) -> None: ...

def pass_complex64_const(array: Annotated[ArrayLike, dict(dtype='complex64', writable=False)]) -> None: ...
def pass_complex64_const(array: Annotated[NDArray, dict(dtype='complex64', writable=False)]) -> None: ...

def pass_float32(array: Annotated[ArrayLike, dict(dtype='float32')]) -> None: ...
def pass_float32(array: Annotated[NDArray, dict(dtype='float32')]) -> None: ...

def pass_float32_const(array: Annotated[ArrayLike, dict(dtype='float32', writable=False)]) -> None: ...
def pass_float32_const(array: Annotated[NDArray, dict(dtype='float32', writable=False)]) -> None: ...

def pass_float32_shaped(array: Annotated[ArrayLike, dict(dtype='float32', shape=(3, None, 4))]) -> None: ...
def pass_float32_shaped(array: Annotated[NDArray, dict(dtype='float32', shape=(3, None, 4))]) -> None: ...

def pass_float32_shaped_ordered(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(None, None, 4))]) -> None: ...
def pass_float32_shaped_ordered(array: Annotated[NDArray, dict(dtype='float32', order='C', shape=(None, None, 4))]) -> None: ...

def pass_uint32(array: Annotated[ArrayLike, dict(dtype='uint32')]) -> None: ...
def pass_uint32(array: Annotated[NDArray, dict(dtype='uint32')]) -> None: ...

def passthrough(arg: ArrayLike, /) -> ArrayLike: ...
def passthrough(arg: NDArray, /) -> NDArray: ...

def passthrough_arg_none(arg: ArrayLike | None) -> ArrayLike: ...
def passthrough_arg_none(arg: NDArray | None) -> NDArray: ...

def passthrough_copy(arg: ArrayLike, /) -> ArrayLike: ...
def passthrough_copy(arg: NDArray, /) -> NDArray: ...

def process(arg: Annotated[ArrayLike, dict(dtype='uint8', shape=(None, None, 3), order='C', device='cpu')], /) -> None: ...
def process(arg: Annotated[NDArray, dict(dtype='uint8', shape=(None, None, 3), order='C', device='cpu')], /) -> None: ...

def ret_array_scalar() -> Annotated[ArrayLike, dict(dtype='float32')]: ...
def ret_array_scalar() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32')]: ...

def ret_numpy() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ...
def ret_numpy() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', shape=(2, 4))]: ...

def ret_numpy_const() -> Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2, 4))]: ...
def ret_numpy_const() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', writable=False, shape=(2, 4))]: ...

def ret_numpy_const_ref() -> Annotated[ArrayLike, dict(dtype='float32', writable=False, shape=(2, 4))]: ...
def ret_numpy_const_ref() -> Annotated[numpy.typing.NDArray[numpy.float32], dict(dtype='float32', writable=False, shape=(2, 4))]: ...

def ret_numpy_half() -> Annotated[ArrayLike, dict(dtype='float16', shape=(2, 4))]: ...
def ret_numpy_half() -> Annotated[numpy.typing.NDArray[numpy.float16], dict(dtype='float16', shape=(2, 4))]: ...

def ret_pytorch() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ...
def ret_pytorch() -> Annotated[torch.Tensor, dict(dtype='float32', shape=(2, 4))]: ...

def return_dlpack() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ...
def return_dlpack() -> Annotated[NDArray, dict(dtype='float32', shape=(2, 4))]: ...

@overload
def set_item(arg0: Annotated[ArrayLike, dict(dtype='float64', shape=(None), order='C')], arg1: int, /) -> None: ...
def set_item(arg0: Annotated[NDArray, dict(dtype='float64', shape=(None), order='C')], arg1: int, /) -> None: ...

@overload
def set_item(arg0: Annotated[ArrayLike, dict(dtype='complex128', shape=(None), order='C')], arg1: int, /) -> None: ...
def set_item(arg0: Annotated[NDArray, dict(dtype='complex128', shape=(None), order='C')], arg1: int, /) -> None: ...
Loading

0 comments on commit 2651557

Please sign in to comment.