Skip to content

Commit 3303ff8

Browse files
committed
Wrap config module in a wrapper object and type check settings
1 parent 9cc51aa commit 3303ff8

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

varipeps/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# First import config so it is usable in other parts of the module
2-
from . import config
2+
from .config import wrapper as config
33
from .config import config as varipeps_config
44
from .global_state import global_state as varipeps_global_state
55

@@ -26,3 +26,4 @@
2626

2727
del datetime
2828
del tqdm_logging
29+
del jax_config

varipeps/config.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from jax.tree_util import register_pytree_node_class
55

6-
from typing import TypeVar, Tuple, Any, Type
6+
from typing import TypeVar, Tuple, Any, Type, NoReturn
77

88
T_VariPEPS_Config = TypeVar("T_VariPEPS_Config", bound="VariPEPS_Config")
99

@@ -177,7 +177,6 @@ class VariPEPS_Config:
177177
Constant used in Hager-Zhang line search method.
178178
line_search_hager_zhang_rho (:obj:`float`):
179179
Constant used in Hager-Zhang line search method.
180-
181180
basinhopping_niter (:obj:`int`):
182181
Value for parameter `niter` of :obj:`scipy.optimize.basinhopping`.
183182
See this function for details.
@@ -264,6 +263,25 @@ class VariPEPS_Config:
264263
# Spiral PEPS
265264
spiral_wavevector_type: Wavevector_Type = Wavevector_Type.TWO_PI_POSITIVE_ONLY
266265

266+
def update(self, name: str, value: Any) -> NoReturn:
267+
self.__setattr__(name, value)
268+
269+
def __setattr__(self, name: str, value: Any) -> NoReturn:
270+
try:
271+
field = self.__dataclass_fields__[name]
272+
except KeyError as e:
273+
raise KeyError(f"Unknown config option '{name}'.") from e
274+
275+
if not type(value) is field.type:
276+
if field.type is float and type(value) is int:
277+
pass
278+
else:
279+
raise TypeError(
280+
f"Type mismatch for option '{name}', got '{type(value)}', expected '{field.type}'."
281+
)
282+
283+
super().__setattr__(name, value)
284+
267285
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
268286
aux_data = (
269287
{name: getattr(self, name) for name in self.__dataclass_fields__.keys()},
@@ -283,3 +301,35 @@ def tree_unflatten(
283301

284302

285303
config = VariPEPS_Config()
304+
305+
306+
class ConfigModuleWrapper:
307+
__slots__ = {
308+
"Optimizing_Methods",
309+
"Line_Search_Methods",
310+
"Projector_Method",
311+
"Wavevector_Type",
312+
"VariPEPS_Config",
313+
"config",
314+
}
315+
316+
def __init__(self):
317+
for e in self.__slots__:
318+
setattr(self, e, globals()[e])
319+
320+
def __getattr__(self, name: str) -> Any:
321+
if name.startswith("__") or name in self.__slots__:
322+
return super().__getattr__(name)
323+
else:
324+
return getattr(self.config, name)
325+
326+
def __setattr__(self, name: str, value: Any) -> NoReturn:
327+
if not name.startswith("__") and name not in self.__slots__:
328+
setattr(self.config, name, value)
329+
elif not hasattr(self, name):
330+
super().__setattr__(name, value)
331+
else:
332+
raise AttributeError(f"Attribute '{name}' is write-protected.")
333+
334+
335+
wrapper = ConfigModuleWrapper()

0 commit comments

Comments
 (0)