Skip to content

Commit cdd19f9

Browse files
committed
stubgen: Use inferred types for class attributes
1 parent c7ea011 commit cdd19f9

File tree

5 files changed

+290
-36
lines changed

5 files changed

+290
-36
lines changed

mypy/stubgen.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,9 @@ def __init__(
476476
analyzed: bool = False,
477477
export_less: bool = False,
478478
include_docstrings: bool = False,
479+
known_modules: list[str] | None = None,
479480
) -> None:
480-
super().__init__(_all_, include_private, export_less, include_docstrings)
481+
super().__init__(_all_, include_private, export_less, include_docstrings, known_modules)
481482
self._decorators: list[str] = []
482483
# Stack of defined variables (per scope).
483484
self._vars: list[list[str]] = [[]]
@@ -1233,7 +1234,10 @@ def get_init(
12331234
return None
12341235
self._vars[-1].append(lvalue)
12351236
if annotation is not None:
1236-
typename = self.print_annotation(annotation)
1237+
if isinstance(annotation, UnboundType):
1238+
typename = self.print_annotation(annotation)
1239+
else:
1240+
typename = self.print_annotation(annotation, self.known_modules)
12371241
if (
12381242
isinstance(annotation, UnboundType)
12391243
and not annotation.args
@@ -1460,7 +1464,14 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
14601464
and isinstance(lvalue.expr, NameExpr)
14611465
and lvalue.expr.name == "self"
14621466
):
1463-
self.results.append((lvalue.name, o.rvalue, o.unanalyzed_type))
1467+
# lvalue.node might be populated with an inferred type
1468+
if isinstance(lvalue.node, Var) and (
1469+
lvalue.node.is_ready or not isinstance(get_proper_type(lvalue.node.type), AnyType)
1470+
):
1471+
annotation = lvalue.node.type
1472+
else:
1473+
annotation = o.unanalyzed_type
1474+
self.results.append((lvalue.name, o.rvalue, annotation))
14641475

14651476

14661477
def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression, Type | None]]:
@@ -1652,7 +1663,7 @@ def mypy_options(stubgen_options: Options) -> MypyOptions:
16521663
options.follow_imports = "skip"
16531664
options.incremental = False
16541665
options.ignore_errors = True
1655-
options.semantic_analysis_only = True
1666+
options.semantic_analysis_only = False
16561667
options.python_version = stubgen_options.pyversion
16571668
options.show_traceback = True
16581669
options.transform_source = remove_misplaced_type_comments
@@ -1729,7 +1740,7 @@ def generate_stub_for_py_module(
17291740
) -> None:
17301741
"""Use analysed (or just parsed) AST to generate type stub for single file.
17311742
1732-
If directory for target doesn't exist it will created. Existing stub
1743+
If directory for target doesn't exist it will be created. Existing stub
17331744
will be overwritten.
17341745
"""
17351746
if inspect:
@@ -1748,6 +1759,7 @@ def generate_stub_for_py_module(
17481759
else:
17491760
gen = ASTStubGenerator(
17501761
mod.runtime_all,
1762+
known_modules=all_modules,
17511763
include_private=include_private,
17521764
analyzed=not parse_only,
17531765
export_less=export_less,

mypy/stubgenc.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ def generate_stub_for_c_module(
171171

172172
gen = InspectionStubGenerator(
173173
module_name,
174-
known_modules,
175-
doc_dir,
174+
known_modules=known_modules,
175+
doc_dir=doc_dir,
176176
include_private=include_private,
177177
export_less=export_less,
178178
include_docstrings=include_docstrings,
@@ -240,9 +240,8 @@ def __init__(
240240
else:
241241
self.module = module
242242
self.is_c_module = is_c_module(self.module)
243-
self.known_modules = known_modules
244243
self.resort_members = self.is_c_module
245-
super().__init__(_all_, include_private, export_less, include_docstrings)
244+
super().__init__(_all_, include_private, export_less, include_docstrings, known_modules)
246245
self.module_name = module_name
247246
if self.is_c_module:
248247
# Add additional implicit imports.
@@ -393,10 +392,9 @@ def strip_or_import(self, type_name: str) -> str:
393392
Arguments:
394393
typ: name of the type
395394
"""
396-
local_modules = ["builtins", self.module_name]
397395
parsed_type = parse_type_comment(type_name, 0, 0, None)[1]
398396
assert parsed_type is not None, type_name
399-
return self.print_annotation(parsed_type, self.known_modules, local_modules)
397+
return self.print_annotation(parsed_type, self.known_modules)
400398

401399
def get_obj_module(self, obj: object) -> str | None:
402400
"""Return module name of the object."""

mypy/stubutil.py

+87-17
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,18 @@
2121
from mypy.stubdoc import ArgSig, FunctionSig
2222
from mypy.types import (
2323
AnyType,
24+
CallableType,
25+
DeletedType,
26+
ErasedType,
27+
Instance,
2428
NoneType,
2529
Type,
30+
TypedDictType,
2631
TypeList,
2732
TypeStrVisitor,
33+
TypeVarType,
2834
UnboundType,
35+
UninhabitedType,
2936
UnionType,
3037
UnpackType,
3138
)
@@ -251,6 +258,23 @@ def __init__(
251258
self.known_modules = known_modules
252259
self.local_modules = local_modules or ["builtins"]
253260

261+
def track_imports(self, s: str) -> str | None:
262+
if self.known_modules is not None and "." in s:
263+
# see if this object is from any of the modules that we're currently processing.
264+
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
265+
for module_name in self.local_modules + sorted(self.known_modules, reverse=True):
266+
if s.startswith(module_name + "."):
267+
if module_name in self.local_modules:
268+
s = s[len(module_name) + 1 :]
269+
arg_module = module_name
270+
break
271+
else:
272+
arg_module = s[: s.rindex(".")]
273+
if arg_module not in self.local_modules:
274+
self.stubgen.import_tracker.add_import(arg_module, require=True)
275+
return s
276+
return None
277+
254278
def visit_any(self, t: AnyType) -> str:
255279
s = super().visit_any(t)
256280
self.stubgen.import_tracker.require_name(s)
@@ -267,19 +291,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
267291
return self.stubgen.add_name("_typeshed.Incomplete")
268292
if fullname in TYPING_BUILTIN_REPLACEMENTS:
269293
s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)
270-
if self.known_modules is not None and "." in s:
271-
# see if this object is from any of the modules that we're currently processing.
272-
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
273-
for module_name in self.local_modules + sorted(self.known_modules, reverse=True):
274-
if s.startswith(module_name + "."):
275-
if module_name in self.local_modules:
276-
s = s[len(module_name) + 1 :]
277-
arg_module = module_name
278-
break
279-
else:
280-
arg_module = s[: s.rindex(".")]
281-
if arg_module not in self.local_modules:
282-
self.stubgen.import_tracker.add_import(arg_module, require=True)
294+
295+
if new_s := self.track_imports(s):
296+
s = new_s
283297
elif s == "NoneType":
284298
# when called without analysis all types are unbound, so this won't hit
285299
# visit_none_type().
@@ -292,6 +306,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
292306
s += "[()]"
293307
return s
294308

309+
def typeddict_item_str(self, t: TypedDictType, name: str, typ: str) -> str:
310+
return f"{name!r}: {typ}"
311+
295312
def visit_none_type(self, t: NoneType) -> str:
296313
return "None"
297314

@@ -322,6 +339,55 @@ def args_str(self, args: Iterable[Type]) -> str:
322339
res.append(arg_str)
323340
return ", ".join(res)
324341

342+
def visit_type_var(self, t: TypeVarType) -> str:
343+
return t.name
344+
345+
def visit_uninhabited_type(self, t: UninhabitedType) -> str:
346+
return self.stubgen.add_name("typing.Any")
347+
348+
def visit_erased_type(self, t: ErasedType) -> str:
349+
return self.stubgen.add_name("typing.Any")
350+
351+
def visit_deleted_type(self, t: DeletedType) -> str:
352+
return self.stubgen.add_name("typing.Any")
353+
354+
def visit_instance(self, t: Instance) -> str:
355+
if t.last_known_value and not t.args:
356+
# Instances with a literal fallback should never be generic. If they are,
357+
# something went wrong so we fall back to showing the full Instance repr.
358+
s = f"{t.last_known_value.accept(self)}"
359+
else:
360+
s = t.type.fullname or t.type.name or self.stubgen.add_name("_typeshed.Incomplete")
361+
362+
s = self.track_imports(s) or s
363+
364+
if t.args:
365+
if t.type.fullname == "builtins.tuple":
366+
assert len(t.args) == 1
367+
s += f"[{self.list_str(t.args)}, ...]"
368+
else:
369+
s += f"[{self.list_str(t.args)}]"
370+
elif t.type.has_type_var_tuple_type and len(t.type.type_vars) == 1:
371+
s += "[()]"
372+
373+
return s
374+
375+
def visit_callable_type(self, t: CallableType) -> str:
376+
from mypy.suggestions import is_tricky_callable
377+
378+
if is_tricky_callable(t):
379+
arg_str = "..."
380+
else:
381+
# Note: for default arguments, we just assume that they
382+
# are required. This isn't right, but neither is the
383+
# other thing, and I suspect this will produce more better
384+
# results than falling back to `...`
385+
args = [typ.accept(self) for typ in t.arg_types]
386+
arg_str = f"[{', '.join(args)}]"
387+
388+
callable = self.stubgen.add_name("typing.Callable")
389+
return f"{callable}[{arg_str}, {t.ret_type.accept(self)}]"
390+
325391

326392
class ClassInfo:
327393
def __init__(
@@ -454,11 +520,11 @@ class ImportTracker:
454520

455521
def __init__(self) -> None:
456522
# module_for['foo'] has the module name where 'foo' was imported from, or None if
457-
# 'foo' is a module imported directly;
523+
# 'foo' is a module imported directly;
458524
# direct_imports['foo'] is the module path used when the name 'foo' was added to the
459-
# namespace.
525+
# namespace.
460526
# reverse_alias['foo'] is the name that 'foo' had originally when imported with an
461-
# alias; examples
527+
# alias; examples
462528
# 'from pkg import mod' ==> module_for['mod'] == 'pkg'
463529
# 'from pkg import mod as m' ==> module_for['m'] == 'pkg'
464530
# ==> reverse_alias['m'] == 'mod'
@@ -618,7 +684,9 @@ def __init__(
618684
include_private: bool = False,
619685
export_less: bool = False,
620686
include_docstrings: bool = False,
687+
known_modules: list[str] | None = None,
621688
) -> None:
689+
self.known_modules = known_modules or []
622690
# Best known value of __all__.
623691
self._all_ = _all_
624692
self._include_private = include_private
@@ -839,7 +907,9 @@ def print_annotation(
839907
known_modules: list[str] | None = None,
840908
local_modules: list[str] | None = None,
841909
) -> str:
842-
printer = AnnotationPrinter(self, known_modules, local_modules)
910+
printer = AnnotationPrinter(
911+
self, known_modules, local_modules or ["builtins", self.module_name]
912+
)
843913
return t.accept(printer)
844914

845915
def is_not_in_all(self, name: str) -> bool:

mypy/types.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -3462,18 +3462,21 @@ def visit_tuple_type(self, t: TupleType, /) -> str:
34623462
return f"{tuple_name}[{s}, fallback={t.partial_fallback.accept(self)}]"
34633463
return f"{tuple_name}[{s}]"
34643464

3465+
def typeddict_item_str(self, t: TypedDictType, name: str, typ: str) -> str:
3466+
modifier = ""
3467+
if name not in t.required_keys:
3468+
modifier += "?"
3469+
if name in t.readonly_keys:
3470+
modifier += "="
3471+
return f"{name!r}{modifier}: {typ}"
3472+
34653473
def visit_typeddict_type(self, t: TypedDictType, /) -> str:
3466-
def item_str(name: str, typ: str) -> str:
3467-
modifier = ""
3468-
if name not in t.required_keys:
3469-
modifier += "?"
3470-
if name in t.readonly_keys:
3471-
modifier += "="
3472-
return f"{name!r}{modifier}: {typ}"
34733474

34743475
s = (
34753476
"{"
3476-
+ ", ".join(item_str(name, typ.accept(self)) for name, typ in t.items.items())
3477+
+ ", ".join(
3478+
self.typeddict_item_str(t, name, typ.accept(self)) for name, typ in t.items.items()
3479+
)
34773480
+ "}"
34783481
)
34793482
prefix = ""

0 commit comments

Comments
 (0)