6
6
import textwrap
7
7
import types
8
8
from ast import FunctionDef , Module , stmt
9
+ from dataclasses import dataclass
9
10
from functools import lru_cache
10
11
from typing import Any , AnyStr , Callable , ForwardRef , NewType , TypeVar , get_type_hints
11
12
13
+ from docutils .frontend import OptionParser
14
+ from docutils .nodes import Node
15
+ from docutils .parsers .rst import Parser as RstParser
16
+ from docutils .utils import new_document
12
17
from sphinx .application import Sphinx
13
18
from sphinx .config import Config
14
19
from sphinx .environment import BuildEnvironment
@@ -641,6 +646,89 @@ def _inject_signature(
641
646
lines .insert (insert_index , type_annotation )
642
647
643
648
649
+ @dataclass
650
+ class InsertIndexInfo :
651
+ insert_index : int
652
+ found_param : bool = False
653
+ found_return : bool = False
654
+ found_directive : bool = False
655
+
656
+
657
+ # Sphinx allows so many synonyms...
658
+ # See sphinx.domains.python.PyObject
659
+ PARAM_SYNONYMS = ("param " , "parameter " , "arg " , "argument " , "keyword " , "kwarg " , "kwparam " )
660
+
661
+
662
+ def line_before_node (node : Node ) -> int :
663
+ line = node .line
664
+ assert line
665
+ return line - 2
666
+
667
+
668
+ def tag_name (node : Node ) -> str :
669
+ return node .tagname # type:ignore[attr-defined,no-any-return] # noqa: SC200
670
+
671
+
672
+ def get_insert_index (app : Sphinx , lines : list [str ]) -> InsertIndexInfo | None :
673
+ # 1. If there is an existing :rtype: anywhere, don't insert anything.
674
+ if any (line .startswith (":rtype:" ) for line in lines ):
675
+ return None
676
+
677
+ # 2. If there is a :returns: anywhere, either modify that line or insert
678
+ # just before it.
679
+ for at , line in enumerate (lines ):
680
+ if line .startswith ((":return:" , ":returns:" )):
681
+ return InsertIndexInfo (insert_index = at , found_return = True )
682
+
683
+ # 3. Insert after the parameters.
684
+ # To find the parameters, parse as a docutils tree.
685
+ settings = OptionParser (components = (RstParser ,)).get_default_values ()
686
+ settings .env = app .env
687
+ doc = new_document ("" , settings = settings )
688
+ RstParser ().parse ("\n " .join (lines ), doc )
689
+
690
+ # Find a top level child which is a field_list that contains a field whose
691
+ # name starts with one of the PARAM_SYNONYMS. This is the parameter list. We
692
+ # hope there is at most of these.
693
+ for idx , child in enumerate (doc .children ):
694
+ if tag_name (child ) != "field_list" :
695
+ continue
696
+
697
+ if any (c .children [0 ].astext ().startswith (PARAM_SYNONYMS ) for c in child .children ):
698
+ idx = idx
699
+ break
700
+ else :
701
+ idx = - 1
702
+
703
+ if idx == - 1 :
704
+ # No parameters
705
+ pass
706
+ elif idx + 1 < len (doc .children ):
707
+ # Unfortunately docutils only tells us the line numbers that nodes start on,
708
+ # not the range (boo!). So insert before the line before the next sibling.
709
+ at = line_before_node (doc .children [idx + 1 ])
710
+ return InsertIndexInfo (insert_index = at , found_param = True )
711
+ else :
712
+ # No next sibling, insert at end
713
+ return InsertIndexInfo (insert_index = len (lines ), found_param = True )
714
+
715
+ # 4. Insert before examples
716
+ # TODO: Maybe adjust which tags to insert ahead of
717
+ for idx , child in enumerate (doc .children ):
718
+ if tag_name (child ) not in ["literal_block" , "paragraph" , "field_list" ]:
719
+ idx = idx
720
+ break
721
+ else :
722
+ idx = - 1
723
+
724
+ if idx != - 1 :
725
+ at = line_before_node (doc .children [idx ])
726
+ return InsertIndexInfo (insert_index = at , found_directive = True )
727
+
728
+ # 5. Otherwise, insert at end
729
+ return InsertIndexInfo (insert_index = len (lines ))
730
+
731
+
644
732
def _inject_rtype (
645
733
type_hints : dict [str , Any ],
646
734
original_obj : Any ,
@@ -653,37 +741,32 @@ def _inject_rtype(
653
741
return
654
742
if what == "method" and name .endswith (".__init__" ): # avoid adding a return type for data class __init__
655
743
return
744
+ if not app .config .typehints_document_rtype :
745
+ return
746
+
747
+ r = get_insert_index (app , lines )
748
+ if r is None :
749
+ return
750
+
751
+ insert_index = r .insert_index
752
+
753
+ if not app .config .typehints_use_rtype and r .found_return and " -- " in lines [insert_index ]:
754
+ return
755
+
656
756
formatted_annotation = format_annotation (type_hints ["return" ], app .config )
657
- insert_index : int | None = len (lines )
658
- extra_newline = False
659
- for at , line in enumerate (lines ):
660
- if line .startswith (":rtype:" ):
661
- insert_index = None
662
- break
663
- if line .startswith (":return:" ) or line .startswith (":returns:" ):
664
- if " -- " in line and not app .config .typehints_use_rtype :
665
- insert_index = None
666
- break
667
- insert_index = at
668
- elif line .startswith (".." ):
669
- # Make sure that rtype comes before any usage or examples section, with a blank line between.
670
- insert_index = at
671
- extra_newline = True
672
- break
673
757
674
- if insert_index is not None and app .config .typehints_document_rtype :
675
- if insert_index == len (lines ): # ensure that :rtype: doesn't get joined with a paragraph of text
676
- lines .append ("" )
677
- insert_index += 1
678
- if app .config .typehints_use_rtype or insert_index == len (lines ):
679
- line = f":rtype: { formatted_annotation } "
680
- if extra_newline :
681
- lines [insert_index :insert_index ] = [line , "\n " ]
682
- else :
683
- lines .insert (insert_index , line )
684
- else :
685
- line = lines [insert_index ]
686
- lines [insert_index ] = f":return: { formatted_annotation } --{ line [line .find (' ' ):]} "
758
+ if insert_index == len (lines ) and not r .found_param :
759
+ # ensure that :rtype: doesn't get joined with a paragraph of text
760
+ lines .append ("" )
761
+ insert_index += 1
762
+ if app .config .typehints_use_rtype or not r .found_return :
763
+ line = f":rtype: { formatted_annotation } "
764
+ lines .insert (insert_index , line )
765
+ if r .found_directive :
766
+ lines .insert (insert_index + 1 , "" )
767
+ else :
768
+ line = lines [insert_index ]
769
+ lines [insert_index ] = f":return: { formatted_annotation } --{ line [line .find (' ' ):]} "
687
770
688
771
689
772
def validate_config (app : Sphinx , env : BuildEnvironment , docnames : list [str ]) -> None : # noqa: U100
0 commit comments