Skip to content

Commit 954b7c0

Browse files
igerberclaude
andcommitted
Signal non-convergence in TROP alternating-minimization solvers
Addresses axis B findings #6 and #7 from the silent-failures audit: trop_global.py:448 outer alternating-min loop, trop_global.py:466 hard-coded range(20) inner FISTA loop, and trop_local.py:680 alternating-minimization loop all exited silently on max_iter exhaustion, returning the current iterate as if converged. - trop_global._solve_global_with_lowrank: thread a converged flag through the outer loop; count non-convergence events from the inner FISTA and surface the count in the outer warning for diagnostic context. One warn_if_not_converged call per solver invocation. - trop_local._estimate_model: thread a converged flag through the outer alternating-min loop; call warn_if_not_converged on exhaustion. - REGISTRY updated under TROP. New TestTROPConvergenceWarnings class (4 tests) exercises both global and local paths with forced non-convergence (max_iter=1, tol=1e-15) and a convergent negative control. Notable: the default TROP local config (max_iter=100, tol=1e-6) does not converge within max_iter on typical synthetic panels, so this PR surfaces a previously silent non-convergence that affected routine user fits. No numerical change in the returned iterate; the warning is additive. Axis-B regression-lint baseline: 5 -> 2 silent range(max_iter) loops remaining (minor loops in honest_did/power not yet addressed). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 32a4d09 commit 954b7c0

4 files changed

Lines changed: 131 additions & 2 deletions

File tree

diff_diff/trop_global.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from diff_diff.trop_local import _soft_threshold_svd, _validate_and_pivot_treatment
2828
from diff_diff.trop_results import TROPResults
29-
from diff_diff.utils import safe_inference
29+
from diff_diff.utils import safe_inference, warn_if_not_converged
3030

3131

