From cdd19f9fb55abeb01ae32b9be81ef6376e279890 Mon Sep 17 00:00:00 2001 From: Chad Dombrova Date: Sat, 26 Apr 2025 18:11:17 -0700 Subject: [PATCH] stubgen: Use inferred types for class attributes --- mypy/stubgen.py | 22 +++-- mypy/stubgenc.py | 10 +-- mypy/stubutil.py | 104 ++++++++++++++++++---- mypy/types.py | 19 ++-- test-data/unit/stubgen.test | 171 ++++++++++++++++++++++++++++++++++++ 5 files changed, 290 insertions(+), 36 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 881686adc5ed..4bbc57c52e7e 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -476,8 +476,9 @@ def __init__( analyzed: bool = False, export_less: bool = False, include_docstrings: bool = False, + known_modules: list[str] | None = None, ) -> None: - super().__init__(_all_, include_private, export_less, include_docstrings) + super().__init__(_all_, include_private, export_less, include_docstrings, known_modules) self._decorators: list[str] = [] # Stack of defined variables (per scope). self._vars: list[list[str]] = [[]] @@ -1233,7 +1234,10 @@ def get_init( return None self._vars[-1].append(lvalue) if annotation is not None: - typename = self.print_annotation(annotation) + if isinstance(annotation, UnboundType): + typename = self.print_annotation(annotation) + else: + typename = self.print_annotation(annotation, self.known_modules) if ( isinstance(annotation, UnboundType) and not annotation.args @@ -1460,7 +1464,14 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: and isinstance(lvalue.expr, NameExpr) and lvalue.expr.name == "self" ): - self.results.append((lvalue.name, o.rvalue, o.unanalyzed_type)) + # lvalue.node might be populated with an inferred type + if isinstance(lvalue.node, Var) and ( + lvalue.node.is_ready or not isinstance(get_proper_type(lvalue.node.type), AnyType) + ): + annotation = lvalue.node.type + else: + annotation = o.unanalyzed_type + self.results.append((lvalue.name, o.rvalue, annotation)) def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression, Type | None]]: @@ -1652,7 +1663,7 @@ def mypy_options(stubgen_options: Options) -> MypyOptions: options.follow_imports = "skip" options.incremental = False options.ignore_errors = True - options.semantic_analysis_only = True + options.semantic_analysis_only = False options.python_version = stubgen_options.pyversion options.show_traceback = True options.transform_source = remove_misplaced_type_comments @@ -1729,7 +1740,7 @@ def generate_stub_for_py_module( ) -> None: """Use analysed (or just parsed) AST to generate type stub for single file. - If directory for target doesn't exist it will created. Existing stub + If directory for target doesn't exist it will be created. Existing stub will be overwritten. """ if inspect: @@ -1748,6 +1759,7 @@ def generate_stub_for_py_module( else: gen = ASTStubGenerator( mod.runtime_all, + known_modules=all_modules, include_private=include_private, analyzed=not parse_only, export_less=export_less, diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index b03a88cf6f43..552b5e75ead3 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -171,8 +171,8 @@ def generate_stub_for_c_module( gen = InspectionStubGenerator( module_name, - known_modules, - doc_dir, + known_modules=known_modules, + doc_dir=doc_dir, include_private=include_private, export_less=export_less, include_docstrings=include_docstrings, @@ -240,9 +240,8 @@ def __init__( else: self.module = module self.is_c_module = is_c_module(self.module) - self.known_modules = known_modules self.resort_members = self.is_c_module - super().__init__(_all_, include_private, export_less, include_docstrings) + super().__init__(_all_, include_private, export_less, include_docstrings, known_modules) self.module_name = module_name if self.is_c_module: # Add additional implicit imports. @@ -393,10 +392,9 @@ def strip_or_import(self, type_name: str) -> str: Arguments: typ: name of the type """ - local_modules = ["builtins", self.module_name] parsed_type = parse_type_comment(type_name, 0, 0, None)[1] assert parsed_type is not None, type_name - return self.print_annotation(parsed_type, self.known_modules, local_modules) + return self.print_annotation(parsed_type, self.known_modules) def get_obj_module(self, obj: object) -> str | None: """Return module name of the object.""" diff --git a/mypy/stubutil.py b/mypy/stubutil.py index fecd9b29d57d..f39cb50f5daa 100644 --- a/mypy/stubutil.py +++ b/mypy/stubutil.py @@ -21,11 +21,18 @@ from mypy.stubdoc import ArgSig, FunctionSig from mypy.types import ( AnyType, + CallableType, + DeletedType, + ErasedType, + Instance, NoneType, Type, + TypedDictType, TypeList, TypeStrVisitor, + TypeVarType, UnboundType, + UninhabitedType, UnionType, UnpackType, ) @@ -251,6 +258,23 @@ def __init__( self.known_modules = known_modules self.local_modules = local_modules or ["builtins"] + def track_imports(self, s: str) -> str | None: + if self.known_modules is not None and "." in s: + # see if this object is from any of the modules that we're currently processing. + # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo". + for module_name in self.local_modules + sorted(self.known_modules, reverse=True): + if s.startswith(module_name + "."): + if module_name in self.local_modules: + s = s[len(module_name) + 1 :] + arg_module = module_name + break + else: + arg_module = s[: s.rindex(".")] + if arg_module not in self.local_modules: + self.stubgen.import_tracker.add_import(arg_module, require=True) + return s + return None + def visit_any(self, t: AnyType) -> str: s = super().visit_any(t) self.stubgen.import_tracker.require_name(s) @@ -267,19 +291,9 @@ def visit_unbound_type(self, t: UnboundType) -> str: return self.stubgen.add_name("_typeshed.Incomplete") if fullname in TYPING_BUILTIN_REPLACEMENTS: s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True) - if self.known_modules is not None and "." in s: - # see if this object is from any of the modules that we're currently processing. - # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo". - for module_name in self.local_modules + sorted(self.known_modules, reverse=True): - if s.startswith(module_name + "."): - if module_name in self.local_modules: - s = s[len(module_name) + 1 :] - arg_module = module_name - break - else: - arg_module = s[: s.rindex(".")] - if arg_module not in self.local_modules: - self.stubgen.import_tracker.add_import(arg_module, require=True) + + if new_s := self.track_imports(s): + s = new_s elif s == "NoneType": # when called without analysis all types are unbound, so this won't hit # visit_none_type(). @@ -292,6 +306,9 @@ def visit_unbound_type(self, t: UnboundType) -> str: s += "[()]" return s + def typeddict_item_str(self, t: TypedDictType, name: str, typ: str) -> str: + return f"{name!r}: {typ}" + def visit_none_type(self, t: NoneType) -> str: return "None" @@ -322,6 +339,55 @@ def args_str(self, args: Iterable[Type]) -> str: res.append(arg_str) return ", ".join(res) + def visit_type_var(self, t: TypeVarType) -> str: + return t.name + + def visit_uninhabited_type(self, t: UninhabitedType) -> str: + return self.stubgen.add_name("typing.Any") + + def visit_erased_type(self, t: ErasedType) -> str: + return self.stubgen.add_name("typing.Any") + + def visit_deleted_type(self, t: DeletedType) -> str: + return self.stubgen.add_name("typing.Any") + + def visit_instance(self, t: Instance) -> str: + if t.last_known_value and not t.args: + # Instances with a literal fallback should never be generic. If they are, + # something went wrong so we fall back to showing the full Instance repr. + s = f"{t.last_known_value.accept(self)}" + else: + s = t.type.fullname or t.type.name or self.stubgen.add_name("_typeshed.Incomplete") + + s = self.track_imports(s) or s + + if t.args: + if t.type.fullname == "builtins.tuple": + assert len(t.args) == 1 + s += f"[{self.list_str(t.args)}, ...]" + else: + s += f"[{self.list_str(t.args)}]" + elif t.type.has_type_var_tuple_type and len(t.type.type_vars) == 1: + s += "[()]" + + return s + + def visit_callable_type(self, t: CallableType) -> str: + from mypy.suggestions import is_tricky_callable + + if is_tricky_callable(t): + arg_str = "..." + else: + # Note: for default arguments, we just assume that they + # are required. This isn't right, but neither is the + # other thing, and I suspect this will produce more better + # results than falling back to `...` + args = [typ.accept(self) for typ in t.arg_types] + arg_str = f"[{', '.join(args)}]" + + callable = self.stubgen.add_name("typing.Callable") + return f"{callable}[{arg_str}, {t.ret_type.accept(self)}]" + class ClassInfo: def __init__( @@ -454,11 +520,11 @@ class ImportTracker: def __init__(self) -> None: # module_for['foo'] has the module name where 'foo' was imported from, or None if - # 'foo' is a module imported directly; + # 'foo' is a module imported directly; # direct_imports['foo'] is the module path used when the name 'foo' was added to the - # namespace. + # namespace. # reverse_alias['foo'] is the name that 'foo' had originally when imported with an - # alias; examples + # alias; examples # 'from pkg import mod' ==> module_for['mod'] == 'pkg' # 'from pkg import mod as m' ==> module_for['m'] == 'pkg' # ==> reverse_alias['m'] == 'mod' @@ -618,7 +684,9 @@ def __init__( include_private: bool = False, export_less: bool = False, include_docstrings: bool = False, + known_modules: list[str] | None = None, ) -> None: + self.known_modules = known_modules or [] # Best known value of __all__. self._all_ = _all_ self._include_private = include_private @@ -839,7 +907,9 @@ def print_annotation( known_modules: list[str] | None = None, local_modules: list[str] | None = None, ) -> str: - printer = AnnotationPrinter(self, known_modules, local_modules) + printer = AnnotationPrinter( + self, known_modules, local_modules or ["builtins", self.module_name] + ) return t.accept(printer) def is_not_in_all(self, name: str) -> bool: diff --git a/mypy/types.py b/mypy/types.py index 41a958ae93cc..cfe56f3c9c05 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3462,18 +3462,21 @@ def visit_tuple_type(self, t: TupleType, /) -> str: return f"{tuple_name}[{s}, fallback={t.partial_fallback.accept(self)}]" return f"{tuple_name}[{s}]" + def typeddict_item_str(self, t: TypedDictType, name: str, typ: str) -> str: + modifier = "" + if name not in t.required_keys: + modifier += "?" + if name in t.readonly_keys: + modifier += "=" + return f"{name!r}{modifier}: {typ}" + def visit_typeddict_type(self, t: TypedDictType, /) -> str: - def item_str(name: str, typ: str) -> str: - modifier = "" - if name not in t.required_keys: - modifier += "?" - if name in t.readonly_keys: - modifier += "=" - return f"{name!r}{modifier}: {typ}" s = ( "{" - + ", ".join(item_str(name, typ.accept(self)) for name, typ in t.items.items()) + + ", ".join( + self.typeddict_item_str(t, name, typ.accept(self)) for name, typ in t.items.items() + ) + "}" ) prefix = "" diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index bf17c34b99a7..2f8d4069d80f 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -268,6 +268,20 @@ class C: [out] x: int +class C: + x: int + def __init__(self) -> None: ... + +[case testSelfAndClassBodyAssignment_semanal] +x = 1 +class C: + x = 1 + def __init__(self): + self.x = 1 + self.x = 1 +[out] +x: int + class C: x: int def __init__(self) -> None: ... @@ -1394,6 +1408,22 @@ class A: def __init__(self, a: Incomplete | None = None) -> None: ... def method(self, a: Incomplete | None = None) -> None: ... +[case testInferOptionalOnlyFunc_semanal] +class A: + x = None + def __init__(self, a=None): + self.x = [] + def method(self, a=None): + self.x = [] +[out] +from _typeshed import Incomplete + +class A: + x: Incomplete + def __init__(self, a: Incomplete | None = None) -> None: ... + def method(self, a: Incomplete | None = None) -> None: ... + + [case testAnnotationImportsFrom] import foo from collections import defaultdict @@ -2333,6 +2363,147 @@ class C(Base, metaclass=abc.ABCMeta): @abstractmethod def other(self): ... +[case testInferredAttribute_semanal] +from typing import Any + +class Foo: + def __init__(self, name: str, timeout=0, any: Any = None, incomplete=None) -> None: + self.name = name + self.timeout = timeout + self.inferred_any = any + self.explicit_any: Any = any + self.incomplete = incomplete + self.cache: dict[str, int] = {} + +[out] +from _typeshed import Incomplete +from typing import Any + +class Foo: + name: str + timeout: Incomplete + inferred_any: Incomplete + explicit_any: Any + incomplete: Incomplete + cache: dict[str, int] + def __init__(self, name: str, timeout: int = 0, any: Any = None, incomplete: Incomplete | None = None) -> None: ... + +[case testInferredGenericAttribute_semanal] +from typing import Generic, TypeVar + +T = TypeVar("T") + +class Foo(Generic[T]): + def __init__(self, things: list[T]) -> None: + self.things = things + +[out] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class Foo(Generic[T]): + things: list[T] + def __init__(self, things: list[T]) -> None: ... + +[case testInferredTypedDict_semanal] +from typing import TypedDict + +class Context(TypedDict, total=False): + name: str + aliases: dict[str, str] + tools: set[str] + +class A: + def __init__(self): + self.context: Context | None = None + +[out] +from typing import TypedDict + +class Context(TypedDict, total=False): + name: str + aliases: dict[str, str] + tools: set[str] + +class A: + context: TypedDict('main.Context', {'name': str, 'aliases': dict[str, str], 'tools': set[str]}) | None + def __init__(self) -> None: ... + +[case testInferredComplexAlias_semanal] +# modules: main a + +import collections +from a import valid, B + +def func() -> int: + return 2 + +aliased_func = func +int_value = 1 + +class A: + cls_var = valid + + def __init__(self, arg1: str, arg2: list[B], arg3: list[str], arg4: collections.defaultdict[str, collections.deque[int]]) -> None: + self.arg1 = arg1 + self.arg2 = arg2 + self.arg3 = arg3 + self.arg4 = arg4 + self.func = func + + def meth(self) -> None: + func_value = int_value + + alias_meth = meth + alias_func = func + alias_alias_func = aliased_func + int_value = int_value + +class C: + def __init__(self, arg: A): + self.arg = arg + +[file a.py] +valid : list[int] = [1, 2, 3] + +class B: + pass + + +[out] +# main.pyi +import a +import collections +from a import B, valid +from typing import Callable + +def func() -> int: ... +aliased_func = func +int_value: int + +class A: + cls_var = valid + arg1: str + arg2: list[a.B] + arg3: list[str] + arg4: collections.defaultdict[str, collections.deque[int]] + func: Callable[[], int] + def __init__(self, arg1: str, arg2: list[B], arg3: list[str], arg4: collections.defaultdict[str, collections.deque[int]]) -> None: ... + def meth(self) -> None: ... + alias_meth = meth + alias_func = func + alias_alias_func = aliased_func + int_value = int_value + +class C: + arg: A + def __init__(self, arg: A) -> None: ... +# a.pyi +valid: list[int] + +class B: ... + [case testInvalidNumberOfArgsInAnnotation] def f(x): # type: () -> int