Skip to content

Commit 938dbe2

Browse files
authored
Respect explicit return type of __new__() (#21441)
Fixes #8330 Fixes #15182 Closes #16020 Closes #14471 Fixes #13824 Fixes #14502 With this PR will still give an error if the explicit return `__new__()` is not a subtype of current class, but now we will actually use it. There are _two exceptions_ (to preserve backwards compatibility): * If the return type is `Any` with still use current class as the return type. * If the explicit return type comes from a superclass and is a supertype of implicit return type. ```python class A: def __new__(cls): ... reveal_type(A()) # still __main__.A class B: def __new__(cls) -> B: return cls() class C(B): ... reveal_type(C()) # still __main__.C ``` This uses a more principled implementation than some earlier attempts: adding a new dedicated attribute to `CallableType` for this purpose. Some comments: * This PR has a bit for boilerplate, but this is expected. When adding a new attribute to a type, one needs to update most visitors. * While doing the above I noticed that `CallableType.type_guard` and `CallableType.type_is` were not handled in dependency visitors (neither coarse-grained nor fine-grained). IMO they definitely should be handled there, so I now handle them. * I try to reduce the size of `CallableType` a bit to compensate for new attribute by removing (rarely used) `min_args` attribute. I also use compact flags serialization. * I skimmed the code base and updated all places where we "casually" use `CallableType.ret_type` for various type object edge cases. * Note that I don't assert that `instance_type` is set when `is_type_obj()` returns `True`. This is mostly to avoid breaking 3rd party plugins. * I didn't add a test case for each edge case, but I added (improved) test cases from #16020 plus some more. Suggestions for more test cases are welcome. * It looks like this exposes a pre-existing bug where we leaked type variables in `type[T]`. It is relatively niche edge case, but it looks important for NumPy stubs, so I am trying to fix it here.
1 parent 668733d commit 938dbe2

26 files changed

Lines changed: 627 additions & 111 deletions

mypy/applytype.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,17 @@ def apply_generic_arguments(
179179
assert isinstance(typ, TypeVarLikeType)
180180
remaining_tvars.append(typ)
181181

182+
instance_type = None
183+
if callable.instance_type is not None:
184+
instance_type = expand_type(callable.instance_type, id_to_type)
185+
assert isinstance(instance_type, ProperType)
186+
182187
return callable.copy_modified(
183188
ret_type=expand_type(callable.ret_type, id_to_type),
184189
variables=remaining_tvars,
185190
type_guard=type_guard,
186191
type_is=type_is,
192+
instance_type=instance_type,
187193
)
188194

189195

mypy/cache.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from mypy_extensions import u8
7070

7171
# High-level cache layout format
72-
CACHE_VERSION: Final = 8
72+
CACHE_VERSION: Final = 9
7373

7474
# Type used internally to represent errors:
7575
# (path, line, column, end_line, end_column, severity, message, code)
@@ -558,6 +558,20 @@ def write_json(data: WriteBuffer, value: dict[str, Any]) -> None:
558558
write_json_value(data, value[key])
559559

560560

561+
def write_flags(data: WriteBuffer, flags: list[bool]) -> None:
562+
assert len(flags) <= 26, "This many flags not supported yet"
563+
packed = 0
564+
for i, flag in enumerate(flags):
565+
if flag:
566+
packed |= 1 << i
567+
write_int(data, packed)
568+
569+
570+
def read_flags(data: ReadBuffer, num_flags: int) -> list[bool]:
571+
packed = read_int(data)
572+
return [(packed & (1 << i)) != 0 for i in range(num_flags)]
573+
574+
561575
def write_errors(data: WriteBuffer, errs: list[ErrorTuple]) -> None:
562576
write_tag(data, LIST_GEN)
563577
write_int_bare(data, len(errs))

mypy/checker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5515,7 +5515,7 @@ def check_except_handler_test(self, n: Expression, is_star: bool) -> Type:
55155515
if not item.is_type_obj():
55165516
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
55175517
return self.default_exception_type(is_star)
5518-
exc_type = erase_typevars(item.ret_type)
5518+
exc_type = erase_typevars(item.get_instance_type())
55195519
elif isinstance(ttype, TypeType):
55205520
exc_type = ttype.item
55215521
else:

mypy/checkexpr.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@
124124
from mypy.subtypes import (
125125
covers_at_runtime,
126126
find_member,
127-
is_equivalent,
128127
is_same_type,
129128
is_subtype,
130129
non_method_protocol_members,
@@ -689,7 +688,7 @@ def method_fullname(self, object_type: Type, method_name: str) -> str | None:
689688
# For class method calls, object_type is a callable representing the class object.
690689
# We "unwrap" it to a regular type, as the class/instance method difference doesn't
691690
# affect the fully qualified name.
692-
object_type = get_proper_type(object_type.ret_type)
691+
object_type = object_type.get_instance_type()
693692
elif isinstance(object_type, TypeType):
694693
object_type = object_type.item
695694

@@ -717,9 +716,9 @@ def always_returns_none(self, node: Expression) -> bool:
717716
if isinstance(typ, Instance):
718717
info = typ.type
719718
elif isinstance(typ, CallableType) and typ.is_type_obj():
720-
ret_type = get_proper_type(typ.ret_type)
721-
if isinstance(ret_type, Instance):
722-
info = ret_type.type
719+
instance_type = typ.get_instance_type(force_fallback=True)
720+
if isinstance(instance_type, Instance):
721+
info = instance_type.type
723722
else:
724723
return False
725724
else:
@@ -1668,9 +1667,10 @@ def check_callable_call(
16681667
callee = callee.with_unpacked_kwargs().with_normalized_var_args()
16691668
if callable_name is None and callee.name:
16701669
callable_name = callee.name
1671-
ret_type = get_proper_type(callee.ret_type)
1672-
if callee.is_type_obj() and isinstance(ret_type, Instance):
1673-
callable_name = ret_type.type.fullname
1670+
if callee.is_type_obj():
1671+
instance_type = callee.get_instance_type(force_fallback=True)
1672+
if isinstance(instance_type, Instance):
1673+
callable_name = instance_type.type.fullname
16741674
if isinstance(callable_node, RefExpr) and callable_node.fullname in ENUM_BASES:
16751675
# An Enum() call that failed SemanticAnalyzerPass2.check_enum_call().
16761676
return callee.ret_type, callee
@@ -1867,7 +1867,7 @@ def check_callable_call(
18671867
if (
18681868
callee.is_type_obj()
18691869
and (len(arg_types) == 1)
1870-
and is_equivalent(callee.ret_type, self.named_type("builtins.type"))
1870+
and is_named_instance(callee.get_instance_type(), "builtins.type")
18711871
):
18721872
callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0]))
18731873

mypy/checkmember.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -407,15 +407,8 @@ def validate_super_call(node: FuncBase, mx: MemberContext) -> None:
407407
def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: MemberContext) -> Type:
408408
# Class attribute.
409409
# TODO super?
410-
ret_type = typ.items[0].ret_type
411-
assert isinstance(ret_type, ProperType)
412-
if isinstance(ret_type, TupleType):
413-
ret_type = tuple_fallback(ret_type)
414-
if isinstance(ret_type, TypedDictType):
415-
ret_type = ret_type.fallback
416-
if isinstance(ret_type, LiteralType):
417-
ret_type = ret_type.fallback
418-
if isinstance(ret_type, Instance):
410+
instance_type = typ.items[0].get_instance_type(force_fallback=True)
411+
if isinstance(instance_type, Instance):
419412
if not mx.is_operator:
420413
# When Python sees an operator (eg `3 == 4`), it automatically translates that
421414
# into something like `int.__eq__(3, 4)` instead of `(3).__eq__(4)` as an
@@ -432,14 +425,18 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member
432425
# See https://github.com/python/mypy/pull/1787 for more info.
433426
# TODO: do not rely on same type variables being present in all constructor overloads.
434427
result = analyze_class_attribute_access(
435-
ret_type, name, mx, original_vars=typ.items[0].variables, mcs_fallback=typ.fallback
428+
instance_type,
429+
name,
430+
mx,
431+
original_vars=typ.items[0].variables,
432+
mcs_fallback=typ.fallback,
436433
)
437434
if result:
438435
return result
439436
# Look up from the 'type' type.
440437
return _analyze_member_access(name, typ.fallback, mx)
441438
else:
442-
assert False, f"Unexpected type {ret_type!r}"
439+
assert False, f"Unexpected type {instance_type!r}"
443440

444441

445442
def analyze_type_type_member_access(
@@ -721,7 +718,7 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
721718
dunder_get_type = expand_type_by_instance(bound_method, typ)
722719

723720
if isinstance(instance_type, FunctionLike) and instance_type.is_type_obj():
724-
owner_type = instance_type.items[0].ret_type
721+
owner_type = instance_type.items[0].get_instance_type()
725722
instance_type = NoneType()
726723
elif isinstance(instance_type, TypeType):
727724
owner_type = instance_type.item

mypy/constraints.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import mypy.typeops
1010
from mypy.argmap import ArgTypeExpander
1111
from mypy.erasetype import erase_typevars
12+
from mypy.expandtype import expand_type_by_instance
1213
from mypy.maptype import map_instance_to_supertype
1314
from mypy.nodes import (
1415
ARG_OPT,
@@ -275,7 +276,11 @@ def infer_constraints_for_callable(
275276

276277

277278
def infer_constraints(
278-
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
279+
template: Type,
280+
actual: Type,
281+
direction: int,
282+
skip_neg_op: bool = False,
283+
erase_types: bool = True,
279284
) -> list[Constraint]:
280285
"""Infer type constraints.
281286
@@ -312,14 +317,14 @@ def infer_constraints(
312317
# Return early on an empty branch.
313318
return []
314319
type_state.inferring.append((template, actual))
315-
res = _infer_constraints(template, actual, direction, skip_neg_op)
320+
res = _infer_constraints(template, actual, direction, skip_neg_op, erase_types)
316321
type_state.inferring.pop()
317322
return res
318-
return _infer_constraints(template, actual, direction, skip_neg_op)
323+
return _infer_constraints(template, actual, direction, skip_neg_op, erase_types)
319324

320325

321326
def _infer_constraints(
322-
template: Type, actual: Type, direction: int, skip_neg_op: bool
327+
template: Type, actual: Type, direction: int, skip_neg_op: bool, erase_types: bool
323328
) -> list[Constraint]:
324329
orig_template = template
325330
template = get_proper_type(template)
@@ -424,7 +429,7 @@ def _infer_constraints(
424429
return []
425430

426431
# Remaining cases are handled by ConstraintBuilderVisitor.
427-
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op))
432+
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op, erase_types))
428433

429434

430435
def _is_type_type(tp: ProperType) -> TypeGuard[TypeType | UnionType]:
@@ -659,14 +664,20 @@ class ConstraintBuilderVisitor(TypeVisitor[list[Constraint]]):
659664
# TODO: The value may be None. Is that actually correct?
660665
actual: ProperType
661666

662-
def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None:
667+
def __init__(
668+
self, actual: ProperType, direction: int, skip_neg_op: bool, erase_types: bool
669+
) -> None:
663670
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
664671
self.actual = actual
665672
self.direction = direction
666673
# Whether to skip polymorphic inference (involves inference in opposite direction)
667674
# this is used to prevent infinite recursion when both template and actual are
668675
# generic callables.
669676
self.skip_neg_op = skip_neg_op
677+
# Normally we should erase generic actual type when inferring against type[T]
678+
# to avoid leaking type variables, see testGenericClassAsArgumentToType.
679+
# The only exception is self-types in generic classes, where we set this to False.
680+
self.erase_types = erase_types
670681

671682
# Trivial leaf types
672683

@@ -759,13 +770,11 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
759770
and template.type.is_protocol
760771
and self.direction == SUPERTYPE_OF
761772
):
762-
ret_type = get_proper_type(actual.ret_type)
763-
if isinstance(ret_type, TupleType):
764-
ret_type = mypy.typeops.tuple_fallback(ret_type)
765-
if isinstance(ret_type, Instance):
773+
instance_type = actual.get_instance_type(force_fallback=True)
774+
if isinstance(instance_type, Instance):
766775
res.extend(
767776
self.infer_constraints_from_protocol_members(
768-
ret_type, template, ret_type, template, class_obj=True
777+
instance_type, template, instance_type, template, class_obj=True
769778
)
770779
)
771780
actual = actual.fallback
@@ -1213,6 +1222,20 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
12131222
elif isinstance(self.actual, Overloaded):
12141223
return self.infer_against_overloaded(self.actual, template)
12151224
elif isinstance(self.actual, TypeType):
1225+
# This matches the corresponding logic in subtypes.py.
1226+
item = self.actual.item
1227+
if isinstance(item, TupleType):
1228+
item = mypy.typeops.tuple_fallback(item)
1229+
if isinstance(item, Instance):
1230+
constructor = mypy.typeops.type_object_type(item.type)
1231+
constructor = expand_type_by_instance(constructor, item)
1232+
# Only consider return type to match historic behavior (see below).
1233+
if isinstance(constructor, CallableType):
1234+
return infer_constraints(
1235+
template.ret_type, constructor.ret_type, self.direction
1236+
)
1237+
elif isinstance(constructor, Overloaded):
1238+
return self.infer_against_overloaded(constructor, template, ret_only=True)
12161239
return infer_constraints(template.ret_type, self.actual.item, self.direction)
12171240
elif isinstance(self.actual, Instance):
12181241
# Instances with __call__ method defined are considered structural
@@ -1228,14 +1251,16 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
12281251
return []
12291252

12301253
def infer_against_overloaded(
1231-
self, overloaded: Overloaded, template: CallableType
1254+
self, overloaded: Overloaded, template: CallableType, ret_only: bool = False
12321255
) -> list[Constraint]:
12331256
# Create constraints by matching an overloaded type against a template.
12341257
# This is tricky to do in general. We cheat by only matching against
12351258
# the first overload item that is callable compatible. This
12361259
# seems to work somewhat well, but we should really use a more
12371260
# reliable technique.
12381261
item = find_matching_overload_item(overloaded, template)
1262+
if ret_only:
1263+
return infer_constraints(template.ret_type, item.ret_type, self.direction)
12391264
return infer_constraints(template, item, self.direction)
12401265

12411266
def visit_tuple_type(self, template: TupleType) -> list[Constraint]:
@@ -1398,8 +1423,18 @@ def visit_overloaded(self, template: Overloaded) -> list[Constraint]:
13981423

13991424
def visit_type_type(self, template: TypeType) -> list[Constraint]:
14001425
if isinstance(self.actual, CallableType):
1426+
if self.actual.is_type_obj():
1427+
instance_type = self.actual.get_instance_type()
1428+
if self.erase_types:
1429+
instance_type = erase_typevars(instance_type)
1430+
return infer_constraints(template.item, instance_type, self.direction)
14011431
return infer_constraints(template.item, self.actual.ret_type, self.direction)
14021432
elif isinstance(self.actual, Overloaded):
1433+
if self.actual.is_type_obj():
1434+
instance_type = self.actual.items[0].get_instance_type()
1435+
if self.erase_types:
1436+
instance_type = erase_typevars(instance_type)
1437+
return infer_constraints(template.item, instance_type, self.direction)
14031438
return infer_constraints(template.item, self.actual.items[0].ret_type, self.direction)
14041439
elif isinstance(self.actual, TypeType):
14051440
return infer_constraints(template.item, self.actual.item, self.direction)

mypy/expandtype.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,11 +485,16 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
485485
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
486486
else:
487487
arg_types = self.expand_types(t.arg_types)
488+
instance_type = None
489+
if t.instance_type is not None:
490+
instance_type = t.instance_type.accept(self)
491+
assert isinstance(instance_type, ProperType)
488492
expanded = t.copy_modified(
489493
arg_types=arg_types,
490494
ret_type=t.ret_type.accept(self),
491-
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
492-
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
495+
type_guard=t.type_guard.accept(self) if t.type_guard is not None else None,
496+
type_is=t.type_is.accept(self) if t.type_is is not None else None,
497+
instance_type=instance_type,
493498
)
494499
if needs_normalization:
495500
return expanded.with_normalized_var_args()

mypy/fixup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
279279
ct.type_guard.accept(self)
280280
if ct.type_is is not None:
281281
ct.type_is.accept(self)
282+
if ct.instance_type is not None:
283+
ct.instance_type.accept(self)
282284

283285
def visit_overloaded(self, t: Overloaded) -> None:
284286
for ct in t.items:

mypy/indirection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def visit_instance(self, t: types.Instance) -> None:
134134
def visit_callable_type(self, t: types.CallableType) -> None:
135135
self._visit_type_list(t.arg_types)
136136
self._visit(t.ret_type)
137+
if t.type_guard is not None:
138+
self._visit(t.type_guard)
139+
if t.type_is is not None:
140+
self._visit(t.type_is)
141+
if t.instance_type is not None:
142+
self._visit(t.instance_type)
137143
self._visit_type_tuple(t.variables)
138144

139145
def visit_overloaded(self, t: types.Overloaded) -> None:

mypy/infer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,11 @@ def infer_type_arguments(
7070
actual: Type,
7171
is_supertype: bool = False,
7272
skip_unsatisfied: bool = False,
73+
erase_types: bool = True,
7374
) -> list[Type | None]:
7475
# Like infer_function_type_arguments, but only match a single type
7576
# against a generic type.
76-
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF)
77+
constraints = infer_constraints(
78+
template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF, erase_types=erase_types
79+
)
7780
return solve_constraints(type_vars, constraints, skip_unsatisfied=skip_unsatisfied)[0]

0 commit comments

Comments
 (0)