Skip to content

Commit 287c698

Browse files
Added logging to structure factor ctmrg
1 parent 50e8003 commit 287c698

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

varipeps/ctmrg/structure_factor_routine.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import jax.util
77
from jax.lax import cond, while_loop
88
import jax.debug as jdebug
9+
import logging
10+
import time
11+
12+
logger = logging.getLogger("varipeps.ctmrg")
913

1014
from varipeps import varipeps_config, varipeps_global_state
1115
from varipeps.peps import PEPS_Tensor, PEPS_Unit_Cell
@@ -126,8 +130,8 @@ def _ctmrg_body_func_structure_factor(carry):
126130
measure = jnp.linalg.norm(corner_svd - last_corner_svd)
127131
converged = measure < eps
128132

129-
if config.ctmrg_print_steps:
130-
debug_print("CTMRG: {}: {}", count, measure)
133+
if logger.isEnabledFor(logging.DEBUG):
134+
jax.debug.callback(lambda cnt, msr: logger.debug(f"CTMRG: Step {cnt}: {msr}"), count, measure, ordered=True)
131135
if config.ctmrg_verbose_output:
132136
for ti, ctm_enum_i, diff in verbose_data:
133137
debug_print(
@@ -245,6 +249,7 @@ def calc_ctmrg_env_structure_factor(
245249
norm_smallest_S = jnp.nan
246250
already_tried_chi = {working_unitcell[0, 0][0][0].chi}
247251

252+
t0 = time.perf_counter()
248253
while True:
249254
tmp_count = 0
250255
corner_singular_vals = None
@@ -305,6 +310,17 @@ def calc_ctmrg_env_structure_factor(
305310
)
306311
)
307312

313+
if not converged and logger.isEnabledFor(logging.WARNING):
314+
logger.warning(
315+
"CTMRG (SF): ❌ did not converge, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
316+
time.perf_counter() - t0, end_count, norm_smallest_S
317+
)
318+
elif logger.isEnabledFor(logging.INFO):
319+
logger.info(
320+
"CTMRG (SF): ✅ converged, took %.2f seconds. (Steps: %d, Smallest SVD Norm: %.3e)",
321+
time.perf_counter() - t0, end_count, norm_smallest_S
322+
)
323+
308324
current_truncation_eps = (
309325
varipeps_config.ctmrg_truncation_eps
310326
if varipeps_global_state.ctmrg_effective_truncation_eps is None
@@ -327,15 +343,14 @@ def calc_ctmrg_env_structure_factor(
327343
working_unitcell = working_unitcell.change_chi(new_chi)
328344
initial_unitcell = initial_unitcell.change_chi(new_chi)
329345

330-
if varipeps_config.ctmrg_print_steps:
331-
debug_print(
332-
"CTMRG: Increasing chi to {} since smallest SVD Norm was {}.",
346+
if logger.isEnabledFor(logging.INFO):
347+
logger.info(
348+
"CTMRG (SF): Increasing chi to %d since smallest SVD Norm was %.3e.",
333349
new_chi,
334350
norm_smallest_S,
335351
)
336352

337353
already_tried_chi.add(new_chi)
338-
339354
continue
340355
elif (
341356
varipeps_config.ctmrg_heuristic_decrease_chi
@@ -352,15 +367,14 @@ def calc_ctmrg_env_structure_factor(
352367
if not new_chi in already_tried_chi:
353368
working_unitcell = working_unitcell.change_chi(new_chi)
354369

355-
if varipeps_config.ctmrg_print_steps:
356-
debug_print(
357-
"CTMRG: Decreasing chi to {} since smallest SVD Norm was {}.",
370+
if logger.isEnabledFor(logging.INFO):
371+
logger.info(
372+
"CTMRG (SF): Decreasing chi to %d since smallest SVD Norm was %.3e.",
358373
new_chi,
359374
norm_smallest_S,
360375
)
361376

362377
already_tried_chi.add(new_chi)
363-
364378
continue
365379

366380
if (
@@ -376,9 +390,9 @@ def calc_ctmrg_env_structure_factor(
376390
new_truncation_eps
377391
<= varipeps_config.ctmrg_increase_truncation_eps_max_value
378392
):
379-
if varipeps_config.ctmrg_print_steps:
380-
debug_print(
381-
"CTMRG: Increasing SVD truncation eps to {}.",
393+
if logger.isEnabledFor(logging.INFO):
394+
logger.info(
395+
"CTMRG (SF): Increasing SVD truncation eps to %g.",
382396
new_truncation_eps,
383397
)
384398
varipeps_global_state.ctmrg_effective_truncation_eps = (

0 commit comments

Comments
 (0)