diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 71b8b0ba59f5..850853fc23ae 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -185,6 +185,8 @@ def is_subtype( # steps we come back to initial call is_subtype(A, B) and immediately return True. with pop_on_exit(type_state.get_assumptions(is_proper=False), left, right): return _is_subtype(left, right, subtype_context, proper_subtype=False) + left = get_proper_type(left) + right = get_proper_type(right) return _is_subtype(left, right, subtype_context, proper_subtype=False) @@ -282,6 +284,21 @@ def is_same_type( ) +# This is a helper function used to check for recursive type of distributed tuple +def is_structurally_recursive(typ: Type, seen: set[Type] | None = None) -> bool: + if seen is None: + seen = set() + typ = get_proper_type(typ) + if typ in seen: + return True + seen.add(typ) + if isinstance(typ, UnionType): + return any(is_structurally_recursive(item, seen.copy()) for item in typ.items) + if isinstance(typ, TupleType): + return any(is_structurally_recursive(item, seen.copy()) for item in typ.items) + return False + + # This is a common entry point for subtyping checks (both proper and non-proper). # Never call this private function directly, use the public versions. def _is_subtype( @@ -303,7 +320,30 @@ def _is_subtype( # TODO: should we consider all types proper subtypes of UnboundType and/or # ErasedType as we do for non-proper subtyping. return True - + if isinstance(left, TupleType) and isinstance(right, UnionType): + # check only if not recursive type because if recursive type, + # test run into maximum recursive depth reached + if not is_structurally_recursive(left) and not is_structurally_recursive(right): + fallback = left.partial_fallback + tuple_items = left.items + if hasattr(left, "fallback") and left.fallback is not None: + fallback = left.fallback + for i in range(len(tuple_items)): + uitems = tuple_items[i] + uitems_type = get_proper_type(uitems) + if isinstance(uitems_type, UnionType): + new_tuples = [ + TupleType( + tuple_items[:i] + [uitem] + tuple_items[i + 1 :], fallback=fallback + ) + for uitem in uitems_type.items + ] + result = [ + _is_subtype(t, right, subtype_context, proper_subtype=False) + for t in new_tuples + ] + inverted_list = [not item for item in result] + return not any(inverted_list) if isinstance(right, UnionType) and not isinstance(left, UnionType): # Normally, when 'left' is not itself a union, the only way # 'left' can be a subtype of the union 'right' is if it is a diff --git a/test-data/unit/check-tuples.test b/test-data/unit/check-tuples.test index 3424d053fe42..1c6f78a2ce9f 100644 --- a/test-data/unit/check-tuples.test +++ b/test-data/unit/check-tuples.test @@ -1041,6 +1041,47 @@ reveal_type(x) # N: Revealed type is "Tuple[builtins.int, fallback=__main__.Tes [out] +-- Union-in-Tuple distribution +-- --------------------------------------------------------- + + +[case testTupleUnionDistributionSuccess] +from typing import Tuple, Optional, Union + +def f1(x: Tuple[float, Optional[float]]) -> Union[Tuple[float, float], Tuple[float, None]]: + return x + +def f2(x: Tuple[Union[int, str], float]) -> Union[Tuple[int, float], Tuple[str, float]]: + return x + +def f3(x: Tuple[int, Union[str, None]]) -> Union[Tuple[int, str], Tuple[int, None]]: + return x + +def f4(x: Tuple[Union[int, float]]) -> Union[Tuple[int], Tuple[float]]: + return x + +def f5(x: Tuple[Union[int, str], Union[bool, None]]) -> Union[ + Tuple[int, bool], + Tuple[int, None], + Tuple[str, bool], + Tuple[str, None], +]: + return x +[builtins fixtures/tuple.pyi] +[out] + +[case testTupleUnionDistributionFail] +from typing import Tuple, Optional, Union + +def g1(x: Tuple[float, Optional[float]]) -> Union[Tuple[float, float], Tuple[str, float]]: + return x # E: Incompatible return value type (got "Tuple[float, Optional[float]]", expected "Union[Tuple[float, float], Tuple[str, float]]") + +def g2(x: Tuple[float, Optional[float]]) -> Union[Tuple[float, str], Tuple[float, None]]: + return x # E: Incompatible return value type (got "Tuple[float, Optional[float]]", expected "Union[Tuple[float, str], Tuple[float, None]]") +[builtins fixtures/tuple.pyi] +[out] + + -- Variable-length tuples (Tuple[t, ...] with literal '...') -- ---------------------------------------------------------