Skip to content

Commit 0e09a61

Browse files
better support for literal? types
1 parent 72cff30 commit 0e09a61

File tree

12 files changed

+401
-38
lines changed

12 files changed

+401
-38
lines changed

mypy/join.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,18 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
141141
new_type = join_types(ta, sa, self)
142142
assert new_type is not None
143143
args.append(new_type)
144-
result: ProperType = Instance(t.type, args)
144+
lkv = t.last_known_value if t.last_known_value == s.last_known_value else None
145+
result: ProperType = Instance(t.type, args, last_known_value=lkv)
145146
elif t.type.bases and is_proper_subtype(
146147
t, s, subtype_context=SubtypeContext(ignore_type_params=True)
147148
):
148149
result = self.join_instances_via_supertype(t, s)
150+
elif s.type.bases and is_proper_subtype(
151+
s, t, subtype_context=SubtypeContext(ignore_type_params=True)
152+
):
153+
result = self.join_instances_via_supertype(s, t)
154+
elif is_subtype(t, s, subtype_context=SubtypeContext(ignore_type_params=True)):
155+
result = self.join_instances_via_supertype(t, s)
149156
else:
150157
# Now t is not a subtype of s, and t != s. Now s could be a subtype
151158
# of t; alternatively, we need to find a common supertype. This works
@@ -626,13 +633,17 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
626633
def visit_literal_type(self, t: LiteralType) -> ProperType:
627634
if isinstance(self.s, LiteralType):
628635
if t == self.s:
636+
# E.g. Literal["x"], Literal["x"] -> Literal["x"]
629637
return t
630638
if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
631639
return mypy.typeops.make_simplified_union([self.s, t])
640+
# E.g. Literal["x"], Literal["y"] -> str
632641
return join_types(self.s.fallback, t.fallback)
633642
elif isinstance(self.s, Instance) and self.s.last_known_value == t:
634-
return t
643+
# E.g. Literal["x"], Literal["x"]? -> Literal["x"]?
644+
return self.s
635645
else:
646+
# E.g. Literal["x"], Literal["y"]? -> str
636647
return join_types(self.s, t.fallback)
637648

638649
def visit_partial_type(self, t: PartialType) -> ProperType:

mypy/meet.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def meet_types(s: Type, t: Type) -> ProperType:
8181
t = get_proper_type(t)
8282

8383
if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type:
84+
# special casing for dealing with last known values
85+
lkv = meet_last_known_values(t.last_known_value, s.last_known_value)
86+
t = t.copy_modified(last_known_value=lkv)
87+
s = s.copy_modified(last_known_value=lkv)
88+
8489
# Code in checker.py should merge any extra_items where possible, so we
8590
# should have only compatible extra_items here. We check this before
8691
# the below subtype check, so that extra_attrs will not get erased.
@@ -113,6 +118,30 @@ def meet_types(s: Type, t: Type) -> ProperType:
113118
return t.accept(TypeMeetVisitor(s))
114119

115120

121+
def meet_last_known_values(
122+
left: LiteralType | None, right: LiteralType | None
123+
) -> LiteralType | None:
124+
"""Return the meet of two last_known_values."""
125+
if left is None:
126+
return right
127+
if right is None:
128+
return left
129+
130+
lkv_meet = meet_types(left, right)
131+
132+
if isinstance(lkv_meet, UninhabitedType):
133+
return None
134+
if isinstance(lkv_meet, LiteralType):
135+
return lkv_meet
136+
137+
msg = (
138+
f"Unexpected result: "
139+
f"meet of last_known_values {left=!s} and {right=!s} "
140+
f"resulted in {lkv_meet!s}"
141+
)
142+
raise ValueError(msg)
143+
144+
116145
def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
117146
"""Return the declared type narrowed down to another type."""
118147
# TODO: check infinite recursion for aliases here.
@@ -1114,8 +1143,14 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
11141143
def visit_literal_type(self, t: LiteralType) -> ProperType:
11151144
if isinstance(self.s, LiteralType) and self.s == t:
11161145
return t
1117-
elif isinstance(self.s, Instance) and is_subtype(t.fallback, self.s):
1118-
return t
1146+
elif isinstance(self.s, Instance):
1147+
# if is_subtype(t.fallback, self.s):
1148+
# return t
1149+
if self.s.last_known_value is not None:
1150+
# meet(Literal["max"]?, Literal["max"]) -> Literal["max"]
1151+
# meet(Literal["sum"]?, Literal["max"]) -> Never
1152+
return meet_types(self.s.last_known_value, t)
1153+
return self.default(self.s)
11191154
else:
11201155
return self.default(self.s)
11211156

mypy/solve.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
319319
elif top is None:
320320
candidate = bottom
321321
elif is_subtype(bottom, top):
322-
candidate = bottom
322+
# Need to meet in case like Literal["x"]? <: T <: Literal["x"]
323+
candidate = meet_types(bottom, top)
323324
else:
324325
candidate = None
325326
return candidate

