Skip to content

Commit 363be14

Browse files
authored
Fix crash on invalid recursive variadic alias (#21572)
Fixes #21125 This one was tricky. This is because the above issue actually exposed _two_ different crash scenarios: * A crash on invalid constructs like `*tuple[Ts]` (must be `*tuple[*Ts]`). * An infinite recursion when trying to detect pathological and divergent aliases. And while working on this I discovered two more cases: * A crash where an invalid type (like a union) appears in unpack in a recursive alias definition. * A crash on non-normalizeable recursive tuple. I fix the first by tightening logic in `typeanal.py` w.r.t. where exactly a `TypeVarTuple` is allowed. I fix the second and third by avoiding `get_proper_type()` calls in `expand_type()` for recursive tuples. The fourth is the most problematic, and is kind of a fundamental thing. This PR only avoids an immediate crash for such aliases. We will still need to update various call sites where we special-case tuples to expect non-normal ones. Couple more related things: * I fix couple issues with `is_recursive` cache invalidation. * I added a fast path to `detect_diverging_alias()` to avoid creating sets unless really needed.
1 parent cbd8c82 commit 363be14

7 files changed

Lines changed: 74 additions & 18 deletions

File tree

mypy/expandtype.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,11 @@ def visit_instance(self, t: Instance) -> Type:
228228
if t.type.fullname == "builtins.tuple":
229229
# Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
230230
arg = args[0]
231-
if isinstance(arg, UnpackType):
231+
if isinstance(arg, UnpackType) and not (
232+
isinstance(arg.type, TypeAliasType) and arg.type.is_recursive
233+
):
232234
unpacked = get_proper_type(arg.type)
233235
if isinstance(unpacked, Instance):
234-
# TODO: this and similar asserts below may be unsafe because get_proper_type()
235-
# may be called during semantic analysis before all invalid types are removed.
236236
assert unpacked.type.fullname == "builtins.tuple"
237237
args = list(unpacked.args)
238238
return t.copy_modified(args=args)
@@ -536,7 +536,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:
536536
if len(items) == 1:
537537
# Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
538538
item = items[0]
539-
if isinstance(item, UnpackType):
539+
if isinstance(item, UnpackType) and not (
540+
isinstance(item.type, TypeAliasType) and item.type.is_recursive
541+
):
540542
unpacked = get_proper_type(item.type)
541543
if isinstance(unpacked, Instance):
542544
# expand_type() may be called during semantic analysis, before invalid unpacks are fixed.

mypy/semanal.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4263,6 +4263,8 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
42634263
# An alias gets updated.
42644264
updated = False
42654265
if isinstance(existing.node, TypeAlias):
4266+
# Invalidate recursive status cache in case it was previously set.
4267+
existing.node._is_recursive = None
42664268
if existing.node.target != res:
42674269
# Copy expansion to the existing alias, this matches how we update base classes
42684270
# for a TypeInfo _in place_ if there are nested placeholders.
@@ -4271,8 +4273,6 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
42714273
existing.node.alias_tvars = alias_tvars
42724274
existing.node.no_args = no_args
42734275
updated = True
4274-
# Invalidate recursive status cache in case it was previously set.
4275-
existing.node._is_recursive = None
42764276
else:
42774277
# Otherwise just replace existing placeholder with type alias *in place*.
42784278
existing._node = alias_node
@@ -5830,6 +5830,8 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
58305830
):
58315831
updated = False
58325832
if isinstance(existing.node, TypeAlias):
5833+
# Invalidate recursive status cache in case it was previously set.
5834+
existing.node._is_recursive = None
58335835
if (
58345836
existing.node.target != res
58355837
or existing.node.alias_tvars != alias_node.alias_tvars
@@ -5840,8 +5842,6 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
58405842
existing.node.default_depends = default_depends
58415843
existing.node.alias_tvars = alias_tvars
58425844
updated = True
5843-
# Invalidate recursive status cache in case it was previously set.
5844-
existing.node._is_recursive = None
58455845
else:
58465846
# Otherwise just replace existing placeholder with type alias *in place*.
58475847
existing._node = alias_node

mypy/semanal_typeargs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ def visit_type_alias_type(self, t: TypeAliasType) -> None:
108108
self.seen_aliases.discard(t)
109109

110110
def visit_tuple_type(self, t: TupleType) -> None:
111-
t.items = flatten_nested_tuples(t.items)
111+
# Unfortunately, universal normalization of tuples is not possible in presence of
112+
# recursive aliases, see testNoCrashOnNonNormalRecursiveTuple for an example.
113+
# TODO: update the places where we handle tuples to always expect non-normal ones.
114+
t.items = flatten_nested_tuples(t.items, handle_recursive=False)
112115
for i, it in enumerate(t.items):
113116
if self.check_non_paramspec(it, "tuple", t):
114117
t.items[i] = AnyType(TypeOfAny.from_error)

mypy/server/astmerge.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ def visit_var(self, node: Var) -> None:
340340
super().visit_var(node)
341341

342342
def visit_type_alias(self, node: TypeAlias) -> None:
343+
# Updating alias target can invalidate its recursive status.
344+
node._is_recursive = None
343345
self.fixup_type(node.target)
344346
for v in node.alias_tvars:
345347
self.fixup_type(v)

mypy/typeanal.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
BoolTypeQuery,
7676
CallableArgument,
7777
CallableType,
78+
CollectAliasesVisitor,
7879
DeletedType,
7980
EllipsisType,
8081
ErasedType,
@@ -275,7 +276,9 @@ def __init__(
275276
self.prohibit_special_class_field_types = prohibit_special_class_field_types
276277
# Allow variables typed as Type[Any] and type (useful for base classes).
277278
self.allow_type_any = allow_type_any
278-
self.allow_type_var_tuple = False
279+
# Level of nesting at which a TypeVarTuple is allowed. Note we specify exact level
280+
# to prohibit things like Unpack[list[Ts]], which are not supported.
281+
self.allow_type_var_tuple = -1
279282
self.allow_unpack = allow_unpack
280283
# Set when we are analyzing a default of a type variable.
281284
self.analyzing_tvar_def = analyzing_tvar_def
@@ -453,7 +456,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
453456
self.fail(msg, t, code=codes.VALID_TYPE)
454457
return AnyType(TypeOfAny.from_error)
455458
assert isinstance(tvar_def, TypeVarTupleType)
456-
if not self.allow_type_var_tuple:
459+
if self.allow_type_var_tuple != self.nesting_level:
457460
self.fail(
458461
f'TypeVarTuple "{t.name}" is only valid with an unpack',
459462
t,
@@ -808,9 +811,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
808811
if not self.allow_unpack:
809812
self.fail(message_registry.INVALID_UNPACK_POSITION, t, code=codes.VALID_TYPE)
810813
return AnyType(TypeOfAny.from_error)
811-
self.allow_type_var_tuple = True
814+
self.allow_type_var_tuple = self.nesting_level + 1
812815
result = UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
813-
self.allow_type_var_tuple = False
816+
self.allow_type_var_tuple = -1
814817
return result
815818
elif fullname in SELF_TYPE_NAMES:
816819
if t.args:
@@ -1161,9 +1164,9 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
11611164
if not self.allow_unpack:
11621165
self.fail(message_registry.INVALID_UNPACK_POSITION, t.type, code=codes.VALID_TYPE)
11631166
return AnyType(TypeOfAny.from_error)
1164-
self.allow_type_var_tuple = True
1167+
self.allow_type_var_tuple = self.nesting_level + 1
11651168
result = UnpackType(self.anal_type(t.type), from_star_syntax=t.from_star_syntax)
1166-
self.allow_type_var_tuple = False
1169+
self.allow_type_var_tuple = -1
11671170
return result
11681171

11691172
def visit_parameters(self, t: Parameters) -> Type:
@@ -2523,6 +2526,15 @@ def detect_diverging_alias(node: TypeAlias, target: Type) -> bool:
25232526
They may be handy in rare cases, e.g. to express a union of non-mixed nested lists:
25242527
Nested = Union[T, Nested[List[T]]] ~> Union[T, List[T], List[List[T]], ...]
25252528
"""
2529+
is_recursive = node._is_recursive
2530+
if is_recursive is None:
2531+
is_recursive = node in node.target.accept(CollectAliasesVisitor())
2532+
if not is_recursive:
2533+
# Fast path: this is not a recursive alias at all.
2534+
return False
2535+
# Note we only cache positive case, caching negative case is risky, as this type alias
2536+
# (or more importantly any other alias it uses) may be not ready yet.
2537+
node._is_recursive = True
25262538
visitor = DivergingAliasDetector({node})
25272539
_ = target.accept(visitor)
25282540
return visitor.diverging

mypy/types.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4301,12 +4301,12 @@ def find_unpack_in_list(items: Sequence[Type]) -> int | None:
43014301
# Funky code here avoids mypyc narrowing the type of unpack_index.
43024302
old_index = unpack_index
43034303
assert old_index is None
4304-
# Don't return so that we can also sanity check there is only one.
4304+
# Don't return so that we can also sanity-check there is only one.
43054305
unpack_index = i
43064306
return unpack_index
43074307

43084308

4309-
def flatten_nested_tuples(types: Iterable[Type]) -> list[Type]:
4309+
def flatten_nested_tuples(types: Iterable[Type], handle_recursive: bool = True) -> list[Type]:
43104310
"""Recursively flatten TupleTypes nested with Unpack.
43114311
43124312
For example this will transform
@@ -4320,7 +4320,12 @@ def flatten_nested_tuples(types: Iterable[Type]) -> list[Type]:
43204320
res.append(typ)
43214321
continue
43224322
p_type = get_proper_type(typ.type)
4323-
if not isinstance(p_type, TupleType):
4323+
if (
4324+
not isinstance(p_type, TupleType)
4325+
or not handle_recursive
4326+
and isinstance(typ.type, TypeAliasType)
4327+
and typ.type.is_recursive
4328+
):
43244329
res.append(typ)
43254330
continue
43264331
if isinstance(typ.type, TypeAliasType):

test-data/unit/check-typevar-tuple.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,39 @@ z: C
882882
reveal_type(x) # N: Revealed type is "Any"
883883
reveal_type(y) # N: Revealed type is "Any"
884884
reveal_type(z) # N: Revealed type is "tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]"
885+
[builtins fixtures/tuple.pyi]
886+
887+
[case testBanPathologicalRecursiveTuplesGeneric]
888+
from typing import TypeVarTuple, Unpack
889+
890+
Ts = TypeVarTuple("Ts")
891+
A = tuple[Unpack[B[Unpack[Ts]]]] # E: Invalid recursive alias: a tuple item of itself \
892+
# E: Name "B" is used before definition
893+
B = tuple[Unpack[A[Unpack[Ts]]]]
894+
[builtins fixtures/tuple.pyi]
895+
896+
[case testNoCrashOnInvalidRecursiveUnpackOfUnion]
897+
from typing import Unpack
898+
899+
A = tuple[int, str] | list[tuple[Unpack[A]]] # E: "tuple[int, str] | list[tuple[Unpack[A]]]" cannot be unpacked (must be tuple or TypeVarTuple)
900+
[builtins fixtures/tuple.pyi]
901+
902+
[case testNoCrashOnNonNormalRecursiveTuple]
903+
from typing import Unpack
904+
905+
A = tuple[int, list[tuple[str, Unpack[A]]]]
906+
a: A
907+
x, y = a
908+
y[0] = 1 # E: Incompatible types in assignment (expression has type "int", target has type "tuple[str, Unpack[A]]")
909+
[builtins fixtures/list.pyi]
885910

911+
[case testBanTypeVarTupleNotImmediatelyInsideUnpack]
912+
from typing import TypeVarTuple, Unpack
913+
914+
Ts = TypeVarTuple("Ts")
915+
A = tuple[Unpack[tuple[Ts]]] # E: TypeVarTuple "Ts" is only valid with an unpack
916+
x: A[int, str]
917+
reveal_type(x) # N: Revealed type is "tuple[Any]"
886918
[builtins fixtures/tuple.pyi]
887919

888920
[case testInferenceAgainstGenericVariadicWithBadType]

0 commit comments

Comments
 (0)