-
Notifications
You must be signed in to change notification settings - Fork 1k
/
gen_docs.py
96 lines (83 loc) · 3.14 KB
/
gen_docs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import sys
import os
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
from megatron.neox_arguments import neox_args, deepspeed_args
from inspect import getmembers, getsource
from dataclasses import field, is_dataclass
from itertools import tee, zip_longest
import pathlib
def pairwise(iterable):
    """Return an iterator of overlapping neighbour pairs from *iterable*.

    s -> (s0, s1), (s1, s2), ..., (sN, None)

    Unlike ``itertools.pairwise``, the final element is still emitted,
    paired with ``None``, so callers can detect "no next item".
    An empty input yields nothing.
    """
    current_it, lookahead_it = tee(iterable)
    # Advance the look-ahead copy one step; tolerate an empty input.
    next(lookahead_it, None)
    # zip_longest (not zip) keeps the last element, filling its partner
    # with None once the look-ahead copy runs out.
    return zip_longest(current_it, lookahead_it)
def _field_docstrings(dcls):
    """Return ``{field_name: docstring}`` for fields declared in *dcls*'s body.

    A field's docstring is the bare string-literal statement written
    directly after its annotated assignment in the class body. Fields with
    no such string are omitted. Parsing the class source with ``ast``
    (rather than searching raw text) means a field name appearing inside
    another name, a type hint, or a docstring cannot be mistaken for the
    real definition, and a later method docstring is never attributed to
    the last field.

    Raises:
        OSError, TypeError: propagated from ``inspect.getsource`` when the
            class source cannot be retrieved.
    """
    import ast
    import textwrap

    # Dedent so the source of a class defined in an indented scope parses.
    tree = ast.parse(textwrap.dedent(getsource(dcls)))
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    statements = class_node.body
    docs = {}
    for stmt, following in zip(statements, statements[1:]):
        if not (isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)):
            continue
        if (
            isinstance(following, ast.Expr)
            and isinstance(following.value, ast.Constant)
            and isinstance(following.value.value, str)
        ):
            docs[stmt.target.id] = following.value.value.strip()
    return docs


def get_docs(module):
    """Collect documentation for every dataclass found in *module*.

    Args:
        module: module object whose dataclass members are documented.

    Returns:
        dict mapping each dataclass's attribute name in *module* to::

            {"doc": <stripped class docstring>,
             "attributes": {field_name: {"name": str, "type": str,
                                         "default": object, "doc": str}}}

        ``default`` is the field's declared default. For fields using
        ``default_factory`` it is a freshly built factory value. For fields
        with no default it is ``dataclasses.MISSING``. ``doc`` is "" when the
        field has no docstring.
    """
    from dataclasses import MISSING
    from typing import get_origin

    def _is_dataclass_type(obj):
        # is_dataclass() is also true for dataclass *instances*; only the
        # classes themselves carry the source and field definitions we need.
        return isinstance(obj, type) and is_dataclass(obj)

    results = {}
    for name, dcls in getmembers(module, _is_dataclass_type):
        field_docs = _field_docstrings(dcls)
        attributes = {}
        for field_name, field_def in dcls.__dataclass_fields__.items():
            field_type = field_def.type
            # Plain classes (int, str, bool, ...) are shown by bare name.
            # Typing constructs (Literal[...], Optional[...], List[...]),
            # parameterized builtins and string annotations keep their
            # full form so their parameters are not lost.
            if isinstance(field_type, type) and get_origin(field_type) is None:
                field_type = field_type.__name__
            else:
                field_type = str(field_type)

            field_default = field_def.default
            # Fields declared with default_factory have default == MISSING;
            # show the value the factory actually produces instead.
            if field_default is MISSING and field_def.default_factory is not MISSING:
                field_default = field_def.default_factory()

            attributes[field_name] = {
                "name": field_name,
                "type": field_type,
                "default": field_default,
                "doc": field_docs.get(field_name, ""),
            }
        results[name] = {
            "doc": (dcls.__doc__ or "").strip(),
            "attributes": attributes,
        }
    return results
def to_md(docs, intro_str=""):
    """Render a docs mapping (as returned by ``get_docs``) as Markdown text.

    Args:
        docs: mapping of section name -> {"doc": ..., "attributes": {...}},
            where each attribute entry has "type", "default" and "doc" keys.
        intro_str: text emitted before the first section.

    Returns:
        A single string in which every emitted chunk (heading, class doc,
        attribute line, default line, attribute doc, spacer) is separated
        from the next by a blank line.
    """
    chunks = [intro_str]
    for section_name, section in docs.items():
        chunks += [f"## {section_name}", f"{section['doc']}", ""]
        for attr_name, attr in section["attributes"].items():
            # Attribute name and type, then its default and its docstring.
            chunks.extend(
                (
                    f"- **{attr_name}**: {attr['type']}",
                    f" Default = {str(attr['default'])}",
                    f" {attr['doc']}",
                    "",
                )
            )
    return "\n\n".join(chunks)
if __name__ == "__main__":
    # Gather docs from both argument modules; entries from deepspeed_args
    # replace same-named entries from neox_args.
    docs = get_docs(neox_args)
    docs.update(get_docs(deepspeed_args))
    intro_str = """Arguments for gpt-neox. All of the following can be specified in your .yml config file(s):\n"""
    md = to_md(docs, intro_str=intro_str)
    # Write the Markdown next to this script. Pin UTF-8 so the output does
    # not depend on the platform's default locale encoding (e.g. cp1252).
    out_path = pathlib.Path(__file__).parent.resolve() / "neox_arguments.md"
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(md)