@@ -1966,6 +1966,152 @@ def test_te_covariate_interaction_validation(self):
19661966 with pytest .raises (ValueError , match = "te_covariate_interaction must be finite" ):
19671967 generate_survey_did_data (add_covariates = True , te_covariate_interaction = np .nan , seed = 42 )
19681968
1969+ # --- conditional_pt parameter tests ---
1970+
def test_conditional_pt_requires_covariates(self):
    """A non-zero conditional_pt with add_covariates=False raises ValueError."""
    from diff_diff.prep_dgp import generate_survey_did_data

    expected_message = "conditional_pt requires add_covariates"
    invalid_kwargs = {"conditional_pt": 0.3, "add_covariates": False, "seed": 42}

    with pytest.raises(ValueError, match=expected_message):
        generate_survey_did_data(**invalid_kwargs)
1977+
def test_conditional_pt_nonfinite_rejected(self):
    """Infinite and NaN values of conditional_pt raise ValueError."""
    from diff_diff.prep_dgp import generate_survey_did_data

    expected_message = "conditional_pt must be finite"

    # Infinity is checked first, then NaN; each must be rejected on its own.
    for bad_value in (np.inf, np.nan):
        with pytest.raises(ValueError, match=expected_message):
            generate_survey_did_data(
                add_covariates=True, conditional_pt=bad_value, seed=42
            )
1990+
def test_conditional_pt_x1_distribution_shift(self):
    """Mean x1 of treated units exceeds that of controls under conditional_pt.

    Compares period-1 rows with ``first_treat > 0`` against rows with
    ``first_treat == 0``.
    """
    from diff_diff.prep_dgp import generate_survey_did_data

    data = generate_survey_did_data(
        n_units=1000,
        n_periods=4,
        add_covariates=True,
        conditional_pt=0.3,
        seed=42,
    )

    first_period = data[data["period"] == 1]
    treated_mask = first_period["first_treat"] > 0
    control_mask = first_period["first_treat"] == 0

    treated_x1 = first_period.loc[treated_mask, "x1"].values
    control_x1 = first_period.loc[control_mask, "x1"].values
    shift = treated_x1.mean() - control_x1.mean()

    # The shift is expected to be about 1.0 SD; the test only demands 0.5.
    assert shift > 0.5, f"x1 mean shift too small: {shift:.3f} "
2008+
def test_conditional_pt_unconditional_pt_fails(self):
    """Raw period-1 to period-2 outcome trends differ between the groups.

    Without conditioning on covariates, the mean outcome change of units
    with ``first_treat > 0`` should be measurably different from that of
    the remaining units.
    """
    from diff_diff.prep_dgp import generate_survey_did_data

    data = generate_survey_did_data(
        n_units=2000,
        n_periods=8,
        add_covariates=True,
        conditional_pt=0.5,
        never_treated_frac=0.5,
        seed=42,
    )

    # Periods 1 and 2 serve as the pre-treatment window
    # (NOTE(review): treatment is presumed to start at period 3 or later).
    first = data[data["period"] == 1].set_index("unit")
    second = data[data["period"] == 2].set_index("unit")
    shared_units = first.index.intersection(second.index)

    outcome_change = (
        second.loc[shared_units, "outcome"] - first.loc[shared_units, "outcome"]
    )
    treated_mask = first.loc[shared_units, "first_treat"] > 0

    treated_trend = outcome_change[treated_mask].mean()
    control_trend = outcome_change[~treated_mask].mean()
    gap = abs(treated_trend - control_trend)

    # conditional_pt=0.5 combined with the x1 shift should leave a visible gap.
    assert gap > 0.01, f"Unconditional PT gap too small: {gap:.4f} "
2034+
def test_conditional_pt_conditional_pt_holds(self):
    """Adding x1 as a regressor shrinks the treated/control pre-trend gap.

    The PSU and noise parameters are set low so that the conditional_pt
    signal dominates the period-1 to period-2 outcome change.
    """
    from diff_diff.prep_dgp import generate_survey_did_data

    data = generate_survey_did_data(
        n_units=2000,
        n_periods=8,
        add_covariates=True,
        conditional_pt=2.0,
        never_treated_frac=0.5,
        psu_re_sd=0.1,
        psu_period_factor=0.1,
        noise_sd=0.2,
        seed=42,
    )

    first = data[data["period"] == 1].set_index("unit")
    second = data[data["period"] == 2].set_index("unit")
    shared_units = first.index.intersection(second.index)

    delta_y = (
        second.loc[shared_units, "outcome"].values
        - first.loc[shared_units, "outcome"].values
    )
    x1_vals = first.loc[shared_units, "x1"].values
    treated = (first.loc[shared_units, "first_treat"] > 0).values.astype(float)
    n_obs = len(delta_y)

    def abs_treated_coef(*regressors):
        # Least-squares fit of delta_y on [intercept, *regressors]; the
        # treated indicator is always the first regressor, hence index 1.
        design = np.column_stack([np.ones(n_obs), *regressors])
        coefs = np.linalg.lstsq(design, delta_y, rcond=None)[0]
        return abs(coefs[1])

    # delta_y ~ treated: expected to show a large gap.
    uncond_gap = abs_treated_coef(treated)
    # delta_y ~ treated + x1: the gap should shrink once x1 is controlled for.
    cond_gap = abs_treated_coef(treated, x1_vals)

    assert uncond_gap > 0.05, f"Unconditional gap too small: {uncond_gap:.4f} "
    assert cond_gap < uncond_gap * 0.5, (
        f"Conditional gap ({cond_gap:.4f} ) should be much smaller than "
        f"unconditional ({uncond_gap:.4f} )"
    )
2078+
def test_conditional_pt_backward_compatible(self):
    """Passing conditional_pt=0.0 yields the same frame as omitting it."""
    from diff_diff.prep_dgp import generate_survey_did_data

    shared_kwargs = {"n_units": 100, "add_covariates": True, "seed": 99}

    baseline = generate_survey_did_data(**shared_kwargs)
    explicit_zero = generate_survey_did_data(conditional_pt=0.0, **shared_kwargs)

    pd.testing.assert_frame_equal(baseline, explicit_zero)
2090+
def test_conditional_pt_panel_and_crosssection(self):
    """conditional_pt produces shifted x1 with panel=True and panel=False."""
    from diff_diff.prep_dgp import generate_survey_did_data

    n_units = 500
    n_periods = 4

    for panel_mode in (True, False):
        data = generate_survey_did_data(
            n_units=n_units,
            n_periods=n_periods,
            add_covariates=True,
            conditional_pt=0.3,
            panel=panel_mode,
            seed=42,
        )

        # Sanity checks: one row per unit-period and the covariate exists.
        assert len(data) == n_units * n_periods
        assert "x1" in data.columns

        # In period 1 the treated group's mean x1 must exceed the controls'.
        first_period = data[data["period"] == 1]
        treated_mask = first_period["first_treat"] > 0
        control_mask = first_period["first_treat"] == 0
        x1_treated = first_period.loc[treated_mask, "x1"].mean()
        x1_control = first_period.loc[control_mask, "x1"].mean()
        assert x1_treated > x1_control, (
            f"panel={panel_mode} : treated x1 not shifted"
        )
2114+
19692115
19702116class TestAggregateSurvey :
19712117 """Tests for aggregate_survey function."""
0 commit comments