Skip to content

Commit cf77f18

Browse files
committed
Save spiral vectors in autosave in a more consistent way
1 parent 873b157 commit cf77f18

File tree

1 file changed

+37
-9
lines changed

1 file changed

+37
-9
lines changed

varipeps/optimization/optimizer.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import jax.numpy as jnp
1212
from jax.lax import scan
1313
from jax.flatten_util import ravel_pytree
14+
from jax.util import safe_zip
1415

1516
from varipeps import varipeps_config, varipeps_global_state
1617
from varipeps.config import Optimizing_Methods
@@ -213,7 +214,7 @@ def _autosave_wrapper(
213214
additional_input,
214215
):
215216
auxiliary_data = {
216-
"best_run": best_run if best_run is not None else 0,
217+
"best_run": jnp.array(best_run if best_run is not None else 0),
217218
"current_energy": working_value,
218219
}
219220

@@ -223,16 +224,43 @@ def _autosave_wrapper(
223224
auxiliary_data[f"step_chi_{k:d}"] = step_chi[k]
224225
auxiliary_data[f"step_conv_{k:d}"] = step_conv[k]
225226

227+
spiral_vectors = None
226228
if spiral_indices is not None:
227-
for spiral_i in spiral_indices:
228-
auxiliary_data[f"spiral_vector_{spiral_i:d}"] = working_tensors[spiral_i]
229+
spiral_vectors = [working_tensors[spiral_i] for spiral_i in spiral_indices]
230+
231+
if any(i.size == 1 for i in spiral_vectors):
232+
spiral_vectors_x = additional_input.get("spiral_vectors_x")
233+
spiral_vectors_y = additional_input.get("spiral_vectors_y")
234+
if spiral_vectors_x is not None:
235+
if isinstance(spiral_vectors_x, jnp.ndarray):
236+
spiral_vectors_x = (spiral_vectors_x,)
237+
spiral_vectors = tuple(
238+
jnp.array((sx, sy))
239+
for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors)
240+
)
241+
elif spiral_vectors_y is not None:
242+
if isinstance(spiral_vectors_y, jnp.ndarray):
243+
spiral_vectors_y = (spiral_vectors_y,)
244+
spiral_vectors = tuple(
245+
jnp.array((sx, sy))
246+
for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y)
247+
)
229248
elif additional_input.get("spiral_vectors") is not None:
230-
add_input_spiral = additional_input.get("spiral_vectors")
231-
if isinstance(add_input_spiral, jnp.ndarray):
232-
add_input_spiral = (add_input_spiral,)
233-
for spiral_i, elem in enumerate(add_input_spiral):
234-
spiral_i += 1
235-
auxiliary_data[f"spiral_vector_{spiral_i:d}"] = elem
249+
spiral_vectors = additional_input.get("spiral_vectors")
250+
if isinstance(spiral_vectors, jnp.ndarray):
251+
spiral_vectors = (spiral_vectors,)
252+
253+
if spiral_vectors is not None:
254+
spiral_vectors = [
255+
e if e.size == 2 else jnp.array((e, e)).reshape(2) for e in spiral_vectors
256+
]
257+
258+
if len(spiral_vectors) == 1:
259+
auxiliary_data["spiral_vector"] = spiral_vectors[0]
260+
else:
261+
for spiral_i, vec in enumerate(spiral_vectors):
262+
spiral_i += 1
263+
auxiliary_data[f"spiral_vector_{spiral_i:d}"] = vec
236264

237265
autosave_func(
238266
autosave_filename,

0 commit comments

Comments
 (0)