Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ def visit_instance(self, t: Instance) -> Type:
if t.type.fullname == "builtins.tuple":
# Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
arg = args[0]
if isinstance(arg, UnpackType):
if isinstance(arg, UnpackType) and not (
isinstance(arg.type, TypeAliasType) and arg.type.is_recursive
):
unpacked = get_proper_type(arg.type)
if isinstance(unpacked, Instance):
# TODO: this and similar asserts below may be unsafe because get_proper_type()
# may be called during semantic analysis before all invalid types are removed.
assert unpacked.type.fullname == "builtins.tuple"
args = list(unpacked.args)
return t.copy_modified(args=args)
Expand Down Expand Up @@ -535,7 +535,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:
if len(items) == 1:
# Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
item = items[0]
if isinstance(item, UnpackType):
if isinstance(item, UnpackType) and not (
isinstance(item.type, TypeAliasType) and item.type.is_recursive
):
unpacked = get_proper_type(item.type)
if isinstance(unpacked, Instance):
# expand_type() may be called during semantic analysis, before invalid unpacks are fixed.
Expand Down
8 changes: 4 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4263,6 +4263,8 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# An alias gets updated.
updated = False
if isinstance(existing.node, TypeAlias):
# Invalidate recursive status cache in case it was previously set.
existing.node._is_recursive = None
if existing.node.target != res:
# Copy expansion to the existing alias, this matches how we update base classes
# for a TypeInfo _in place_ if there are nested placeholders.
Expand All @@ -4271,8 +4273,6 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
existing.node.alias_tvars = alias_tvars
existing.node.no_args = no_args
updated = True
# Invalidate recursive status cache in case it was previously set.
existing.node._is_recursive = None
else:
# Otherwise just replace existing placeholder with type alias *in place*.
existing._node = alias_node
Expand Down Expand Up @@ -5830,6 +5830,8 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
):
updated = False
if isinstance(existing.node, TypeAlias):
# Invalidate recursive status cache in case it was previously set.
existing.node._is_recursive = None
if (
existing.node.target != res
or existing.node.alias_tvars != alias_node.alias_tvars
Expand All @@ -5840,8 +5842,6 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
existing.node.default_depends = default_depends
existing.node.alias_tvars = alias_tvars
updated = True
# Invalidate recursive status cache in case it was previously set.
existing.node._is_recursive = None
else:
# Otherwise just replace existing placeholder with type alias *in place*.
existing._node = alias_node
Expand Down
5 changes: 4 additions & 1 deletion mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def visit_type_alias_type(self, t: TypeAliasType) -> None:
self.seen_aliases.discard(t)

def visit_tuple_type(self, t: TupleType) -> None:
t.items = flatten_nested_tuples(t.items)
# Unfortunately, universal normalization of tuples is not possible in presence of
# recursive aliases, see testNoCrashOnNonNormalRecursiveTuple for an example.
# TODO: update the places where we handle tuples to always expect non-normal ones.
t.items = flatten_nested_tuples(t.items, handle_recursive=False)
for i, it in enumerate(t.items):
if self.check_non_paramspec(it, "tuple", t):
t.items[i] = AnyType(TypeOfAny.from_error)
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def visit_var(self, node: Var) -> None:
super().visit_var(node)

def visit_type_alias(self, node: TypeAlias) -> None:
# Updating alias target can invalidate its recursive status.
node._is_recursive = None
self.fixup_type(node.target)
for v in node.alias_tvars:
self.fixup_type(v)
Expand Down
24 changes: 18 additions & 6 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
BoolTypeQuery,
CallableArgument,
CallableType,
CollectAliasesVisitor,
DeletedType,
EllipsisType,
ErasedType,
Expand Down Expand Up @@ -275,7 +276,9 @@ def __init__(
self.prohibit_special_class_field_types = prohibit_special_class_field_types
# Allow variables typed as Type[Any] and type (useful for base classes).
self.allow_type_any = allow_type_any
self.allow_type_var_tuple = False
# Level of nesting at which a TypeVarTuple is allowed. Note we specify exact level
# to prohibit things like Unpack[list[Ts]], which are not supported.
self.allow_type_var_tuple = -1
self.allow_unpack = allow_unpack
# Set when we are analyzing a default of a type variable.
self.analyzing_tvar_def = analyzing_tvar_def
Expand Down Expand Up @@ -453,7 +456,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
assert isinstance(tvar_def, TypeVarTupleType)
if not self.allow_type_var_tuple:
if self.allow_type_var_tuple != self.nesting_level:
self.fail(
f'TypeVarTuple "{t.name}" is only valid with an unpack',
t,
Expand Down Expand Up @@ -808,9 +811,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
if not self.allow_unpack:
self.fail(message_registry.INVALID_UNPACK_POSITION, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
self.allow_type_var_tuple = True
self.allow_type_var_tuple = self.nesting_level + 1
result = UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
self.allow_type_var_tuple = False
self.allow_type_var_tuple = -1
return result
elif fullname in SELF_TYPE_NAMES:
if t.args:
Expand Down Expand Up @@ -1161,9 +1164,9 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
if not self.allow_unpack:
self.fail(message_registry.INVALID_UNPACK_POSITION, t.type, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
self.allow_type_var_tuple = True
self.allow_type_var_tuple = self.nesting_level + 1
result = UnpackType(self.anal_type(t.type), from_star_syntax=t.from_star_syntax)
self.allow_type_var_tuple = False
self.allow_type_var_tuple = -1
return result

def visit_parameters(self, t: Parameters) -> Type:
Expand Down Expand Up @@ -2518,6 +2521,15 @@ def detect_diverging_alias(node: TypeAlias, target: Type) -> bool:
They may be handy in rare cases, e.g. to express a union of non-mixed nested lists:
Nested = Union[T, Nested[List[T]]] ~> Union[T, List[T], List[List[T]], ...]
"""
is_recursive = node._is_recursive
if is_recursive is None:
is_recursive = node in node.target.accept(CollectAliasesVisitor())
if not is_recursive:
# Fast path: this is not a recursive alias at all.
return False
# Note we only cache positive case, caching negative case is risky, as this type alias
# (or more importantly any other alias it uses) may be not ready yet.
node._is_recursive = True
visitor = DivergingAliasDetector({node})
_ = target.accept(visitor)
return visitor.diverging
Expand Down
11 changes: 8 additions & 3 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4251,12 +4251,12 @@ def find_unpack_in_list(items: Sequence[Type]) -> int | None:
# Funky code here avoids mypyc narrowing the type of unpack_index.
old_index = unpack_index
assert old_index is None
# Don't return so that we can also sanity check there is only one.
# Don't return so that we can also sanity-check there is only one.
unpack_index = i
return unpack_index


def flatten_nested_tuples(types: Iterable[Type]) -> list[Type]:
def flatten_nested_tuples(types: Iterable[Type], handle_recursive: bool = True) -> list[Type]:
"""Recursively flatten TupleTypes nested with Unpack.

For example this will transform
Expand All @@ -4270,7 +4270,12 @@ def flatten_nested_tuples(types: Iterable[Type]) -> list[Type]:
res.append(typ)
continue
p_type = get_proper_type(typ.type)
if not isinstance(p_type, TupleType):
if (
not isinstance(p_type, TupleType)
or not handle_recursive
and isinstance(typ.type, TypeAliasType)
and typ.type.is_recursive
):
res.append(typ)
continue
if isinstance(typ.type, TypeAliasType):
Expand Down
32 changes: 32 additions & 0 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,39 @@ z: C
reveal_type(x) # N: Revealed type is "Any"
reveal_type(y) # N: Revealed type is "Any"
reveal_type(z) # N: Revealed type is "tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]"
[builtins fixtures/tuple.pyi]

[case testBanPathologicalRecursiveTuplesGeneric]
from typing import TypeVarTuple, Unpack

Ts = TypeVarTuple("Ts")
A = tuple[Unpack[B[Unpack[Ts]]]] # E: Invalid recursive alias: a tuple item of itself \
# E: Name "B" is used before definition
B = tuple[Unpack[A[Unpack[Ts]]]]
[builtins fixtures/tuple.pyi]

[case testNoCrashOnInvalidRecursiveUnpackOfUnion]
from typing import Unpack

A = tuple[int, str] | list[tuple[Unpack[A]]] # E: "tuple[int, str] | list[tuple[Unpack[A]]]" cannot be unpacked (must be tuple or TypeVarTuple)
[builtins fixtures/tuple.pyi]

[case testNoCrashOnNonNormalRecursiveTuple]
from typing import Unpack

A = tuple[int, list[tuple[str, Unpack[A]]]]
a: A
x, y = a
y[0] = 1 # E: Incompatible types in assignment (expression has type "int", target has type "tuple[str, Unpack[A]]")
[builtins fixtures/list.pyi]

[case testBanTypeVarTupleNotImmediatelyInsideUnpack]
from typing import TypeVarTuple, Unpack

Ts = TypeVarTuple("Ts")
A = tuple[Unpack[tuple[Ts]]] # E: TypeVarTuple "Ts" is only valid with an unpack
x: A[int, str]
reveal_type(x) # N: Revealed type is "tuple[Any]"
[builtins fixtures/tuple.pyi]

[case testInferenceAgainstGenericVariadicWithBadType]
Expand Down
Loading