Commit 4a48487

Merge pull request #334 from igerber/fix/axis-a-minor-solver-paths
Surface silent np.linalg.solve fallbacks across axis-A minor solver paths
2 parents 56730af + 099507d commit 4a48487

13 files changed

Lines changed: 751 additions & 16 deletions

TODO.md

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ Deferred items from PR reviews that were not addressed before merge.
 | HC2 / HC2 + Bell-McCaffrey on absorbed-FE fits currently raises `NotImplementedError` in three places: `TwoWayFixedEffects` unconditionally; `DifferenceInDifferences(absorb=..., vcov_type in {"hc2","hc2_bm"})`; `MultiPeriodDiD(absorb=..., vcov_type in {"hc2","hc2_bm"})`. Within-transformation preserves coefficients and residuals under FWL but not the hat matrix, so the reduced-design `h_ii` is not the diagonal of the full FE projection and CR2's block adjustment `A_g = (I - H_gg)^{-1/2}` is likewise wrong on absorbed cluster blocks. Lifting the guard needs HC2/CR2-BM computed from the full absorbed projection (unit/time FE dummies reconstructed internally, or a FE-aware hat-matrix formulation) and a parity harness against a full-dummy OLS run or R `fixest`/`clubSandwich`. HC1/CR1 are unaffected by this because they have no leverage term. | `twfe.py::fit`, `estimators.py::DifferenceInDifferences.fit`, `estimators.py::MultiPeriodDiD.fit` | Phase 1a | Medium |
 | Weighted CR2 Bell-McCaffrey cluster-robust (`vcov_type="hc2_bm"` + `cluster_ids` + `weights`) currently raises `NotImplementedError`. Weighted hat matrix and residual rebalancing need threading per clubSandwich WLS handling. | `linalg.py::_compute_cr2_bm` | Phase 1a | Medium |
 | Regenerate `benchmarks/data/clubsandwich_cr2_golden.json` from R (`Rscript benchmarks/R/generate_clubsandwich_golden.R`). Current JSON has `source: python_self_reference` as a stability anchor until an authoritative R run. | `benchmarks/R/generate_clubsandwich_golden.R` | Phase 1a | Medium |
+| `honest_did.py:1907` `np.linalg.solve(A_sys, b_sys) / except LinAlgError: continue` is a silent basis-rejection in the vertex-enumeration loop that is algorithmically intentional (try the next basis). Consider surfacing a count of rejected bases as a diagnostic when ARP enumeration exhausts, so users see when the vertex search was heavily constrained. Not a silent failure in the sense of the Phase 2 audit (the algorithm is supposed to skip), but the diagnostic would help debug borderline cases. | `honest_did.py` | #334 | Low |
 
 #### Performance
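
The diagnostic proposed in the new TODO entry could be prototyped roughly as follows. This is a minimal sketch, not the code in `honest_did.py`: `enumerate_vertices`, its arguments, and the warning text are hypothetical, and it brute-forces every basis of a small system `A x <= b` only to show where a rejected-basis counter would sit.

```python
import itertools
import warnings

import numpy as np


def enumerate_vertices(A, b, tol=1e-9):
    """Hypothetical vertex enumeration for the polytope {x : A x <= b}.

    Returns the feasible vertices and the number of bases rejected
    because the basis system was singular.
    """
    m, n = A.shape
    vertices = []
    n_rejected = 0
    n_bases = 0
    for rows in itertools.combinations(range(m), n):
        n_bases += 1
        A_sys = A[list(rows)]
        b_sys = b[list(rows)]
        try:
            x = np.linalg.solve(A_sys, b_sys)
        except np.linalg.LinAlgError:
            # Intentional skip (try the next basis), but now counted.
            n_rejected += 1
            continue
        if np.all(A @ x <= b + tol):
            vertices.append(x)
    if not vertices and n_rejected:
        # Only surface the count when the search came up empty.
        warnings.warn(
            f"Vertex enumeration found no vertex; {n_rejected} of "
            f"{n_bases} bases were rejected as singular.",
            UserWarning,
            stacklevel=2,
        )
    return vertices, n_rejected


# Unit square 0 <= x, y <= 1 written as A x <= b.
A = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
b = np.array([1.0, 0.0, 1.0, 0.0])
vertices, n_rejected = enumerate_vertices(A, b)
print(len(vertices), n_rejected)
```

Warning only when enumeration exhausts keeps the normal path quiet, which matches the TODO's note that skipping a basis is expected behaviour rather than a failure.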

diff_diff/efficient_did_covariates.py

Lines changed: 64 additions & 1 deletion
@@ -172,7 +172,15 @@ def estimate_propensity_ratio_sieve(
 
     Selects K via AIC/BIC: ``IC(K) = 2*loss(K) + C_n*K/n``.
 
-    On singular basis: tries lower K. Short-circuits r_{g,g}(X) = 1.
+    Precondition check per K: if ``cond(Psi_{g'}' W Psi_{g'}) > 1/sqrt(eps)``
+    (≈ 6.7e7), that K is skipped. LinAlgError on the `np.linalg.solve` call
+    or a non-finite beta skips as well. If at least one K succeeds but
+    others were skipped, emits a ``UserWarning`` listing the skipped K
+    values (silent-failure audit PR, axis-A finding #18). If every K is
+    skipped, the caller falls back to a constant ratio of 1 with a
+    separate "estimation failed for all K values" warning.
+
+    Short-circuits ``r_{g,g}(X) = 1`` for same-cohort comparisons (PT-All).
 
     Parameters
     ----------
@@ -227,6 +235,10 @@ def estimate_propensity_ratio_sieve(
 
     best_ic = np.inf
     best_ratio = np.ones(n_units)  # fallback: constant ratio 1
+    singular_K: List[int] = []  # K values skipped due to rank deficiency (#18)
+    # Near-singular matrices solve without raising LinAlgError but return
+    # numerically meaningless beta. Rule-of-thumb threshold: 1/sqrt(eps).
+    cond_threshold = 1.0 / np.sqrt(np.finfo(float).eps)
 
     for K in range(1, k_max + 1):
         n_basis = comb(K + d, d)
@@ -249,13 +261,23 @@ def estimate_propensity_ratio_sieve(
         A = Psi_gp.T @ Psi_gp
         b = Psi_g.sum(axis=0)
 
+        # Precondition check (#18, axis A): reject near-singular A explicitly
+        # so np.linalg.solve can't silently return garbage coefficients.
+        with np.errstate(invalid="ignore", over="ignore"):
+            A_cond = float(np.linalg.cond(A))
+        if not np.isfinite(A_cond) or A_cond > cond_threshold:
+            singular_K.append(K)
+            continue
+
         try:
             beta = np.linalg.solve(A, b)
         except np.linalg.LinAlgError:
+            singular_K.append(K)
             continue  # singular — try next K
 
         # Check for NaN/Inf in solution
         if not np.all(np.isfinite(beta)):
+            singular_K.append(K)
             continue
 
         # Predicted ratio for all units
@@ -282,6 +304,18 @@ def estimate_propensity_ratio_sieve(
             UserWarning,
             stacklevel=2,
         )
+    elif singular_K:
+        # Finding #18 (axis A): partial K-failure was previously silent.
+        # Surface it so users see that the selected basis order was
+        # forced by rank deficiency at higher K rather than by the IC.
+        warnings.warn(
+            f"Propensity ratio sieve: skipped K={singular_K} due to "
+            f"rank-deficient or non-finite normal equations. "
+            f"Selected basis used the remaining K values; "
+            f"this may indicate limited variation in the covariates.",
+            UserWarning,
+            stacklevel=2,
+        )
 
     # Overlap diagnostics: warn if ratios require significant clipping
     n_extreme = int(np.sum((best_ratio < 1.0 / ratio_clip) | (best_ratio > ratio_clip)))
@@ -329,6 +363,14 @@ def estimate_inverse_propensity_sieve(
     units on the RHS (not just group g), following the paper's
     algorithm step 4.
 
+    Precondition check per K: if ``cond(Psi_{g'}' W Psi_{g'}) > 1/sqrt(eps)``
+    (≈ 6.7e7), that K is skipped. LinAlgError on the `np.linalg.solve` call
+    or a non-finite beta skips as well. If at least one K succeeds but
+    others were skipped, emits a ``UserWarning`` listing the skipped K
+    values (silent-failure audit PR, axis-A finding #18). If every K is
+    skipped, the caller falls back to unconditional ``n/n_group`` scaling
+    with a separate "estimation failed for all K values" warning.
+
     Parameters
     ----------
     covariate_matrix : ndarray, shape (n_units, n_covariates)
@@ -377,6 +419,8 @@ def estimate_inverse_propensity_sieve(
 
     best_ic = np.inf
     best_s = np.full(n_units, fallback_ratio)  # fallback: unconditional
+    singular_K: List[int] = []  # K values skipped due to rank deficiency (#18)
+    cond_threshold = 1.0 / np.sqrt(np.finfo(float).eps)
 
     for K in range(1, k_max + 1):
         n_basis = comb(K + d, d)
@@ -397,11 +441,20 @@ def estimate_inverse_propensity_sieve(
         # RHS: sum of basis over ALL units (not just one group)
         b = basis_all.sum(axis=0)
 
+        # Precondition check (#18, axis A): see ratio-sieve comment above.
+        with np.errstate(invalid="ignore", over="ignore"):
+            A_cond = float(np.linalg.cond(A))
+        if not np.isfinite(A_cond) or A_cond > cond_threshold:
+            singular_K.append(K)
+            continue
+
         try:
             beta = np.linalg.solve(A, b)
         except np.linalg.LinAlgError:
+            singular_K.append(K)
             continue
         if not np.all(np.isfinite(beta)):
+            singular_K.append(K)
             continue
 
         s_hat = basis_all @ beta
@@ -423,6 +476,16 @@ def estimate_inverse_propensity_sieve(
             UserWarning,
             stacklevel=2,
         )
+    elif singular_K:
+        # Finding #18 (axis A): partial K-failure was previously silent.
+        warnings.warn(
+            f"Inverse propensity sieve: skipped K={singular_K} due to "
+            f"rank-deficient or non-finite normal equations. "
+            f"Selected basis used the remaining K values; "
+            f"this may indicate limited variation in the covariates.",
+            UserWarning,
+            stacklevel=2,
+        )
 
     # Overlap diagnostics: warn if s_hat values require clipping
     n_clipped = int(np.sum((best_s < 1.0) | (best_s > float(n_units))))
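
The precondition check in both sieve functions covers a case the `try/except` alone misses: a nearly singular matrix solves without raising. The sketch below reproduces that on a 2x2 example. `solve_with_cond_check` is a hypothetical helper written for illustration and is not part of the commit; only the threshold and the order of the three checks mirror the diff above.

```python
import numpy as np

# Same rule-of-thumb threshold as the diff: 1/sqrt(machine eps), about 6.7e7.
COND_THRESHOLD = 1.0 / np.sqrt(np.finfo(float).eps)


def solve_with_cond_check(A, b, threshold=COND_THRESHOLD):
    """Hypothetical helper: return beta, or None if this system is skipped."""
    with np.errstate(invalid="ignore", over="ignore"):
        cond = float(np.linalg.cond(A))
    if not np.isfinite(cond) or cond > threshold:
        return None  # near-singular: reject before solving
    try:
        beta = np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        return None  # exactly singular
    if not np.all(np.isfinite(beta)):
        return None  # non-finite solution
    return beta


# Nearly collinear columns: solve() succeeds, but cond(A) is far above the threshold.
A_bad = np.array([[1.0, 1.0], [1.0, 1.0 + 1e-12]])
b = np.array([1.0, 2.0])

unchecked = np.linalg.solve(A_bad, b)   # no exception is raised here
checked = solve_with_cond_check(A_bad, b)
print(np.linalg.cond(A_bad) > COND_THRESHOLD, checked is None)
```

Checking the condition number costs an extra SVD per K, which is small next to the basis construction for the low K values a sieve typically tries.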

diff_diff/staggered.py

Lines changed: 77 additions & 15 deletions
@@ -92,11 +92,29 @@ def _linear_regression(
     return beta, residuals
 
 
-def _safe_inv(A: np.ndarray) -> np.ndarray:
-    """Invert a square matrix with lstsq fallback for near-singular cases."""
+def _safe_inv(
+    A: np.ndarray,
+    tracker: Optional[list] = None,
+) -> np.ndarray:
+    """Invert a square matrix with lstsq fallback for near-singular cases.
+
+    Parameters
+    ----------
+    A : np.ndarray
+        Square matrix to invert.
+    tracker : list, optional
+        When provided, one condition-number sample of ``A`` is appended on
+        every LinAlgError fallback. ``CallawaySantAnna.fit()`` initializes
+        a list and emits a single aggregate `UserWarning` after the fit
+        finishes, rather than surfacing a separate warning per fallback.
+        Sibling of finding #17 in the Phase 2 silent-failures audit.
+    """
     try:
         return np.linalg.solve(A, np.eye(A.shape[0]))
     except np.linalg.LinAlgError:
+        if tracker is not None:
+            with np.errstate(invalid="ignore", over="ignore"):
+                tracker.append(float(np.linalg.cond(A)))
         return np.linalg.lstsq(A, np.eye(A.shape[0]), rcond=None)[0]
 
 
@@ -1436,6 +1454,12 @@ def fit(
         # Reset stale state from prior fit (prevents leaking event-study VCV)
         self._event_study_vcov = None
 
+        # Tracker for _safe_inv lstsq fallbacks across all analytical SE
+        # paths (PS Hessian, OR bread, event-study bread, etc.). Emit ONE
+        # aggregate warning at the end of fit rather than fanning out per
+        # cell. Sibling of PR #9 finding #17.
+        self._safe_inv_tracker: List[float] = []
+
         if not self.panel:
             warnings.warn(
                 "panel=False uses repeated cross-section DRDID estimators "
@@ -1976,6 +2000,26 @@ def fit(
                         eff_data["effect"] + cband_crit_value * se_val,
                     )
 
+        # Consolidated _safe_inv lstsq-fallback warning (sibling of PR #9
+        # finding #17). Rank-deficient PS Hessian / OR bread matrices in the
+        # analytical SE paths previously fell back to np.linalg.lstsq
+        # silently per cell. Now aggregated here into ONE UserWarning so
+        # a bad design surface doesn't quietly degrade analytical SEs.
+        if self._safe_inv_tracker:
+            n_fallbacks = len(self._safe_inv_tracker)
+            finite_conds = [c for c in self._safe_inv_tracker if np.isfinite(c)]
+            max_cond = max(finite_conds) if finite_conds else float("inf")
+            warnings.warn(
+                f"Rank-deficient matrix encountered {n_fallbacks} time(s) "
+                f"in analytical SE paths (propensity-score Hessian or "
+                f"outcome-regression bread); fell back to np.linalg.lstsq. "
+                f"Max condition number of affected matrix: {max_cond:.2e}. "
+                f"Analytical SEs may be numerically unstable; consider "
+                f"dropping collinear covariates or using n_bootstrap > 0.",
+                UserWarning,
+                stacklevel=2,
+            )
+
         # Store results
         # Retrieve event-study VCV from aggregation mixin (Phase 7d).
         # Clear it when bootstrap overwrites event-study SEs to prevent
@@ -2276,7 +2320,7 @@ def _ipw_estimation(
             W_ps = W_ps * sw_all
         # R: Hessian.ps = crossprod(X * sqrt(W)) / n
         H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
-        H_psi_inv = _safe_inv(H_psi)
+        H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker)
 
         D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
         score_ps = (D_all - pscore_all)[:, None] * X_all_int
@@ -2562,7 +2606,7 @@ def _doubly_robust(
         if sw_all is not None:
             W_ps = W_ps * sw_all
         H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
-        H_psi_inv = _safe_inv(H_psi)
+        H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker)
 
         D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
         score_ps = (D_all - pscore_all)[:, None] * X_all_int
@@ -2584,7 +2628,7 @@ def _doubly_robust(
         X_c_int = X_control_with_intercept
         W_diag = sw_control if sw_control is not None else np.ones(n_c)
         XtWX = X_c_int.T @ (W_diag[:, None] * X_c_int)
-        bread = _safe_inv(XtWX)
+        bread = _safe_inv(XtWX, tracker=self._safe_inv_tracker)
 
         # M1: dATT/dbeta — gradient of DR ATT w.r.t. OR parameters
         X_t_int = X_treated_with_intercept
@@ -2628,7 +2672,7 @@ def _doubly_robust(
 
         W_ps = pscore_all * (1 - pscore_all)
         H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
-        H_psi_inv = _safe_inv(H_psi)
+        H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker)
 
         D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
         score_ps = (D_all - pscore_all)[:, None] * X_all_int
@@ -2645,7 +2689,7 @@ def _doubly_robust(
         # --- OR IF correction ---
         X_c_int = X_control_with_intercept
         XtX = X_c_int.T @ X_c_int
-        bread = _safe_inv(XtX)
+        bread = _safe_inv(XtX, tracker=self._safe_inv_tracker)
 
         X_t_int = X_treated_with_intercept
         M1 = (
@@ -3204,8 +3248,14 @@ def _outcome_regression_rc(
         # R's colMeans (= sum/n_all) for M1, matching the product exactly.
         W_ct = sw_ct if sw_ct is not None else np.ones(n_ct)
         W_cs = sw_cs if sw_cs is not None else np.ones(n_cs)
-        bread_t = _safe_inv(X_ct_int.T @ (W_ct[:, None] * X_ct_int))
-        bread_s = _safe_inv(X_cs_int.T @ (W_cs[:, None] * X_cs_int))
+        bread_t = _safe_inv(
+            X_ct_int.T @ (W_ct[:, None] * X_ct_int),
+            tracker=self._safe_inv_tracker,
+        )
+        bread_s = _safe_inv(
+            X_cs_int.T @ (W_cs[:, None] * X_cs_int),
+            tracker=self._safe_inv_tracker,
+        )
 
         # R: M1 = colMeans(w.cont * out.x) = sum(w_D * X) / n_all
         M1 = (
@@ -3407,7 +3457,7 @@ def _ipw_estimation_rc(
             W_ps = W_ps * sw_all
         # R: Hessian.ps = crossprod(X * sqrt(W)) / n
         H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all
-        H_psi_inv = _safe_inv(H_psi)
+        H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker)
 
         score_ps = (D_all - pscore)[:, None] * X_all_int
         if sw_all is not None:
@@ -3744,7 +3794,7 @@ def _doubly_robust_rc(
             W_ps = W_ps * sw_all
         # R: Hessian.ps = crossprod(X * sqrt(W)) / n
         H_psi = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all
-        H_psi_inv = _safe_inv(H_psi)
+        H_psi_inv = _safe_inv(H_psi, tracker=self._safe_inv_tracker)
 
         score_ps = (D_all - pscore)[:, None] * X_all_int
         if sw_all is not None:
@@ -3779,8 +3829,14 @@ def _doubly_robust_rc(
         # =====================================================================
         W_ct_vals = sw_ct if sw_ct is not None else np.ones(n_ct)
         W_cs_vals = sw_cs if sw_cs is not None else np.ones(n_cs)
-        bread_ct = _safe_inv(X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int))
-        bread_cs = _safe_inv(X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int))
+        bread_ct = _safe_inv(
+            X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int),
+            tracker=self._safe_inv_tracker,
+        )
+        bread_cs = _safe_inv(
+            X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int),
+            tracker=self._safe_inv_tracker,
+        )
 
         # R: asy.lin.rep.ols (per-obs OLS score * bread)
         asy_lin_rep_ct = (W_ct_vals * resid_ct)[:, None] * X_ct_int @ bread_ct
@@ -3818,8 +3874,14 @@ def _doubly_robust_rc(
         # =====================================================================
         W_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt)
         W_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs)
-        bread_gt = _safe_inv(X_gt_int.T @ (W_gt_vals[:, None] * X_gt_int))
-        bread_gs = _safe_inv(X_gs_int.T @ (W_gs_vals[:, None] * X_gs_int))
+        bread_gt = _safe_inv(
+            X_gt_int.T @ (W_gt_vals[:, None] * X_gt_int),
+            tracker=self._safe_inv_tracker,
+        )
+        bread_gs = _safe_inv(
+            X_gs_int.T @ (W_gs_vals[:, None] * X_gs_int),
+            tracker=self._safe_inv_tracker,
+        )
 
         asy_lin_rep_gt = (W_gt_vals * resid_gt)[:, None] * X_gt_int @ bread_gt
         asy_lin_rep_gs = (W_gs_vals * resid_gs)[:, None] * X_gs_int @ bread_gs
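
The tracker pattern threaded through the call sites above can be exercised in isolation. In the sketch below, `safe_inv` follows the `_safe_inv` body shown in this diff, while `fit_like` and its shortened warning text are hypothetical stand-ins for `CallawaySantAnna.fit()`, assumed here only to show one aggregate warning replacing one warning per fallback.

```python
import warnings
from typing import List, Optional

import numpy as np


def safe_inv(A: np.ndarray, tracker: Optional[list] = None) -> np.ndarray:
    """Invert A; on LinAlgError record cond(A) and fall back to lstsq."""
    try:
        return np.linalg.solve(A, np.eye(A.shape[0]))
    except np.linalg.LinAlgError:
        if tracker is not None:
            with np.errstate(invalid="ignore", over="ignore"):
                tracker.append(float(np.linalg.cond(A)))
        return np.linalg.lstsq(A, np.eye(A.shape[0]), rcond=None)[0]


def fit_like(matrices):
    """Hypothetical driver: invert many matrices, warn once at the end."""
    tracker: List[float] = []
    inverses = [safe_inv(M, tracker=tracker) for M in matrices]
    if tracker:
        finite_conds = [c for c in tracker if np.isfinite(c)]
        max_cond = max(finite_conds) if finite_conds else float("inf")
        warnings.warn(
            f"Rank-deficient matrix encountered {len(tracker)} time(s); "
            f"fell back to np.linalg.lstsq. "
            f"Max condition number: {max_cond:.2e}.",
            UserWarning,
            stacklevel=2,
        )
    return inverses, tracker


matrices = [
    np.eye(2),                            # invertible, no fallback
    np.array([[1.0, 2.0], [2.0, 4.0]]),   # exactly singular
    np.zeros((2, 2)),                     # exactly singular
]
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    inverses, tracker = fit_like(matrices)
n_user_warnings = sum(issubclass(w.category, UserWarning) for w in caught)
print(len(tracker), n_user_warnings)
```

Note that the tracker records only the `LinAlgError` path. A matrix that is merely ill-conditioned still goes through `np.linalg.solve` untracked, which is the gap the explicit condition-number check closes in the sieve functions.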
