Skip to content

Commit 27f7b65

Browse files
XuehaiPan authored and pytorchmergebot committed
[BE] Ensure generated stub files by gen_pyi are properly formatted (pytorch#150730)
Pull Request resolved: pytorch#150730. Approved by: https://github.com/aorenste
1 parent 7ebea09 commit 27f7b65

File tree

5 files changed

+84
-31
lines changed

5 files changed

+84
-31
lines changed

tools/pyi/gen_pyi.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import argparse
2828
import collections
2929
import importlib
30+
import inspect
3031
import sys
32+
import textwrap
3133
from typing import TYPE_CHECKING
3234
from unittest.mock import Mock, patch
3335
from warnings import warn
@@ -132,7 +134,7 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
132134
"_bool | _int | slice | EllipsisType | Tensor | None" # not SupportsIndex!
133135
)
134136
_index_types = f"SupportsIndex | {_leaf_types} | _NestedSequence[{_leaf_types}]"
135-
_index_type_def = f"_Index: TypeAlias = {_index_types}"
137+
_index_type_def = f"_Index: TypeAlias = {_index_types} # fmt: skip"
136138
INDICES = "indices: _Index | tuple[_Index, ...]"
137139

138140
blocklist = [
@@ -252,6 +254,11 @@ def sig_for_ops(opname: str) -> list[str]:
252254
f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[has-type]"
253255
]
254256
elif name in arithmetic_ops:
257+
if name.startswith("i"):
258+
# In-place binary-operation dunder methods, like `__iadd__`, should return `Self`
259+
return [
260+
f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034"
261+
]
255262
return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."]
256263
elif name in logic_ops:
257264
return [f"def {opname}(self, other: Tensor | _bool) -> Tensor: ..."]
@@ -327,14 +334,29 @@ def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]
327334
arg_list_keyword.insert(flag_pos, "*")
328335
return {
329336
name: [
330-
defs(name, arg_list, "Tensor").format(
331-
return_indices="return_indices: Literal[False] = False",
337+
defs(
338+
name,
339+
[
340+
arg.format(return_indices="return_indices: Literal[False] = False")
341+
for arg in arg_list
342+
],
343+
"Tensor",
332344
),
333-
defs(name, arg_list_positional, "tuple[Tensor, Tensor]").format(
334-
return_indices="return_indices: Literal[True]",
345+
defs(
346+
name,
347+
[
348+
arg.format(return_indices="return_indices: Literal[True]")
349+
for arg in arg_list_positional
350+
],
351+
"tuple[Tensor, Tensor]",
335352
),
336-
defs(name, arg_list_keyword, "tuple[Tensor, Tensor]").format(
337-
return_indices="return_indices: Literal[True]",
353+
defs(
354+
name,
355+
[
356+
arg.format(return_indices="return_indices: Literal[True]")
357+
for arg in arg_list_keyword
358+
],
359+
"tuple[Tensor, Tensor]",
338360
),
339361
]
340362
}
@@ -669,12 +691,16 @@ def mock_add_docstr(func: Mock, docstr: str) -> None:
669691

670692

671693
def add_docstr_to_hint(docstr: str, hint: str) -> str:
694+
docstr = inspect.cleandoc(docstr).strip()
672695
if "..." in hint: # function or method
673696
assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
674-
hint = hint[:-3] # remove "..."
675-
return "\n ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
676-
else: # attribute or property
677-
return f'{hint}\nr"""{docstr}"""\n'
697+
hint = hint.removesuffix("...").rstrip() # remove "..."
698+
content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ")
699+
# Remove trailing whitespace on each line
700+
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
701+
702+
# attribute or property
703+
return f'{hint}\nr"""{docstr}"""'
678704

679705

680706
def gen_pyi(
@@ -1557,7 +1583,7 @@ def replace_special_case(hint: str) -> str:
15571583
# Generate type signatures for legacy classes
15581584
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15591585

1560-
legacy_storage_base_hints = ["class StorageBase(object): ..."]
1586+
legacy_storage_base_hints = ["class StorageBase: ..."]
15611587

15621588
legacy_class_hints = []
15631589
for c in (

torch/_tensor_docs.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -3163,7 +3163,10 @@ def callable(a, b) -> number
31633163
Example:
31643164
31653165
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
3166-
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
3166+
>>> mask = torch.tensor(
3167+
... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
3168+
... dtype=torch.bool,
3169+
... )
31673170
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
31683171
>>> self.masked_scatter_(mask, source)
31693172
tensor([[0, 0, 0, 0, 1],
@@ -3645,7 +3648,7 @@ def callable(a, b) -> number
36453648
# Example 1: Padding
36463649
>>> input_tensor = torch.tensor([[1, 0], [3, 2]])
36473650
>>> static_size = 4
3648-
>>> t = torch.nonzero_static(input_tensor, size = static_size)
3651+
>>> t = torch.nonzero_static(input_tensor, size=static_size)
36493652
tensor([[ 0, 0],
36503653
[ 1, 0],
36513654
[ 1, 1],
@@ -3654,20 +3657,20 @@ def callable(a, b) -> number
36543657
# Example 2: Truncating
36553658
>>> input_tensor = torch.tensor([[1, 0], [3, 2]])
36563659
>>> static_size = 2
3657-
>>> t = torch.nonzero_static(input_tensor, size = static_size)
3660+
>>> t = torch.nonzero_static(input_tensor, size=static_size)
36583661
tensor([[ 0, 0],
36593662
[ 1, 0]], dtype=torch.int64)
36603663
36613664
# Example 3: 0 size
36623665
>>> input_tensor = torch.tensor([10])
36633666
>>> static_size = 0
3664-
>>> t = torch.nonzero_static(input_tensor, size = static_size)
3667+
>>> t = torch.nonzero_static(input_tensor, size=static_size)
36653668
tensor([], size=(0, 1), dtype=torch.int64)
36663669
36673670
# Example 4: 0 rank input
36683671
>>> input_tensor = torch.tensor(10)
36693672
>>> static_size = 2
3670-
>>> t = torch.nonzero_static(input_tensor, size = static_size)
3673+
>>> t = torch.nonzero_static(input_tensor, size=static_size)
36713674
tensor([], size=(2, 0), dtype=torch.int64)
36723675
""",
36733676
)
@@ -6561,7 +6564,10 @@ def callable(a, b) -> number
65616564
Example:
65626565
65636566
>>> self = torch.tensor([0, 0, 0, 0, 0])
6564-
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)
6567+
>>> mask = torch.tensor(
6568+
... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]],
6569+
... dtype=torch.bool,
6570+
... )
65656571
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
65666572
>>> self.masked_scatter(mask, source)
65676573
tensor([[0, 0, 0, 0, 1],

torch/_torch_docs.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -10550,7 +10550,8 @@ def merge_dicts(*dicts):
1055010550
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
1055110551
... [ 1.5027, -0.3270, 0.5905, 0.6538],
1055210552
... [-1.5745, 1.3330, -0.5596, -0.6548],
10553-
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
10553+
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
10554+
... ) # fmt: skip
1055410555
>>> torch.std(a, dim=1, keepdim=True)
1055510556
tensor([[1.0311],
1055610557
[0.7477],
@@ -10608,7 +10609,8 @@ def merge_dicts(*dicts):
1060810609
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
1060910610
... [ 1.5027, -0.3270, 0.5905, 0.6538],
1061010611
... [-1.5745, 1.3330, -0.5596, -0.6548],
10611-
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
10612+
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
10613+
... ) # fmt: skip
1061210614
>>> torch.std_mean(a, dim=0, keepdim=True)
1061310615
(tensor([[1.2620, 1.0028, 1.0957, 0.6038]]),
1061410616
tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]]))
@@ -11896,7 +11898,8 @@ def merge_dicts(*dicts):
1189611898
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
1189711899
... [ 1.5027, -0.3270, 0.5905, 0.6538],
1189811900
... [-1.5745, 1.3330, -0.5596, -0.6548],
11899-
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
11901+
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
11902+
... ) # fmt: skip
1190011903
>>> torch.var(a, dim=1, keepdim=True)
1190111904
tensor([[1.0631],
1190211905
[0.5590],
@@ -11953,7 +11956,8 @@ def merge_dicts(*dicts):
1195311956
... [[ 0.2035, 1.2959, 1.8101, -0.4644],
1195411957
... [ 1.5027, -0.3270, 0.5905, 0.6538],
1195511958
... [-1.5745, 1.3330, -0.5596, -0.6548],
11956-
... [ 0.1264, -0.5080, 1.6420, 0.1992]])
11959+
... [ 0.1264, -0.5080, 1.6420, 0.1992]]
11960+
... ) # fmt: skip
1195711961
>>> torch.var_mean(a, dim=0, keepdim=True)
1195811962
(tensor([[1.5926, 1.0056, 1.2005, 0.3646]]),
1195911963
tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]]))

torchgen/api/python.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,17 @@ def format_function_signature(
212212
if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
213213
return sig
214214

215+
arguments = [f" {arg}," for arg in arguments]
215216
return "\n".join(
216217
(
217218
f"def {name}(",
218-
*(f" {arg}," for arg in arguments),
219+
*(
220+
arg if len(arg) <= 80 else f" # fmt: off\n{arg}\n # fmt: on"
221+
for arg in arguments
222+
),
219223
f"){return_type}: ...",
220224
)
221-
)
225+
).replace(" # fmt: off\n # fmt: on\n", "")
222226

223227

224228
@dataclass(frozen=True)
@@ -1029,7 +1033,7 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
10291033
# does not allow us to override __init__.
10301034
seq_type = f"tuple[{', '.join(python_returns)}]"
10311035
structseq_def_lines = [
1032-
f"class {structseq_name}({seq_type}):",
1036+
f"class {structseq_name}({seq_type}): # fmt: skip",
10331037
]
10341038
for name, ret_type in zip(field_names, python_returns):
10351039
structseq_def_lines.extend(
@@ -1040,7 +1044,11 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
10401044
)
10411045
structseq_def_lines.extend(
10421046
[
1043-
f" def __new__(cls, sequence: {seq_type}) -> Self: ...",
1047+
" def __new__(",
1048+
" cls,",
1049+
f" sequence: {seq_type},",
1050+
" ) -> Self: # fmt: skip",
1051+
" ...",
10441052
f" n_fields: Final[_int] = {len(field_names)}",
10451053
f" n_sequence_fields: Final[_int] = {len(field_names)}",
10461054
" n_unnamed_fields: Final[_int] = 0",
@@ -1051,12 +1059,16 @@ def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
10511059
structseq_def = "\n".join(structseq_def_lines)
10521060
# Example:
10531061
# structseq_def = (
1054-
# "class max(tuple[Tensor, Tensor]):\n"
1062+
# "class max(tuple[Tensor, Tensor]): # fmt: skip\n"
10551063
# " @property\n"
10561064
# " def values(self) -> Tensor: ...\n"
10571065
# " @property\n"
10581066
# " def indices(self) -> Tensor: ...\n"
1059-
# " def __new__(cls, sequence: tuple[Tensor, Tensor]) -> Self: ...\n"
1067+
# " def __new__(\n"
1068+
# " cls,\n"
1069+
# " sequence: tuple[Tensor, Tensor],\n"
1070+
# " ) -> Self: # fmt: skip\n"
1071+
# " ...\n"
10601072
# " n_fields: Final[_int] = 2",
10611073
# " n_sequence_fields: Final[_int] = 2",
10621074
# " n_unnamed_fields: Final[_int] = 0",

torchgen/code_template.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import itertools
34
import re
5+
import textwrap
46
from typing import TYPE_CHECKING
57

68

@@ -45,9 +47,12 @@ def lookup(v: str) -> object:
4547
return kwargs[v] if v in kwargs else env[v]
4648

4749
def indent_lines(indent: str, v: Sequence[object]) -> str:
48-
return "".join(
49-
[indent + l + "\n" for e in v for l in str(e).splitlines()]
50-
).rstrip()
50+
content = "\n".join(
51+
itertools.chain.from_iterable(str(e).splitlines() for e in v)
52+
)
53+
content = textwrap.indent(content, prefix=indent)
54+
# Remove trailing whitespace on each line
55+
return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
5156

5257
def replace(match: re.Match[str]) -> str:
5358
indent = match.group(1)

0 commit comments

Comments
 (0)