Skip to content

Commit e95ea81

Browse files
authored
Narrow match captures based on previous cases (#21405)
Fixes #16736 Closes #18155 Co-authored-by Codex
1 parent b7a8bab commit e95ea81

2 files changed

Lines changed: 49 additions & 17 deletions

File tree

mypy/checker.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5960,7 +5960,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
59605960
else_map[unwrapped_subject] = else_map[named_subject]
59615961
pattern_map = self.propagate_up_typemap_info(pattern_map)
59625962
else_map = self.propagate_up_typemap_info(else_map)
5963-
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
5963+
self.check_and_remove_capture_conflicts(pattern_type.captures, inferred_types)
59645964
self.push_type_map(pattern_map, from_assignment=False)
59655965
if pattern_map:
59665966
for expr, typ in pattern_map.items():
@@ -6066,15 +6066,8 @@ def infer_variable_types_from_type_maps(
60666066
already_exists = True
60676067
if isinstance(expr.node, Var) and expr.node.is_final:
60686068
self.msg.cant_assign_to_final(expr.name, False, expr)
6069-
if self.check_subtype(
6070-
typ,
6071-
previous_type,
6072-
expr,
6073-
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
6074-
subtype_label="pattern captures type",
6075-
supertype_label="variable has type",
6076-
):
6077-
inferred_types[var] = previous_type
6069+
# We'll check compatibility in check_and_remove_capture_conflicts
6070+
inferred_types[var] = previous_type
60786071

60796072
if not already_exists:
60806073
new_type = UnionType.make_union(types)
@@ -6086,15 +6079,24 @@ def infer_variable_types_from_type_maps(
60866079
self.infer_variable_type(var, first_occurrence, new_type, first_occurrence)
60876080
return inferred_types
60886081

6089-
def remove_capture_conflicts(
6082+
def check_and_remove_capture_conflicts(
60906083
self, type_map: TypeMap, inferred_types: dict[SymbolNode, Type]
60916084
) -> None:
6092-
if not is_unreachable_map(type_map):
6093-
for expr, typ in list(type_map.items()):
6094-
if isinstance(expr, NameExpr):
6095-
node = expr.node
6096-
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
6097-
del type_map[expr]
6085+
if is_unreachable_map(type_map):
6086+
return
6087+
for expr, typ in list(type_map.items()):
6088+
if not isinstance(expr, NameExpr):
6089+
continue
6090+
node = expr.node
6091+
if node not in inferred_types or not self.check_subtype(
6092+
typ,
6093+
inferred_types[node],
6094+
expr,
6095+
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
6096+
subtype_label="pattern captures type",
6097+
supertype_label="variable has type",
6098+
):
6099+
del type_map[expr]
60986100

60996101
def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
61006102
if o.alias_node:

test-data/unit/check-python310.test

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,36 @@ reveal_type(a) # N: Revealed type is "builtins.bool"
17051705
a = 3
17061706
reveal_type(a) # N: Revealed type is "builtins.int"
17071707

1708+
[case testMatchCapturePatternAfterPreviousCase]
1709+
# flags: --strict-equality --warn-unreachable
1710+
1711+
def f1(x: int | None, y: int):
1712+
match x:
1713+
case None:
1714+
pass
1715+
case y:
1716+
reveal_type(y) # N: Revealed type is "builtins.int"
1717+
1718+
def f2(x: int | None, y: int, cond: bool):
1719+
match x:
1720+
case None if cond:
1721+
pass
1722+
case y: # E: Incompatible types in capture pattern (pattern captures type "int | None", variable has type "int")
1723+
reveal_type(y) # N: Revealed type is "builtins.int"
1724+
1725+
def f3(x: int | None, y: int):
1726+
match x:
1727+
case None if True:
1728+
pass
1729+
case y:
1730+
reveal_type(y) # N: Revealed type is "builtins.int"
1731+
1732+
match x:
1733+
case None if False:
1734+
pass # E: Statement is unreachable
1735+
case y: # E: Incompatible types in capture pattern (pattern captures type "int | None", variable has type "int")
1736+
reveal_type(y) # N: Revealed type is "builtins.int"
1737+
17081738
[case testMatchCapturePatternPreexistingIncompatible]
17091739
# flags: --strict-equality --warn-unreachable
17101740
a: str

0 commit comments

Comments
 (0)