Skip to content

Commit 44e7522

Browse files
authored
Merge pull request #314 from igerber/fix/fe-imputation-convergence-warnings
Signal non-convergence in FE imputation alternating-projection solvers
2 parents 475f4e3 + f749416 commit 44e7522

7 files changed

Lines changed: 208 additions & 6 deletions

File tree

diff_diff/imputation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
ImputationDiDResults,
2929
)
3030
from diff_diff.linalg import solve_ols
31-
from diff_diff.utils import safe_inference
31+
from diff_diff.utils import safe_inference, warn_if_not_converged
3232

3333
# =============================================================================
3434
# Main Estimator
@@ -909,6 +909,7 @@ def _iterative_fe(
909909
wsum_t = w_series.groupby(time_vals).transform("sum").values
910910
wsum_u = w_series.groupby(unit_vals).transform("sum").values
911911

912+
converged = False
912913
with np.errstate(invalid="ignore", divide="ignore"):
913914
for iteration in range(max_iter):
914915
resid_after_alpha = y - alpha
@@ -943,7 +944,9 @@ def _iterative_fe(
943944
alpha = alpha_new
944945
beta = beta_new
945946
if max_change < tol:
947+
converged = True
946948
break
949+
warn_if_not_converged(converged, "ImputationDiD iterative FE solver", max_iter, tol)
947950

948951
unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict()
949952
time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict()
@@ -978,6 +981,7 @@ def _iterative_demean(
978981
wsum_t = w_series.groupby(time_vals).transform("sum").values
979982
wsum_u = w_series.groupby(unit_vals).transform("sum").values
980983

984+
converged = False
981985
with np.errstate(invalid="ignore", divide="ignore"):
982986
for _ in range(max_iter):
983987
if weights is not None:
@@ -1001,8 +1005,10 @@ def _iterative_demean(
10011005
result_new = result_after_time - unit_means
10021006
if np.max(np.abs(result_new - result)) < tol:
10031007
result = result_new
1008+
converged = True
10041009
break
10051010
result = result_new
1011+
warn_if_not_converged(converged, "ImputationDiD iterative demean", max_iter, tol)
10061012
return result
10071013

10081014
@staticmethod

diff_diff/two_stage.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
TwoStageBootstrapResults, # noqa: F401
4242
TwoStageDiDResults,
4343
) # noqa: F401 (re-export)
44-
from diff_diff.utils import safe_inference
44+
from diff_diff.utils import safe_inference, warn_if_not_converged
4545

4646
# =============================================================================
4747
# Main Estimator
@@ -887,6 +887,7 @@ def _iterative_fe(
887887
wsum_t = w_series.groupby(time_vals).transform("sum").values
888888
wsum_u = w_series.groupby(unit_vals).transform("sum").values
889889

890+
converged = False
890891
with np.errstate(invalid="ignore", divide="ignore"):
891892
for iteration in range(max_iter):
892893
resid_after_alpha = y - alpha
@@ -920,7 +921,9 @@ def _iterative_fe(
920921
alpha = alpha_new
921922
beta = beta_new
922923
if max_change < tol:
924+
converged = True
923925
break
926+
warn_if_not_converged(converged, "TwoStageDiD iterative FE solver", max_iter, tol)
924927

925928
unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict()
926929
time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict()
@@ -951,6 +954,7 @@ def _iterative_demean(
951954
wsum_t = w_series.groupby(time_vals).transform("sum").values
952955
wsum_u = w_series.groupby(unit_vals).transform("sum").values
953956

957+
converged = False
954958
with np.errstate(invalid="ignore", divide="ignore"):
955959
for _ in range(max_iter):
956960
if weights is not None:
@@ -974,8 +978,10 @@ def _iterative_demean(
974978
result_new = result_after_time - unit_means
975979
if np.max(np.abs(result_new - result)) < tol:
976980
result = result_new
981+
converged = True
977982
break
978983
result = result_new
984+
warn_if_not_converged(converged, "TwoStageDiD iterative demean", max_iter, tol)
979985
return result
980986

981987
def _fit_untreated_model(

diff_diff/utils.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,29 @@ def validate_binary(arr: np.ndarray, name: str) -> None:
6565
raise ValueError(f"{name} must be binary (0 or 1). " f"Found values: {unique_values}")
6666

6767

68+
def warn_if_not_converged(
    converged: bool,
    method_name: str,
    max_iter: int,
    tol: Optional[float] = None,
    stacklevel: int = 3,
) -> None:
    """Emit a UserWarning when an iterative solver exhausts max_iter without converging.

    Shared helper for axis-B silent-failure fixes (iterative loops that otherwise
    return the current iterate without signaling non-convergence).

    Parameters
    ----------
    converged : bool
        Whether the solver met its stopping criterion. When truthy, the
        function returns immediately and nothing is emitted.
    method_name : str
        Label placed at the start of the warning message to identify the solver.
    max_iter : int
        Iteration budget that was exhausted; reported in the message.
    tol : float, optional
        Convergence tolerance. When not ``None`` it is appended to the
        message as ``(tol=...)``.
    stacklevel : int, default 3
        Forwarded unchanged to ``warnings.warn``.
    """
    if converged:
        return
    tol_suffix = f" (tol={tol})" if tol is not None else ""
    # Agree in number with the budget so max_iter=1 does not read "1 iterations".
    iter_word = "iteration" if max_iter == 1 else "iterations"
    warnings.warn(
        f"{method_name} did not converge in {max_iter} {iter_word}{tol_suffix}. "
        "Results may be inaccurate.",
        UserWarning,
        stacklevel=stacklevel,
    )
89+
90+
6891
def compute_robust_se(
6992
X: np.ndarray, residuals: np.ndarray, cluster_ids: Optional[np.ndarray] = None
7093
) -> np.ndarray:
@@ -1791,6 +1814,8 @@ def within_transform(
17911814
inplace: bool = False,
17921815
suffix: str = "_demeaned",
17931816
weights: Optional[np.ndarray] = None,
1817+
max_iter: int = 100,
1818+
tol: float = 1e-8,
17941819
) -> pd.DataFrame:
17951820
"""
17961821
Apply two-way within transformation to remove unit and time fixed effects.
@@ -1818,6 +1843,14 @@ def within_transform(
18181843
Suffix for new column names when inplace=False.
18191844
weights : np.ndarray, optional
18201845
Observation weights for weighted group means.
1846+
max_iter : int, default 100
1847+
Maximum number of alternating-projection iterations. Used only when
1848+
``weights`` is not ``None``; the unweighted path is a single pass and
1849+
ignores this argument. Emits a ``UserWarning`` per call when any
1850+
variable fails to converge within this budget.
1851+
tol : float, default 1e-8
1852+
Convergence tolerance on the max absolute change across the iterate.
1853+
Used only when ``weights`` is not ``None``.
18211854
18221855
Returns
18231856
-------
@@ -1853,29 +1886,45 @@ def _weighted_group_demean(x, groups, w, w_sum):
18531886
wx_sum = pd.Series(w * x).groupby(groups).transform("sum").values
18541887
return x - wx_sum / w_sum
18551888

1889+
non_converged_vars: List[str] = []
18561890
if inplace:
18571891
for var in variables:
18581892
x = data[var].values.astype(np.float64)
1859-
for _iter in range(100): # max iterations
1893+
converged = False
1894+
for _iter in range(max_iter):
18601895
x_old = x.copy()
18611896
x = _weighted_group_demean(x, unit_groups, w, unit_w_sum)
18621897
x = _weighted_group_demean(x, time_groups, w, time_w_sum)
1863-
if np.max(np.abs(x - x_old)) < 1e-8:
1898+
if np.max(np.abs(x - x_old)) < tol:
1899+
converged = True
18641900
break
1901+
if not converged:
1902+
non_converged_vars.append(var)
18651903
data[var] = x
18661904
else:
18671905
demeaned_data = {}
18681906
for var in variables:
18691907
x = data[var].values.astype(np.float64)
1870-
for _iter in range(100):
1908+
converged = False
1909+
for _iter in range(max_iter):
18711910
x_old = x.copy()
18721911
x = _weighted_group_demean(x, unit_groups, w, unit_w_sum)
18731912
x = _weighted_group_demean(x, time_groups, w, time_w_sum)
1874-
if np.max(np.abs(x - x_old)) < 1e-8:
1913+
if np.max(np.abs(x - x_old)) < tol:
1914+
converged = True
18751915
break
1916+
if not converged:
1917+
non_converged_vars.append(var)
18761918
demeaned_data[f"{var}{suffix}"] = x
18771919
demeaned_df = pd.DataFrame(demeaned_data, index=data.index)
18781920
data = pd.concat([data, demeaned_df], axis=1)
1921+
if non_converged_vars:
1922+
warn_if_not_converged(
1923+
False,
1924+
f"within_transform weighted demean (variables: {non_converged_vars})",
1925+
max_iter,
1926+
tol,
1927+
)
18791928
else:
18801929
# Cache groupby objects for efficiency
18811930
unit_grouper = data.groupby(unit, sort=False)

docs/methodology/REGISTRY.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,7 @@ where `W_it(h) = 1[K_it = h]` are lead indicators, estimated on `Omega_0` only.
10831083
- **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and design-based variance via `compute_survey_if_variance()`. PSU clustering, stratification, and FPC are fully supported in the Theorem 3 variance path. When `resolved_survey` is present, the observation-level influence function (`v_it * epsilon_tilde_it`) is passed to `compute_survey_if_variance()` which applies the stratified PSU-level sandwich with FPC correction. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
10841084
- **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns.
10851085
- **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean.
1086+
- **Note:** Both the iterative FE solver (`_iterative_fe`, Step 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization and the pre-trend test) emit `UserWarning` when `max_iter` is exhausted without reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`.
10861087

10871088
**Reference implementation(s):**
10881089
- Stata: `did_imputation` (Borusyak, Jaravel, Spiess; available from SSC)
@@ -1160,6 +1161,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus
11601161
- **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some non-Prop-5 event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0.
11611162
- **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0.
11621163
- **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU clustering, stratification, and FPC are fully supported in the meat matrix via `_compute_stratified_meat_from_psu_scores()`. When strata or FPC are present, the meat computation replaces `S' S` with the stratified formula `sum_h (1 - f_h) * (n_h/(n_h-1)) * centered_h' centered_h`. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
1164+
- **Note:** Both the iterative FE solver (`_iterative_fe`, Stage 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization) emit `UserWarning` when `max_iter` is exhausted without reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`.
11631165

11641166
**Reference implementation(s):**
11651167
- R: `did2s::did2s()` (Kyle Butts & John Gardner)
@@ -1299,6 +1301,7 @@ The saturated ETWFE regression includes:
12991301

13001302
The interaction coefficient `δ_{g,t}` identifies `ATT(g, t)` under parallel trends.
13011303
- **Note:** OLS path uses iterative alternating-projection within-transformation (uniform weights) for exact FE absorption on both balanced and unbalanced panels. One-pass demeaning (`y - ȳ_i - ȳ_t + ȳ`) is only exact for balanced panels.
1304+
- **Note:** The weighted within-transformation (`utils.within_transform` with `weights`) is invoked on every WooldridgeDiD fit (survey weights when provided, `np.ones` otherwise) and emits a `UserWarning` on non-convergence per the shared convention documented under *Absorbed Fixed Effects with Survey Weights*.
13021305

13031306
*Nonlinear extensions (Wooldridge 2023):*
13041307

@@ -2520,6 +2523,15 @@ unequal selection probabilities).
25202523
are rejected (single-pass sequential demeaning is not the correct weighted
25212524
FWL projection for N > 1 dimensions; iterative alternating projections are
25222525
needed but not yet implemented).
2526+
- **Note:** The shared weighted within-transformation path
2527+
(`diff_diff.utils.within_transform`, hit whenever `weights is not None`) emits
2528+
a `UserWarning` per call when any transformed variable exits the
2529+
alternating-projection loop without reaching `tol` within `max_iter`.
2530+
Defaults: `max_iter=100`, `tol=1e-8`. This signal applies uniformly across
2531+
TwoWayFixedEffects, SunAbraham, BaconDecomposition, and WooldridgeDiD whenever
2532+
they route through this helper (survey-weighted or otherwise). Silent return
2533+
of the current iterate was classified as a silent failure under the Phase 2
2534+
audit and replaced with this explicit signal.
25232535

25242536
### Survey Degrees of Freedom
25252537

tests/test_imputation.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2087,3 +2087,57 @@ def test_balanced_cohort_mask_requires_negative_horizons(self):
20872087
df_treated, "first_treat", all_horizons, 1, cohort_rel_times
20882088
)
20892089
assert all(mask1)
2090+
2091+
def test_iterative_fe_warns_on_nonconvergence(self):
    """_iterative_fe must emit a "did not converge" UserWarning with max_iter=1, tol=1e-15."""
    unit_count, period_count = 8, 5
    total_obs = unit_count * period_count
    # Balanced panel layout: unit ids in contiguous runs, period ids cycling.
    unit_ids = np.repeat(np.arange(unit_count), period_count)
    period_ids = np.tile(np.arange(period_count), unit_count)
    outcome = np.random.default_rng(42).standard_normal(total_obs)
    row_index = pd.RangeIndex(len(outcome))
    estimator = ImputationDiD()

    with pytest.warns(UserWarning, match="did not converge"):
        estimator._iterative_fe(
            outcome, unit_ids, period_ids, row_index, max_iter=1, tol=1e-15
        )
2103+
2104+
def test_iterative_fe_no_warning_on_convergence(self):
    """_iterative_fe, called without overriding max_iter or tol, must not warn of non-convergence."""
    unit_count, period_count = 8, 5
    total_obs = unit_count * period_count
    # Balanced panel layout: unit ids in contiguous runs, period ids cycling.
    unit_ids = np.repeat(np.arange(unit_count), period_count)
    period_ids = np.tile(np.arange(period_count), unit_count)
    outcome = np.random.default_rng(42).standard_normal(total_obs)
    row_index = pd.RangeIndex(len(outcome))
    estimator = ImputationDiD()

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning so nothing is hidden by once-per-location filtering.
        warnings.simplefilter("always")
        estimator._iterative_fe(outcome, unit_ids, period_ids, row_index)
    messages = [str(record.message) for record in caught]
    assert all("did not converge" not in text for text in messages)
2118+
2119+
def test_iterative_demean_warns_on_nonconvergence(self):
    """_iterative_demean must emit a "did not converge" UserWarning with max_iter=1, tol=1e-15."""
    unit_count, period_count = 8, 5
    total_obs = unit_count * period_count
    # Balanced panel layout: unit ids in contiguous runs, period ids cycling.
    unit_ids = np.repeat(np.arange(unit_count), period_count)
    period_ids = np.tile(np.arange(period_count), unit_count)
    values = np.random.default_rng(42).standard_normal(total_obs)
    row_index = pd.RangeIndex(len(values))

    with pytest.warns(UserWarning, match="did not converge"):
        ImputationDiD._iterative_demean(
            values, unit_ids, period_ids, row_index, max_iter=1, tol=1e-15
        )
2130+
2131+
def test_iterative_demean_no_warning_on_convergence(self):
    """_iterative_demean, called without overriding max_iter or tol, must not warn of non-convergence."""
    unit_count, period_count = 8, 5
    total_obs = unit_count * period_count
    # Balanced panel layout: unit ids in contiguous runs, period ids cycling.
    unit_ids = np.repeat(np.arange(unit_count), period_count)
    period_ids = np.tile(np.arange(period_count), unit_count)
    values = np.random.default_rng(42).standard_normal(total_obs)
    row_index = pd.RangeIndex(len(values))

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning so nothing is hidden by once-per-location filtering.
        warnings.simplefilter("always")
        ImputationDiD._iterative_demean(values, unit_ids, period_ids, row_index)
    messages = [str(record.message) for record in caught]
    assert all("did not converge" not in text for text in messages)

tests/test_methodology_twfe.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,27 @@ def test_demeaned_outcome_sums_to_zero(self):
235235
np.testing.assert_allclose(unit_sums.values, 0, atol=1e-10)
236236
np.testing.assert_allclose(time_sums.values, 0, atol=1e-10)
237237

238+
def test_within_transform_weighted_warns_on_nonconvergence(self):
    """within_transform with weights must emit a "did not converge" UserWarning at max_iter=1, tol=1e-15."""
    panel = generate_twfe_panel(n_units=20, n_periods=4, seed=99)
    # Passing any weights array selects the iterative (weighted) code path.
    unit_weights = np.ones(len(panel))
    target_columns = ["outcome"]

    with pytest.warns(UserWarning, match="did not converge"):
        within_transform(
            panel,
            target_columns,
            "unit",
            "period",
            weights=unit_weights,
            max_iter=1,
            tol=1e-15,
        )
248+
249+
def test_within_transform_weighted_no_warning_on_convergence(self):
    """within_transform with weights and default max_iter/tol must not warn of non-convergence."""
    panel = generate_twfe_panel(n_units=20, n_periods=4, seed=99)
    # Passing any weights array selects the iterative (weighted) code path.
    unit_weights = np.ones(len(panel))
    target_columns = ["outcome"]

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning so nothing is hidden by once-per-location filtering.
        warnings.simplefilter("always")
        within_transform(panel, target_columns, "unit", "period", weights=unit_weights)
    messages = [str(record.message) for record in caught]
    assert all("did not converge" not in text for text in messages)
258+
238259

239260
# =============================================================================
240261
# Phase 2: R Comparison

0 commit comments

Comments
 (0)