@@ -1180,7 +1180,7 @@ def test_basic_shape_and_columns(self):
11801180 """Test output shape and expected columns."""
11811181 from diff_diff .prep import generate_survey_did_data
11821182
1183- data = generate_survey_did_data (n_units = 100 , n_periods = 4 , seed = 42 )
1183+ data = generate_survey_did_data (n_units = 100 , n_periods = 4 , cohort_periods = [ 2 , 3 ], seed = 42 )
11841184 assert len (data ) == 400 # 100 units x 4 periods
11851185 expected = {"unit" , "period" , "outcome" , "first_treat" , "treated" ,
11861186 "true_effect" , "stratum" , "psu" , "fpc" , "weight" }
@@ -1297,8 +1297,8 @@ def test_top_level_import(self):
12971297 """Test that generate_survey_did_data is importable from diff_diff."""
12981298 from diff_diff import generate_survey_did_data
12991299
1300- data = generate_survey_did_data (n_units = 10 , n_periods = 2 , seed = 42 )
1301- assert len (data ) == 20
1300+ data = generate_survey_did_data (n_units = 10 , n_periods = 4 , cohort_periods = [ 2 ] , seed = 42 )
1301+ assert len (data ) == 40
13021302
13031303 def test_jk1_minimum_psu_guard (self ):
13041304 """Test that JK1 replicates require at least 2 PSUs."""
@@ -1316,10 +1316,10 @@ def test_repeated_cross_section(self):
13161316 from diff_diff .prep import generate_survey_did_data
13171317
13181318 data = generate_survey_did_data (
1319- n_units = 20 , n_periods = 3 , panel = False , seed = 42 ,
1319+ n_units = 20 , n_periods = 4 , cohort_periods = [ 2 ] , panel = False , seed = 42 ,
13201320 )
1321- assert len (data ) == 60
1322- assert data ["unit" ].nunique () == 60 # unique across all periods
1321+ assert len (data ) == 80
1322+ assert data ["unit" ].nunique () == 80 # unique across all periods
13231323 # No unit appears in more than one period
13241324 assert data .groupby ("unit" )["period" ].nunique ().max () == 1
13251325
@@ -1338,3 +1338,23 @@ def test_empty_cohort_periods(self):
13381338
13391339 with pytest .raises (ValueError , match = "cohort_periods must be" ):
13401340 generate_survey_did_data (cohort_periods = [], seed = 42 )
1341+
1342+ def test_cohort_period_out_of_range (self ):
1343+ """Test that out-of-range cohort periods raise ValueError."""
1344+ import pytest
1345+ from diff_diff .prep import generate_survey_did_data
1346+
1347+ # Period 0 is invalid (must be >= 1)
1348+ with pytest .raises (ValueError , match = "must be between" ):
1349+ generate_survey_did_data (cohort_periods = [0 ], seed = 42 )
1350+ # Period == n_periods is invalid (must be < n_periods)
1351+ with pytest .raises (ValueError , match = "must be between" ):
1352+ generate_survey_did_data (n_periods = 8 , cohort_periods = [8 ], seed = 42 )
1353+
1354+ def test_cohort_period_non_integer (self ):
1355+ """Test that non-integer cohort periods raise ValueError."""
1356+ import pytest
1357+ from diff_diff .prep import generate_survey_did_data
1358+
1359+ with pytest .raises (ValueError , match = "must contain integers" ):
1360+ generate_survey_did_data (cohort_periods = [2.5 ], seed = 42 )
0 commit comments