Skip to content

Commit bb3e497

Browse files
authored
[mypyc] Fix crash on accessing StopAsyncIteration (#21406)
Using `StopAsyncIteration` in a `raise` expression or as the caught type in an `except` block crashes at runtime, because the generated code loads the address of the `PyExc_StopAsyncIteration` variable and casts it to `PyObject *`. The variable itself [is already](https://github.com/python/cpython/blob/main/Include/pyerrors.h#L80) a `PyObject *`, so its address is really a `PyObject **`, and interpreting that address as a `PyObject *` is incorrect. To fix this, add a special case that uses `LoadGlobal` instead of `LoadAddress` for builtin type variables that are already `PyObject *`.
1 parent 472d034 commit bb3e497

10 files changed

Lines changed: 109 additions & 21 deletions

File tree

mypyc/irbuild/builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,9 @@ def add_coroutine_setup_call(self, class_name: str, obj: Value) -> Value:
16031603
)
16041604
)
16051605

1606+
def load_builtin(self, name: str, line: int) -> Value | None:
1607+
return self.builder.load_builtin(name, line)
1608+
16061609

16071610
def gen_arg_defaults(builder: IRBuilder) -> None:
16081611
"""Generate blocks for arguments that have default values.

mypyc/irbuild/expression.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@
136136
from mypyc.primitives.generic_ops import iter_op, name_op
137137
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
138138
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
139-
from mypyc.primitives.registry import builtin_names
140139
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
141140
from mypyc.primitives.str_ops import str_slice_op
142141
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
@@ -157,9 +156,8 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
157156
)
158157
return builder.none(expr.line)
159158
fullname = expr.node.fullname
160-
if fullname in builtin_names:
161-
typ, src = builtin_names[fullname]
162-
return builder.add(LoadAddress(typ, src, expr.line))
159+
if builtin := builder.load_builtin(fullname, expr.line):
160+
return builtin
163161
# special cases
164162
if fullname == "builtins.None":
165163
return builder.none(expr.line)

mypyc/irbuild/for_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
IntOp,
3737
LoadAddress,
3838
LoadErrorValue,
39+
LoadGlobal,
3940
LoadLiteral,
40-
LoadMem,
4141
MethodCall,
4242
RaiseStandardError,
4343
Register,
@@ -63,7 +63,6 @@
6363
is_tuple_rprimitive,
6464
object_pointer_rprimitive,
6565
object_rprimitive,
66-
pointer_rprimitive,
6766
short_int_rprimitive,
6867
)
6968
from mypyc.irbuild.builder import IRBuilder
@@ -828,8 +827,9 @@ def gen_condition(self) -> None:
828827
line = self.line
829828

830829
def except_match() -> Value:
831-
addr = builder.add(LoadAddress(pointer_rprimitive, stop_async_iteration_op.src, line))
832-
return builder.add(LoadMem(stop_async_iteration_op.type, addr, borrow=True))
830+
return builder.add(
831+
LoadGlobal(stop_async_iteration_op.type, stop_async_iteration_op.src, line)
832+
)
833833

834834
def try_body() -> None:
835835
awaitable = builder.call_c(anext_op, [builder.read(self.iter_target, line)], line)

mypyc/irbuild/function.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
ComparisonOp,
4646
GetAttr,
4747
Integer,
48-
LoadAddress,
4948
LoadLiteral,
5049
Register,
5150
Return,
@@ -85,7 +84,6 @@
8584
)
8685
from mypyc.primitives.generic_ops import generic_getattr, generic_setattr, py_setattr_op
8786
from mypyc.primitives.misc_ops import register_function
88-
from mypyc.primitives.registry import builtin_names
8987
from mypyc.sametype import is_same_method_signature, is_same_type
9088

9189
# Top-level transform functions
@@ -935,9 +933,8 @@ def load_type(builder: IRBuilder, typ: TypeInfo, unbounded_type: Type | None, li
935933
if typ in builder.mapper.type_to_ir:
936934
class_ir = builder.mapper.type_to_ir[typ]
937935
class_obj = builder.builder.get_native_type(class_ir)
938-
elif typ.fullname in builtin_names:
939-
builtin_addr_type, src = builtin_names[typ.fullname]
940-
class_obj = builder.add(LoadAddress(builtin_addr_type, src, line))
936+
elif builtin := builder.load_builtin(typ.fullname, line):
937+
class_obj = builtin
941938
elif isinstance(unbounded_type, UnboundType):
942939
path_parts = unbounded_type.name.split(".")
943940
class_obj = builder.load_global_str(path_parts[0], line)
@@ -1013,8 +1010,8 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
10131010
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
10141011
builder.add(Return(coerced))
10151012

1016-
typ, src = builtin_names["builtins.int"]
1017-
int_type_obj = builder.add(LoadAddress(typ, src, line))
1013+
int_type_obj = builder.load_builtin("builtins.int", line)
1014+
assert int_type_obj
10181015
is_int = builder.builder.type_is_op(impl_to_use, int_type_obj, line)
10191016

10201017
native_call, non_native_call = BasicBlock(), BasicBlock()

mypyc/irbuild/ll_builder.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@
185185
ERR_NEG_INT,
186186
CFunctionDescription,
187187
binary_ops,
188+
builtin_names,
188189
function_ops,
190+
global_names,
189191
method_call_ops,
190192
unary_ops,
191193
)
@@ -323,8 +325,18 @@ def set_mem(self, ptr: Value, value_type: RType, value: Value) -> None:
323325
def get_element(self, reg: Value, field: str) -> Value:
324326
return self.add(GetElement(reg, field))
325327

326-
def load_address(self, name: str, rtype: RType) -> Value:
327-
return self.add(LoadAddress(rtype, name))
328+
def load_address(self, name: str, rtype: RType, line: int = -1) -> Value:
329+
return self.add(LoadAddress(rtype, name, line))
330+
331+
def load_global(self, name: str, rtype: RType, line: int) -> Value:
332+
return self.add(LoadGlobal(rtype, name, line))
333+
334+
def load_builtin(self, name: str, line: int) -> Value | None:
335+
if builtin := builtin_names.get(name):
336+
return self.load_address(builtin[1], builtin[0], line)
337+
if glob := global_names.get(name):
338+
return self.load_global(glob[1], glob[0], line)
339+
return None
328340

329341
def load_struct_field(
330342
self, ptr: Value, struct: RStruct, field: str, *, borrow: bool = False

mypyc/irbuild/vec.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@
5454
vec_api_by_item_type,
5555
vec_item_type_tags,
5656
)
57-
from mypyc.primitives.registry import builtin_names
5857

5958
if TYPE_CHECKING:
6059
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
@@ -213,8 +212,7 @@ def vec_item_type_info(
213212
builder: LowLevelIRBuilder, typ: RType, line: int
214213
) -> tuple[Value | None, bool, int]:
215214
if isinstance(typ, RPrimitive) and typ.is_refcounted:
216-
typ, src = builtin_names[typ.name]
217-
return builder.load_address(src, typ), False, 0
215+
return builder.load_builtin(typ.name, line), False, 0
218216
elif isinstance(typ, RInstance):
219217
return builder.load_native_type_object(typ.name), False, 0
220218
elif typ in vec_item_type_tags:

mypyc/primitives/misc_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
custom_primitive_op,
3232
function_op,
3333
load_address_op,
34+
load_global_op,
3435
method_op,
3536
)
3637

@@ -52,7 +53,7 @@
5253
)
5354

5455
# Get the boxed StopAsyncIteration object
55-
stop_async_iteration_op = load_address_op(
56+
stop_async_iteration_op = load_global_op(
5657
name="builtins.StopAsyncIteration", type=object_rprimitive, src="PyExc_StopAsyncIteration"
5758
)
5859

mypyc/primitives/registry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,12 @@ class LoadAddressDescription(NamedTuple):
8585
# Primitive ops for unary ops
8686
unary_ops: dict[str, list[PrimitiveDescription]] = {}
8787

88+
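# Mapping of fully qualified type name (e.g. "builtins.int") to (type, name of a C variable that is the object itself, so it is loaded by taking its address).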
8889
builtin_names: dict[str, tuple[RType, str]] = {}
8990

91+
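# Mapping of fully qualified type name (e.g. "builtins.StopAsyncIteration") to (type, name of a C variable that already holds a `PyObject *`, so it is loaded by reading its value).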
92+
global_names: dict[str, tuple[RType, str]] = {}
93+
9094

9195
def method_op(
9296
name: str,
@@ -387,6 +391,12 @@ def load_address_op(name: str, type: RType, src: str) -> LoadAddressDescription:
387391
return LoadAddressDescription(name, type, src)
388392

389393

394+
def load_global_op(name: str, type: RType, src: str) -> LoadAddressDescription:
395+
assert name not in global_names, "already defined: %s" % name
396+
global_names[name] = (type, src)
397+
return LoadAddressDescription(name, type, src)
398+
399+
390400
# Import various modules that set up global state.
391401
import mypyc.primitives.bytearray_ops
392402
import mypyc.primitives.bytes_ops

mypyc/test-data/fixtures/ir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,9 @@ class ReferenceError(Exception): pass
374374
class StopIteration(Exception):
375375
value: Any
376376

377+
class StopAsyncIteration(Exception):
378+
value: Any
379+
377380
class ArithmeticError(Exception): pass
378381
class ZeroDivisionError(ArithmeticError): pass
379382
class OverflowError(ArithmeticError): pass

mypyc/test-data/run-async.test

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,3 +1874,69 @@ def test_nested_coroutine_calls_another_nested_function():
18741874
from typing import Any, Generator
18751875

18761876
def run(x: object) -> object: ...
1877+
1878+
[case testRaiseStopAsyncIteration]
1879+
from async_iter import async_iter
1880+
from testutil import assertRaises
1881+
1882+
class AsyncIter:
1883+
def __init__(self, vals: list[str]) -> None:
1884+
self._iter = iter(vals)
1885+
1886+
def __aiter__(self) -> AsyncIter:
1887+
return self
1888+
1889+
async def __anext__(self) -> str:
1890+
try:
1891+
return next(self._iter)
1892+
except StopIteration:
1893+
raise StopAsyncIteration
1894+
1895+
async def test_iterator() -> None:
1896+
new_list: list[int] = []
1897+
async for v in async_iter([1, 2, 3]):
1898+
new_list.append(v)
1899+
assert new_list == [1, 2, 3]
1900+
1901+
new_list = []
1902+
iter = async_iter([1, 2, 3])
1903+
while True:
1904+
try:
1905+
v = await iter.__anext__()
1906+
new_list.append(v)
1907+
except StopAsyncIteration:
1908+
new_list.append(4)
1909+
break
1910+
assert new_list == [1, 2, 3, 4]
1911+
1912+
with assertRaises(StopAsyncIteration):
1913+
await async_iter([]).__anext__()
1914+
1915+
async def test_wrapper() -> None:
1916+
new_list: list[str] = []
1917+
async for v in AsyncIter(['a', 'b', 'c']):
1918+
new_list.append(v)
1919+
assert new_list == ['a', 'b', 'c']
1920+
1921+
new_list = []
1922+
iter = AsyncIter(['a', 'b', 'c'])
1923+
while True:
1924+
try:
1925+
v = await iter.__anext__()
1926+
new_list.append(v)
1927+
except StopAsyncIteration:
1928+
new_list.append('d')
1929+
break
1930+
assert new_list == ['a', 'b', 'c', 'd']
1931+
1932+
with assertRaises(StopAsyncIteration):
1933+
await AsyncIter([]).__anext__()
1934+
1935+
[file async_iter.py]
1936+
from typing import AsyncIterator
1937+
1938+
async def async_iter(vals: list[int]) -> AsyncIterator[int]:
1939+
for v in vals:
1940+
yield v
1941+
1942+
[typing fixtures/typing-full.pyi]

0 commit comments

Comments
 (0)