Skip to content

Commit 21c25f1

Browse files
igerberclaude
andcommitted
Address AI review: validate new params, rename stratum_effects
- Add type/length/finiteness validation for covariate_effects and te_covariate_interaction - Rename stratum_effects to base_stratum_effects in dgp_truth to clarify these are base TEs before dynamic/covariate modifiers - Add validation rejection tests for malformed inputs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3f06597 commit 21c25f1

2 files changed

Lines changed: 45 additions & 4 deletions

File tree

diff_diff/prep_dgp.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,8 @@ def generate_survey_did_data(
12881288
If True, attaches a diagnostic dict to ``df.attrs["dgp_truth"]``
12891289
with keys: ``population_att`` (weight-weighted average of treated
12901290
true effects), ``deff_kish`` (1 + CV(w)^2), ``stratum_effects``
1291-
(dict mapping stratum index to TE), ``icc_realized`` (ANOVA-based
1291+
(base stratum TEs before dynamic/covariate modifiers),
1292+
``icc_realized`` (ANOVA-based
12921293
ICC computed on period-1 data).
12931294
covariate_effects : tuple of (float, float), optional
12941295
Coefficients ``(beta1, beta2)`` for covariates x1 and x2 in the
@@ -1412,9 +1413,23 @@ def generate_survey_did_data(
14121413
f"got {sum(strata_sizes)}"
14131414
)
14141415

1415-
# --- Resolve covariate coefficients ---
1416+
# --- Validate and resolve covariate coefficients ---
1417+
if covariate_effects is not None:
1418+
covariate_effects = tuple(covariate_effects)
1419+
if len(covariate_effects) != 2:
1420+
raise ValueError(
1421+
f"covariate_effects must have length 2, got {len(covariate_effects)}"
1422+
)
1423+
if not all(np.isfinite(c) for c in covariate_effects):
1424+
raise ValueError(
1425+
f"covariate_effects must be finite, got {covariate_effects}"
1426+
)
14161427
_beta1, _beta2 = covariate_effects if covariate_effects is not None else (0.5, 0.3)
14171428

1429+
if not np.isfinite(te_covariate_interaction):
1430+
raise ValueError(
1431+
f"te_covariate_interaction must be finite, got {te_covariate_interaction}"
1432+
)
14181433
if te_covariate_interaction != 0.0 and not add_covariates:
14191434
raise ValueError(
14201435
"te_covariate_interaction requires add_covariates=True"
@@ -1696,7 +1711,7 @@ def generate_survey_did_data(
16961711
df.attrs["dgp_truth"] = {
16971712
"population_att": population_att,
16981713
"deff_kish": float(deff_kish),
1699-
"stratum_effects": stratum_effects,
1714+
"base_stratum_effects": stratum_effects,
17001715
"icc_realized": icc_realized,
17011716
}
17021717

tests/test_prep.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1767,7 +1767,7 @@ def test_return_true_population_att(self):
17671767
truth = df.attrs["dgp_truth"]
17681768
assert "population_att" in truth
17691769
assert "deff_kish" in truth
1770-
assert "stratum_effects" in truth
1770+
assert "base_stratum_effects" in truth
17711771
assert "icc_realized" in truth
17721772
assert truth["deff_kish"] >= 1.0
17731773
assert truth["icc_realized"] >= 0.0
@@ -1888,3 +1888,29 @@ def test_te_covariate_interaction_requires_covariates(self):
18881888
generate_survey_did_data(
18891889
te_covariate_interaction=0.5, add_covariates=False, seed=42
18901890
)
1891+
1892+
def test_covariate_effects_validation(self):
1893+
"""covariate_effects must be length 2 and finite."""
1894+
from diff_diff.prep_dgp import generate_survey_did_data
1895+
1896+
with pytest.raises(ValueError, match="covariate_effects must have length 2"):
1897+
generate_survey_did_data(
1898+
add_covariates=True, covariate_effects=(1.0,), seed=42
1899+
)
1900+
with pytest.raises(ValueError, match="covariate_effects must be finite"):
1901+
generate_survey_did_data(
1902+
add_covariates=True, covariate_effects=(np.nan, 0.3), seed=42
1903+
)
1904+
with pytest.raises(ValueError, match="covariate_effects must be finite"):
1905+
generate_survey_did_data(
1906+
add_covariates=True, covariate_effects=(0.5, np.inf), seed=42
1907+
)
1908+
1909+
def test_te_covariate_interaction_validation(self):
1910+
"""te_covariate_interaction must be finite."""
1911+
from diff_diff.prep_dgp import generate_survey_did_data
1912+
1913+
with pytest.raises(ValueError, match="te_covariate_interaction must be finite"):
1914+
generate_survey_did_data(
1915+
add_covariates=True, te_covariate_interaction=np.nan, seed=42
1916+
)

0 commit comments

Comments
 (0)