Skip to content

Commit 37dd2c6

Browse files
committed
stubgen.py: do not import NumPy dtypes individually (clashes with builtins)
1 parent 10cc510 commit 37dd2c6

File tree

2 files changed

+61
-71
lines changed

2 files changed

+61
-71
lines changed

src/stubgen.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -638,14 +638,14 @@ def simplify_types(self, s: str) -> str:
638638
639639
- "local_module.X" -> "X"
640640
641-
- "other_module.X" -> "other_module.XX"
641+
- "other_module.X" -> "other_module.X"
642642
(with "import other_module" added at top)
643643
644644
- "builtins.X" -> "X"
645645
646646
- "NoneType" -> "None"
647647
648-
- "ndarray[...]" -> "Annotated[NDArray, dict(...)]"
648+
- "ndarray[...]" -> "Annotated[NDArray[dtype], dict(..extras..)]"
649649
650650
- "collections.abc.X" -> "X"
651651
(with "from collections.abc import X" added at top)
@@ -710,19 +710,18 @@ def is_valid_module(module_name: str) -> bool:
710710

711711
def _format_ndarray(self, annotation: str) -> str:
712712
"""Improve NumPy type annotations for static type checking"""
713-
ndarray = self.import_object("numpy.typing", "NDArray")
714-
715-
# Extract and remove dtype if present
716713
dtype = None
717714
m = re.search(r"dtype=(\w+)", annotation)
715+
718716
if m:
719-
dtype = self.import_object("numpy", m.group(1))
717+
dtype = "numpy."+ m.group(1)
720718
annotation = re.sub(r"dtype=\w+,?\s*", "", annotation).rstrip(", ")
721719

722720
# Turn shape notation into a valid Python type expression
723721
annotation = annotation.replace("*", "None").replace("(None)", "(None,)")
724722

725723
# Build type while potentially preserving extra information as an annotation
724+
ndarray = self.import_object("numpy.typing", "NDArray")
726725
result = f"{ndarray}[{dtype}]" if dtype else ndarray
727726
if annotation:
728727
annotated = self.import_object("typing", "Annotated")

tests/test_ndarray_ext.pyi.ref

Lines changed: 56 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
from typing import Annotated, overload
22

3-
from numpy import (
4-
bool,
5-
complex128,
6-
complex64,
7-
float16,
8-
float32,
9-
float64,
10-
uint32,
11-
uint8
12-
)
3+
import numpy
134
from numpy.typing import NDArray
145

156

@@ -33,45 +24,45 @@ def check_float(arg: NDArray, /) -> bool: ...
3324

3425
def check_bool(arg: NDArray, /) -> bool: ...
3526

36-
def pass_float32(array: NDArray[float32]) -> None: ...
27+
def pass_float32(array: NDArray[numpy.float32]) -> None: ...
3728

38-
def pass_float32_const(array: Annotated[NDArray[float32], dict(writable=False)]) -> None: ...
29+
def pass_float32_const(array: Annotated[NDArray[numpy.float32], dict(writable=False)]) -> None: ...
3930

40-
def pass_complex64(array: NDArray[complex64]) -> None: ...
31+
def pass_complex64(array: NDArray[numpy.complex64]) -> None: ...
4132

42-
def pass_complex64_const(array: Annotated[NDArray[complex64], dict(writable=False)]) -> None: ...
33+
def pass_complex64_const(array: Annotated[NDArray[numpy.complex64], dict(writable=False)]) -> None: ...
4334

44-
def pass_uint32(array: NDArray[uint32]) -> None: ...
35+
def pass_uint32(array: NDArray[numpy.uint32]) -> None: ...
4536

46-
def pass_bool(array: NDArray[bool]) -> None: ...
37+
def pass_bool(array: NDArray[numpy.bool]) -> None: ...
4738

48-
def pass_float32_shaped(array: Annotated[NDArray[float32], dict(shape=(3, None, 4))]) -> None: ...
39+
def pass_float32_shaped(array: Annotated[NDArray[numpy.float32], dict(shape=(3, None, 4))]) -> None: ...
4940

50-
def pass_float32_shaped_ordered(array: Annotated[NDArray[float32], dict(shape=(None, None, 4), order='C')]) -> None: ...
41+
def pass_float32_shaped_ordered(array: Annotated[NDArray[numpy.float32], dict(shape=(None, None, 4), order='C')]) -> None: ...
5142

5243
def check_rw_by_value(arg: NDArray, /) -> bool: ...
5344

5445
def check_ro_by_value_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ...
5546

