Skip to content

Commit 133df19

Browse files
igerberclaude
andcommitted
Add solve_logit weight validation, update docstrings, add negative tests
P2: solve_logit() now validates weights (shape, NaN, Inf, positive) before IRLS, giving clear errors instead of opaque numerical failures. P3 docs: update CallawaySantAnna fit() docstring to weights-only contract; add survey_design param docs to ImputationDiD/TwoStageDiD fit() and wrapper docstrings; update REGISTRY treatment_effects["weight"] note for survey mode. P3 tests: add negative tests for solve_logit bad weights (NaN, negative, wrong shape); add aweight/fweight rejection and FPC rejection tests for ImputationDiD and TwoStageDiD. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 869b85a commit 133df19

6 files changed

Lines changed: 124 additions & 3 deletions

File tree

diff_diff/imputation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ def fit(
203203
balance_e : int, optional
204204
When computing event study, restrict to cohorts observed at all
205205
relative times in [-balance_e, max_h].
206+
survey_design : SurveyDesign, optional
207+
Survey design specification for design-based inference. Supports
208+
pweight only (aweight/fweight raise ValueError). FPC raises
209+
NotImplementedError. PSU is used as cluster variable for Theorem 3
210+
variance. Strata enters survey df for t-distribution inference.
211+
Requires analytical inference (n_bootstrap=0).
206212
207213
Returns
208214
-------
@@ -1951,6 +1957,12 @@ def imputation_did(
19511957
Aggregation mode: None, "simple", "event_study", "group", "all".
19521958
balance_e : int, optional
19531959
Balance event study to cohorts observed at all relative times.
1960+
survey_design : SurveyDesign, optional
1961+
Survey design specification for design-based inference. Supports
1962+
pweight only (aweight/fweight raise ValueError). FPC raises
1963+
NotImplementedError. PSU is used as cluster variable for Theorem 3
1964+
variance. Strata enters survey df for t-distribution inference.
1965+
Requires analytical inference (n_bootstrap=0).
19541966
**kwargs
19551967
Additional keyword arguments passed to ImputationDiD constructor.
19561968

diff_diff/linalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,6 +1157,18 @@ def solve_logit(
11571157
X_with_intercept = np.column_stack([np.ones(n), X])
11581158
k = p + 1 # number of parameters including intercept
11591159

1160+
# Validate weights
1161+
if weights is not None:
1162+
weights = np.asarray(weights, dtype=np.float64)
1163+
if weights.shape != (n,):
1164+
raise ValueError(f"weights must have shape ({n},), got {weights.shape}")
1165+
if np.any(np.isnan(weights)):
1166+
raise ValueError("weights contain NaN values")
1167+
if np.any(~np.isfinite(weights)):
1168+
raise ValueError("weights contain Inf values")
1169+
if np.any(weights <= 0):
1170+
raise ValueError("weights must be strictly positive")
1171+
11601172
# Validate rank_deficient_action
11611173
valid_actions = {"warn", "error", "silent"}
11621174
if rank_deficient_action not in valid_actions:

diff_diff/staggered.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,8 +1155,10 @@ def fit(
11551155
For event study, balance the panel at relative time e.
11561156
Ensures all groups contribute to each relative period.
11571157
survey_design : SurveyDesign, optional
1158-
Survey design specification for design-based inference.
1159-
Supports weights, strata, PSU, and FPC.
1158+
Survey design specification. Only weights-only designs are supported
1159+
(strata/PSU/FPC raise NotImplementedError). Supports pweight only.
1160+
Covariates + IPW/DR + survey also raises NotImplementedError.
1161+
Use analytical inference (n_bootstrap=0) with survey_design.
11601162
11611163
Returns
11621164
-------

diff_diff/two_stage.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ def fit(
199199
balance_e : int, optional
200200
When computing event study, restrict to cohorts observed at all
201201
relative times in [-balance_e, max_h].
202+
survey_design : SurveyDesign, optional
203+
Survey design specification for design-based inference. Supports
204+
pweight only (aweight/fweight raise ValueError). FPC raises
205+
NotImplementedError. PSU is used as cluster variable for Theorem 3
206+
variance. Strata enters survey df for t-distribution inference.
207+
Requires analytical inference (n_bootstrap=0).
202208
203209
Returns
204210
-------
@@ -1663,6 +1669,12 @@ def two_stage_did(
16631669
Aggregation mode: None, "simple", "event_study", "group", "all".
16641670
balance_e : int, optional
16651671
Balance event study to cohorts observed at all relative times.
1672+
survey_design : SurveyDesign, optional
1673+
Survey design specification for design-based inference. Supports
1674+
pweight only (aweight/fweight raise ValueError). FPC raises
1675+
NotImplementedError. PSU is used as cluster variable for Theorem 3
1676+
variance. Strata enters survey df for t-distribution inference.
1677+
Requires analytical inference (n_bootstrap=0).
16661678
**kwargs
16671679
Additional keyword arguments passed to TwoStageDiD constructor.
16681680

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ Y_it = alpha_i + beta_t [+ X'_it * delta] + W'_it * gamma + epsilon_it
840840
- **`balance_e` cohort filtering:** When `balance_e` is set, cohort balance is checked against the *full panel* (pre + post treatment) via `_build_cohort_rel_times()`, requiring observations at every relative time in `[-balance_e, max_h]`. Both analytical aggregation and bootstrap inference use the same `_compute_balanced_cohort_mask` with pre-computed cohort horizons.
841841
- **Bootstrap clustering:** Multiplier bootstrap generates weights at `cluster_var` granularity (defaults to `unit` if `cluster` not specified). Invalid cluster column raises ValueError.
842842
- **Non-constant `first_treat` within a unit:** Emits `UserWarning` identifying the count and example unit. The estimator proceeds using the first observed value per unit (via `.first()` aggregation), but results may be unreliable.
843-
- **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand.
843+
- **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand (unweighted), or normalized survey weights `sw_i/sum(sw)` when `survey_design` is active.
844844
- **Rank-deficient covariates in variance:** Covariates with NaN coefficients (dropped for rank deficiency in Step 1) are excluded from the variance design matrices `A_0`/`A_1`. Only covariates with finite coefficients participate in the `v_it` projection.
845845
- **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails.
846846
- **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and survey-weighted conservative variance (Theorem 3). PSU is used as the cluster variable for Theorem 3 variance. Strata enters survey df (n_PSU - n_strata) for t-distribution inference. FPC is not supported (raises NotImplementedError). Strata does NOT enter the variance formula itself (no stratified sandwich) — this is conservative relative to stratified variance. Bootstrap + survey deferred.

tests/test_survey_phase4.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,37 @@ def test_weight_scale_invariance(self):
183183

184184
np.testing.assert_allclose(beta1, beta2, atol=1e-10)
185185

186+
def test_nan_weights_raises(self):
187+
"""NaN weights should raise ValueError."""
188+
rng = np.random.RandomState(42)
189+
n = 50
190+
X = rng.randn(n, 2)
191+
y = (X @ [0.5, -0.5] + rng.randn(n) > 0).astype(float)
192+
weights = np.ones(n)
193+
weights[3] = np.nan
194+
with pytest.raises(ValueError, match="NaN"):
195+
solve_logit(X, y, weights=weights)
196+
197+
def test_negative_weights_raises(self):
198+
"""Negative weights should raise ValueError."""
199+
rng = np.random.RandomState(42)
200+
n = 50
201+
X = rng.randn(n, 2)
202+
y = (X @ [0.5, -0.5] + rng.randn(n) > 0).astype(float)
203+
weights = np.ones(n)
204+
weights[0] = -1.0
205+
with pytest.raises(ValueError, match="strictly positive"):
206+
solve_logit(X, y, weights=weights)
207+
208+
def test_wrong_shape_weights_raises(self):
209+
"""Wrong-length weights should raise ValueError."""
210+
rng = np.random.RandomState(42)
211+
n = 50
212+
X = rng.randn(n, 2)
213+
y = (X @ [0.5, -0.5] + rng.randn(n) > 0).astype(float)
214+
with pytest.raises(ValueError, match="shape"):
215+
solve_logit(X, y, weights=np.ones(n + 5))
216+
186217

187218
# =============================================================================
188219
# TestImputationDiDSurvey
@@ -444,6 +475,32 @@ def test_aggregate_all_with_survey(self, staggered_survey_data, survey_design_we
444475
assert result.event_study_effects is not None
445476
assert result.group_effects is not None
446477

478+
def test_aweight_raises(self, staggered_survey_data):
479+
"""aweight survey design should raise ValueError."""
480+
sd = SurveyDesign(weights="weight", weight_type="aweight")
481+
with pytest.raises(ValueError, match="pweight"):
482+
ImputationDiD().fit(
483+
staggered_survey_data,
484+
"outcome",
485+
"unit",
486+
"period",
487+
"first_treat",
488+
survey_design=sd,
489+
)
490+
491+
def test_fpc_raises(self, staggered_survey_data):
492+
"""FPC survey design should raise NotImplementedError."""
493+
sd = SurveyDesign(weights="weight", fpc="fpc")
494+
with pytest.raises(NotImplementedError, match="FPC"):
495+
ImputationDiD().fit(
496+
staggered_survey_data,
497+
"outcome",
498+
"unit",
499+
"period",
500+
"first_treat",
501+
survey_design=sd,
502+
)
503+
447504

448505
# =============================================================================
449506
# TestTwoStageDiDSurvey
@@ -647,6 +704,32 @@ def test_always_treated_with_survey(self, staggered_survey_data):
647704
assert np.isfinite(result.overall_se)
648705
assert result.survey_metadata is not None
649706

707+
def test_aweight_raises(self, staggered_survey_data):
708+
"""aweight survey design should raise ValueError."""
709+
sd = SurveyDesign(weights="weight", weight_type="aweight")
710+
with pytest.raises(ValueError, match="pweight"):
711+
TwoStageDiD().fit(
712+
staggered_survey_data,
713+
"outcome",
714+
"unit",
715+
"period",
716+
"first_treat",
717+
survey_design=sd,
718+
)
719+
720+
def test_fpc_raises(self, staggered_survey_data):
721+
"""FPC survey design should raise NotImplementedError."""
722+
sd = SurveyDesign(weights="weight", fpc="fpc")
723+
with pytest.raises(NotImplementedError, match="FPC"):
724+
TwoStageDiD().fit(
725+
staggered_survey_data,
726+
"outcome",
727+
"unit",
728+
"period",
729+
"first_treat",
730+
survey_design=sd,
731+
)
732+
650733

651734
# =============================================================================
652735
# TestCallawaySantAnnaSurvey

0 commit comments

Comments
 (0)