Skip to content

Commit 8339dcc

Browse files
igerberclaude
andcommitted
Signal silent sparse -> dense lstsq fallback in ImputationDiD and TwoStageDiD
Addresses axis-C findings #8, #9, and #10 from the silent-failures audit: three sites where a sparse factorization failure silently fell back to dense lstsq without any user-facing signal. - diff_diff/imputation.py:1516 (variance path: scipy.sparse.linalg.spsolve on (A_0' W A_0) z = A_1' w). Bare `except Exception` was swallowing the root cause before dense lstsq. Now emits a UserWarning identifying the exception type and explaining the fallback implication. - diff_diff/two_stage.py:1647 (GMM sandwich: sparse_factorized on X'_{10} W X_{10} for Stage 1 normal equations). `except RuntimeError` was silent; now emits a UserWarning. - diff_diff/two_stage_bootstrap.py:134 (bootstrap path: same pattern as above). `except RuntimeError` was silent; now emits a UserWarning. All three are single-call sites (per fit, or per aggregation level, or per bootstrap replicate at most a handful of times) so no aggregation wrapper pattern is needed — one warning per fallback event is appropriate. REGISTRY.md updated under ImputationDiD and TwoStageDiD. New tests (3): monkey-patch the sparse entry point to raise a RuntimeError, run .fit(), assert the UserWarning fires with the expected message prefix. Works against both the variance and bootstrap surfaces. Axis-C baseline: 3 major silent-fallback sites (imputation, two_stage, two_stage_bootstrap) -> 0 remaining in these files. PowerAnalysis simulation counter (finding #11) and ContinuousDiD B-spline (#12) still open as separate follow-ups. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 306ed99 commit 8339dcc

6 files changed

Lines changed: 104 additions & 6 deletions

File tree

diff_diff/imputation.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,8 +1515,18 @@ def _build_A_sparse(df_sub, unit_vals, time_vals):
15151515
A0tA0_sparse = A_0.T @ A_0 # stays sparse
15161516
try:
15171517
z = spsolve(A0tA0_sparse.tocsc(), A1_w)
1518-
except Exception:
1519-
# Fallback to dense lstsq if sparse solver fails (e.g., singular matrix)
1518+
except Exception as exc:
1519+
# Fallback to dense lstsq if sparse solver fails (e.g., singular matrix).
1520+
# Silent-failure audit axis C: emit a UserWarning on fallback instead
1521+
# of swallowing the error.
1522+
warnings.warn(
1523+
"ImputationDiD variance: sparse solve of (A_0' [W] A_0) z = A_1' w "
1524+
f"failed ({type(exc).__name__}); falling back to dense lstsq. This "
1525+
"may indicate a rank-deficient or near-singular normal-equations "
1526+
"matrix and variance estimates may be less reliable.",
1527+
UserWarning,
1528+
stacklevel=2,
1529+
)
15201530
A0tA0_dense = A0tA0_sparse.toarray()
15211531
z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None)
15221532

diff_diff/two_stage.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,8 +1652,18 @@ def _compute_gmm_variance(
16521652
gamma_hat = np.column_stack(
16531653
[solve_XtX(Xt1_WX2[:, j]) for j in range(Xt1_WX2.shape[1])]
16541654
)
1655-
except RuntimeError:
1656-
# Singular matrix — fall back to dense least-squares
1655+
except RuntimeError as exc:
1656+
# Singular matrix — fall back to dense least-squares. Silent-failure
1657+
# audit axis C: emit a UserWarning on fallback instead of swallowing.
1658+
warnings.warn(
1659+
"TwoStageDiD GMM sandwich: sparse factorization of "
1660+
f"(X'_{{10}} W X_{{10}}) failed ({type(exc).__name__}); falling "
1661+
"back to dense lstsq. This may indicate a rank-deficient or "
1662+
"near-singular Stage 1 design matrix and SE estimates may be "
1663+
"less reliable.",
1664+
UserWarning,
1665+
stacklevel=2,
1666+
)
16571667
gamma_hat = np.linalg.lstsq(XtWX_10.toarray(), Xt1_WX2, rcond=None)[0]
16581668
if gamma_hat.ndim == 1:
16591669
gamma_hat = gamma_hat.reshape(-1, 1)

diff_diff/two_stage_bootstrap.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,17 @@ def _compute_cluster_S_scores(
139139
gamma_hat = np.column_stack(
140140
[solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])]
141141
)
142-
except RuntimeError:
142+
except RuntimeError as exc:
143+
# Silent-failure audit axis C: emit a UserWarning on fallback instead
144+
# of swallowing the error.
145+
warnings.warn(
146+
"TwoStageDiD bootstrap: sparse factorization of X_10' X_10 "
147+
f"failed ({type(exc).__name__}); falling back to dense lstsq. "
148+
"This may indicate a rank-deficient or near-singular Stage 1 "
149+
"design matrix and bootstrap SE estimates may be less reliable.",
150+
UserWarning,
151+
stacklevel=2,
152+
)
143153
gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0]
144154
if gamma_hat.ndim == 1:
145155
gamma_hat = gamma_hat.reshape(-1, 1)