56-
def check_rw_by_value_float64(arg: Annotated[NDArray[float64], dict(shape=(None,))], /) -> bool: ...
47+
def check_rw_by_value_float64(arg: Annotated[NDArray[numpy.float64], dict(shape=(None,))], /) -> bool: ...
5748

58-
def check_ro_by_value_const_float64(arg: Annotated[NDArray[float64], dict(shape=(None,), writable=False)], /) -> bool: ...
49+
def check_ro_by_value_const_float64(arg: Annotated[NDArray[numpy.float64], dict(shape=(None,), writable=False)], /) -> bool: ...
5950

6051
def check_rw_by_const_ref(arg: NDArray, /) -> bool: ...
6152

6253
def check_ro_by_const_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ...
6354

64-
def check_rw_by_const_ref_float64(arg: Annotated[NDArray[float64], dict(shape=(None,))], /) -> bool: ...
55+
def check_rw_by_const_ref_float64(arg: Annotated[NDArray[numpy.float64], dict(shape=(None,))], /) -> bool: ...
6556

66-
def check_ro_by_const_ref_const_float64(arg: Annotated[NDArray[float64], dict(shape=(None,), writable=False)], /) -> bool: ...
57+
def check_ro_by_const_ref_const_float64(arg: Annotated[NDArray[numpy.float64], dict(shape=(None,), writable=False)], /) -> bool: ...
6758

6859
def check_rw_by_rvalue_ref(arg: NDArray, /) -> bool: ...
6960

7061
def check_ro_by_rvalue_ref_ro(arg: Annotated[NDArray, dict(writable=False)], /) -> bool: ...
7162

72-
def check_rw_by_rvalue_ref_float64(arg: Annotated[NDArray[float64], dict(shape=(None,))], /) -> bool: ...
63+
def check_rw_by_rvalue_ref_float64(arg: Annotated[NDArray[numpy.float64], dict(shape=(None,))], /) -> bool: ...
7364

74-
def check_ro_by_rvalue_ref_const_float64(arg: Annotated[NDArray[float64], dict(shape=(None,), writable=False)], /) -> bool: ...
65+
def check_ro_by_rvalue_ref_const_float64(arg: Annotated[NDArray[numpy.float64], dict(shape=(None,), writable=False)], /) -> bool: ...
7566

7667
@overload
7768
def check_order(arg: Annotated[NDArray, dict(order='C')], /) -> str: ...
@@ -91,116 +82,116 @@ def check_device(arg: Annotated[NDArray, dict(device='cpu')], /) -> str: ...
9182
def check_device(arg: Annotated[NDArray, dict(device='cuda')], /) -> str: ...
9283

9384
@overload
94-
def initialize(arg: Annotated[NDArray[float32], dict(shape=(10), device='cpu')], /) -> None: ...
85+
def initialize(arg: Annotated[NDArray[numpy.float32], dict(shape=(10), device='cpu')], /) -> None: ...
9586

9687
@overload
97-
def initialize(arg: Annotated[NDArray[float32], dict(shape=(10, None), device='cpu')], /) -> None: ...
88+
def initialize(arg: Annotated[NDArray[numpy.float32], dict(shape=(10, None), device='cpu')], /) -> None: ...
9889

99-
def noimplicit(array: Annotated[NDArray[float32], dict(shape=(2, 2), order='C')]) -> int: ...
90+
def noimplicit(array: Annotated[NDArray[numpy.float32], dict(shape=(2, 2), order='C')]) -> int: ...
10091

101-
def implicit(array: Annotated[NDArray[float32], dict(shape=(2, 2), order='C')]) -> int: ...
92+
def implicit(array: Annotated[NDArray[numpy.float32], dict(shape=(2, 2), order='C')]) -> int: ...
10293

10394
def inspect_ndarray(arg: NDArray, /) -> None: ...
10495

105-
def process(arg: Annotated[NDArray[uint8], dict(shape=(None, None, 3), order='C', device='cpu')], /) -> None: ...
96+
def process(arg: Annotated[NDArray[numpy.uint8], dict(shape=(None, None, 3), order='C', device='cpu')], /) -> None: ...
10697

10798
def destruct_count() -> int: ...
10899

109-
def return_dlpack() -> Annotated[NDArray[float32], dict(shape=(2, 4))]: ...
100+
def return_dlpack() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4))]: ...
110101

111102
def passthrough(arg: NDArray, /) -> NDArray: ...
112103

113104
def passthrough_copy(arg: NDArray, /) -> NDArray: ...
114105

