3
3
4
4
from jax .tree_util import register_pytree_node_class
5
5
6
- from typing import TypeVar , Tuple , Any , Type
6
+ from typing import TypeVar , Tuple , Any , Type , NoReturn
7
7
8
8
T_VariPEPS_Config = TypeVar ("T_VariPEPS_Config" , bound = "VariPEPS_Config" )
9
9
@@ -177,7 +177,6 @@ class VariPEPS_Config:
177
177
Constant used in Hager-Zhang line search method.
178
178
line_search_hager_zhang_rho (:obj:`float`):
179
179
Constant used in Hager-Zhang line search method.
180
-
181
180
basinhopping_niter (:obj:`int`):
182
181
Value for parameter `niter` of :obj:`scipy.optimize.basinhopping`.
183
182
See this function for details.
@@ -264,6 +263,25 @@ class VariPEPS_Config:
264
263
# Spiral PEPS
265
264
spiral_wavevector_type : Wavevector_Type = Wavevector_Type .TWO_PI_POSITIVE_ONLY
266
265
266
+ def update (self , name : str , value : Any ) -> NoReturn :
267
+ self .__setattr__ (name , value )
268
+
269
+ def __setattr__ (self , name : str , value : Any ) -> NoReturn :
270
+ try :
271
+ field = self .__dataclass_fields__ [name ]
272
+ except KeyError as e :
273
+ raise KeyError (f"Unknown config option '{ name } '." ) from e
274
+
275
+ if not type (value ) is field .type :
276
+ if field .type is float and type (value ) is int :
277
+ pass
278
+ else :
279
+ raise TypeError (
280
+ f"Type mismatch for option '{ name } ', got '{ type (value )} ', expected '{ field .type } '."
281
+ )
282
+
283
+ super ().__setattr__ (name , value )
284
+
267
285
def tree_flatten (self ) -> Tuple [Tuple [Any , ...], Tuple [Any , ...]]:
268
286
aux_data = (
269
287
{name : getattr (self , name ) for name in self .__dataclass_fields__ .keys ()},
@@ -283,3 +301,35 @@ def tree_unflatten(
283
301
284
302
285
303
config = VariPEPS_Config ()
304
+
305
+
306
+ class ConfigModuleWrapper :
307
+ __slots__ = {
308
+ "Optimizing_Methods" ,
309
+ "Line_Search_Methods" ,
310
+ "Projector_Method" ,
311
+ "Wavevector_Type" ,
312
+ "VariPEPS_Config" ,
313
+ "config" ,
314
+ }
315
+
316
+ def __init__ (self ):
317
+ for e in self .__slots__ :
318
+ setattr (self , e , globals ()[e ])
319
+
320
+ def __getattr__ (self , name : str ) -> Any :
321
+ if name .startswith ("__" ) or name in self .__slots__ :
322
+ return super ().__getattr__ (name )
323
+ else :
324
+ return getattr (self .config , name )
325
+
326
+ def __setattr__ (self , name : str , value : Any ) -> NoReturn :
327
+ if not name .startswith ("__" ) and name not in self .__slots__ :
328
+ setattr (self .config , name , value )
329
+ elif not hasattr (self , name ):
330
+ super ().__setattr__ (name , value )
331
+ else :
332
+ raise AttributeError (f"Attribute '{ name } ' is write-protected." )
333
+
334
+
335
+ wrapper = ConfigModuleWrapper ()
0 commit comments