Skip to content

Commit 306ed99

Browse files
authored
Merge pull request #317 from igerber/fix/trop-convergence-warnings
Signal non-convergence in TROP alternating-minimization solvers
2 parents 76fb12c + 41b6a5f commit 306ed99

5 files changed

Lines changed: 352 additions & 10 deletions

File tree

diff_diff/trop.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
_PrecomputedStructures,
3838
TROPResults,
3939
)
40-
from diff_diff.utils import safe_inference
40+
from diff_diff.utils import safe_inference, warn_if_not_converged
4141

4242

4343
class TROP(TROPLocalMixin, TROPGlobalMixin):
@@ -748,6 +748,8 @@ def fit(
748748

749749
# Use pre-computed treated observations
750750
treated_observations = self._precomputed["treated_observations"]
751+
nonconverg_tracker: list = []
752+
n_fits_attempted = 0
751753

752754
for t, i in treated_observations:
753755
unit_id = idx_to_unit[i]
@@ -765,8 +767,10 @@ def fit(
765767
)
766768

767769
# Fit model with these weights
770+
n_fits_attempted += 1
768771
alpha_hat, beta_hat, L_hat = self._estimate_model(
769-
Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods
772+
Y, control_mask, weight_matrix, lambda_nn, n_units, n_periods,
773+
_nonconvergence_tracker=nonconverg_tracker,
770774
)
771775

772776
# Compute treatment effect: tau_{it} = Y_{it} - alpha_i - beta_t - L_{it}
@@ -782,6 +786,16 @@ def fit(
782786
beta_estimates.append(beta_hat)
783787
L_estimates.append(L_hat)
784788

789+
if nonconverg_tracker:
790+
warn_if_not_converged(
791+
False,
792+
f"TROP local per-treated-observation fit: "
793+
f"{len(nonconverg_tracker)} of {n_fits_attempted} "
794+
f"fits did not converge",
795+
self.max_iter,
796+
self.tol,
797+
)
798+
785799
# Count valid treated observations
786800
n_valid_treated = len(tau_values)
787801
if n_valid_treated == 0:

diff_diff/trop_global.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from diff_diff.trop_local import _soft_threshold_svd, _validate_and_pivot_treatment
2828
from diff_diff.trop_results import TROPResults
29-
from diff_diff.utils import safe_inference
29+
from diff_diff.utils import safe_inference, warn_if_not_converged
3030

3131

3232
class TROPGlobalMixin:
@@ -156,6 +156,7 @@ def _solve_global_model(
156156
Y: np.ndarray,
157157
delta: np.ndarray,
158158
lambda_nn: float,
159+
_nonconvergence_tracker: Optional[List[int]] = None,
159160
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
160161
"""
161162
Dispatch to no-lowrank or with-lowrank solver based on lambda_nn.
@@ -168,7 +169,8 @@ def _solve_global_model(
168169
L = np.zeros((n_periods, n_units))
169170
else:
170171
mu, alpha, beta, L = self._solve_global_with_lowrank(
171-
Y, delta, lambda_nn, self.max_iter, self.tol
172+
Y, delta, lambda_nn, self.max_iter, self.tol,
173+
_nonconvergence_tracker=_nonconvergence_tracker,
172174
)
173175
return mu, alpha, beta, L
174176

@@ -273,14 +275,18 @@ def _loocv_score_global(
273275

274276
tau_sq_sum = 0.0
275277
n_valid = 0
278+
nonconverg_tracker: List[int] = []
276279

277280
for t_ex, i_ex in control_obs:
278281
# Create modified delta with excluded observation zeroed out
279282
delta_ex = delta.copy()
280283
delta_ex[t_ex, i_ex] = 0.0
281284

282285
try:
283-
mu, alpha, beta, L = self._solve_global_model(Y, delta_ex, lambda_nn)
286+
mu, alpha, beta, L = self._solve_global_model(
287+
Y, delta_ex, lambda_nn,
288+
_nonconvergence_tracker=nonconverg_tracker,
289+
)
284290

285291
# Pseudo treatment effect: tau = Y - mu - alpha - beta - L
286292
if np.isfinite(Y[t_ex, i_ex]):
@@ -292,6 +298,16 @@ def _loocv_score_global(
292298
# Any failure means this lambda combination is invalid per Equation 5
293299
return np.inf
294300

301+
if nonconverg_tracker:
302+
warn_if_not_converged(
303+
False,
304+
f"TROP global LOOCV: {len(nonconverg_tracker)} of {len(control_obs)} "
305+
f"per-observation fits did not converge "
306+
f"(\u03bb=({lambda_time}, {lambda_unit}, {lambda_nn}))",
307+
self.max_iter,
308+
self.tol,
309+
)
310+
295311
if n_valid == 0:
296312
return np.inf
297313

@@ -395,6 +411,7 @@ def _solve_global_with_lowrank(
395411
lambda_nn: float,
396412
max_iter: int = 100,
397413
tol: float = 1e-6,
414+
_nonconvergence_tracker: Optional[List[int]] = None,
398415
) -> Tuple[float, np.ndarray, np.ndarray, np.ndarray]:
399416
"""
400417
Solve TWFE + low-rank on control data via alternating minimization.
@@ -445,6 +462,9 @@ def _solve_global_with_lowrank(
445462
# Initialize L = 0
446463
L = np.zeros((n_periods, n_units))
447464

465+
_FISTA_MAX_ITER = 20
466+
inner_nonconverged_count = 0
467+
outer_converged = False
448468
for iteration in range(max_iter):
449469
L_old = L.copy()
450470

@@ -463,7 +483,8 @@ def _solve_global_with_lowrank(
463483
L_inner_prev = L_inner # share reference initially (no copy needed)
464484
t_fista = 1.0
465485

466-
for _ in range(20):
486+
inner_converged = False
487+
for _ in range(_FISTA_MAX_ITER):
467488
# FISTA momentum
468489
t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0
469490
momentum = (t_fista - 1.0) / t_fista_new
@@ -479,14 +500,29 @@ def _solve_global_with_lowrank(
479500

480501
# Convergence check (L_inner_prev holds the pre-SVD value)
481502
if np.max(np.abs(L_inner - L_inner_prev)) < tol:
503+
inner_converged = True
482504
break
505+
if not inner_converged:
506+
inner_nonconverged_count += 1
483507

484508
L = L_inner
485509

486510
# Outer convergence check
487511
if np.max(np.abs(L - L_old)) < tol:
512+
outer_converged = True
488513
break
489514

515+
if not outer_converged:
516+
if _nonconvergence_tracker is not None:
517+
_nonconvergence_tracker.append(inner_nonconverged_count)
518+
else:
519+
detail = (
520+
f"TROP global alternating minimization "
521+
f"(inner FISTA non-converged in {inner_nonconverged_count}/{max_iter} "
522+
f"outer iterations, FISTA max_iter={_FISTA_MAX_ITER})"
523+
)
524+
warn_if_not_converged(False, detail, max_iter, tol)
525+
490526
# Final re-solve with the last L iterate, whether or not the outer loop converged (match Rust behavior)
491527
Y_adj = Y_safe - L
492528
mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked)
@@ -984,6 +1020,7 @@ def _bootstrap_variance_global(
9841020
n_control_units = len(control_units)
9851021

9861022
bootstrap_estimates_list: List[float] = []
1023+
nonconverg_tracker: List[int] = []
9871024

9881025
for _ in range(self.n_bootstrap):
9891026
# Stratified sampling
@@ -1018,6 +1055,7 @@ def _bootstrap_variance_global(
10181055
optimal_lambda,
10191056
treated_periods,
10201057
survey_design=survey_design,
1058+
_nonconvergence_tracker=nonconverg_tracker,
10211059
)
10221060
if np.isfinite(tau):
10231061
bootstrap_estimates_list.append(tau)
@@ -1026,6 +1064,15 @@ def _bootstrap_variance_global(
10261064

10271065
bootstrap_estimates = np.array(bootstrap_estimates_list)
10281066

1067+
if nonconverg_tracker:
1068+
warn_if_not_converged(
1069+
False,
1070+
f"TROP global bootstrap: {len(nonconverg_tracker)} of "
1071+
f"{self.n_bootstrap} replicate fits did not converge",
1072+
self.max_iter,
1073+
self.tol,
1074+
)
1075+
10291076
if len(bootstrap_estimates) < 10:
10301077
warnings.warn(
10311078
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.", UserWarning
@@ -1169,6 +1216,7 @@ def _bootstrap_rao_wu_global(
11691216
)
11701217

11711218
bootstrap_estimates_list: List[float] = []
1219+
nonconverg_tracker: List[int] = []
11721220

11731221
for _ in range(self.n_bootstrap):
11741222
try:
@@ -1187,7 +1235,10 @@ def _bootstrap_rao_wu_global(
11871235
delta = self._compute_global_weights(
11881236
Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods
11891237
)
1190-
mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
1238+
mu, alpha, beta, L = self._solve_global_model(
1239+
Y, delta, lambda_nn,
1240+
_nonconvergence_tracker=nonconverg_tracker,
1241+
)
11911242

11921243
# Extract weighted ATT using Rao-Wu rescaled weights
11931244
att, _, _ = self._extract_posthoc_tau(
@@ -1201,6 +1252,15 @@ def _bootstrap_rao_wu_global(
12011252

12021253
bootstrap_estimates = np.array(bootstrap_estimates_list)
12031254

1255+
if nonconverg_tracker:
1256+
warn_if_not_converged(
1257+
False,
1258+
f"TROP global Rao-Wu bootstrap: {len(nonconverg_tracker)} of "
1259+
f"{self.n_bootstrap} replicate fits did not converge",
1260+
self.max_iter,
1261+
self.tol,
1262+
)
1263+
12041264
if len(bootstrap_estimates) < 10:
12051265
warnings.warn(
12061266
f"Only {len(bootstrap_estimates)} bootstrap iterations succeeded.",
@@ -1222,6 +1282,7 @@ def _fit_global_with_fixed_lambda(
12221282
fixed_lambda: Tuple[float, float, float],
12231283
treated_periods: int,
12241284
survey_design=None,
1285+
_nonconvergence_tracker: Optional[List[int]] = None,
12251286
) -> float:
12261287
"""
12271288
Fit global model with fixed tuning parameters.
@@ -1263,7 +1324,10 @@ def _fit_global_with_fixed_lambda(
12631324
)
12641325

12651326
# Fit model on control data and extract post-hoc tau
1266-
mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn)
1327+
mu, alpha, beta, L = self._solve_global_model(
1328+
Y, delta, lambda_nn,
1329+
_nonconvergence_tracker=_nonconvergence_tracker,
1330+
)
12671331
att, _, _ = self._extract_posthoc_tau(
12681332
Y, D, mu, alpha, beta, L, unit_weights=local_weight_arr
12691333
)

0 commit comments

Comments
 (0)