115106
def passthrough_arg_none(arg: NDArray | None) -> NDArray: ...
116107

117-
def ret_numpy() -> Annotated[NDArray[float32], dict(shape=(2, 4))]: ...
108+
def ret_numpy() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4))]: ...
118109

119-
def ret_numpy_const_ref() -> Annotated[NDArray[float32], dict(shape=(2, 4), order='C', writable=False)]: ...
110+
def ret_numpy_const_ref() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4), order='C', writable=False)]: ...
120111

121-
def ret_numpy_const_ref_f() -> Annotated[NDArray[float32], dict(shape=(2, 4), order='F', writable=False)]: ...
112+
def ret_numpy_const_ref_f() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4), order='F', writable=False)]: ...
122113

123-
def ret_numpy_const() -> Annotated[NDArray[float32], dict(shape=(2, 4), writable=False)]: ...
114+
def ret_numpy_const() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4), writable=False)]: ...
124115

125-
def ret_pytorch() -> Annotated[NDArray[float32], dict(shape=(2, 4))]: ...
116+
def ret_pytorch() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4))]: ...
126117

127-
def ret_array_scalar() -> NDArray[float32]: ...
118+
def ret_array_scalar() -> NDArray[numpy.float32]: ...
128119

129-
def noop_3d_c_contig(arg: Annotated[NDArray[float32], dict(shape=(None, None, None), order='C')], /) -> None: ...
120+
def noop_3d_c_contig(arg: Annotated[NDArray[numpy.float32], dict(shape=(None, None, None), order='C')], /) -> None: ...
130121

131-
def noop_2d_f_contig(arg: Annotated[NDArray[float32], dict(shape=(None, None), order='F')], /) -> None: ...
122+
def noop_2d_f_contig(arg: Annotated[NDArray[numpy.float32], dict(shape=(None, None), order='F')], /) -> None: ...
132123

133-
def accept_rw(arg: Annotated[NDArray[float32], dict(shape=(2))], /) -> float: ...
124+
def accept_rw(arg: Annotated[NDArray[numpy.float32], dict(shape=(2))], /) -> float: ...
134125

135-
def accept_ro(arg: Annotated[NDArray[float32], dict(shape=(2), writable=False)], /) -> float: ...
126+
def accept_ro(arg: Annotated[NDArray[numpy.float32], dict(shape=(2), writable=False)], /) -> float: ...
136127

137128
def check(arg: object, /) -> bool: ...
138129

139-
def accept_np_both_true_contig_a(arg: Annotated[NDArray[float32], dict(shape=(2, 1), order='A')], /) -> float: ...
130+
def accept_np_both_true_contig_a(arg: Annotated[NDArray[numpy.float32], dict(shape=(2, 1), order='A')], /) -> float: ...
140131

141-
def accept_np_both_true_contig_c(arg: Annotated[NDArray[float32], dict(shape=(2, 1), order='C')], /) -> float: ...
132+
def accept_np_both_true_contig_c(arg: Annotated[NDArray[numpy.float32], dict(shape=(2, 1), order='C')], /) -> float: ...
142133

143-
def accept_np_both_true_contig_f(arg: Annotated[NDArray[float32], dict(shape=(2, 1), order='F')], /) -> float: ...
134+
def accept_np_both_true_contig_f(arg: Annotated[NDArray[numpy.float32], dict(shape=(2, 1), order='F')], /) -> float: ...
144135

145136
class Cls:
146137
def __init__(self) -> None: ...
147138

148-
def f1(self) -> NDArray[float32]: ...
139+
def f1(self) -> NDArray[numpy.float32]: ...
149140

150-
def f2(self) -> NDArray[float32]: ...
141+
def f2(self) -> NDArray[numpy.float32]: ...
151142

152-
def f1_ri(self) -> NDArray[float32]: ...
143+
def f1_ri(self) -> NDArray[numpy.float32]: ...
153144

154-
def f2_ri(self) -> NDArray[float32]: ...
145+
def f2_ri(self) -> NDArray[numpy.float32]: ...
155146

156-
def f3_ri(self, arg: object, /) -> NDArray[float32]: ...
147+
def f3_ri(self, arg: object, /) -> NDArray[numpy.float32]: ...
157148

158149
def fill_view_1(x: NDArray) -> None: ...
159150

160-
def fill_view_2(x: Annotated[NDArray[float32], dict(shape=(None, None), device='cpu')]) -> None: ...
151+
def fill_view_2(x: Annotated[NDArray[numpy.float32], dict(shape=(None, None), device='cpu')]) -> None: ...
161152

