@@ -403,3 +403,71 @@ def test_to_dict(self, simple_data):
403403 assert "overall_att" in d
404404 assert "n_obs" in d
405405 assert "estimation_method" in d
406+
407+
408+ # ---------------------------------------------------------------------------
409+ # Regression tests for specific bug fixes
410+ # ---------------------------------------------------------------------------
411+
412+ class TestStaggeredTripleDiffRegressions :
413+ def test_base_period_outside_panel_warns (self ):
414+ """Cohort with base period before observed panel should warn, not crash."""
415+ # Cohort g=2 with anticipation=1 needs base_period = g-1-1 = 0,
416+ # but periods start at 1. Should warn and skip that cell.
417+ data = generate_staggered_ddd_data (
418+ n_units = 100 , n_periods = 4 , cohort_periods = [2 , 4 ],
419+ seed = 77 ,
420+ )
421+ est = StaggeredTripleDifference (anticipation = 1 )
422+ import warnings as _w
423+ with _w .catch_warnings (record = True ) as caught :
424+ _w .simplefilter ("always" )
425+ res = est .fit (data , "outcome" , "unit" , "period" ,
426+ "first_treat" , "eligibility" )
427+ base_period_warnings = [
428+ w for w in caught if "outside the observed panel" in str (w .message )
429+ ]
430+ assert len (base_period_warnings ) > 0 , "Expected warning about base period"
431+ assert np .isfinite (res .overall_att )
432+
433+ def test_empty_subgroup_warns (self ):
434+ """Data where one (S,Q) cell is empty should warn, not crash."""
435+ data = generate_staggered_ddd_data (
436+ n_units = 100 , cohort_periods = [4 , 6 ], seed = 88 ,
437+ )
438+ # Remove all ineligible units from cohort 6 to make (S=6,Q=0) empty
439+ mask = ~ ((data ["first_treat" ] == 6 ) & (data ["eligibility" ] == 0 ))
440+ data = data [mask ].reset_index (drop = True )
441+ est = StaggeredTripleDifference ()
442+ import warnings as _w
443+ with _w .catch_warnings (record = True ) as caught :
444+ _w .simplefilter ("always" )
445+ res = est .fit (data , "outcome" , "unit" , "period" ,
446+ "first_treat" , "eligibility" )
447+ subgroup_warnings = [
448+ w for w in caught if "Empty subgroup" in str (w .message )
449+ ]
450+ assert len (subgroup_warnings ) > 0 , "Expected warning about empty subgroup"
451+ assert np .isfinite (res .overall_att )
452+
453+ def test_collinear_covariates_cached_ps_finite (self ):
454+ """Collinear covariates with PS cache reuse should produce finite results."""
455+ data = generate_staggered_ddd_data (
456+ n_units = 200 , treatment_effect = 3.0 ,
457+ add_covariates = True , seed = 55 ,
458+ )
459+ # Add a perfectly collinear covariate (x3 = 2*x1)
460+ data ["x3" ] = 2.0 * data ["x1" ]
461+ est = StaggeredTripleDifference (
462+ estimation_method = "dr" , rank_deficient_action = "warn" ,
463+ )
464+ import warnings as _w
465+ with _w .catch_warnings (record = True ):
466+ _w .simplefilter ("always" )
467+ res = est .fit (data , "outcome" , "unit" , "period" ,
468+ "first_treat" , "eligibility" ,
469+ covariates = ["x1" , "x2" , "x3" ])
470+ # All group-time effects should be finite despite collinearity
471+ for (g , t ), eff in res .group_time_effects .items ():
472+ assert np .isfinite (eff ["effect" ]), f"Non-finite ATT at (g={ g } ,t={ t } )"
473+ assert np .isfinite (eff ["se" ]), f"Non-finite SE at (g={ g } ,t={ t } )"
0 commit comments