21
21
from mypy .stubdoc import ArgSig , FunctionSig
22
22
from mypy .types import (
23
23
AnyType ,
24
+ CallableType ,
25
+ DeletedType ,
26
+ ErasedType ,
27
+ Instance ,
24
28
NoneType ,
25
29
Type ,
30
+ TypedDictType ,
26
31
TypeList ,
27
32
TypeStrVisitor ,
33
+ TypeVarType ,
28
34
UnboundType ,
35
+ UninhabitedType ,
29
36
UnionType ,
30
37
UnpackType ,
31
38
)
@@ -251,6 +258,23 @@ def __init__(
251
258
self .known_modules = known_modules
252
259
self .local_modules = local_modules or ["builtins" ]
253
260
261
+ def track_imports (self , s : str ) -> str | None :
262
+ if self .known_modules is not None and "." in s :
263
+ # see if this object is from any of the modules that we're currently processing.
264
+ # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
265
+ for module_name in self .local_modules + sorted (self .known_modules , reverse = True ):
266
+ if s .startswith (module_name + "." ):
267
+ if module_name in self .local_modules :
268
+ s = s [len (module_name ) + 1 :]
269
+ arg_module = module_name
270
+ break
271
+ else :
272
+ arg_module = s [: s .rindex ("." )]
273
+ if arg_module not in self .local_modules :
274
+ self .stubgen .import_tracker .add_import (arg_module , require = True )
275
+ return s
276
+ return None
277
+
254
278
def visit_any (self , t : AnyType ) -> str :
255
279
s = super ().visit_any (t )
256
280
self .stubgen .import_tracker .require_name (s )
@@ -267,19 +291,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
267
291
return self .stubgen .add_name ("_typeshed.Incomplete" )
268
292
if fullname in TYPING_BUILTIN_REPLACEMENTS :
269
293
s = self .stubgen .add_name (TYPING_BUILTIN_REPLACEMENTS [fullname ], require = True )
270
- if self .known_modules is not None and "." in s :
271
- # see if this object is from any of the modules that we're currently processing.
272
- # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
273
- for module_name in self .local_modules + sorted (self .known_modules , reverse = True ):
274
- if s .startswith (module_name + "." ):
275
- if module_name in self .local_modules :
276
- s = s [len (module_name ) + 1 :]
277
- arg_module = module_name
278
- break
279
- else :
280
- arg_module = s [: s .rindex ("." )]
281
- if arg_module not in self .local_modules :
282
- self .stubgen .import_tracker .add_import (arg_module , require = True )
294
+
295
+ if new_s := self .track_imports (s ):
296
+ s = new_s
283
297
elif s == "NoneType" :
284
298
# when called without analysis all types are unbound, so this won't hit
285
299
# visit_none_type().
@@ -292,6 +306,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
292
306
s += "[()]"
293
307
return s
294
308
309
+ def typeddict_item_str (self , t : TypedDictType , name : str , typ : str ) -> str :
310
+ return f"{ name !r} : { typ } "
311
+
295
312
def visit_none_type (self , t : NoneType ) -> str :
296
313
return "None"
297
314
@@ -322,6 +339,55 @@ def args_str(self, args: Iterable[Type]) -> str:
322
339
res .append (arg_str )
323
340
return ", " .join (res )
324
341
342
+ def visit_type_var (self , t : TypeVarType ) -> str :
343
+ return t .name
344
+
345
+ def visit_uninhabited_type (self , t : UninhabitedType ) -> str :
346
+ return self .stubgen .add_name ("typing.Any" )
347
+
348
+ def visit_erased_type (self , t : ErasedType ) -> str :
349
+ return self .stubgen .add_name ("typing.Any" )
350
+
351
+ def visit_deleted_type (self , t : DeletedType ) -> str :
352
+ return self .stubgen .add_name ("typing.Any" )
353
+
354
+ def visit_instance (self , t : Instance ) -> str :
355
+ if t .last_known_value and not t .args :
356
+ # Instances with a literal fallback should never be generic. If they are,
357
+ # something went wrong so we fall back to showing the full Instance repr.
358
+ s = f"{ t .last_known_value .accept (self )} "
359
+ else :
360
+ s = t .type .fullname or t .type .name or self .stubgen .add_name ("_typeshed.Incomplete" )
361
+
362
+ s = self .track_imports (s ) or s
363
+
364
+ if t .args :
365
+ if t .type .fullname == "builtins.tuple" :
366
+ assert len (t .args ) == 1
367
+ s += f"[{ self .list_str (t .args )} , ...]"
368
+ else :
369
+ s += f"[{ self .list_str (t .args )} ]"
370
+ elif t .type .has_type_var_tuple_type and len (t .type .type_vars ) == 1 :
371
+ s += "[()]"
372
+
373
+ return s
374
+
375
+ def visit_callable_type (self , t : CallableType ) -> str :
376
+ from mypy .suggestions import is_tricky_callable
377
+
378
+ if is_tricky_callable (t ):
379
+ arg_str = "..."
380
+ else :
381
+ # Note: for default arguments, we just assume that they
382
+ # are required. This isn't right, but neither is the
383
+ # other thing, and I suspect this will produce more better
384
+ # results than falling back to `...`
385
+ args = [typ .accept (self ) for typ in t .arg_types ]
386
+ arg_str = f"[{ ', ' .join (args )} ]"
387
+
388
+ callable = self .stubgen .add_name ("typing.Callable" )
389
+ return f"{ callable } [{ arg_str } , { t .ret_type .accept (self )} ]"
390
+
325
391
326
392
class ClassInfo :
327
393
def __init__ (
@@ -454,11 +520,11 @@ class ImportTracker:
454
520
455
521
def __init__ (self ) -> None :
456
522
# module_for['foo'] has the module name where 'foo' was imported from, or None if
457
- # 'foo' is a module imported directly;
523
+ # 'foo' is a module imported directly;
458
524
# direct_imports['foo'] is the module path used when the name 'foo' was added to the
459
- # namespace.
525
+ # namespace.
460
526
# reverse_alias['foo'] is the name that 'foo' had originally when imported with an
461
- # alias; examples
527
+ # alias; examples
462
528
# 'from pkg import mod' ==> module_for['mod'] == 'pkg'
463
529
# 'from pkg import mod as m' ==> module_for['m'] == 'pkg'
464
530
# ==> reverse_alias['m'] == 'mod'
@@ -618,7 +684,9 @@ def __init__(
618
684
include_private : bool = False ,
619
685
export_less : bool = False ,
620
686
include_docstrings : bool = False ,
687
+ known_modules : list [str ] | None = None ,
621
688
) -> None :
689
+ self .known_modules = known_modules or []
622
690
# Best known value of __all__.
623
691
self ._all_ = _all_
624
692
self ._include_private = include_private
@@ -839,7 +907,9 @@ def print_annotation(
839
907
known_modules : list [str ] | None = None ,
840
908
local_modules : list [str ] | None = None ,
841
909
) -> str :
842
- printer = AnnotationPrinter (self , known_modules , local_modules )
910
+ printer = AnnotationPrinter (
911
+ self , known_modules , local_modules or ["builtins" , self .module_name ]
912
+ )
843
913
return t .accept (printer )
844
914
845
915
def is_not_in_all (self , name : str ) -> bool :
0 commit comments