mypy/subtypes.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,13 @@ def visit_instance(self, left: Instance) -> bool:
548548
assert isinstance(erased, Instance)
549549
t = erased
550550
nominal = True
551+
if self.proper_subtype and right.last_known_value is not None:
552+
if left.last_known_value is None:
553+
# E.g. str is not a proper subtype of Literal["x"]?
554+
nominal = False
555+
else:
556+
# E.g. Literal[A]? <: Literal[B]? requires A <: B
557+
nominal &= self._is_subtype(left.last_known_value, right.last_known_value)
551558
if right.type.has_type_var_tuple_type:
552559
# For variadic instances we simply find the correct type argument mappings,
553560
# all the heavy lifting is done by the tuple subtyping.
@@ -628,8 +635,14 @@ def visit_instance(self, left: Instance) -> bool:
628635
return True
629636
if isinstance(item, Instance):
630637
return is_named_instance(item, "builtins.object")
631-
if isinstance(right, LiteralType) and left.last_known_value is not None:
632-
return self._is_subtype(left.last_known_value, right)
638+
if isinstance(right, LiteralType):
639+
if self.proper_subtype:
640+
# Instance types like Literal["sum"]? is *assignable* to Literal["sum"],
641+
# but is not a proper subtype of it. (Literal["sum"]? is a gradual type,
642+
# that is a proper subtype of str, and assignable to Literal["sum"].
643+
return False
644+
if left.last_known_value is not None:
645+
return self._is_subtype(left.last_known_value, right)
633646
if isinstance(right, FunctionLike):
634647
# Special case: Instance can be a subtype of Callable / Overloaded.
635648
call = find_member("__call__", left, left, is_operator=True)
@@ -964,6 +977,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
964977
def visit_literal_type(self, left: LiteralType) -> bool:
965978
if isinstance(self.right, LiteralType):
966979
return left == self.right
980+
elif (
981+
isinstance(self.right, Instance)
982+
and self.right.last_known_value is not None
983+
and self.proper_subtype
984+
):
985+
return self._is_subtype(left, self.right.last_known_value)
967986
else:
968987
return self._is_subtype(left.fallback, self.right)
969988

@@ -2138,6 +2157,11 @@ def covers_at_runtime(item: Type, supertype: Type) -> bool:
21382157
item = get_proper_type(item)
21392158
supertype = get_proper_type(supertype)
21402159

2160+
# Use last known value for Instance types, if available.
2161+
# This ensures that e.g. Literal["max"]? is covered by Literal["max"].
2162+
if isinstance(item, Instance) and item.last_known_value is not None:
2163+
item = item.last_known_value
2164+
21412165
# Since runtime type checks will ignore type arguments, erase the types.
21422166
if not (isinstance(supertype, FunctionLike) and supertype.is_type_obj()):
21432167
supertype = erase_type(supertype)

mypy/test/testsubtypes.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
4-
from mypy.subtypes import is_subtype
4+
from mypy.subtypes import is_proper_subtype, is_subtype, restrict_subtype_away
55
from mypy.test.helpers import Suite
66
from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture
77
from mypy.types import Instance, TupleType, Type, UninhabitedType, UnpackType
@@ -277,6 +277,74 @@ def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None:
277277
def test_fallback_not_subtype_of_tuple(self) -> None:
278278
self.assert_not_subtype(self.fx.a, TupleType([self.fx.b], fallback=self.fx.a))
279279