docs/methodology/REGISTRY.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ where `W_it(h) = 1[K_it = h]` are lead indicators, estimated on `Omega_0` only.
10791079
- **Non-constant `first_treat` within a unit:** Emits `UserWarning` identifying the count and example unit. The estimator proceeds using the first observed value per unit (via `.first()` aggregation), but results may be unreliable.
10801080
- **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand (unweighted), or normalized survey weights `sw_i/sum(sw)` when `survey_design` is active.
10811081
- **Rank-deficient covariates in variance:** Covariates with NaN coefficients (dropped for rank deficiency in Step 1) are excluded from the variance design matrices `A_0`/`A_1`. Only covariates with finite coefficients participate in the `v_it` projection.
1082-
- **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails.
1082+
- **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails and emits a `UserWarning` on the fallback (silent-failure audit axis C) so callers know variance estimates came from the degraded path.
10831083
- **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and design-based variance via `compute_survey_if_variance()`. PSU clustering, stratification, and FPC are fully supported in the Theorem 3 variance path. When `resolved_survey` is present, the observation-level influence function (`v_it * epsilon_tilde_it`) is passed to `compute_survey_if_variance()` which applies the stratified PSU-level sandwich with FPC correction. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
10841084
- **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns.
10851085
- **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean.
@@ -1162,6 +1162,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus
11621162
- **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0.
11631163
- **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU clustering, stratification, and FPC are fully supported in the meat matrix via `_compute_stratified_meat_from_psu_scores()`. When strata or FPC are present, the meat computation replaces `S' S` with the stratified formula `sum_h (1 - f_h) * (n_h/(n_h-1)) * centered_h' centered_h`. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
11641164
- **Note:** Both the iterative FE solver (`_iterative_fe`, Stage 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization) emit `UserWarning` when `max_iter` exhausts without reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`.
1165+
- **Note:** The GMM sandwich and bootstrap paths both use `scipy.sparse.linalg.factorized` for the Stage 1 normal-equations solve `(X'_{10} W X_{10}) gamma = X'_1 W X_2` and fall back to dense `lstsq` when the sparse factorization raises `RuntimeError` on a near-singular matrix. Both fallback sites emit a `UserWarning` (silent-failure audit axis C) so callers know SE estimates came from the degraded path rather than the fast sparse path.
11651166

11661167
**Reference implementation(s):**
11671168
- R: `did2s::did2s()` (Kyle Butts & John Gardner)

tests/test_imputation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,31 @@ def test_sparse_solver_dense_fallback(self):
885885
assert np.isfinite(results.overall_se)
886886
assert results.overall_se > 0
887887

888+
def test_sparse_solver_dense_fallback_emits_warning(self):
889+
"""Silent-failure audit axis C: the sparse -> dense lstsq fallback must
890+
emit a UserWarning so callers are informed that variance estimates come
891+
from the degraded path."""
892+
import unittest.mock
893+
894+
data = generate_test_data(n_units=80, n_periods=8, seed=42)
895+
rng = np.random.default_rng(42)
896+
data["x1"] = rng.standard_normal(len(data))
897+
898+
est = ImputationDiD()
899+
900+
with unittest.mock.patch(
901+
"diff_diff.imputation.spsolve", side_effect=RuntimeError("test failure")
902+
):
903+
with pytest.warns(UserWarning, match="sparse solve.*falling back to dense lstsq"):
904+
est.fit(
905+
data,
906+
outcome="outcome",
907+
unit="unit",
908+
time="time",
909+
first_treat="first_treat",
910+
covariates=["x1"],
911+
)
912+
888913

889914
# =============================================================================
890915
# TestImputationBootstrap

tests/test_two_stage.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,48 @@ def test_event_study_se_positive(self):
490490
assert eff["se"] > 0, f"SE at h={h} should be positive"
491491
assert np.isfinite(eff["se"])
492492

493+
def test_sparse_factorized_dense_fallback_emits_warning(self):
494+
"""Silent-failure audit axis C: when sparse factorization of Stage 1's
495+
normal-equations matrix fails and the GMM sandwich falls back to dense
496+
lstsq, a UserWarning must surface so callers know SE came from the
497+
degraded path rather than the fast sparse path."""
498+
import unittest.mock
499+
500+
data = generate_test_data()
501+
502+
with unittest.mock.patch(
503+
"diff_diff.two_stage.sparse_factorized",
504+
side_effect=RuntimeError("test failure"),
505+
):
506+
with pytest.warns(UserWarning, match="sparse factorization.*falling back to dense lstsq"):
507+
TwoStageDiD().fit(
508+
data,
509+
outcome="outcome",
510+
unit="unit",
511+
time="time",
512+
first_treat="first_treat",
513+
)
514+
515+
def test_sparse_factorized_bootstrap_dense_fallback_emits_warning(self):
516+
"""Silent-failure audit axis C: the TwoStage bootstrap path has the
517+
same sparse->dense fallback and must also emit a UserWarning."""
518+
import unittest.mock
519+
520+
data = generate_test_data()
521+
522+
with unittest.mock.patch(
523+
"diff_diff.two_stage_bootstrap.sparse_factorized",
524+
side_effect=RuntimeError("test failure"),
525+
):
526+
with pytest.warns(UserWarning, match="sparse factorization.*falling back to dense lstsq"):
527+
TwoStageDiD(n_bootstrap=4, seed=42).fit(
528+
data,
529+
outcome="outcome",
530+
unit="unit",
531+
time="time",
532+
first_treat="first_treat",
533+
)
534+
493535

494536
# =============================================================================
495537
# TestTwoStageDiDEdgeCases

0 commit comments

Comments
 (0)