-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
195 lines (159 loc) · 5.23 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""PVCNN configuration class and helpers."""
import collections
import collections.abc
import importlib
import importlib.util
import os

import six

from utils.container import AttrDict
# Public API of this module: the names a star import exposes.
__all__ = [
    "Config", "configs",
    "update_configs_from_module", "update_configs_from_arguments",
]
def _is_sequence_like(value):
    """Return True for list-like values: sequences other than str/bytes."""
    # collections.UserList already subclasses collections.abc.Sequence; it is
    # listed explicitly only for readability.
    return isinstance(
        value, (collections.abc.Sequence, collections.UserList)
    ) and not isinstance(value, (str, bytes))


class Config(AttrDict):
    """Holds the parameters needed for training / testing the model.

    A ``Config`` is a mapping of parameters that can optionally be bound to a
    callable ``func``. Calling the config calls ``func`` with the stored
    positional ``args`` and the stored items as keyword arguments. Nested,
    non-detached ``Config`` values are called first and replaced by their
    results.

    Args:
        func: Callable (function or class) invoked by ``__call__``. ``None``
            makes the config a plain parameter container; calling it then
            returns the config itself.
        args: Default positional arguments for ``func`` (tuple or list).
        keys: If given, only items whose key is in ``keys`` are forwarded to
            ``func`` as keyword arguments.
        detach: If true, an enclosing config skips calling this config when
            it appears inside that config's arguments.
        **kwargs: Initial items of the mapping.

    Raises:
        TypeError: If ``func`` is not callable, or ``args`` / ``keys`` is not
            a non-string sequence.
    """

    def __init__(self, func=None, args=None, keys=None, detach=False, **kwargs):
        super().__init__(**kwargs)
        if func is not None and not callable(func):
            raise TypeError(f"func {repr(func)} is not a callable function or class")
        # Strings are sequences too, but a str ``args`` would be split into
        # characters and a str ``keys`` would match keys by substring.
        if args is not None and not _is_sequence_like(args):
            raise TypeError(f"args {repr(args)} is not an iterable tuple or list")
        if keys is not None and not _is_sequence_like(keys):
            raise TypeError(f"keys {repr(keys)} is not an iterable tuple or list")
        # Written straight into __dict__, presumably to bypass AttrDict's
        # attribute handling so these are not stored as config items
        # (TODO confirm against utils.container.AttrDict).
        self.__dict__["_func_"] = func
        self.__dict__["_args_"] = args
        self.__dict__["_detach_"] = detach
        self.__dict__["_keys_"] = keys

    def __call__(self, *args, **kwargs):
        """Call ``func`` with the stored arguments, or return ``self``.

        Positional ``args`` given here replace the stored ``args`` entirely.
        Keyword arguments given here take precedence over stored items.
        Before ``func`` runs, the arguments are walked breadth-first. Tuples
        are converted to lists, and every nested non-detached ``Config`` is
        called and replaced by its result. That result is then walked too.
        """
        if self._func_ is None:
            return self

        # Explicit positional args override the stored ones.
        if args:
            args = list(args)
        elif self._args_:
            args = list(self._args_)

        # Stored items fill in keyword arguments that were not given explicitly.
        for k, v in self.items():
            if self._keys_ is None or k in self._keys_:
                kwargs.setdefault(k, v)

        # Recursively call non-detached funcs.
        # NOTE(review): replacements are written back into the containers
        # being walked. Lists/dicts stored in this config (and func-less nested
        # Configs, whose call returns themselves) are shared rather than copied,
        # so calling this config mutates them in place. Confirm this is
        # intended before relying on repeated calls.
        queue = collections.deque([args, kwargs])
        while queue:
            x = queue.popleft()
            if _is_sequence_like(x):
                items = enumerate(x)
            elif isinstance(x, (collections.abc.Mapping, collections.UserDict)):
                items = x.items()
            else:
                items = []
            for k, v in items:
                if isinstance(v, tuple):
                    v = x[k] = list(v)
                elif isinstance(v, Config):
                    if v._detach_:
                        continue
                    v = x[k] = v()
                queue.append(v)

        return self._func_(*args, **kwargs)

    def __str__(self, indent=0):
        """Render the config as an indented, human-readable tree."""
        text = ""
        if self._func_ is not None:
            text += " " * indent + "[func] = " + str(self._func_)
            extra = False
            if self._detach_:
                text += " (detach=" + str(self._detach_)
                extra = True
            if self._keys_:
                text += ", " if extra else " ("
                text += "keys=" + str(self._keys_)
                extra = True
            text += ")\n" if extra else "\n"
            if self._args_:
                for k, v in enumerate(self._args_):
                    text += " " * indent + "[args:" + str(k) + "] = " + str(v) + "\n"
        for k, v in self.items():
            text += " " * indent + "[" + str(k) + "]"
            if isinstance(v, Config):
                text += "\n" + v.__str__(indent + 2)
            else:
                text += " = " + str(v)
            text += "\n"
        return text.rstrip("\n")

    def __repr__(self):
        """Render the config as a call expression: ``func(*args, k=v, ...)``."""
        text = ""
        if self._func_ is not None:
            text += repr(self._func_)
        items = []
        if self._func_ is not None and self._args_:
            items += [repr(v) for v in self._args_]
        items += [str(k) + "=" + repr(v) for k, v in self.items()]
        if self._func_ is not None and self._detach_:
            items += ["detach=" + repr(self._detach_)]
        text += "(" + ", ".join(items) + ")"
        return text

    @staticmethod
    def update_from_modules(*modules):
        """Import each module by name (presumably so it updates ``configs``).

        Entries may be dotted module names or relative ``.py`` file paths.
        A trailing ``.py`` is dropped, and path separators become dots.
        """
        for module in modules:
            # Strip only a trailing suffix: ``str.replace`` would also mangle
            # names such as ``configs.pytorch``.
            if module.endswith(".py"):
                module = module[: -len(".py")]
            module = module.replace(os.sep, ".").replace("/", ".")
            importlib.import_module(module)

    @staticmethod
    def update_from_arguments(*args):
        """Apply ``--configs.<key>=<value>`` arguments via update_configs_from_arguments."""
        update_configs_from_arguments(args)
# Global root of the configuration tree. update_configs_from_arguments writes
# into it and creates nested Config nodes on demand.
configs = Config()
def update_configs_from_module(*mods):
    """Update the configs by executing configuration source files.

    For each path in ``mods``, the ``__init__.py`` files are run first. This
    covers the current directory (the empty prefix) and then every directory
    prefix of the path, outermost first, wherever such a file exists. Finally
    the file itself is run. Within one call, each path string is executed at
    most once, so shared ancestors run only once.

    Args:
        *mods: Paths to Python source files.

    Raises:
        ImportError: If a path cannot be loaded as a Python module, e.g. its
            suffix is not a recognized Python module suffix.
    """
    imported_modules = set()

    # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
    def exec_module_once(module):
        if module in imported_modules:
            return
        imported_modules.add(module)
        spec = importlib.util.spec_from_file_location(
            os.path.basename(module), module
        )
        # spec_from_file_location returns None when no loader matches the
        # file suffix; fail clearly instead of with an AttributeError below.
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load configuration file {module!r}")
        module_obj = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module_obj)

    for mod in mods:
        mod = os.path.normpath(mod)
        # Index 0 yields the empty prefix (current directory); every separator
        # yields the directory prefix that ends just before it.
        # NOTE(review): for absolute paths the empty prefix still probes the
        # current working directory, not the filesystem root. Confirm intent.
        for index, char in enumerate(mod):
            if index == 0 or char == os.sep:
                submod = os.path.join(mod[:index], "__init__.py")
                if os.path.exists(submod):
                    exec_module_once(submod)
        exec_module_once(mod)
def _parse_config_value(text):
    """Convert one command-line value string into a Python value.

    An empty value becomes ``""``. A value wrapped in matching single or
    double quotes becomes the unquoted string. Anything else is evaluated as
    a Python expression against this module's globals.
    """
    if not text:
        return ""
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ("'", '"'):
        return text[1:-1]
    # SECURITY NOTE(review): eval runs arbitrary code. This is tolerable only
    # if the arguments come from the operator's own command line (presumably
    # the case here). Never route untrusted input through this function.
    # ast.literal_eval would be safer but rejects expressions such as
    # ``1e-3 * 4`` that callers may rely on. A throwaway locals dict keeps
    # evaluated assignments (e.g. ``:=``) out of the module namespace.
    return eval(text, globals(), {})  # pylint: disable=eval-used


def update_configs_from_arguments(args: list):
    """Update the configs from a list of arguments.

    Each argument must start with ``--configs.`` followed by a dotted key
    path. The value is given inline (``--configs.a.b=1``) or as the next list
    element (``--configs.a.b 1``). Missing intermediate nodes are created as
    empty ``Config`` objects under the global ``configs``.

    Raises:
        ValueError: If an argument lacks the ``--configs.`` prefix, or a
            space-separated key has no following value.
    """
    prefix = "--configs."
    index = 0
    while index < len(args):
        arg = args[index]
        if not arg.startswith(prefix):
            raise ValueError(f"unrecognized argument {arg}")
        # Strip only the leading prefix; str.replace would also remove a
        # "--configs." occurring inside the value.
        arg = arg[len(prefix):]
        if "=" in arg:
            key_path, _, val = arg.partition("=")
            index += 1
        else:
            if index + 1 >= len(args):
                raise ValueError(f"missing value for argument {args[index]}")
            key_path, val = arg, args[index + 1]
            index += 2
        keys = key_path.split(".")

        # Walk (creating as needed) down to the parent of the leaf key.
        config = configs
        for k in keys[:-1]:
            if k not in config:
                config[k] = Config()
            config = config[k]
        config[keys[-1]] = _parse_config_value(val)