280+
def test_literal(self) -> None:
281+
str1 = self.fx.lit_str1
282+
str2 = self.fx.lit_str2
283+
str1_inst = self.fx.lit_str1_inst
284+
str2_inst = self.fx.lit_str2_inst
285+
str_type = self.fx.str_type
286+
287+
# other operand is the fallback type
288+
# "x" ≲ str -> YES
289+
# str ≲ "x" -> NO
290+
# "x"? ≲ str -> YES
291+
# str ≲ "x"? -> YES
292+
self.assert_subtype(str1, str_type)
293+
self.assert_not_subtype(str_type, str1)
294+
self.assert_subtype(str1_inst, str_type)
295+
self.assert_subtype(str_type, str1_inst)
296+
297+
# other operand is the same literal
298+
# "x" ≲ "x" -> YES
299+
# "x" ≲ "x"? -> YES
300+
# "x"? ≲ "x" -> YES
301+
# "x"? ≲ "x"? -> YES
302+
self.assert_subtype(str1, str1)
303+
self.assert_subtype(str1, str1_inst)
304+
self.assert_subtype(str1_inst, str1)
305+
self.assert_subtype(str1_inst, str1_inst)
306+
307+
# other operand is a different literal
308+
# "x" ≲ "y" -> NO
309+
# "x" ≲ "y"? -> YES
310+
# "x"? ≲ "y" -> NO
311+
# "x"? ≲ "y"? -> YES
312+
self.assert_not_subtype(str1, str2)
313+
self.assert_subtype(str1, str2_inst)
314+
self.assert_not_subtype(str1_inst, str2)
315+
self.assert_subtype(str1_inst, str2_inst)
316+
317+
# check proper subtyping
318+
# other operand is the fallback type
319+
# "x" <: str -> YES
320+
# str <: "x" -> NO
321+
# "x"? <: str -> YES
322+
# str <: "x"? -> NO
323+
self.assert_proper_subtype(str1, str_type)
324+
self.assert_not_proper_subtype(str_type, str1)
325+
self.assert_proper_subtype(str1_inst, str_type)
326+
self.assert_not_proper_subtype(str_type, str1_inst)
327+
328+
# other operand is the same literal
329+
# "x" <: "x" -> YES
330+
# "x" <: "x"? -> YES
331+
# "x"? <: "x" -> NO
332+
# "x"? <: "x"? -> YES
333+
self.assert_proper_subtype(str1, str1)
334+
self.assert_proper_subtype(str1, str1_inst)
335+
self.assert_not_proper_subtype(str1_inst, str1)
336+
self.assert_proper_subtype(str1_inst, str1_inst)
337+
338+
# other operand is a different literal
339+
# "x" <: "y" -> NO
340+
# "x" <: "y"? -> NO
341+
# "x"? <: "y" -> NO
342+
# "x"? <: "y"? -> NO
343+
self.assert_not_proper_subtype(str1, str2)
344+
self.assert_not_proper_subtype(str1, str2_inst)
345+
self.assert_not_proper_subtype(str1_inst, str2)
346+
self.assert_not_proper_subtype(str1_inst, str2_inst)
347+
280348
# IDEA: Maybe add these test cases (they are tested pretty well in type
281349
# checker tests already):
282350
# * more interface subtyping test cases
@@ -287,6 +355,12 @@ def test_fallback_not_subtype_of_tuple(self) -> None:
287355
# * any type
288356
# * generic function types
289357

358+
def assert_proper_subtype(self, s: Type, t: Type) -> None:
359+
assert is_proper_subtype(s, t), f"{s} not proper subtype of {t}"
360+
361+
def assert_not_proper_subtype(self, s: Type, t: Type) -> None:
362+
assert not is_proper_subtype(s, t), f"{s} not proper subtype of {t}"
363+
290364
def assert_subtype(self, s: Type, t: Type) -> None:
291365
assert is_subtype(s, t), f"{s} not subtype of {t}"
292366

@@ -304,3 +378,53 @@ def assert_equivalent(self, s: Type, t: Type) -> None:
304378
def assert_unrelated(self, s: Type, t: Type) -> None:
305379
self.assert_not_subtype(s, t)
306380
self.assert_not_subtype(t, s)
381+
382+
383+
class RestrictionSuite(Suite):
384+
# Tests for type restrictions "A - B", i.e. ``T <: A and not T <: B``.
385+
386+
def setUp(self) -> None:
387+
self.fx = TypeFixture()
388+
389+
def assert_restriction(self, s: Type, t: Type, expected: Type) -> None:
390+
actual = restrict_subtype_away(s, t)
391+
msg = f"restrict_subtype_away({s}, {t}) == {{}} ({{}} expected)"
392+
self.assertEqual(actual, expected, msg=msg.format(actual, expected))
393+
394+
def test_literal(self) -> None:
395+
str1 = self.fx.lit_str1
396+
str2 = self.fx.lit_str2
397+
str1_inst = self.fx.lit_str1_inst
398+
str2_inst = self.fx.lit_str2_inst
399+
str_type = self.fx.str_type
400+
uninhabited = self.fx.uninhabited
401+
402+
# other operand is the fallback type
403+
# "x" - str -> Never
404+
# str - "x" -> str
405+
# "x"? - str -> Never
406+
# str - "x"? -> Never
407+
self.assert_restriction(str1, str_type, uninhabited)
408+
self.assert_restriction(str_type, str1, str_type)
409+
self.assert_restriction(str1_inst, str_type, uninhabited)
410+
self.assert_restriction(str_type, str1_inst, uninhabited)
411+
412+
# other operand is the same literal
413+
# "x" - "x" -> Never
414+
# "x" - "x"? -> Never
415+
# "x"? - "x" -> Never
416+
# "x"? - "x"? -> Never
417+
self.assert_restriction(str1, str1, uninhabited)
418+
self.assert_restriction(str1, str1_inst, uninhabited)
419+
self.assert_restriction(str1_inst, str1, uninhabited)
420+
self.assert_restriction(str1_inst, str1_inst, uninhabited)
421+
422+
# other operand is a different literal
423+
# "x" - "y" -> "x"
424+
# "x" - "y"? -> Never
425+
# "x"? - "y" -> "x"?
426+
# "x"? - "y"? -> Never
427+
self.assert_restriction(str1, str2, str1)
428+
self.assert_restriction(str1, str2_inst, uninhabited)
429+
self.assert_restriction(str1_inst, str2, str1_inst)
430+
self.assert_restriction(str1_inst, str2_inst, uninhabited)

0 commit comments

Comments
 (0)