3232
class TROPGlobalMixin:
@@ -445,6 +445,9 @@ def _solve_global_with_lowrank(
445445
# Initialize L = 0
446446
L = np.zeros((n_periods, n_units))
447447

448+
_FISTA_MAX_ITER = 20
449+
inner_nonconverged_count = 0
450+
outer_converged = False
448451
for iteration in range(max_iter):
449452
L_old = L.copy()
450453

@@ -463,7 +466,8 @@ def _solve_global_with_lowrank(
463466
L_inner_prev = L_inner # share reference initially (no copy needed)
464467
t_fista = 1.0
465468

466-
for _ in range(20):
469+
inner_converged = False
470+
for _ in range(_FISTA_MAX_ITER):
467471
# FISTA momentum
468472
t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0
469473
momentum = (t_fista - 1.0) / t_fista_new
@@ -479,14 +483,26 @@ def _solve_global_with_lowrank(
479483

480484
# Convergence check (L_inner_prev holds the pre-SVD value)
481485
if np.max(np.abs(L_inner - L_inner_prev)) < tol:
486+
inner_converged = True
482487
break
488+
if not inner_converged:
489+
inner_nonconverged_count += 1
483490

484491
L = L_inner
485492

486493
# Outer convergence check
487494
if np.max(np.abs(L - L_old)) < tol:
495+
outer_converged = True
488496
break
489497

498+
if not outer_converged:
499+
detail = (
500+
f"TROP global alternating minimization "
501+
f"(inner FISTA non-converged in {inner_nonconverged_count}/{max_iter} "
502+
f"outer iterations, FISTA max_iter={_FISTA_MAX_ITER})"
503+
)
504+
warn_if_not_converged(False, detail, max_iter, tol)
505+
490506
# Final re-solve with converged L (match Rust behavior)
491507
Y_adj = Y_safe - L
492508
mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked)

diff_diff/trop_local.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_rust_unit_distance_matrix,
2626
)
2727
from diff_diff.trop_results import _PrecomputedStructures
28+
from diff_diff.utils import warn_if_not_converged
2829

2930

3031
def _validate_and_pivot_treatment(data, time, unit, treatment, all_periods, all_units):
@@ -677,6 +678,7 @@ def _estimate_model(
677678

678679
# Alternating minimization following Algorithm 1 (page 9)
679680
# Minimize: sum W_{ti}(Y_{ti} - alpha_i - beta_t - L_{ti})^2 + lambda_nn||L||_*
681+
converged = False
680682
for _ in range(self.max_iter):
681683
alpha_old = alpha.copy()
682684
beta_old = beta.copy()
@@ -717,7 +719,11 @@ def _estimate_model(
717719
L_diff = np.max(np.abs(L - L_old))
718720

719721
if max(alpha_diff, beta_diff, L_diff) < self.tol:
722+
converged = True
720723
break
724+
warn_if_not_converged(
725+
converged, "TROP local alternating minimization", self.max_iter, self.tol
726+
)
721727

722728
return alpha, beta, L
723729

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,6 +1972,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
19721972
- **Bootstrap minimum**: `n_bootstrap` must be >= 2 (enforced via `ValueError`). TROP uses bootstrap for all variance estimation — there is no analytical SE formula.
19731973
- **LOOCV failure metadata**: When LOOCV fits fail in the Rust backend, the first failed observation coordinates (t, i) are returned to Python for informative warning messages
19741974
- **Inference CI distribution**: After `safe_inference()` migration, CI uses t-distribution (df = max(1, n_treated_obs - 1)), consistent with p_value. Previously CI used normal-distribution while p_value used t-distribution (inconsistent). This is a minor behavioral change; CIs may be slightly wider for small n_treated_obs.
1975+
- **Note:** Both the `local` alternating-minimization solver (`_estimate_model`) and the `global` alternating-minimization solver (`_solve_global_with_lowrank`, including its hard-coded inner FISTA loop of 20 iterations) emit `UserWarning` via `diff_diff.utils.warn_if_not_converged` when the outer loop exhausts `max_iter` without reaching `tol`. The global-method warning surfaces the inner-FISTA non-convergence count as diagnostic context. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the convention used across other iterative solvers in the library.
19751976

19761977
**Reference implementation(s):**
19771978
- Authors' replication code (forthcoming)

tests/test_trop.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3936,3 +3936,109 @@ def test_observed_treatment_nan_raises_local(self):
39363936
)
39373937
with pytest.raises(ValueError, match="missing treatment values"):
39383938
trop_est.fit(df, "outcome", "treated", "unit", "time")
3939+
3940+
3941+
class TestTROPConvergenceWarnings:
3942+
"""Silent-failure audit axis B: TROP alternating minimization must warn on non-convergence."""
3943+
3944+
@staticmethod
3945+
def _panel_matrices(simple_panel_data):
3946+
"""Pivot simple_panel_data into (Y, D, n_units, n_periods, treated_periods)."""
3947+
all_units = sorted(simple_panel_data["unit"].unique())
3948+
all_periods = sorted(simple_panel_data["period"].unique())
3949+
n_units = len(all_units)
3950+
n_periods = len(all_periods)
3951+
Y = (
3952+
simple_panel_data.pivot(index="period", columns="unit", values="outcome")
3953+
.reindex(index=all_periods, columns=all_units)
3954+
.values
3955+
)
3956+
D = (
3957+
simple_panel_data.pivot(index="period", columns="unit", values="treated")
3958+
.reindex(index=all_periods, columns=all_units)
3959+
.fillna(0)
3960+
.astype(int)
3961+
.values
3962+
)
3963+
treated_periods = int(np.sum(np.any(D == 1, axis=1)))
3964+
return Y, D, n_units, n_periods, treated_periods
3965+
3966+
def test_global_alternating_min_warns_on_nonconvergence(self, simple_panel_data):
3967+
"""_solve_global_with_lowrank must warn when outer alternating-min loop exhausts max_iter."""
3968+
Y, D, n_units, n_periods, treated_periods = self._panel_matrices(simple_panel_data)
3969+
3970+
trop_est = TROP(
3971+
method="global",
3972+
lambda_time_grid=[1.0],
3973+
lambda_unit_grid=[1.0],
3974+
lambda_nn_grid=[0.1],
3975+
seed=42,
3976+
)
3977+
delta = trop_est._compute_global_weights(
3978+
Y, D, 1.0, 1.0, treated_periods, n_units, n_periods
3979+
)
3980+
3981+
with pytest.warns(UserWarning, match="did not converge"):
3982+
trop_est._solve_global_with_lowrank(Y, delta, lambda_nn=0.1, max_iter=1, tol=1e-15)
3983+
3984+
def test_global_alternating_min_no_warning_on_convergence(self, simple_panel_data):
3985+
"""_solve_global_with_lowrank must not warn on a well-behaved fit with generous max_iter."""
3986+
Y, D, n_units, n_periods, treated_periods = self._panel_matrices(simple_panel_data)
3987+
3988+
trop_est = TROP(
3989+
method="global",
3990+
lambda_time_grid=[1.0],
3991+
lambda_unit_grid=[1.0],
3992+
lambda_nn_grid=[0.1],
3993+
seed=42,
3994+
)
3995+
delta = trop_est._compute_global_weights(
3996+
Y, D, 1.0, 1.0, treated_periods, n_units, n_periods
3997+
)
3998+
3999+
with warnings.catch_warnings(record=True) as w:
4000+
warnings.simplefilter("always")
4001+
trop_est._solve_global_with_lowrank(Y, delta, lambda_nn=0.1, max_iter=500, tol=1e-6)
4002+
assert not any("did not converge" in str(x.message) for x in w)
4003+
4004+
def test_local_alternating_min_warns_on_nonconvergence(self, simple_panel_data):
4005+
"""TROP local _estimate_model must warn when alternating-min exhausts max_iter."""
4006+
Y, D, n_units, n_periods, _ = self._panel_matrices(simple_panel_data)
4007+
control_mask = (np.sum(D, axis=0) == 0) # units never treated
4008+
4009+
trop_est = TROP(
4010+
method="local",
4011+
lambda_time_grid=[1.0],
4012+
lambda_unit_grid=[1.0],
4013+
lambda_nn_grid=[0.1],
4014+
max_iter=1,
4015+
tol=1e-15,
4016+
seed=42,
4017+
)
4018+
W = np.where(D == 0, 1.0, 0.0)
4019+
4020+
with pytest.warns(UserWarning, match="did not converge"):
4021+
trop_est._estimate_model(Y, control_mask, W, lambda_nn=0.1,
4022+
n_units=n_units, n_periods=n_periods)
4023+
4024+
def test_local_alternating_min_no_warning_on_convergence(self, simple_panel_data):
4025+
"""TROP local _estimate_model must not warn on a well-behaved fit."""
4026+
Y, D, n_units, n_periods, _ = self._panel_matrices(simple_panel_data)
4027+
control_mask = (np.sum(D, axis=0) == 0)
4028+
4029+
trop_est = TROP(
4030+
method="local",
4031+
lambda_time_grid=[1.0],
4032+
lambda_unit_grid=[1.0],
4033+
lambda_nn_grid=[0.1],
4034+
max_iter=500,
4035+
tol=1e-6,
4036+
seed=42,
4037+
)
4038+
W = np.where(D == 0, 1.0, 0.0)
4039+
4040+
with warnings.catch_warnings(record=True) as w:
4041+
warnings.simplefilter("always")
4042+
trop_est._estimate_model(Y, control_mask, W, lambda_nn=0.1,
4043+
n_units=n_units, n_periods=n_periods)
4044+
assert not any("did not converge" in str(x.message) for x in w)

0 commit comments

Comments
 (0)