Skip to content

Commit eec1fe8

Browse files
igerberclaude
andcommitted
Address AI review: tighten tolerances, extract shared DGP
P2: Replace loose max(10%*estimate, 200) tolerance with 0.1*SE (10% of one published standard error). Our actual diffs are all < 0.03 SE, so this catches real drift while absorbing minor sample differences. ATT(9,10)=90 now accepts [26, 154] instead of [-110, 290]. P3: Extract Compustat DGP into tests/edid_dgp.py as the single source of truth. Both test_efficient_did.py and test_efficient_did_validation.py import from it. Truth values (ES_avg, overall_att) are derived programmatically from the shared DGP parameters. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a4b3881 commit eec1fe8

4 files changed

Lines changed: 116 additions & 160 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ python-packages = ["diff_diff"]
8989

9090
[tool.pytest.ini_options]
9191
testpaths = ["tests"]
92+
pythonpath = ["tests"]
9293
python_files = "test_*.py"
9394
# Exclude slow tests by default; use `pytest -m ''` to run all tests
9495
addopts = "-v --tb=short -m 'not slow'"

tests/edid_dgp.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""
2+
Shared Compustat-style DGP for EfficientDiD tests.
3+
4+
Used by both test_efficient_did.py and test_efficient_did_validation.py
5+
to avoid duplication. Based on Section 5.2 of Chen, Sant'Anna & Xie (2025).
6+
"""
7+
8+
import numpy as np
9+
import pandas as pd
10+
11+
# DGP parameters — treatment effect coefficients
12+
ATT_COEFS = {5: 0.154, 8: 0.093}
13+
N_PERIODS = 11
14+
15+
16+
def make_compustat_dgp(n_units=400, n_periods=N_PERIODS, rho=0.0, seed=42):
17+
"""Simplified Compustat-style DGP from Section 5.2.
18+
19+
Groups: G=5 (~1/3), G=8 (~1/3), G=inf (~1/3).
20+
ATT(5,t) = 0.154*(t-4), ATT(8,t) = 0.093*(t-7).
21+
"""
22+
rng = np.random.default_rng(seed)
23+
n_t = n_periods
24+
25+
n_g5 = n_units // 3
26+
n_g8 = n_units // 3
27+
ft = np.full(n_units, np.inf)
28+
ft[:n_g5] = 5
29+
ft[n_g5 : n_g5 + n_g8] = 8
30+
31+
units = np.repeat(np.arange(n_units), n_t)
32+
times = np.tile(np.arange(1, n_t + 1), n_units)
33+
ft_col = np.repeat(ft, n_t)
34+
35+
alpha_t = rng.normal(0, 0.1, n_t)
36+
eta_i = rng.normal(0, 0.5, n_units)
37+
unit_fe = np.repeat(eta_i, n_t)
38+
time_fe = np.tile(alpha_t, n_units)
39+
40+
eps = np.zeros((n_units, n_t))
41+
eps[:, 0] = rng.normal(0, 0.3, n_units)
42+
for t in range(1, n_t):
43+
eps[:, t] = rho * eps[:, t - 1] + rng.normal(0, 0.3, n_units)
44+
eps_flat = eps.flatten()
45+
46+
tau = np.zeros(len(units))
47+
for i in range(n_units):
48+
g = ft[i]
49+
if np.isinf(g):
50+
continue
51+
for t_idx in range(n_t):
52+
t = t_idx + 1
53+
if g == 5 and t >= 5:
54+
tau[i * n_t + t_idx] = ATT_COEFS[5] * (t - 4)
55+
elif g == 8 and t >= 8:
56+
tau[i * n_t + t_idx] = ATT_COEFS[8] * (t - 7)
57+
58+
y = unit_fe + time_fe + tau + eps_flat
59+
60+
return pd.DataFrame(
61+
{"unit": units, "time": times, "first_treat": ft_col, "y": y}
62+
)
63+
64+
65+
def true_es_avg():
66+
"""Derive ES_avg from DGP treatment effect parameters."""
67+
max_e = {g: N_PERIODS - g for g in ATT_COEFS}
68+
all_e = range(0, max(max_e.values()) + 1)
69+
es_values = []
70+
for e in all_e:
71+
contributing = [
72+
coef * (e + 1)
73+
for g, coef in ATT_COEFS.items()
74+
if e <= max_e[g]
75+
]
76+
if contributing:
77+
es_values.append(np.mean(contributing))
78+
return np.mean(es_values)
79+
80+
81+
def true_overall_att():
82+
"""Compute true overall_att using cohort-size weighting (library convention)."""
83+
effects = []
84+
for g, coef in ATT_COEFS.items():
85+
for t in range(g, N_PERIODS + 1):
86+
effects.append(coef * (t - g + 1))
87+
return np.mean(effects)

tests/test_efficient_did.py

Lines changed: 4 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -112,67 +112,12 @@ def _make_staggered_panel(
112112
)
113113

114114

115-
def _make_compustat_dgp(
116-
n_units=400,
117-
n_periods=11,
118-
rho=0.0,
119-
seed=42,
120-
):
121-
"""Simplified Compustat-style DGP from Section 5.2.
122-
123-
Groups: G=5 (~1/3), G=8 (~1/3), G=inf (~1/3).
124-
ATT(5,t) = 0.154*(t-4), ATT(8,t) = 0.093*(t-7).
125-
"""
126-
rng = np.random.default_rng(seed)
127-
n_t = n_periods
128-
129-
# Assign groups
130-
n_g5 = n_units // 3
131-
n_g8 = n_units // 3
132-
ft = np.full(n_units, np.inf)
133-
ft[:n_g5] = 5
134-
ft[n_g5 : n_g5 + n_g8] = 8
135-
136-
units = np.repeat(np.arange(n_units), n_t)
137-
times = np.tile(np.arange(1, n_t + 1), n_units)
138-
ft_col = np.repeat(ft, n_t)
139-
140-
# Unit and time FE
141-
alpha_t = rng.normal(0, 0.1, n_t)
142-
eta_i = rng.normal(0, 0.5, n_units)
143-
unit_fe = np.repeat(eta_i, n_t)
144-
time_fe = np.tile(alpha_t, n_units)
145-
146-
# AR(1) errors
147-
eps = np.zeros((n_units, n_t))
148-
eps[:, 0] = rng.normal(0, 0.3, n_units)
149-
for t in range(1, n_t):
150-
eps[:, t] = rho * eps[:, t - 1] + rng.normal(0, 0.3, n_units)
151-
eps_flat = eps.flatten()
115+
from edid_dgp import make_compustat_dgp
152116

153-
# Treatment effects
154-
tau = np.zeros(len(units))
155-
for i in range(n_units):
156-
g = ft[i]
157-
if np.isinf(g):
158-
continue
159-
for t_idx in range(n_t):
160-
t = t_idx + 1
161-
if g == 5 and t >= 5:
162-
tau[i * n_t + t_idx] = 0.154 * (t - 4)
163-
elif g == 8 and t >= 8:
164-
tau[i * n_t + t_idx] = 0.093 * (t - 7)
165-
166-
y = unit_fe + time_fe + tau + eps_flat
167117

168-
return pd.DataFrame(
169-
{
170-
"unit": units,
171-
"time": times,
172-
"first_treat": ft_col,
173-
"y": y,
174-
}
175-
)
118+
def _make_compustat_dgp(n_units=400, n_periods=11, rho=0.0, seed=42):
119+
"""Delegate to shared DGP in edid_dgp.py."""
120+
return make_compustat_dgp(n_units=n_units, n_periods=n_periods, rho=rho, seed=seed)
176121

177122

178123
# =============================================================================

tests/test_efficient_did_validation.py

Lines changed: 24 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616

1717
from diff_diff import CallawaySantAnna, EfficientDiD
18+
from edid_dgp import make_compustat_dgp, true_es_avg, true_overall_att
1819

1920
# =============================================================================
2021
# Data Loaders & Helpers
@@ -66,70 +67,24 @@ def _get_effect(effects_dict, g, t):
6667
raise KeyError(f"ATT({g},{t}) not found in results")
6768

6869

69-
def _assert_close(actual, expected, label, rtol=0.10, atol=200):
70-
"""Assert actual is close to expected with combined tolerance."""
71-
tol = max(rtol * abs(expected), atol)
70+
def _assert_close(actual, expected, label, se=None, se_frac=0.1):
71+
"""Assert actual is close to expected, tolerance based on published SE.
72+
73+
Default tolerance is 0.1 * SE (10% of one standard error). Our actual
74+
diffs are all < 0.03 SE, so this catches real drift while absorbing the
75+
4-individual sample difference (656 vs paper's 652).
76+
"""
77+
if se is not None:
78+
tol = se_frac * se
79+
else:
80+
tol = max(0.05 * abs(expected), 50)
7281
diff = abs(actual - expected)
7382
assert diff < tol, (
7483
f"{label}: expected {expected}, got {actual:.1f} "
7584
f"(diff={diff:.1f}, tol={tol:.1f})"
7685
)
7786

7887

79-
# =============================================================================
80-
# Compustat DGP (copied from test_efficient_did.py)
81-
# =============================================================================
82-
83-
84-
def _make_compustat_dgp(n_units=400, n_periods=11, rho=0.0, seed=42):
85-
"""Simplified Compustat-style DGP from Section 5.2.
86-
87-
Groups: G=5 (~1/3), G=8 (~1/3), G=inf (~1/3).
88-
ATT(5,t) = 0.154*(t-4), ATT(8,t) = 0.093*(t-7).
89-
"""
90-
rng = np.random.default_rng(seed)
91-
n_t = n_periods
92-
93-
n_g5 = n_units // 3
94-
n_g8 = n_units // 3
95-
ft = np.full(n_units, np.inf)
96-
ft[:n_g5] = 5
97-
ft[n_g5 : n_g5 + n_g8] = 8
98-
99-
units = np.repeat(np.arange(n_units), n_t)
100-
times = np.tile(np.arange(1, n_t + 1), n_units)
101-
ft_col = np.repeat(ft, n_t)
102-
103-
alpha_t = rng.normal(0, 0.1, n_t)
104-
eta_i = rng.normal(0, 0.5, n_units)
105-
unit_fe = np.repeat(eta_i, n_t)
106-
time_fe = np.tile(alpha_t, n_units)
107-
108-
eps = np.zeros((n_units, n_t))
109-
eps[:, 0] = rng.normal(0, 0.3, n_units)
110-
for t in range(1, n_t):
111-
eps[:, t] = rho * eps[:, t - 1] + rng.normal(0, 0.3, n_units)
112-
eps_flat = eps.flatten()
113-
114-
tau = np.zeros(len(units))
115-
for i in range(n_units):
116-
g = ft[i]
117-
if np.isinf(g):
118-
continue
119-
for t_idx in range(n_t):
120-
t = t_idx + 1
121-
if g == 5 and t >= 5:
122-
tau[i * n_t + t_idx] = 0.154 * (t - 4)
123-
elif g == 8 and t >= 8:
124-
tau[i * n_t + t_idx] = 0.093 * (t - 7)
125-
126-
y = unit_fe + time_fe + tau + eps_flat
127-
128-
return pd.DataFrame(
129-
{"unit": units, "time": times, "first_treat": ft_col, "y": y}
130-
)
131-
132-
13388
def _compute_es_avg(result):
13489
"""Compute ES_avg (Eq 2.3): uniform average over post-treatment horizons."""
13590
if result.event_study_effects is None:
@@ -142,41 +97,7 @@ def _compute_es_avg(result):
14297
return np.mean(list(es.values()))
14398

14499

145-
# Ground truth derived from DGP parameters (not hard-coded)
146-
_ATT_COEFS = {5: 0.154, 8: 0.093} # ATT(g,t) = coef * (t - g + 1) for t >= g
147-
_N_PERIODS = 11
148-
149-
150-
def _true_es_avg_from_dgp():
151-
"""Derive ES_avg from DGP treatment effect parameters."""
152-
max_e = {g: _N_PERIODS - g for g in _ATT_COEFS}
153-
all_e = range(0, max(max_e.values()) + 1)
154-
es_values = []
155-
for e in all_e:
156-
contributing = [
157-
coef * (e + 1)
158-
for g, coef in _ATT_COEFS.items()
159-
if e <= max_e[g]
160-
]
161-
if contributing:
162-
es_values.append(np.mean(contributing))
163-
return np.mean(es_values)
164-
165-
166-
_TRUE_ES_AVG_COMPUSTAT = _true_es_avg_from_dgp()
167-
168-
169-
def _true_overall_att_compustat():
170-
"""Compute true overall_att using cohort-size weighting (our convention)."""
171-
# Groups have equal size (1/3 each), so pi_5 = pi_8
172-
# Post-treatment (g,t) cells:
173-
# G=5: t=5..11 -> 7 cells with effects 0.154*(1..7)
174-
# G=8: t=8..11 -> 4 cells with effects 0.093*(1..4)
175-
effects_g5 = [0.154 * k for k in range(1, 8)] # 7 cells
176-
effects_g8 = [0.093 * k for k in range(1, 5)] # 4 cells
177-
# Cohort-size-weighted: both groups have same pi, so weight by count
178-
all_effects = effects_g5 + effects_g8
179-
return np.mean(all_effects)
100+
_TRUE_ES_AVG_COMPUSTAT = true_es_avg()
180101

181102

182103
def _run_mc_simulation(n_sims, rho, seed=1000, also_cs=False):
@@ -188,7 +109,7 @@ def _run_mc_simulation(n_sims, rho, seed=1000, also_cs=False):
188109
cs_estimates_list = []
189110

190111
for i in range(n_sims):
191-
data = _make_compustat_dgp(rho=rho, seed=seed + i)
112+
data = make_compustat_dgp(rho=rho, seed=seed + i)
192113

193114
edid = EfficientDiD(pt_assumption="all")
194115
res = edid.fit(
@@ -266,24 +187,23 @@ def test_sample_selection_yields_expected_counts(self, hrs_data):
266187
)
267188

268189
def test_group_time_effects_match_table6(self, edid_hrs_result):
269-
for (g, t), (expected_effect, _) in TABLE6_EDID.items():
190+
for (g, t), (expected_effect, se) in TABLE6_EDID.items():
270191
info = _get_effect(edid_hrs_result.group_time_effects, g, t)
271-
_assert_close(info["effect"], expected_effect, f"ATT({g},{t})")
192+
_assert_close(info["effect"], expected_effect, f"ATT({g},{t})", se=se)
272193

273194
def test_event_study_effects_match_table6(self, edid_hrs_result):
274-
for e, (expected_effect, _) in TABLE6_ES.items():
275-
# Find event study effect matching relative time e
195+
for e, (expected_effect, se) in TABLE6_ES.items():
276196
found = False
277197
for rel_time, info in edid_hrs_result.event_study_effects.items():
278198
if int(rel_time) == e:
279-
_assert_close(info["effect"], expected_effect, f"ES({e})")
199+
_assert_close(info["effect"], expected_effect, f"ES({e})", se=se)
280200
found = True
281201
break
282202
assert found, f"ES({e}) not found in event study effects"
283203

284204
def test_es_avg_matches_table6(self, edid_hrs_result):
285205
es_avg = _compute_es_avg(edid_hrs_result)
286-
_assert_close(es_avg, TABLE6_ES_AVG[0], "ES_avg")
206+
_assert_close(es_avg, TABLE6_ES_AVG[0], "ES_avg", se=TABLE6_ES_AVG[1])
287207

288208
def test_se_diagnostic_comparison(self, edid_hrs_result):
289209
"""Log and sanity-check analytical vs cluster-robust SEs."""
@@ -307,11 +227,14 @@ def test_cs_cross_validation(self, hrs_data):
307227
hrs_data, outcome="outcome", unit="unit", time="time",
308228
first_treat="first_treat",
309229
)
230+
# CS-SA paper SEs from Table 6
231+
cs_ses = {(8,8): 1035, (8,9): 909, (8,10): 1008,
232+
(9,9): 702, (9,10): 651, (10,10): 995}
310233
for (g, t), expected_effect in TABLE6_CS_SA.items():
311234
info = _get_effect(cs_result.group_time_effects, g, t)
312235
_assert_close(
313236
info["effect"], expected_effect,
314-
f"CS ATT({g},{t})", rtol=0.15, atol=300,
237+
f"CS ATT({g},{t})", se=cs_ses[(g, t)],
315238
)
316239

317240
def test_pretreatment_effects_near_zero(self, edid_hrs_result):
@@ -411,7 +334,7 @@ def test_coverage_approximately_correct(self, ci_params):
411334
n_sims = ci_params.bootstrap(200, min_n=49)
412335
mc = _run_mc_simulation(n_sims, rho=0, seed=5000)
413336

414-
true_overall = _true_overall_att_compustat()
337+
true_overall = true_overall_att()
415338
covered = sum(
416339
ci[0] <= true_overall <= ci[1]
417340
for ci in mc["edid_overall_ci"]

0 commit comments

Comments
 (0)