Skip to content

Commit 4fbb82d

Browse files
igerberclaude
andcommitted
Address P2/P3: validate trim_weights caps, add survey pretrend coverage
P2: Validate upper/lower are finite and non-negative in trim_weights. Reject NaN/negative/inf caps before applying np.minimum/np.maximum. P3: Add survey+covariates pretrend test and survey+bootstrap pretrend test exercising the distinct code paths. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f065afe commit 4fbb82d

2 files changed

Lines changed: 50 additions & 0 deletions

File tree

diff_diff/prep.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,13 @@ def trim_weights(
12801280
raise ValueError(f"quantile must be in (0, 1), got {quantile}")
12811281
upper = float(np.nanquantile(w, quantile))
12821282

1283+
# Validate cap values are finite and non-negative
1284+
if upper is not None:
1285+
if not np.isfinite(upper) or upper < 0:
1286+
raise ValueError(f"upper must be finite and >= 0, got {upper}")
1287+
if lower is not None:
1288+
if not np.isfinite(lower) or lower < 0:
1289+
raise ValueError(f"lower must be finite and >= 0, got {lower}")
12831290
if upper is not None and lower is not None and lower > upper:
12841291
raise ValueError(
12851292
f"lower ({lower}) must be <= upper ({upper}). "

tests/test_survey_phase8.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,49 @@ def test_pretrend_test_survey_no_raise(self, staggered_data):
12831283
assert np.isfinite(pt["f_stat"])
12841284
assert np.isfinite(pt["p_value"])
12851285

1286+
def test_pretrends_survey_with_covariates(self, staggered_data):
1287+
"""Survey pretrends with covariates uses the survey+covariate path."""
1288+
from diff_diff import ImputationDiD
1289+
1290+
data = staggered_data
1291+
data["x1"] = np.random.default_rng(42).normal(0, 1, len(data))
1292+
sd = SurveyDesign(weights="weight", strata="stratum", psu="psu")
1293+
est = ImputationDiD(pretrends=True, horizon_max=3)
1294+
result = est.fit(
1295+
data,
1296+
outcome="outcome",
1297+
unit="unit",
1298+
time="time",
1299+
first_treat="first_treat",
1300+
covariates=["x1"],
1301+
survey_design=sd,
1302+
aggregate="event_study",
1303+
)
1304+
pre = [h for h in result.event_study_effects if h < -1]
1305+
for h in pre:
1306+
assert np.isfinite(result.event_study_effects[h]["se"])
1307+
1308+
def test_pretrends_survey_with_bootstrap(self, staggered_data):
1309+
"""Survey pretrends + bootstrap: bootstrap doesn't overwrite pre-period SEs."""
1310+
from diff_diff import ImputationDiD
1311+
1312+
data = staggered_data
1313+
sd = SurveyDesign(weights="weight", strata="stratum", psu="psu")
1314+
est = ImputationDiD(pretrends=True, horizon_max=3, n_bootstrap=20, seed=42)
1315+
result = est.fit(
1316+
data,
1317+
outcome="outcome",
1318+
unit="unit",
1319+
time="time",
1320+
first_treat="first_treat",
1321+
survey_design=sd,
1322+
aggregate="event_study",
1323+
)
1324+
pre = [h for h in result.event_study_effects if h < -1]
1325+
for h in pre:
1326+
# Pre-period SEs come from lead regression, not bootstrap
1327+
assert np.isfinite(result.event_study_effects[h]["se"])
1328+
12861329
def test_pretrends_survey_always_treated_psu(self):
12871330
"""Survey pretrends with a PSU/stratum that has no untreated obs."""
12881331
from diff_diff import ImputationDiD

0 commit comments

Comments
 (0)