Skip to content

Commit df331ab

Browse files
igerber and claude committed
Validate cohort_periods range and type in generate_survey_did_data
Mirror the existing generate_staggered_data() guard: reject cohort periods that are non-integer, < 1, or >= n_periods. Add negative tests for out-of-range and non-integer inputs. Fix existing tests that used default cohort_periods with small n_periods (now caught by the new validation).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2dc1888 commit df331ab

2 files changed

Lines changed: 35 additions & 6 deletions

File tree

diff_diff/prep_dgp.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,15 @@ def generate_survey_did_data(
12221222
cohort_periods = [3, 5]
12231223
if not cohort_periods:
12241224
raise ValueError("cohort_periods must be a non-empty list of integers")
1225+
for cp in cohort_periods:
1226+
if not isinstance(cp, int) or isinstance(cp, bool):
1227+
raise ValueError(
1228+
f"cohort_periods must contain integers, got {cp!r}"
1229+
)
1230+
if cp < 1 or cp >= n_periods:
1231+
raise ValueError(
1232+
f"Cohort period {cp} must be between 1 and {n_periods - 1}"
1233+
)
12251234

12261235
valid_wv = ("none", "moderate", "high")
12271236
if weight_variation not in valid_wv:

tests/test_prep.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ def test_basic_shape_and_columns(self):
11801180
"""Test output shape and expected columns."""
11811181
from diff_diff.prep import generate_survey_did_data
11821182

1183-
data = generate_survey_did_data(n_units=100, n_periods=4, seed=42)
1183+
data = generate_survey_did_data(n_units=100, n_periods=4, cohort_periods=[2, 3], seed=42)
11841184
assert len(data) == 400 # 100 units x 4 periods
11851185
expected = {"unit", "period", "outcome", "first_treat", "treated",
11861186
"true_effect", "stratum", "psu", "fpc", "weight"}
@@ -1297,8 +1297,8 @@ def test_top_level_import(self):
12971297
"""Test that generate_survey_did_data is importable from diff_diff."""
12981298
from diff_diff import generate_survey_did_data
12991299

1300-
data = generate_survey_did_data(n_units=10, n_periods=2, seed=42)
1301-
assert len(data) == 20
1300+
data = generate_survey_did_data(n_units=10, n_periods=4, cohort_periods=[2], seed=42)
1301+
assert len(data) == 40
13021302

13031303
def test_jk1_minimum_psu_guard(self):
13041304
"""Test that JK1 replicates require at least 2 PSUs."""
@@ -1316,10 +1316,10 @@ def test_repeated_cross_section(self):
13161316
from diff_diff.prep import generate_survey_did_data
13171317

13181318
data = generate_survey_did_data(
1319-
n_units=20, n_periods=3, panel=False, seed=42,
1319+
n_units=20, n_periods=4, cohort_periods=[2], panel=False, seed=42,
13201320
)
1321-
assert len(data) == 60
1322-
assert data["unit"].nunique() == 60 # unique across all periods
1321+
assert len(data) == 80
1322+
assert data["unit"].nunique() == 80 # unique across all periods
13231323
# No unit appears in more than one period
13241324
assert data.groupby("unit")["period"].nunique().max() == 1
13251325

@@ -1338,3 +1338,23 @@ def test_empty_cohort_periods(self):
13381338

13391339
with pytest.raises(ValueError, match="cohort_periods must be"):
13401340
generate_survey_did_data(cohort_periods=[], seed=42)
1341+
1342+
def test_cohort_period_out_of_range(self):
1343+
"""Test that out-of-range cohort periods raise ValueError."""
1344+
import pytest
1345+
from diff_diff.prep import generate_survey_did_data
1346+
1347+
# Period 0 is invalid (must be >= 1)
1348+
with pytest.raises(ValueError, match="must be between"):
1349+
generate_survey_did_data(cohort_periods=[0], seed=42)
1350+
# Period == n_periods is invalid (must be < n_periods)
1351+
with pytest.raises(ValueError, match="must be between"):
1352+
generate_survey_did_data(n_periods=8, cohort_periods=[8], seed=42)
1353+
1354+
def test_cohort_period_non_integer(self):
1355+
"""Test that non-integer cohort periods raise ValueError."""
1356+
import pytest
1357+
from diff_diff.prep import generate_survey_did_data
1358+
1359+
with pytest.raises(ValueError, match="must contain integers"):
1360+
generate_survey_did_data(cohort_periods=[2.5], seed=42)

0 commit comments

Comments
 (0)