@@ -2205,3 +2205,89 @@ def test_default_base_period_is_varying(self):
22052205 cs = CallawaySantAnna ()
22062206 assert cs .base_period == "varying"
22072207 assert cs .get_params ()["base_period" ] == "varying"
2208+
2209+ def test_varying_mode_no_fallback_to_nonconsecutive (self ):
2210+ """Varying mode skips pre-treatment effects where t-1 doesn't exist."""
2211+ # Create data where first period (e.g., period 1) has no t-1 predecessor
2212+ data = generate_staggered_data (
2213+ n_units = 100 ,
2214+ n_periods = 6 , # periods 1-6
2215+ n_cohorts = 2 ,
2216+ treatment_effect = 2.0 ,
2217+ seed = 42
2218+ )
2219+
2220+ # Identify the earliest time period in data
2221+ min_period = data ['time' ].min ()
2222+
2223+ cs = CallawaySantAnna (base_period = "varying" )
2224+ results = cs .fit (
2225+ data ,
2226+ outcome = 'outcome' ,
2227+ unit = 'unit' ,
2228+ time = 'time' ,
2229+ first_treat = 'first_treat'
2230+ )
2231+
2232+ # In varying mode, ATT(g, min_period) should NOT be computed for
2233+ # any cohort g because t-1 (period 0) doesn't exist
2234+ for (g , t ) in results .group_time_effects .keys ():
2235+ if t == min_period :
2236+ # This should not happen - the (g, min_period) pair should be skipped
2237+ pytest .fail (
2238+ f"ATT({ g } , { t } ) should not exist because t-1 doesn't exist. "
2239+ "Fallback to non-consecutive base period was incorrectly applied."
2240+ )
2241+
2242+ def test_no_post_treatment_effects_returns_nan_with_warning (self ):
2243+ """Warn and return NaN when no post-treatment effects exist."""
2244+ import warnings
2245+
2246+ # Create data where the treatment cohort treats AFTER the last observed period
2247+ # so there are no post-treatment periods (t >= g never holds)
2248+ n_units = 50
2249+ n_periods = 5
2250+ np .random .seed (42 )
2251+
2252+ data = []
2253+ for unit in range (n_units ):
2254+ for t in range (1 , n_periods + 1 ):
2255+ # Treated units get treated at period 6 (beyond data range)
2256+ # Data only goes to period 5, so no post-treatment periods exist
2257+ first_treat = n_periods + 1 if unit < n_units // 2 else 0
2258+ outcome = np .random .randn ()
2259+ data .append ({
2260+ 'unit' : unit ,
2261+ 'time' : t ,
2262+ 'outcome' : outcome ,
2263+ 'first_treat' : first_treat
2264+ })
2265+
2266+ df = pd .DataFrame (data )
2267+
2268+ cs = CallawaySantAnna (base_period = "varying" )
2269+
2270+ with warnings .catch_warnings (record = True ) as w :
2271+ warnings .simplefilter ("always" )
2272+ results = cs .fit (
2273+ df ,
2274+ outcome = 'outcome' ,
2275+ unit = 'unit' ,
2276+ time = 'time' ,
2277+ first_treat = 'first_treat'
2278+ )
2279+
2280+ # Should have emitted a warning about no post-treatment effects
2281+ warning_messages = [str (warning .message ) for warning in w ]
2282+ has_warning = any (
2283+ "No post-treatment effects" in msg for msg in warning_messages
2284+ )
2285+ assert has_warning , (
2286+ f"Expected warning about no post-treatment effects, got: { warning_messages } "
2287+ )
2288+
2289+ # Overall ATT should be NaN
2290+ assert np .isnan (results .overall_att ), (
2291+ f"Expected NaN for overall_att when no post-treatment effects exist, "
2292+ f"got { results .overall_att } "
2293+ )
0 commit comments