11
11
import jax .numpy as jnp
12
12
from jax .lax import scan
13
13
from jax .flatten_util import ravel_pytree
14
+ from jax .util import safe_zip
14
15
15
16
from varipeps import varipeps_config , varipeps_global_state
16
17
from varipeps .config import Optimizing_Methods
@@ -213,7 +214,7 @@ def _autosave_wrapper(
213
214
additional_input ,
214
215
):
215
216
auxiliary_data = {
216
- "best_run" : best_run if best_run is not None else 0 ,
217
+ "best_run" : jnp . array ( best_run if best_run is not None else 0 ) ,
217
218
"current_energy" : working_value ,
218
219
}
219
220
@@ -223,16 +224,43 @@ def _autosave_wrapper(
223
224
auxiliary_data [f"step_chi_{ k :d} " ] = step_chi [k ]
224
225
auxiliary_data [f"step_conv_{ k :d} " ] = step_conv [k ]
225
226
227
+ spiral_vectors = None
226
228
if spiral_indices is not None :
227
- for spiral_i in spiral_indices :
228
- auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = working_tensors [spiral_i ]
229
+ spiral_vectors = [working_tensors [spiral_i ] for spiral_i in spiral_indices ]
230
+
231
+ if any (i .size == 1 for i in spiral_vectors ):
232
+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
233
+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
234
+ if spiral_vectors_x is not None :
235
+ if isinstance (spiral_vectors_x , jnp .ndarray ):
236
+ spiral_vectors_x = (spiral_vectors_x ,)
237
+ spiral_vectors = tuple (
238
+ jnp .array ((sx , sy ))
239
+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
240
+ )
241
+ elif spiral_vectors_y is not None :
242
+ if isinstance (spiral_vectors_y , jnp .ndarray ):
243
+ spiral_vectors_y = (spiral_vectors_y ,)
244
+ spiral_vectors = tuple (
245
+ jnp .array ((sx , sy ))
246
+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
247
+ )
229
248
elif additional_input .get ("spiral_vectors" ) is not None :
230
- add_input_spiral = additional_input .get ("spiral_vectors" )
231
- if isinstance (add_input_spiral , jnp .ndarray ):
232
- add_input_spiral = (add_input_spiral ,)
233
- for spiral_i , elem in enumerate (add_input_spiral ):
234
- spiral_i += 1
235
- auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = elem
249
+ spiral_vectors = additional_input .get ("spiral_vectors" )
250
+ if isinstance (spiral_vectors , jnp .ndarray ):
251
+ spiral_vectors = (spiral_vectors ,)
252
+
253
+ if spiral_vectors is not None :
254
+ spiral_vectors = [
255
+ e if e .size == 2 else jnp .array ((e , e )).reshape (2 ) for e in spiral_vectors
256
+ ]
257
+
258
+ if len (spiral_vectors ) == 1 :
259
+ auxiliary_data ["spiral_vector" ] = spiral_vectors [0 ]
260
+ else :
261
+ for spiral_i , vec in enumerate (spiral_vectors ):
262
+ spiral_i += 1
263
+ auxiliary_data [f"spiral_vector_{ spiral_i :d} " ] = vec
236
264
237
265
autosave_func (
238
266
autosave_filename ,
0 commit comments