2626)
2727from diff_diff .trop_local import _soft_threshold_svd , _validate_and_pivot_treatment
2828from diff_diff .trop_results import TROPResults
29- from diff_diff .utils import safe_inference
29+ from diff_diff .utils import safe_inference , warn_if_not_converged
3030
3131
3232class TROPGlobalMixin :
@@ -156,6 +156,7 @@ def _solve_global_model(
156156 Y : np .ndarray ,
157157 delta : np .ndarray ,
158158 lambda_nn : float ,
159+ _nonconvergence_tracker : Optional [List [int ]] = None ,
159160 ) -> Tuple [float , np .ndarray , np .ndarray , np .ndarray ]:
160161 """
161162 Dispatch to no-lowrank or with-lowrank solver based on lambda_nn.
@@ -168,7 +169,8 @@ def _solve_global_model(
168169 L = np .zeros ((n_periods , n_units ))
169170 else :
170171 mu , alpha , beta , L = self ._solve_global_with_lowrank (
171- Y , delta , lambda_nn , self .max_iter , self .tol
172+ Y , delta , lambda_nn , self .max_iter , self .tol ,
173+ _nonconvergence_tracker = _nonconvergence_tracker ,
172174 )
173175 return mu , alpha , beta , L
174176
@@ -273,14 +275,18 @@ def _loocv_score_global(
273275
274276 tau_sq_sum = 0.0
275277 n_valid = 0
278+ nonconverg_tracker : List [int ] = []
276279
277280 for t_ex , i_ex in control_obs :
278281 # Create modified delta with excluded observation zeroed out
279282 delta_ex = delta .copy ()
280283 delta_ex [t_ex , i_ex ] = 0.0
281284
282285 try :
283- mu , alpha , beta , L = self ._solve_global_model (Y , delta_ex , lambda_nn )
286+ mu , alpha , beta , L = self ._solve_global_model (
287+ Y , delta_ex , lambda_nn ,
288+ _nonconvergence_tracker = nonconverg_tracker ,
289+ )
284290
285291 # Pseudo treatment effect: tau = Y - mu - alpha - beta - L
286292 if np .isfinite (Y [t_ex , i_ex ]):
@@ -292,6 +298,16 @@ def _loocv_score_global(
292298 # Any failure means this lambda combination is invalid per Equation 5
293299 return np .inf
294300
301+ if nonconverg_tracker :
302+ warn_if_not_converged (
303+ False ,
304+ f"TROP global LOOCV: { len (nonconverg_tracker )} of { len (control_obs )} "
305+ f"per-observation fits did not converge "
306+ f"(\u03bb =({ lambda_time } , { lambda_unit } , { lambda_nn } ))" ,
307+ self .max_iter ,
308+ self .tol ,
309+ )
310+
295311 if n_valid == 0 :
296312 return np .inf
297313
@@ -395,6 +411,7 @@ def _solve_global_with_lowrank(
395411 lambda_nn : float ,
396412 max_iter : int = 100 ,
397413 tol : float = 1e-6 ,
414+ _nonconvergence_tracker : Optional [List [int ]] = None ,
398415 ) -> Tuple [float , np .ndarray , np .ndarray , np .ndarray ]:
399416 """
400417 Solve TWFE + low-rank on control data via alternating minimization.
@@ -445,6 +462,9 @@ def _solve_global_with_lowrank(
445462 # Initialize L = 0
446463 L = np .zeros ((n_periods , n_units ))
447464
465+ _FISTA_MAX_ITER = 20
466+ inner_nonconverged_count = 0
467+ outer_converged = False
448468 for iteration in range (max_iter ):
449469 L_old = L .copy ()
450470
@@ -463,7 +483,8 @@ def _solve_global_with_lowrank(
463483 L_inner_prev = L_inner # share reference initially (no copy needed)
464484 t_fista = 1.0
465485
466- for _ in range (20 ):
486+ inner_converged = False
487+ for _ in range (_FISTA_MAX_ITER ):
467488 # FISTA momentum
468489 t_fista_new = (1.0 + np .sqrt (1.0 + 4.0 * t_fista ** 2 )) / 2.0
469490 momentum = (t_fista - 1.0 ) / t_fista_new
@@ -479,14 +500,29 @@ def _solve_global_with_lowrank(
479500
480501 # Convergence check (L_inner_prev holds the pre-SVD value)
481502 if np .max (np .abs (L_inner - L_inner_prev )) < tol :
503+ inner_converged = True
482504 break
505+ if not inner_converged :
506+ inner_nonconverged_count += 1
483507
484508 L = L_inner
485509
486510 # Outer convergence check
487511 if np .max (np .abs (L - L_old )) < tol :
512+ outer_converged = True
488513 break
489514
515+ if not outer_converged :
516+ if _nonconvergence_tracker is not None :
517+ _nonconvergence_tracker .append (inner_nonconverged_count )
518+ else :
519+ detail = (
520+ f"TROP global alternating minimization "
521+ f"(inner FISTA non-converged in { inner_nonconverged_count } /{ max_iter } "
522+ f"outer iterations, FISTA max_iter={ _FISTA_MAX_ITER } )"
523+ )
524+ warn_if_not_converged (False , detail , max_iter , tol )
525+
490526 # Final re-solve with converged L (match Rust behavior)
491527 Y_adj = Y_safe - L
492528 mu , alpha , beta = self ._solve_global_no_lowrank (Y_adj , delta_masked )
@@ -984,6 +1020,7 @@ def _bootstrap_variance_global(
9841020 n_control_units = len (control_units )
9851021
9861022 bootstrap_estimates_list : List [float ] = []
1023+ nonconverg_tracker : List [int ] = []
9871024
9881025 for _ in range (self .n_bootstrap ):
9891026 # Stratified sampling
@@ -1018,6 +1055,7 @@ def _bootstrap_variance_global(
10181055 optimal_lambda ,
10191056 treated_periods ,
10201057 survey_design = survey_design ,
1058+ _nonconvergence_tracker = nonconverg_tracker ,
10211059 )
10221060 if np .isfinite (tau ):
10231061 bootstrap_estimates_list .append (tau )
@@ -1026,6 +1064,15 @@ def _bootstrap_variance_global(
10261064
10271065 bootstrap_estimates = np .array (bootstrap_estimates_list )
10281066
1067+ if nonconverg_tracker :
1068+ warn_if_not_converged (
1069+ False ,
1070+ f"TROP global bootstrap: { len (nonconverg_tracker )} of "
1071+ f"{ self .n_bootstrap } replicate fits did not converge" ,
1072+ self .max_iter ,
1073+ self .tol ,
1074+ )
1075+
10291076 if len (bootstrap_estimates ) < 10 :
10301077 warnings .warn (
10311078 f"Only { len (bootstrap_estimates )} bootstrap iterations succeeded." , UserWarning
@@ -1169,6 +1216,7 @@ def _bootstrap_rao_wu_global(
11691216 )
11701217
11711218 bootstrap_estimates_list : List [float ] = []
1219+ nonconverg_tracker : List [int ] = []
11721220
11731221 for _ in range (self .n_bootstrap ):
11741222 try :
@@ -1187,7 +1235,10 @@ def _bootstrap_rao_wu_global(
11871235 delta = self ._compute_global_weights (
11881236 Y , D , lambda_time , lambda_unit , treated_periods , n_units , n_periods
11891237 )
1190- mu , alpha , beta , L = self ._solve_global_model (Y , delta , lambda_nn )
1238+ mu , alpha , beta , L = self ._solve_global_model (
1239+ Y , delta , lambda_nn ,
1240+ _nonconvergence_tracker = nonconverg_tracker ,
1241+ )
11911242
11921243 # Extract weighted ATT using Rao-Wu rescaled weights
11931244 att , _ , _ = self ._extract_posthoc_tau (
@@ -1201,6 +1252,15 @@ def _bootstrap_rao_wu_global(
12011252
12021253 bootstrap_estimates = np .array (bootstrap_estimates_list )
12031254
1255+ if nonconverg_tracker :
1256+ warn_if_not_converged (
1257+ False ,
1258+ f"TROP global Rao-Wu bootstrap: { len (nonconverg_tracker )} of "
1259+ f"{ self .n_bootstrap } replicate fits did not converge" ,
1260+ self .max_iter ,
1261+ self .tol ,
1262+ )
1263+
12041264 if len (bootstrap_estimates ) < 10 :
12051265 warnings .warn (
12061266 f"Only { len (bootstrap_estimates )} bootstrap iterations succeeded." ,
@@ -1222,6 +1282,7 @@ def _fit_global_with_fixed_lambda(
12221282 fixed_lambda : Tuple [float , float , float ],
12231283 treated_periods : int ,
12241284 survey_design = None ,
1285+ _nonconvergence_tracker : Optional [List [int ]] = None ,
12251286 ) -> float :
12261287 """
12271288 Fit global model with fixed tuning parameters.
@@ -1263,7 +1324,10 @@ def _fit_global_with_fixed_lambda(
12631324 )
12641325
12651326 # Fit model on control data and extract post-hoc tau
1266- mu , alpha , beta , L = self ._solve_global_model (Y , delta , lambda_nn )
1327+ mu , alpha , beta , L = self ._solve_global_model (
1328+ Y , delta , lambda_nn ,
1329+ _nonconvergence_tracker = _nonconvergence_tracker ,
1330+ )
12671331 att , _ , _ = self ._extract_posthoc_tau (
12681332 Y , D , mu , alpha , beta , L , unit_weights = local_weight_arr
12691333 )
0 commit comments