21
21
from mypy .stubdoc import ArgSig , FunctionSig
22
22
from mypy .types import (
23
23
AnyType ,
24
+ CallableType ,
25
+ DeletedType ,
26
+ ErasedType ,
27
+ Instance ,
24
28
NoneType ,
25
29
Type ,
26
30
TypeList ,
27
31
TypeStrVisitor ,
32
+ TypeVarType ,
28
33
UnboundType ,
34
+ UninhabitedType ,
29
35
UnionType ,
30
36
UnpackType ,
31
37
)
@@ -251,6 +257,23 @@ def __init__(
251
257
self .known_modules = known_modules
252
258
self .local_modules = local_modules or ["builtins" ]
253
259
260
+ def track_imports (self , s : str ) -> str | None :
261
+ if self .known_modules is not None and "." in s :
262
+ # see if this object is from any of the modules that we're currently processing.
263
+ # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
264
+ for module_name in self .local_modules + sorted (self .known_modules , reverse = True ):
265
+ if s .startswith (module_name + "." ):
266
+ if module_name in self .local_modules :
267
+ s = s [len (module_name ) + 1 :]
268
+ arg_module = module_name
269
+ break
270
+ else :
271
+ arg_module = s [: s .rindex ("." )]
272
+ if arg_module not in self .local_modules :
273
+ self .stubgen .import_tracker .add_import (arg_module , require = True )
274
+ return s
275
+ return None
276
+
254
277
def visit_any (self , t : AnyType ) -> str :
255
278
s = super ().visit_any (t )
256
279
self .stubgen .import_tracker .require_name (s )
@@ -267,19 +290,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
267
290
return self .stubgen .add_name ("_typeshed.Incomplete" )
268
291
if fullname in TYPING_BUILTIN_REPLACEMENTS :
269
292
s = self .stubgen .add_name (TYPING_BUILTIN_REPLACEMENTS [fullname ], require = True )
270
- if self .known_modules is not None and "." in s :
271
- # see if this object is from any of the modules that we're currently processing.
272
- # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
273
- for module_name in self .local_modules + sorted (self .known_modules , reverse = True ):
274
- if s .startswith (module_name + "." ):
275
- if module_name in self .local_modules :
276
- s = s [len (module_name ) + 1 :]
277
- arg_module = module_name
278
- break
279
- else :
280
- arg_module = s [: s .rindex ("." )]
281
- if arg_module not in self .local_modules :
282
- self .stubgen .import_tracker .add_import (arg_module , require = True )
293
+
294
+ if new_s := self .track_imports (s ):
295
+ s = new_s
283
296
elif s == "NoneType" :
284
297
# when called without analysis all types are unbound, so this won't hit
285
298
# visit_none_type().
@@ -322,6 +335,55 @@ def args_str(self, args: Iterable[Type]) -> str:
322
335
res .append (arg_str )
323
336
return ", " .join (res )
324
337
338
+ def visit_type_var (self , t : TypeVarType ) -> str :
339
+ return t .name
340
+
341
+ def visit_uninhabited_type (self , t : UninhabitedType ) -> str :
342
+ return self .stubgen .add_name ("typing.Any" )
343
+
344
+ def visit_erased_type (self , t : ErasedType ) -> str :
345
+ return self .stubgen .add_name ("typing.Any" )
346
+
347
+ def visit_deleted_type (self , t : DeletedType ) -> str :
348
+ return self .stubgen .add_name ("typing.Any" )
349
+
350
+ def visit_instance (self , t : Instance ) -> str :
351
+ if t .last_known_value and not t .args :
352
+ # Instances with a literal fallback should never be generic. If they are,
353
+ # something went wrong so we fall back to showing the full Instance repr.
354
+ s = f"{ t .last_known_value .accept (self )} "
355
+ else :
356
+ s = t .type .fullname or t .type .name or self .stubgen .add_name ("_typeshed.Incomplete" )
357
+
358
+ s = self .track_imports (s ) or s
359
+
360
+ if t .args :
361
+ if t .type .fullname == "builtins.tuple" :
362
+ assert len (t .args ) == 1
363
+ s += f"[{ self .list_str (t .args )} , ...]"
364
+ else :
365
+ s += f"[{ self .list_str (t .args )} ]"
366
+ elif t .type .has_type_var_tuple_type and len (t .type .type_vars ) == 1 :
367
+ s += "[()]"
368
+
369
+ return s
370
+
371
+ def visit_callable_type (self , t : CallableType ) -> str :
372
+ from mypy .suggestions import is_tricky_callable
373
+
374
+ if is_tricky_callable (t ):
375
+ arg_str = "..."
376
+ else :
377
+ # Note: for default arguments, we just assume that they
378
+ # are required. This isn't right, but neither is the
379
+ # other thing, and I suspect this will produce more better
380
+ # results than falling back to `...`
381
+ args = [typ .accept (self ) for typ in t .arg_types ]
382
+ arg_str = f"[{ ', ' .join (args )} ]"
383
+
384
+ callable = self .stubgen .add_name ("typing.Callable" )
385
+ return f"{ callable } [{ arg_str } , { t .ret_type .accept (self )} ]"
386
+
325
387
326
388
class ClassInfo :
327
389
def __init__ (
@@ -454,11 +516,11 @@ class ImportTracker:
454
516
455
517
def __init__ (self ) -> None :
456
518
# module_for['foo'] has the module name where 'foo' was imported from, or None if
457
- # 'foo' is a module imported directly;
519
+ # 'foo' is a module imported directly;
458
520
# direct_imports['foo'] is the module path used when the name 'foo' was added to the
459
- # namespace.
521
+ # namespace.
460
522
# reverse_alias['foo'] is the name that 'foo' had originally when imported with an
461
- # alias; examples
523
+ # alias; examples
462
524
# 'from pkg import mod' ==> module_for['mod'] == 'pkg'
463
525
# 'from pkg import mod as m' ==> module_for['m'] == 'pkg'
464
526
# ==> reverse_alias['m'] == 'mod'
@@ -618,7 +680,9 @@ def __init__(
618
680
include_private : bool = False ,
619
681
export_less : bool = False ,
620
682
include_docstrings : bool = False ,
683
+ known_modules : list [str ] | None = None ,
621
684
) -> None :
685
+ self .known_modules = known_modules or []
622
686
# Best known value of __all__.
623
687
self ._all_ = _all_
624
688
self ._include_private = include_private
@@ -839,7 +903,9 @@ def print_annotation(
839
903
known_modules : list [str ] | None = None ,
840
904
local_modules : list [str ] | None = None ,
841
905
) -> str :
842
- printer = AnnotationPrinter (self , known_modules , local_modules )
906
+ printer = AnnotationPrinter (
907
+ self , known_modules , local_modules or ["builtins" , self .module_name ]
908
+ )
843
909
return t .accept (printer )
844
910
845
911
def is_not_in_all (self , name : str ) -> bool :
0 commit comments