2121from mypy .stubdoc import ArgSig , FunctionSig
2222from mypy .types import (
2323 AnyType ,
24+ CallableType ,
25+ DeletedType ,
26+ ErasedType ,
27+ Instance ,
2428 NoneType ,
2529 Type ,
30+ TypedDictType ,
2631 TypeList ,
2732 TypeStrVisitor ,
33+ TypeVarType ,
2834 UnboundType ,
35+ UninhabitedType ,
2936 UnionType ,
3037 UnpackType ,
3138)
@@ -251,6 +258,23 @@ def __init__(
251258 self .known_modules = known_modules
252259 self .local_modules = local_modules or ["builtins" ]
253260
261+ def track_imports (self , s : str ) -> str | None :
262+ if self .known_modules is not None and "." in s :
263+ # see if this object is from any of the modules that we're currently processing.
264+ # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
265+ for module_name in self .local_modules + sorted (self .known_modules , reverse = True ):
266+ if s .startswith (module_name + "." ):
267+ if module_name in self .local_modules :
268+ s = s [len (module_name ) + 1 :]
269+ arg_module = module_name
270+ break
271+ else :
272+ arg_module = s [: s .rindex ("." )]
273+ if arg_module not in self .local_modules :
274+ self .stubgen .import_tracker .add_import (arg_module , require = True )
275+ return s
276+ return None
277+
254278 def visit_any (self , t : AnyType ) -> str :
255279 s = super ().visit_any (t )
256280 self .stubgen .import_tracker .require_name (s )
@@ -267,19 +291,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
267291 return self .stubgen .add_name ("_typeshed.Incomplete" )
268292 if fullname in TYPING_BUILTIN_REPLACEMENTS :
269293 s = self .stubgen .add_name (TYPING_BUILTIN_REPLACEMENTS [fullname ], require = True )
270- if self .known_modules is not None and "." in s :
271- # see if this object is from any of the modules that we're currently processing.
272- # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
273- for module_name in self .local_modules + sorted (self .known_modules , reverse = True ):
274- if s .startswith (module_name + "." ):
275- if module_name in self .local_modules :
276- s = s [len (module_name ) + 1 :]
277- arg_module = module_name
278- break
279- else :
280- arg_module = s [: s .rindex ("." )]
281- if arg_module not in self .local_modules :
282- self .stubgen .import_tracker .add_import (arg_module , require = True )
294+
295+ if new_s := self .track_imports (s ):
296+ s = new_s
283297 elif s == "NoneType" :
284298 # when called without analysis all types are unbound, so this won't hit
285299 # visit_none_type().
@@ -292,6 +306,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
292306 s += "[()]"
293307 return s
294308
309+ def typeddict_item_str (self , t : TypedDictType , name : str , typ : str ) -> str :
310+ return f"{ name !r} : { typ } "
311+
295312 def visit_none_type (self , t : NoneType ) -> str :
296313 return "None"
297314
@@ -322,6 +339,55 @@ def args_str(self, args: Iterable[Type]) -> str:
322339 res .append (arg_str )
323340 return ", " .join (res )
324341
342+ def visit_type_var (self , t : TypeVarType ) -> str :
343+ return t .name
344+
345+ def visit_uninhabited_type (self , t : UninhabitedType ) -> str :
346+ return self .stubgen .add_name ("typing.Any" )
347+
348+ def visit_erased_type (self , t : ErasedType ) -> str :
349+ return self .stubgen .add_name ("typing.Any" )
350+
351+ def visit_deleted_type (self , t : DeletedType ) -> str :
352+ return self .stubgen .add_name ("typing.Any" )
353+
354+ def visit_instance (self , t : Instance ) -> str :
355+ if t .last_known_value and not t .args :
356+ # Instances with a literal fallback should never be generic. If they are,
357+ # something went wrong so we fall back to showing the full Instance repr.
358+ s = f"{ t .last_known_value .accept (self )} "
359+ else :
360+ s = t .type .fullname or t .type .name or self .stubgen .add_name ("_typeshed.Incomplete" )
361+
362+ s = self .track_imports (s ) or s
363+
364+ if t .args :
365+ if t .type .fullname == "builtins.tuple" :
366+ assert len (t .args ) == 1
367+ s += f"[{ self .list_str (t .args )} , ...]"
368+ else :
369+ s += f"[{ self .list_str (t .args )} ]"
370+ elif t .type .has_type_var_tuple_type and len (t .type .type_vars ) == 1 :
371+ s += "[()]"
372+
373+ return s
374+
375+ def visit_callable_type (self , t : CallableType ) -> str :
376+ from mypy .suggestions import is_tricky_callable
377+
378+ if is_tricky_callable (t ):
379+ arg_str = "..."
380+ else :
381+ # Note: for default arguments, we just assume that they
382+ # are required. This isn't right, but neither is the
383+ # other thing, and I suspect this will produce more better
384+ # results than falling back to `...`
385+ args = [typ .accept (self ) for typ in t .arg_types ]
386+ arg_str = f"[{ ', ' .join (args )} ]"
387+
388+ callable = self .stubgen .add_name ("typing.Callable" )
389+ return f"{ callable } [{ arg_str } , { t .ret_type .accept (self )} ]"
390+
325391
326392class ClassInfo :
327393 def __init__ (
@@ -454,11 +520,11 @@ class ImportTracker:
454520
455521 def __init__ (self ) -> None :
456522 # module_for['foo'] has the module name where 'foo' was imported from, or None if
457- # 'foo' is a module imported directly;
523+ # 'foo' is a module imported directly;
458524 # direct_imports['foo'] is the module path used when the name 'foo' was added to the
459- # namespace.
525+ # namespace.
460526 # reverse_alias['foo'] is the name that 'foo' had originally when imported with an
461- # alias; examples
527+ # alias; examples
462528 # 'from pkg import mod' ==> module_for['mod'] == 'pkg'
463529 # 'from pkg import mod as m' ==> module_for['m'] == 'pkg'
464530 # ==> reverse_alias['m'] == 'mod'
@@ -618,7 +684,9 @@ def __init__(
618684 include_private : bool = False ,
619685 export_less : bool = False ,
620686 include_docstrings : bool = False ,
687+ known_modules : list [str ] | None = None ,
621688 ) -> None :
689+ self .known_modules = known_modules or []
622690 # Best known value of __all__.
623691 self ._all_ = _all_
624692 self ._include_private = include_private
@@ -839,7 +907,9 @@ def print_annotation(
839907 known_modules : list [str ] | None = None ,
840908 local_modules : list [str ] | None = None ,
841909 ) -> str :
842- printer = AnnotationPrinter (self , known_modules , local_modules )
910+ printer = AnnotationPrinter (
911+ self , known_modules , local_modules or ["builtins" , self .module_name ]
912+ )
843913 return t .accept (printer )
844914
845915 def is_not_in_all (self , name : str ) -> bool :
0 commit comments