162-
def fill_view_3(x: Annotated[NDArray[float32], dict(shape=(3, 4), order='C', device='cpu')]) -> None: ...
153+
def fill_view_3(x: Annotated[NDArray[numpy.float32], dict(shape=(3, 4), order='C', device='cpu')]) -> None: ...
163154

164-
def fill_view_4(x: Annotated[NDArray[float32], dict(shape=(3, 4), order='F', device='cpu')]) -> None: ...
155+
def fill_view_4(x: Annotated[NDArray[numpy.float32], dict(shape=(3, 4), order='F', device='cpu')]) -> None: ...
165156

166-
def fill_view_5(x: Annotated[NDArray[complex64], dict(shape=(2, 2), order='C', device='cpu')]) -> None: ...
157+
def fill_view_5(x: Annotated[NDArray[numpy.complex64], dict(shape=(2, 2), order='C', device='cpu')]) -> None: ...
167158

168-
def fill_view_6(x: Annotated[NDArray[complex64], dict(shape=(2, 2), order='C', device='cpu')]) -> None: ...
159+
def fill_view_6(x: Annotated[NDArray[numpy.complex64], dict(shape=(2, 2), order='C', device='cpu')]) -> None: ...
169160

170-
def ret_numpy_half() -> Annotated[NDArray[float16], dict(shape=(2, 4))]: ...
161+
def ret_numpy_half() -> Annotated[NDArray[numpy.float16], dict(shape=(2, 4))]: ...
171162

172163
def cast(arg: bool, /) -> NDArray: ...
173164

174165
@overload
175-
def set_item(arg0: Annotated[NDArray[float64], dict(shape=(None,), order='C')], arg1: int, /) -> None: ...
166+
def set_item(arg0: Annotated[NDArray[numpy.float64], dict(shape=(None,), order='C')], arg1: int, /) -> None: ...
176167

177168
@overload
178-
def set_item(arg0: Annotated[NDArray[complex128], dict(shape=(None,), order='C')], arg1: int, /) -> None: ...
169+
def set_item(arg0: Annotated[NDArray[numpy.complex128], dict(shape=(None,), order='C')], arg1: int, /) -> None: ...
179170

180171
def test_implicit_conversion(arg: Annotated[NDArray, dict(order='C', device='cpu', writable=False)]) -> Annotated[NDArray, dict(order='C', device='cpu', writable=False)]: ...
181172

182-
def ret_infer_c() -> Annotated[NDArray[float32], dict(shape=(2, 4), order='C')]: ...
173+
def ret_infer_c() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4), order='C')]: ...
183174

184-
def ret_infer_f() -> Annotated[NDArray[float32], dict(shape=(2, 4), order='F')]: ...
175+
def ret_infer_f() -> Annotated[NDArray[numpy.float32], dict(shape=(2, 4), order='F')]: ...
185176

186177
class Matrix4f:
187178
def __init__(self) -> None: ...
188179

189-
def data(self) -> Annotated[NDArray[float32], dict(shape=(4, 4), order='F')]: ...
180+
def data(self) -> Annotated[NDArray[numpy.float32], dict(shape=(4, 4), order='F')]: ...
190181

191-
def data_ref(self) -> Annotated[NDArray[float32], dict(shape=(4, 4), order='F')]: ...
182+
def data_ref(self) -> Annotated[NDArray[numpy.float32], dict(shape=(4, 4), order='F')]: ...
192183

193-
def data_copy(self) -> Annotated[NDArray[float32], dict(shape=(4, 4), order='F')]: ...
184+
def data_copy(self) -> Annotated[NDArray[numpy.float32], dict(shape=(4, 4), order='F')]: ...
194185

195186
def ret_from_stack_1() -> object: ...
196187

197-
def ret_from_stack_2() -> Annotated[NDArray[float32], dict(shape=(3))]: ...
188+
def ret_from_stack_2() -> Annotated[NDArray[numpy.float32], dict(shape=(3))]: ...
198189

199190
class Wrapper:
200-
def __init__(self, arg: NDArray[float32], /) -> None: ...
191+
def __init__(self, arg: NDArray[numpy.float32], /) -> None: ...
201192

202193
@property
203-
def value(self) -> NDArray[float32]: ...
194+
def value(self) -> NDArray[numpy.float32]: ...
204195

205196
@value.setter
206-
def value(self, arg: NDArray[float32], /) -> None: ...
197+
def value(self, arg: NDArray[numpy.float32], /) -> None: ...

0 commit comments

Comments
 (0)