Commit 7514cbe

Merge pull request #319 from igerber/fix/sparse-to-dense-lstsq-fallback

Signal silent sparse -> dense lstsq fallback in ImputationDiD and TwoStageDiD

2 parents: da9d3d3 + 28507ec

6 files changed: 119 additions & 6 deletions

diff_diff/imputation.py (12 additions, 2 deletions)

@@ -1515,8 +1515,18 @@ def _build_A_sparse(df_sub, unit_vals, time_vals):
     A0tA0_sparse = A_0.T @ A_0  # stays sparse
     try:
         z = spsolve(A0tA0_sparse.tocsc(), A1_w)
-    except Exception:
-        # Fallback to dense lstsq if sparse solver fails (e.g., singular matrix)
+    except Exception as exc:
+        # Fallback to dense lstsq if sparse solver fails (e.g., singular matrix).
+        # Silent-failure audit axis C: emit a UserWarning on fallback instead
+        # of swallowing the error.
+        warnings.warn(
+            "ImputationDiD variance: sparse solve of (A_0' [W] A_0) z = A_1' w "
+            f"failed ({type(exc).__name__}); falling back to dense lstsq. This "
+            "may indicate a rank-deficient or near-singular normal-equations "
+            "matrix and variance estimates may be less reliable.",
+            UserWarning,
+            stacklevel=2,
+        )
         A0tA0_dense = A0tA0_sparse.toarray()
         z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None)

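Outside the library, the pattern this hunk introduces can be sketched in isolation. All names below are illustrative, not diff_diff's API; the sketch assumes `spsolve` signals an exactly singular matrix via `MatrixRankWarning` (returning NaNs), which the wrapper promotes to an exception so the dense fallback and its `UserWarning` fire.

```python
import warnings

import numpy as np
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import spsolve


def solve_normal_equations(AtA, Atb):
    """Solve (A'A) z = A'b, preferring sparse; warn and use dense lstsq on failure."""
    AtA = csc_matrix(AtA)
    try:
        with warnings.catch_warnings():
            # spsolve flags a singular matrix with MatrixRankWarning and returns
            # NaNs; promote the warning to an exception so the fallback triggers.
            warnings.simplefilter("error")
            z = spsolve(AtA, Atb)
    except Exception as exc:
        warnings.warn(
            f"sparse solve failed ({type(exc).__name__}); falling back to "
            "dense lstsq. The normal-equations matrix may be rank-deficient.",
            UserWarning,
            stacklevel=2,
        )
        # np.linalg.lstsq returns the minimum-norm solution for singular systems.
        z, _, _, _ = np.linalg.lstsq(AtA.toarray(), Atb, rcond=None)
    return z


# An exactly singular system exercises the degraded path.
AtA = np.array([[1.0, 1.0], [1.0, 1.0]])
Atb = np.array([2.0, 2.0])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    z = solve_normal_equations(AtA, Atb)
print(z)  # minimum-norm solution: [1. 1.]
```

Note the warning carries the exception's type name rather than swallowing it, mirroring the `type(exc).__name__` interpolation in the hunk above.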
diff_diff/two_stage.py (12 additions, 2 deletions)

@@ -1652,8 +1652,18 @@ def _compute_gmm_variance(
         gamma_hat = np.column_stack(
             [solve_XtX(Xt1_WX2[:, j]) for j in range(Xt1_WX2.shape[1])]
         )
-    except RuntimeError:
-        # Singular matrix — fall back to dense least-squares
+    except RuntimeError as exc:
+        # Singular matrix — fall back to dense least-squares. Silent-failure
+        # audit axis C: emit a UserWarning on fallback instead of swallowing.
+        warnings.warn(
+            "TwoStageDiD GMM sandwich: sparse factorization of "
+            f"(X'_{{10}} W X_{{10}}) failed ({type(exc).__name__}); falling "
+            "back to dense lstsq. This may indicate a rank-deficient or "
+            "near-singular Stage 1 design matrix and SE estimates may be "
+            "less reliable.",
+            UserWarning,
+            stacklevel=2,
+        )
         gamma_hat = np.linalg.lstsq(XtWX_10.toarray(), Xt1_WX2, rcond=None)[0]
     if gamma_hat.ndim == 1:
         gamma_hat = gamma_hat.reshape(-1, 1)

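The TwoStageDiD sites guard `scipy.sparse.linalg.factorized`, whose default SuperLU backend raises `RuntimeError` ("Factor is exactly singular") when LU hits a zero pivot. A standalone sketch of that control flow, with illustrative names (`XtX`, `XtY`, `gamma`) rather than the library's:

```python
import warnings

import numpy as np
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import factorized

# Rank-1 matrix: row 2 is half of row 1, so LU factorization hits a zero pivot.
XtX = csc_matrix(np.array([[2.0, 4.0], [1.0, 2.0]]))
XtY = np.array([[6.0], [3.0]])

try:
    # With the default SuperLU backend, factorized() raises RuntimeError on
    # an exactly singular matrix (behavior may differ if scikit-umfpack is
    # installed).
    solve = factorized(XtX)
    gamma = np.column_stack([solve(XtY[:, j]) for j in range(XtY.shape[1])])
except RuntimeError as exc:
    warnings.warn(
        f"sparse factorization failed ({type(exc).__name__}); "
        "falling back to dense lstsq.",
        UserWarning,
        stacklevel=2,
    )
    # Dense lstsq returns the minimum-norm solution for the rank-deficient system.
    gamma = np.linalg.lstsq(XtX.toarray(), XtY, rcond=None)[0]
```

Because the factorization fails up front, the fallback covers every right-hand-side column at once, which is why the diff reassigns the whole `gamma_hat` rather than patching individual columns.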
diff_diff/two_stage_bootstrap.py (11 additions, 1 deletion)

@@ -139,7 +139,17 @@ def _compute_cluster_S_scores(
         gamma_hat = np.column_stack(
             [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])]
         )
-    except RuntimeError:
+    except RuntimeError as exc:
+        # Silent-failure audit axis C: emit a UserWarning on fallback instead
+        # of swallowing the error.
+        warnings.warn(
+            "TwoStageDiD bootstrap: sparse factorization of X_10' X_10 "
+            f"failed ({type(exc).__name__}); falling back to dense lstsq. "
+            "This may indicate a rank-deficient or near-singular Stage 1 "
+            "design matrix and bootstrap SE estimates may be less reliable.",
+            UserWarning,
+            stacklevel=2,
+        )
        gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0]
     if gamma_hat.ndim == 1:
         gamma_hat = gamma_hat.reshape(-1, 1)

docs/methodology/REGISTRY.md (2 additions, 1 deletion)

@@ -1079,7 +1079,7 @@ where `W_it(h) = 1[K_it = h]` are lead indicators, estimated on `Omega_0` only.
 - **Non-constant `first_treat` within a unit:** Emits `UserWarning` identifying the count and example unit. The estimator proceeds using the first observed value per unit (via `.first()` aggregation), but results may be unreliable.
 - **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand (unweighted), or normalized survey weights `sw_i/sum(sw)` when `survey_design` is active.
 - **Rank-deficient covariates in variance:** Covariates with NaN coefficients (dropped for rank deficiency in Step 1) are excluded from the variance design matrices `A_0`/`A_1`. Only covariates with finite coefficients participate in the `v_it` projection.
-- **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails.
+- **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails and emits a `UserWarning` on the fallback (silent-failure audit axis C) so callers know variance estimates came from the degraded path.
 - **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and design-based variance via `compute_survey_if_variance()`. PSU clustering, stratification, and FPC are fully supported in the Theorem 3 variance path. When `resolved_survey` is present, the observation-level influence function (`v_it * epsilon_tilde_it`) is passed to `compute_survey_if_variance()` which applies the stratified PSU-level sandwich with FPC correction. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
 - **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns.
 - **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean.

@@ -1162,6 +1162,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus
 - **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0.
 - **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU clustering, stratification, and FPC are fully supported in the meat matrix via `_compute_stratified_meat_from_psu_scores()`. When strata or FPC are present, the meat computation replaces `S' S` with the stratified formula `sum_h (1 - f_h) * (n_h/(n_h-1)) * centered_h' centered_h`. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
 - **Note:** Both the iterative FE solver (`_iterative_fe`, Stage 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization) emit `UserWarning` when `max_iter` exhausts without reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`.
+- **Note:** The GMM sandwich and bootstrap paths both use `scipy.sparse.linalg.factorized` for the Stage 1 normal-equations solve `(X'_{10} W X_{10}) gamma = X'_1 W X_2` and fall back to dense `lstsq` when the sparse factorization raises `RuntimeError` on a near-singular matrix. Both fallback sites emit a `UserWarning` (silent-failure audit axis C) so callers know SE estimates came from the degraded path rather than the fast sparse path.
 
 **Reference implementation(s):**
 - R: `did2s::did2s()` (Kyle Butts & John Gardner)

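With the warning documented in REGISTRY.md as caller-facing behavior, downstream code can escalate or detect the degraded path using only the standard `warnings` machinery. The message text below mimics the pattern of the diff's messages but is not the library's exact wording:

```python
import warnings

# Escalate the degraded-path warning to a hard error (useful in CI, where a
# silent fall back to dense lstsq should fail the build):
with warnings.catch_warnings():
    warnings.filterwarnings(
        "error",
        message=".*falling back to dense lstsq.*",
        category=UserWarning,
    )
    try:
        warnings.warn(
            "sparse solve failed (RuntimeError); falling back to dense lstsq.",
            UserWarning,
        )
        escalated = False
    except UserWarning:
        escalated = True

# Or detect the degraded path after the fact without interrupting the run:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn(
        "sparse solve failed (RuntimeError); falling back to dense lstsq.",
        UserWarning,
    )
degraded = any("falling back to dense lstsq" in str(w.message) for w in caught)
```

The `message` argument to `filterwarnings` is a regex matched against the start of the warning text, so a `.*` prefix is needed to match the mid-message phrase.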
tests/test_imputation.py (25 additions)

@@ -885,6 +885,31 @@ def test_sparse_solver_dense_fallback(self):
         assert np.isfinite(results.overall_se)
         assert results.overall_se > 0
 
+    def test_sparse_solver_dense_fallback_emits_warning(self):
+        """Silent-failure audit axis C: the sparse -> dense lstsq fallback must
+        emit a UserWarning so callers are informed that variance estimates come
+        from the degraded path."""
+        import unittest.mock
+
+        data = generate_test_data(n_units=80, n_periods=8, seed=42)
+        rng = np.random.default_rng(42)
+        data["x1"] = rng.standard_normal(len(data))
+
+        est = ImputationDiD()
+
+        with unittest.mock.patch(
+            "diff_diff.imputation.spsolve", side_effect=RuntimeError("test failure")
+        ):
+            with pytest.warns(UserWarning, match="sparse solve.*falling back to dense lstsq"):
+                est.fit(
+                    data,
+                    outcome="outcome",
+                    unit="unit",
+                    time="time",
+                    first_treat="first_treat",
+                    covariates=["x1"],
+                )
+
 
 # =============================================================================
 # TestImputationBootstrap

tests/test_two_stage.py (57 additions)

@@ -490,6 +490,63 @@ def test_event_study_se_positive(self):
             assert eff["se"] > 0, f"SE at h={h} should be positive"
             assert np.isfinite(eff["se"])
 
+    def test_sparse_factorized_dense_fallback_emits_warning(self):
+        """Silent-failure audit axis C: when sparse factorization of Stage 1's
+        normal-equations matrix fails and the GMM sandwich falls back to dense
+        lstsq, a UserWarning must surface so callers know SE came from the
+        degraded path rather than the fast sparse path.
+
+        Also verifies the dense fallback still yields finite, usable SEs so
+        that a future regression in the fallback control flow cannot keep the
+        warning while breaking the degraded path."""
+        import unittest.mock
+
+        data = generate_test_data()
+
+        with unittest.mock.patch(
+            "diff_diff.two_stage.sparse_factorized",
+            side_effect=RuntimeError("test failure"),
+        ):
+            with pytest.warns(UserWarning, match="sparse factorization.*falling back to dense lstsq"):
+                results = TwoStageDiD().fit(
+                    data,
+                    outcome="outcome",
+                    unit="unit",
+                    time="time",
+                    first_treat="first_treat",
+                )
+
+        # Dense fallback must still produce a usable SE.
+        assert np.isfinite(results.overall_se)
+        assert results.overall_se > 0
+
+    def test_sparse_factorized_bootstrap_dense_fallback_emits_warning(self):
+        """Silent-failure audit axis C: the TwoStage bootstrap path has the
+        same sparse->dense fallback and must also emit a UserWarning.
+
+        Also verifies the bootstrap dense fallback still yields finite,
+        usable SEs."""
+        import unittest.mock
+
+        data = generate_test_data()
+
+        with unittest.mock.patch(
+            "diff_diff.two_stage_bootstrap.sparse_factorized",
+            side_effect=RuntimeError("test failure"),
+        ):
+            with pytest.warns(UserWarning, match="sparse factorization.*falling back to dense lstsq"):
+                results = TwoStageDiD(n_bootstrap=4, seed=42).fit(
+                    data,
+                    outcome="outcome",
+                    unit="unit",
+                    time="time",
+                    first_treat="first_treat",
+                )
+
+        # Bootstrap dense fallback must still produce a usable SE.
+        assert np.isfinite(results.overall_se)
+        assert results.overall_se > 0
+
 
 # =============================================================================
 # TestTwoStageDiDEdgeCases
