Skip to content

Commit 8bcbaa7

Browse files
igerber and claude committed
Fix OLS survey edge cases and harden solve_poisson (round 3)
- P1: Compute survey TSL vcov on kept columns only when solve_ols drops rank-deficient columns; expand back with NaN. Prevents singular bread matrix on all-eventually-treated ETWFE designs.
- P1: Guard against zero-weight unit/time groups before within_transform; raise targeted ValueError instead of letting NaN propagate.
- P2: Add weight validation (shape, NaN, Inf, non-negative, positive sum) to solve_poisson(weights=...) matching solve_logit pattern.
- P2: Add regression tests for rank-deficient survey OLS and zero-weight unit rejection.
- P3: Add pweight-only note to REGISTRY.md and survey-roadmap.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 31c74fa commit 8bcbaa7

5 files changed

Lines changed: 72 additions & 5 deletions

File tree

diff_diff/linalg.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,6 +2401,20 @@ def solve_poisson(
24012401
"""
24022402
n, k_orig = X.shape
24032403

2404+
# Validate weights (mirrors solve_logit validation)
2405+
if weights is not None:
2406+
weights = np.asarray(weights, dtype=np.float64)
2407+
if weights.shape != (n,):
2408+
raise ValueError(f"weights must have shape ({n},), got {weights.shape}")
2409+
if np.any(np.isnan(weights)):
2410+
raise ValueError("weights contain NaN values")
2411+
if np.any(~np.isfinite(weights)):
2412+
raise ValueError("weights contain Inf values")
2413+
if np.any(weights < 0):
2414+
raise ValueError("weights must be non-negative")
2415+
if np.sum(weights) <= 0:
2416+
raise ValueError("weights sum to zero — no observations have positive weight")
2417+
24042418
# Validate rank_deficient_action (same as solve_logit/solve_ols)
24052419
valid_actions = ("warn", "error", "silent")
24062420
if rank_deficient_action not in valid_actions:

diff_diff/wooldridge.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,23 @@ def _fit_ols(
647647
# weighted FWL projection — all columns (treatment interactions +
648648
# covariates) are demeaned together.
649649
wt_weights = survey_weights if survey_weights is not None else np.ones(len(tmp))
650+
651+
# Guard: zero-weight unit/time groups cause 0/0 in within_transform
652+
if survey_weights is not None and np.any(survey_weights == 0):
653+
for grp_col, grp_label in [(unit, "unit"), (time, "time period")]:
654+
grp_sums = sample.groupby(grp_col).apply(
655+
lambda g: survey_weights[g.index].sum(),
656+
include_groups=False,
657+
)
658+
zero_grps = grp_sums[grp_sums == 0].index.tolist()
659+
if zero_grps:
660+
raise ValueError(
661+
f"Survey weights sum to zero for {grp_label}(s) "
662+
f"{zero_grps[:3]}. Cannot compute weighted "
663+
f"within-transformation. Remove zero-weight "
664+
f"{grp_label}s or use non-zero weights."
665+
)
666+
650667
transformed = within_transform(
651668
tmp, all_vars, unit=unit, time=time, suffix="_demeaned",
652669
weights=wt_weights,
@@ -671,7 +688,15 @@ def _fit_ols(
671688
# Survey TSL vcov replaces cluster-robust vcov
672689
if resolved is not None:
673690
from diff_diff.survey import compute_survey_vcov
674-
vcov = compute_survey_vcov(X, resids, resolved)
691+
nan_mask_ols = np.isnan(coefs)
692+
if np.any(nan_mask_ols):
693+
kept = ~nan_mask_ols
694+
vcov_kept = compute_survey_vcov(X[:, kept], resids, resolved)
695+
vcov = np.full((len(coefs), len(coefs)), np.nan)
696+
kept_idx = np.where(kept)[0]
697+
vcov[np.ix_(kept_idx, kept_idx)] = vcov_kept
698+
else:
699+
vcov = compute_survey_vcov(X, resids, resolved)
675700

676701
# 7. Extract β_{g,t} and build gt_effects dict
677702
gt_effects: Dict[Tuple, Dict] = {}

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,7 @@ where `g(·)` is the link inverse (logistic or exp), `η_i` is the individual li
11791179
**Survey design notes:**
11801180
- **OLS path:** Survey-weighted within-transformation + WLS via `solve_ols(weights=...)` + TSL vcov via `compute_survey_vcov()`.
11811181
- **Logit/Poisson paths:** Survey-weighted IRLS via `solve_logit(weights=...)`/`solve_poisson(weights=...)` + X_tilde linearization trick for TSL vcov: `X_tilde = X * sqrt(V)`, `r_tilde = (y - mu) / sqrt(V)`, then `compute_survey_vcov(X_tilde, r_tilde, resolved)` gives correct QMLE sandwich. ASF means and gradients use survey-weighted averaging.
1182+
- **Note:** Only `pweight` (probability weights) are supported; `fweight`/`aweight` raise `ValueError` because the composed survey/QMLE weighting changes their semantics.
11821183
- **Note:** Replicate-weight variance is not yet supported (`NotImplementedError`). Use TSL (strata/PSU/FPC) instead.
11831184
- **Note:** Bootstrap inference (`n_bootstrap > 0`) cannot be combined with `survey_design` — no survey-aware bootstrap variant is implemented.
11841185

docs/survey-roadmap.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,11 @@ co-sign.
209209
### 10f. WooldridgeDiD Survey Support — SHIPPED
210210

211211
WooldridgeDiD (ETWFE) now supports `survey_design` for all three methods
212-
(OLS, logit, Poisson). OLS uses survey-weighted within-transformation +
213-
WLS + TSL vcov. Logit/Poisson use survey-weighted IRLS + X_tilde
214-
linearization for TSL vcov. Replicate-weight designs raise
215-
`NotImplementedError`; bootstrap + survey is rejected.
212+
(OLS, logit, Poisson) with `pweight` only (`fweight`/`aweight` rejected).
213+
OLS uses survey-weighted within-transformation + WLS + TSL vcov.
214+
Logit/Poisson use survey-weighted IRLS + X_tilde linearization for TSL
215+
vcov. Replicate-weight designs raise `NotImplementedError`; bootstrap +
216+
survey is rejected.
216217

217218
### 10g. Practitioner Guidance (LOW priority)
218219

tests/test_wooldridge.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,3 +1521,29 @@ def test_poisson_zero_weight_cell(self, survey_panel):
15211521
)
15221522
assert np.isfinite(r.overall_att)
15231523
assert np.isfinite(r.overall_se)
1524+
1525+
def test_ols_survey_rank_deficient(self, survey_panel):
1526+
"""Survey OLS handles rank-deficient all-eventually-treated designs."""
1527+
from diff_diff.survey import SurveyDesign
1528+
# Remove never-treated (cohort=0) to create rank-deficient design
1529+
df = survey_panel[survey_panel["cohort"] > 0].copy()
1530+
sd = SurveyDesign(weights="weight", strata="stratum", psu="unit")
1531+
r = WooldridgeDiD(control_group="not_yet_treated").fit(
1532+
df, outcome="y", unit="unit", time="time",
1533+
cohort="cohort", survey_design=sd,
1534+
)
1535+
assert np.isfinite(r.overall_att)
1536+
assert np.isfinite(r.overall_se)
1537+
1538+
def test_ols_survey_zero_weight_unit_rejected(self, survey_panel):
1539+
"""Zero-weight unit raises ValueError before within_transform."""
1540+
from diff_diff.survey import SurveyDesign
1541+
df = survey_panel.copy()
1542+
# Zero out all weights for unit 0
1543+
df.loc[df["unit"] == 0, "weight"] = 0.0
1544+
sd = SurveyDesign(weights="weight", strata="stratum", psu="unit")
1545+
with pytest.raises(ValueError, match="Survey weights sum to zero for unit"):
1546+
WooldridgeDiD().fit(
1547+
df, outcome="y", unit="unit", time="time",
1548+
cohort="cohort", survey_design=sd,
1549+
)

0 commit comments

Comments
 (0)