Skip to content

Commit 573de52

Browse files
igerber and claude committed
Surface silent np.linalg.solve fallbacks across axis-A minor solver paths
Addresses findings #17, #18, #19 from the Phase 2 silent-failures audit (axis A, all Minor). Each site previously ran np.linalg.solve against a matrix that could be rank-deficient or near-singular with no user-facing signal.

- StaggeredTripleDifference: `_compute_did_panel` now appends a condition-number sample to an instance tracker on LinAlgError; `fit()` emits ONE aggregate UserWarning reporting the number of affected (g, g_c, t) cells and the max condition number, instead of silently falling back to np.linalg.lstsq per pair. Tracker resets on repeat fit.
- EfficientDiD covariate sieve (estimate_propensity_ratio_sieve, estimate_inverse_propensity_sieve): precondition-check the normal-equations matrix via np.linalg.cond before solve and reject K values whose condition number exceeds 1/sqrt(eps); partial-K skips now surface via a UserWarning listing the skipped K values, instead of being swallowed by `continue`.
- compute_survey_vcov: check cond(X'WX) before the sandwich solve; emit a UserWarning above the 1/sqrt(eps) threshold so ill-conditioned bread matrices don't silently produce unstable variance estimates.

Sibling sites picked up via repo-wide lstsq-fallback pattern grep (per the pattern-check feedback memory):

- two_stage.py:1768 (TSL variance bread)
- two_stage_bootstrap.py:197 (multiplier bootstrap bread)

Both now warn before the lstsq fallback.

Adds 8 targeted tests across test_staggered_triple_diff.py, test_efficient_did.py, and test_survey.py, covering collinear/ill-conditioned triggers and happy-path negatives. REGISTRY.md notes added for each affected estimator section. No behavioral change on well-conditioned inputs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 56730af commit 573de52

9 files changed

Lines changed: 384 additions & 0 deletions

diff_diff/efficient_did_covariates.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@ def estimate_propensity_ratio_sieve(
227227

228228
best_ic = np.inf
229229
best_ratio = np.ones(n_units) # fallback: constant ratio 1
230+
singular_K: List[int] = [] # K values skipped due to rank deficiency (#18)
231+
# Near-singular matrices solve without raising LinAlgError but return
232+
# numerically meaningless beta. Rule-of-thumb threshold: 1/sqrt(eps).
233+
cond_threshold = 1.0 / np.sqrt(np.finfo(float).eps)
230234

231235
for K in range(1, k_max + 1):
232236
n_basis = comb(K + d, d)
@@ -249,13 +253,23 @@ def estimate_propensity_ratio_sieve(
249253
A = Psi_gp.T @ Psi_gp
250254
b = Psi_g.sum(axis=0)
251255

256+
# Precondition check (#18, axis A): reject near-singular A explicitly
257+
# so np.linalg.solve can't silently return garbage coefficients.
258+
with np.errstate(invalid="ignore", over="ignore"):
259+
A_cond = float(np.linalg.cond(A))
260+
if not np.isfinite(A_cond) or A_cond > cond_threshold:
261+
singular_K.append(K)
262+
continue
263+
252264
try:
253265
beta = np.linalg.solve(A, b)
254266
except np.linalg.LinAlgError:
267+
singular_K.append(K)
255268
continue # singular — try next K
256269

257270
# Check for NaN/Inf in solution
258271
if not np.all(np.isfinite(beta)):
272+
singular_K.append(K)
259273
continue
260274

261275
# Predicted ratio for all units
@@ -282,6 +296,18 @@ def estimate_propensity_ratio_sieve(
282296
UserWarning,
283297
stacklevel=2,
284298
)
299+
elif singular_K:
300+
# Finding #18 (axis A): partial K-failure was previously silent.
301+
# Surface it so users see that the selected basis order was
302+
# forced by rank deficiency at higher K rather than by the IC.
303+
warnings.warn(
304+
f"Propensity ratio sieve: skipped K={singular_K} due to "
305+
f"rank-deficient or non-finite normal equations. "
306+
f"Selected basis used the remaining K values; "
307+
f"this may indicate limited variation in the covariates.",
308+
UserWarning,
309+
stacklevel=2,
310+
)
285311

286312
# Overlap diagnostics: warn if ratios require significant clipping
287313
n_extreme = int(np.sum((best_ratio < 1.0 / ratio_clip) | (best_ratio > ratio_clip)))
@@ -377,6 +403,8 @@ def estimate_inverse_propensity_sieve(
377403

378404
best_ic = np.inf
379405
best_s = np.full(n_units, fallback_ratio) # fallback: unconditional
406+
singular_K: List[int] = [] # K values skipped due to rank deficiency (#18)
407+
cond_threshold = 1.0 / np.sqrt(np.finfo(float).eps)
380408

381409
for K in range(1, k_max + 1):
382410
n_basis = comb(K + d, d)
@@ -397,11 +425,20 @@ def estimate_inverse_propensity_sieve(
397425
# RHS: sum of basis over ALL units (not just one group)
398426
b = basis_all.sum(axis=0)
399427

428+
# Precondition check (#18, axis A): see ratio-sieve comment above.
429+
with np.errstate(invalid="ignore", over="ignore"):
430+
A_cond = float(np.linalg.cond(A))
431+
if not np.isfinite(A_cond) or A_cond > cond_threshold:
432+
singular_K.append(K)
433+
continue
434+
400435
try:
401436
beta = np.linalg.solve(A, b)
402437
except np.linalg.LinAlgError:
438+
singular_K.append(K)
403439
continue
404440
if not np.all(np.isfinite(beta)):
441+
singular_K.append(K)
405442
continue
406443

407444
s_hat = basis_all @ beta
@@ -423,6 +460,16 @@ def estimate_inverse_propensity_sieve(
423460
UserWarning,
424461
stacklevel=2,
425462
)
463+
elif singular_K:
464+
# Finding #18 (axis A): partial K-failure was previously silent.
465+
warnings.warn(
466+
f"Inverse propensity sieve: skipped K={singular_K} due to "
467+
f"rank-deficient or non-finite normal equations. "
468+
f"Selected basis used the remaining K values; "
469+
f"this may indicate limited variation in the covariates.",
470+
UserWarning,
471+
stacklevel=2,
472+
)
426473

427474
# Overlap diagnostics: warn if s_hat values require clipping
428475
n_clipped = int(np.sum((best_s < 1.0) | (best_s > float(n_units))))

diff_diff/staggered_triple_diff.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,11 @@ def fit(
348348
{} if (covariates and self.estimation_method in ("ipw", "dr")) else None
349349
)
350350

351+
# Tracker for rank-deficient OR-IF solves across all (g, g_c, t) cells.
352+
# _compute_did_panel appends one condition-number sample per LinAlgError
353+
# so we emit ONE aggregate warning below rather than fanning out.
354+
self._lstsq_fallback_tracker: List[float] = []
355+
351356
for g in treatment_groups:
352357
# In universal mode, skip the reference period (t == g-1-anticipation)
353358
# so it's omitted from GT estimation. The event-study mixin injects
@@ -507,6 +512,26 @@ def fit(
507512
comparison_group_counts[(g, t)] = len(gc_labels)
508513
gmm_weights_store[(g, t)] = dict(zip(gc_labels, gmm_w.tolist()))
509514

515+
# Consolidated OR influence-function rank-deficiency warning.
516+
# Finding #17 in the Phase 2 silent-failures audit: the per-pair OR
517+
# solve at _compute_did_panel() previously fell back to lstsq with no
518+
# signal, so near/fully singular X'WX in the covariate expansion went
519+
# to the user as a normal result.
520+
if self._lstsq_fallback_tracker:
521+
n_cells = len(self._lstsq_fallback_tracker)
522+
finite_conds = [c for c in self._lstsq_fallback_tracker if np.isfinite(c)]
523+
max_cond = max(finite_conds) if finite_conds else float("inf")
524+
warnings.warn(
525+
f"Rank-deficient X'WX detected in the outcome-regression "
526+
f"influence-function step for {n_cells} (g, g_c, t) pair(s); "
527+
f"fell back to np.linalg.lstsq. "
528+
f"Max condition number of affected X'WX: {max_cond:.2e}. "
529+
f"Consider dropping collinear covariates or using "
530+
f"estimation_method='ipw' to avoid the OR projection.",
531+
UserWarning,
532+
stacklevel=2,
533+
)
534+
510535
# Consolidated EPV summary warning
511536
if epv_diagnostics:
512537
low_epv = {k: v for k, v in epv_diagnostics.items() if v.get("is_low")}
@@ -1330,6 +1355,14 @@ def _compute_did_panel(
13301355
try:
13311356
asy_linear_or = (np.linalg.solve(XpX, or_ex.T)).T
13321357
except np.linalg.LinAlgError:
1358+
# Rank-deficient X'WX in the OR influence-function step. Record
1359+
# a condition-number sample so fit() can emit ONE aggregate
1360+
# warning across all (g, g_c, t) cells rather than fanning out.
1361+
tracker = getattr(self, "_lstsq_fallback_tracker", None)
1362+
if tracker is not None:
1363+
with np.errstate(invalid="ignore", over="ignore"):
1364+
cond = float(np.linalg.cond(XpX))
1365+
tracker.append(cond)
13331366
asy_linear_or = (np.linalg.lstsq(XpX, or_ex.T, rcond=None)[0]).T
13341367

13351368
inf_treat_or = -(asy_linear_or @ M1)

diff_diff/survey.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,6 +1445,25 @@ def compute_survey_vcov(
14451445
return np.zeros((k, k))
14461446
return np.full((k, k), np.nan)
14471447

1448+
# Precondition check: near-singular X'WX lets np.linalg.solve return
1449+
# unstable values without raising (finding #19, axis A). Threshold of
1450+
# 1/sqrt(eps) ≈ 6.7e7 is the standard rule of thumb — above it, the
1451+
# sandwich bread becomes numerically unreliable and the caller should
1452+
# be told so.
1453+
with np.errstate(invalid="ignore", over="ignore"):
1454+
XtWX_cond = float(np.linalg.cond(XtWX))
1455+
cond_threshold = 1.0 / np.sqrt(np.finfo(float).eps)
1456+
if np.isfinite(XtWX_cond) and XtWX_cond > cond_threshold:
1457+
warnings.warn(
1458+
f"X'WX is ill-conditioned (cond={XtWX_cond:.2e}) in the "
1459+
f"survey sandwich variance; variance estimates may be "
1460+
f"numerically unstable. This typically indicates near "
1461+
f"multicollinearity or zero-weight strata dominating the "
1462+
f"bread matrix.",
1463+
UserWarning,
1464+
stacklevel=2,
1465+
)
1466+
14481467
# Sandwich: (X'WX)^{-1} meat (X'WX)^{-1}
14491468
try:
14501469
temp = np.linalg.solve(XtWX, meat)

diff_diff/two_stage.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,6 +1768,17 @@ def _compute_gmm_variance(
17681768
try:
17691769
bread = np.linalg.solve(XtWX_2, np.eye(k))
17701770
except np.linalg.LinAlgError:
1771+
# Sibling of finding #17 (axis A) — the lstsq fallback in the
1772+
# TSL-variance bread was previously silent. Surface it so a
1773+
# rank-deficient second-stage design doesn't quietly degrade SEs.
1774+
warnings.warn(
1775+
"Rank-deficient second-stage X'WX in TwoStageDiD TSL variance; "
1776+
"falling back to np.linalg.lstsq for the bread matrix. "
1777+
"Analytical SEs may be numerically unstable; consider dropping "
1778+
"collinear covariates.",
1779+
UserWarning,
1780+
stacklevel=2,
1781+
)
17711782
bread = np.linalg.lstsq(XtWX_2, np.eye(k), rcond=None)[0]
17721783

17731784
# 7. V = bread @ meat @ bread

diff_diff/two_stage_bootstrap.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,17 @@ def _compute_cluster_S_scores(
197197
try:
198198
bread = np.linalg.solve(XtX_2, np.eye(k))
199199
except np.linalg.LinAlgError:
200+
# Sibling of finding #17 (axis A) — silent lstsq fallback in the
201+
# TwoStage bootstrap bread matrix. Called once per (static /
202+
# event-study / group) aggregation, so warning fan-out is bounded.
203+
warnings.warn(
204+
"Rank-deficient second-stage X'WX in TwoStageDiD multiplier "
205+
"bootstrap bread; falling back to np.linalg.lstsq. Bootstrap "
206+
"SEs may be numerically unstable; consider dropping collinear "
207+
"covariates.",
208+
UserWarning,
209+
stacklevel=2,
210+
)
200211
bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0]
201212

202213
return S, bread, unique_clusters

docs/methodology/REGISTRY.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,7 @@ See `docs/methodology/continuous-did.md` Section 4 for full details.
753753
- **Balanced panel**: Short balanced panel required ("large-n, fixed-T" regime). Does not handle unbalanced panels or repeated cross-sections
754754
- Warn if treatment varies within units (non-absorbing treatment)
755755
- Warn if propensity score estimates are near boundary values
756+
- **Note:** Polynomial-sieve propensity fits now reject any K whose normal-equations matrix has condition number above `1/sqrt(eps)` (≈ 6.7e7) — previously a near-singular `np.linalg.solve` could return numerically meaningless coefficients without raising. If at least one K succeeds but others were skipped via this precondition, a `UserWarning` lists the skipped K values. If every K is skipped, the existing "estimation failed for all K values" fallback warning still fires. Axis-A finding #18 in the Phase 2 silent-failures audit.
756757

757758
*Estimator equation -- single treatment date (Equations 3.2, 3.5):*
758759

@@ -1175,6 +1176,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus
11751176
- **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0.
11761177
- **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. PSU clustering, stratification, and FPC are fully supported in the meat matrix via `_compute_stratified_meat_from_psu_scores()`. When strata or FPC are present, the meat computation replaces `S' S` with the stratified formula `sum_h (1 - f_h) * (n_h/(n_h-1)) * centered_h' centered_h`. Strata also enters survey df (n_PSU - n_strata) for t-distribution inference. Bootstrap + survey supported (Phase 6) via PSU-level multiplier weights.
11771178
- **Note:** Both the iterative FE solver (`_iterative_fe`, Stage 1) and the iterative alternating-projection demeaning helper (`_iterative_demean`, used in covariate residualization) emit `UserWarning` when `max_iter` exhausts without reaching `tol`, via `diff_diff.utils.warn_if_not_converged`. Silent return of the current iterate was classified as a silent failure under the Phase 2 audit and replaced with an explicit signal to match the logistic/Poisson IRLS pattern in `linalg.py`.
1179+
- **Note:** When the Stage-2 bread `X'_2 W X_2` is singular, both the analytical TSL variance (`two_stage.py`) and the multiplier-bootstrap bread (`two_stage_bootstrap.py`) now emit a `UserWarning` before falling back to `np.linalg.lstsq`. Previously this fallback was silent. Sibling of axis-A finding #17 in the Phase 2 silent-failures audit; surfaced by the repo-wide lstsq-fallback pattern grep that accompanied the StaggeredTripleDifference fix.
11781180
- **Note:** The GMM sandwich and bootstrap paths both use `scipy.sparse.linalg.factorized` for the Stage 1 normal-equations solve `(X'_{10} W X_{10}) gamma = X'_1 W X_2` and fall back to dense `lstsq` when the sparse factorization raises `RuntimeError` on a near-singular matrix. Both fallback sites emit a `UserWarning` (silent-failure audit axis C) so callers know SE estimates came from the degraded path rather than the fast sparse path.
11791181

11801182
**Reference implementation(s):**
@@ -1695,6 +1697,7 @@ has no additional effect.
16951697
- **Note:** `pscore_fallback` default changed from unconditional to error.
16961698
Set `pscore_fallback="unconditional"` for legacy behavior.
16971699
- Warns on singular GMM covariance matrix (falls back to pseudoinverse)
1700+
- **Note:** Rank-deficient X'WX in the per-pair outcome-regression influence-function step now emits ONE aggregate `UserWarning` at `fit()` time (counting affected (g, g_c, t) cells and reporting the max condition number), instead of silently falling back to `np.linalg.lstsq`. Axis-A finding #17 in the Phase 2 silent-failures audit.
16981701

16991702
*Data structure:*
17001703

@@ -2719,6 +2722,12 @@ unequal selection probabilities).
27192722
per-observation PSUs for the TSL meat computation, consistent with the
27202723
stratified-no-PSU path. The adjustment factor is `n/(n-1)` (not HC1's
27212724
`n/(n-k)`).
2725+
- **Note:** TSL now precondition-checks `X'WX` via `np.linalg.cond` before
2726+
solving the sandwich. If the condition number exceeds `1/sqrt(eps)` (≈
2727+
6.7e7) a `UserWarning` fires stating that the bread is ill-conditioned
2728+
and variance estimates may be numerically unstable. Previously a
2729+
near-singular `X'WX` could silently produce unstable SEs. Axis-A finding #19
2730+
in the Phase 2 silent-failures audit.
27222731

27232732
### Weight Type Effects on Inference
27242733

tests/test_efficient_did.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,3 +2052,85 @@ def test_inverse_propensity_sieve_fallback_warns(self):
20522052
assert np.all(np.isfinite(s_hat))
20532053
# Should fall back to unconditional n/n_group = 100/2 = 50
20542054
assert np.allclose(s_hat, 50.0)
2055+
2056+
2057+
# ---------------------------------------------------------------------------
2058+
# Silent-failure audit PR #9: finding #18 — estimate_*_sieve silently
2059+
# `continue`'d past rank-deficient K values. Now we track skipped K and
2060+
# warn when we ship a result that wasn't the IC-winner across all K.
2061+
# ---------------------------------------------------------------------------
2062+
2063+
2064+
class TestSievePartialKSkipWarning:
2065+
"""Finding #18 (axis A): partial K-failure no longer silent."""
2066+
2067+
def test_ratio_sieve_partial_skip_warns(self):
2068+
"""If some K's are rank-deficient but at least one succeeds,
2069+
the function warns about the partial skip instead of swallowing it."""
2070+
from diff_diff.efficient_did_covariates import estimate_propensity_ratio_sieve
2071+
2072+
rng = np.random.default_rng(7)
2073+
n = 200
2074+
# 1D covariate with discrete support {0, 1}. At K=1 the basis is
2075+
# [1, x]; at K>=2 the basis reaches size >= n_gp for most groups
2076+
# before hitting singularity, but with this discrete support the
2077+
# polynomial powers x^2, x^3, ... equal x, yielding rank-deficient
2078+
# normal equations deterministically.
2079+
X = rng.integers(0, 2, size=(n, 1)).astype(float)
2080+
mask_g = np.zeros(n, dtype=bool)
2081+
mask_g[:100] = True
2082+
mask_gp = np.zeros(n, dtype=bool)
2083+
mask_gp[100:] = True
2084+
with pytest.warns(UserWarning) as caught:
2085+
ratio = estimate_propensity_ratio_sieve(X, mask_g, mask_gp, k_max=3)
2086+
assert np.all(np.isfinite(ratio))
2087+
partial_skip_msgs = [
2088+
str(w.message) for w in caught if "skipped K=" in str(w.message)
2089+
]
2090+
assert partial_skip_msgs, (
2091+
"Expected a partial-K-skip warning when some K's are rank deficient "
2092+
"but at least one succeeds; got none."
2093+
)
2094+
# Message should name the specific K values that were skipped.
2095+
assert any("K=" in m for m in partial_skip_msgs)
2096+
2097+
def test_inverse_propensity_sieve_partial_skip_warns(self):
2098+
"""Same contract for the inverse propensity sieve."""
2099+
from diff_diff.efficient_did_covariates import estimate_inverse_propensity_sieve
2100+
2101+
rng = np.random.default_rng(7)
2102+
n = 200
2103+
X = rng.integers(0, 2, size=(n, 1)).astype(float)
2104+
mask = np.zeros(n, dtype=bool)
2105+
mask[:100] = True
2106+
with pytest.warns(UserWarning) as caught:
2107+
s_hat = estimate_inverse_propensity_sieve(X, mask, k_max=3)
2108+
assert np.all(np.isfinite(s_hat))
2109+
partial_skip_msgs = [
2110+
str(w.message) for w in caught if "skipped K=" in str(w.message)
2111+
]
2112+
assert partial_skip_msgs
2113+
2114+
def test_ratio_sieve_no_warning_when_no_skips(self):
2115+
"""Clean, well-conditioned covariates → no partial-skip warning."""
2116+
from diff_diff.efficient_did_covariates import estimate_propensity_ratio_sieve
2117+
2118+
rng = np.random.default_rng(101)
2119+
n = 300
2120+
X = rng.normal(0, 1, (n, 2))
2121+
mask_g = np.zeros(n, dtype=bool)
2122+
mask_g[:150] = True
2123+
mask_gp = np.zeros(n, dtype=bool)
2124+
mask_gp[150:] = True
2125+
import warnings as _w
2126+
2127+
with _w.catch_warnings(record=True) as caught:
2128+
_w.simplefilter("always")
2129+
ratio = estimate_propensity_ratio_sieve(X, mask_g, mask_gp, k_max=3)
2130+
assert np.all(np.isfinite(ratio))
2131+
partial_skip_msgs = [
2132+
str(w.message) for w in caught if "skipped K=" in str(w.message)
2133+
]
2134+
assert partial_skip_msgs == [], (
2135+
f"Unexpected partial-skip warning on clean data: {partial_skip_msgs}"
2136+
)

0 commit comments

Comments
 (0)