1515import pytest
1616
1717from diff_diff import CallawaySantAnna , EfficientDiD
18+ from edid_dgp import make_compustat_dgp , true_es_avg , true_overall_att
1819
1920# =============================================================================
2021# Data Loaders & Helpers
@@ -66,70 +67,24 @@ def _get_effect(effects_dict, g, t):
6667 raise KeyError (f"ATT({ g } ,{ t } ) not found in results" )
6768
6869
69- def _assert_close (actual , expected , label , rtol = 0.10 , atol = 200 ):
70- """Assert actual is close to expected with combined tolerance."""
71- tol = max (rtol * abs (expected ), atol )
70+ def _assert_close (actual , expected , label , se = None , se_frac = 0.1 ):
71+ """Assert actual is close to expected, tolerance based on published SE.
72+
73+ Default tolerance is 0.1 * SE (10% of one standard error). Our actual
74+ diffs are all < 0.03 SE, so this catches real drift while absorbing the
75+ 4-individual sample difference (656 vs paper's 652).
76+ """
77+ if se is not None :
78+ tol = se_frac * se
79+ else :
80+ tol = max (0.05 * abs (expected ), 50 )
7281 diff = abs (actual - expected )
7382 assert diff < tol , (
7483 f"{ label } : expected { expected } , got { actual :.1f} "
7584 f"(diff={ diff :.1f} , tol={ tol :.1f} )"
7685 )
7786
7887
79- # =============================================================================
80- # Compustat DGP (copied from test_efficient_did.py)
81- # =============================================================================
82-
83-
84- def _make_compustat_dgp (n_units = 400 , n_periods = 11 , rho = 0.0 , seed = 42 ):
85- """Simplified Compustat-style DGP from Section 5.2.
86-
87- Groups: G=5 (~1/3), G=8 (~1/3), G=inf (~1/3).
88- ATT(5,t) = 0.154*(t-4), ATT(8,t) = 0.093*(t-7).
89- """
90- rng = np .random .default_rng (seed )
91- n_t = n_periods
92-
93- n_g5 = n_units // 3
94- n_g8 = n_units // 3
95- ft = np .full (n_units , np .inf )
96- ft [:n_g5 ] = 5
97- ft [n_g5 : n_g5 + n_g8 ] = 8
98-
99- units = np .repeat (np .arange (n_units ), n_t )
100- times = np .tile (np .arange (1 , n_t + 1 ), n_units )
101- ft_col = np .repeat (ft , n_t )
102-
103- alpha_t = rng .normal (0 , 0.1 , n_t )
104- eta_i = rng .normal (0 , 0.5 , n_units )
105- unit_fe = np .repeat (eta_i , n_t )
106- time_fe = np .tile (alpha_t , n_units )
107-
108- eps = np .zeros ((n_units , n_t ))
109- eps [:, 0 ] = rng .normal (0 , 0.3 , n_units )
110- for t in range (1 , n_t ):
111- eps [:, t ] = rho * eps [:, t - 1 ] + rng .normal (0 , 0.3 , n_units )
112- eps_flat = eps .flatten ()
113-
114- tau = np .zeros (len (units ))
115- for i in range (n_units ):
116- g = ft [i ]
117- if np .isinf (g ):
118- continue
119- for t_idx in range (n_t ):
120- t = t_idx + 1
121- if g == 5 and t >= 5 :
122- tau [i * n_t + t_idx ] = 0.154 * (t - 4 )
123- elif g == 8 and t >= 8 :
124- tau [i * n_t + t_idx ] = 0.093 * (t - 7 )
125-
126- y = unit_fe + time_fe + tau + eps_flat
127-
128- return pd .DataFrame (
129- {"unit" : units , "time" : times , "first_treat" : ft_col , "y" : y }
130- )
131-
132-
13388def _compute_es_avg (result ):
13489 """Compute ES_avg (Eq 2.3): uniform average over post-treatment horizons."""
13590 if result .event_study_effects is None :
@@ -142,41 +97,7 @@ def _compute_es_avg(result):
14297 return np .mean (list (es .values ()))
14398
14499
145- # Ground truth derived from DGP parameters (not hard-coded)
146- _ATT_COEFS = {5 : 0.154 , 8 : 0.093 } # ATT(g,t) = coef * (t - g + 1) for t >= g
147- _N_PERIODS = 11
148-
149-
150- def _true_es_avg_from_dgp ():
151- """Derive ES_avg from DGP treatment effect parameters."""
152- max_e = {g : _N_PERIODS - g for g in _ATT_COEFS }
153- all_e = range (0 , max (max_e .values ()) + 1 )
154- es_values = []
155- for e in all_e :
156- contributing = [
157- coef * (e + 1 )
158- for g , coef in _ATT_COEFS .items ()
159- if e <= max_e [g ]
160- ]
161- if contributing :
162- es_values .append (np .mean (contributing ))
163- return np .mean (es_values )
164-
165-
166- _TRUE_ES_AVG_COMPUSTAT = _true_es_avg_from_dgp ()
167-
168-
169- def _true_overall_att_compustat ():
170- """Compute true overall_att using cohort-size weighting (our convention)."""
171- # Groups have equal size (1/3 each), so pi_5 = pi_8
172- # Post-treatment (g,t) cells:
173- # G=5: t=5..11 -> 7 cells with effects 0.154*(1..7)
174- # G=8: t=8..11 -> 4 cells with effects 0.093*(1..4)
175- effects_g5 = [0.154 * k for k in range (1 , 8 )] # 7 cells
176- effects_g8 = [0.093 * k for k in range (1 , 5 )] # 4 cells
177- # Cohort-size-weighted: both groups have same pi, so weight by count
178- all_effects = effects_g5 + effects_g8
179- return np .mean (all_effects )
100+ _TRUE_ES_AVG_COMPUSTAT = true_es_avg ()
180101
181102
182103def _run_mc_simulation (n_sims , rho , seed = 1000 , also_cs = False ):
@@ -188,7 +109,7 @@ def _run_mc_simulation(n_sims, rho, seed=1000, also_cs=False):
188109 cs_estimates_list = []
189110
190111 for i in range (n_sims ):
191- data = _make_compustat_dgp (rho = rho , seed = seed + i )
112+ data = make_compustat_dgp (rho = rho , seed = seed + i )
192113
193114 edid = EfficientDiD (pt_assumption = "all" )
194115 res = edid .fit (
@@ -266,24 +187,23 @@ def test_sample_selection_yields_expected_counts(self, hrs_data):
266187 )
267188
268189 def test_group_time_effects_match_table6 (self , edid_hrs_result ):
269- for (g , t ), (expected_effect , _ ) in TABLE6_EDID .items ():
190+ for (g , t ), (expected_effect , se ) in TABLE6_EDID .items ():
270191 info = _get_effect (edid_hrs_result .group_time_effects , g , t )
271- _assert_close (info ["effect" ], expected_effect , f"ATT({ g } ,{ t } )" )
192+ _assert_close (info ["effect" ], expected_effect , f"ATT({ g } ,{ t } )" , se = se )
272193
273194 def test_event_study_effects_match_table6 (self , edid_hrs_result ):
274- for e , (expected_effect , _ ) in TABLE6_ES .items ():
275- # Find event study effect matching relative time e
195+ for e , (expected_effect , se ) in TABLE6_ES .items ():
276196 found = False
277197 for rel_time , info in edid_hrs_result .event_study_effects .items ():
278198 if int (rel_time ) == e :
279- _assert_close (info ["effect" ], expected_effect , f"ES({ e } )" )
199+ _assert_close (info ["effect" ], expected_effect , f"ES({ e } )" , se = se )
280200 found = True
281201 break
282202 assert found , f"ES({ e } ) not found in event study effects"
283203
284204 def test_es_avg_matches_table6 (self , edid_hrs_result ):
285205 es_avg = _compute_es_avg (edid_hrs_result )
286- _assert_close (es_avg , TABLE6_ES_AVG [0 ], "ES_avg" )
206+ _assert_close (es_avg , TABLE6_ES_AVG [0 ], "ES_avg" , se = TABLE6_ES_AVG [ 1 ] )
287207
288208 def test_se_diagnostic_comparison (self , edid_hrs_result ):
289209 """Log and sanity-check analytical vs cluster-robust SEs."""
@@ -307,11 +227,14 @@ def test_cs_cross_validation(self, hrs_data):
307227 hrs_data , outcome = "outcome" , unit = "unit" , time = "time" ,
308228 first_treat = "first_treat" ,
309229 )
230+ # CS-SA paper SEs from Table 6
231+ cs_ses = {(8 ,8 ): 1035 , (8 ,9 ): 909 , (8 ,10 ): 1008 ,
232+ (9 ,9 ): 702 , (9 ,10 ): 651 , (10 ,10 ): 995 }
310233 for (g , t ), expected_effect in TABLE6_CS_SA .items ():
311234 info = _get_effect (cs_result .group_time_effects , g , t )
312235 _assert_close (
313236 info ["effect" ], expected_effect ,
314- f"CS ATT({ g } ,{ t } )" , rtol = 0.15 , atol = 300 ,
237+ f"CS ATT({ g } ,{ t } )" , se = cs_ses [( g , t )] ,
315238 )
316239
317240 def test_pretreatment_effects_near_zero (self , edid_hrs_result ):
@@ -411,7 +334,7 @@ def test_coverage_approximately_correct(self, ci_params):
411334 n_sims = ci_params .bootstrap (200 , min_n = 49 )
412335 mc = _run_mc_simulation (n_sims , rho = 0 , seed = 5000 )
413336
414- true_overall = _true_overall_att_compustat ()
337+ true_overall = true_overall_att ()
415338 covered = sum (
416339 ci [0 ] <= true_overall <= ci [1 ]
417340 for ci in mc ["edid_overall_ci" ]
0 commit comments