Skip to content

Commit 021bd66

Browse files
igerberclaude
andcommitted
Address PR #97 review round 5: not-yet-treated control group overlap
Fix critical bug where control_group="not_yet_treated" incorrectly included treated cohort g in controls for pre-treatment periods (t < g). When computing ATT(g,t), cohort g should never be in the control group, regardless of whether t < g or t >= g. The fix adds explicit exclusion of cohort g from the not-yet-treated control mask. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 5734964 commit 021bd66

3 files changed

Lines changed: 102 additions & 2 deletions

File tree

diff_diff/staggered.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,9 @@ def _compute_att_gt_fast(
488488
if self.control_group == "never_treated":
489489
control_mask = never_treated_mask
490490
else: # not_yet_treated
491-
# Not yet treated at time t: never-treated OR first_treat > t
492-
control_mask = never_treated_mask | (unit_cohorts > t)
491+
# Not yet treated at time t: never-treated OR (first_treat > t AND not cohort g)
492+
# Must exclude cohort g since they are the treated group for this ATT(g,t)
493+
control_mask = never_treated_mask | ((unit_cohorts > t) & (unit_cohorts != g))
493494

494495
# Extract outcomes for base and post periods
495496
y_base = outcome_matrix[:, base_col]

docs/methodology/REGISTRY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ Aggregations:
227227
- "universal": All comparisons use g-anticipation-1 as base
228228
- Both produce identical post-treatment ATT(g,t); differ only pre-treatment
229229
- Matches R `did::att_gt()` base_period parameter
230+
- Control group with `control_group="not_yet_treated"`:
231+
- Always excludes cohort g from controls when computing ATT(g,t)
232+
- This applies to both pre-treatment (t < g) and post-treatment (t >= g) periods
233+
- For pre-treatment periods: even though cohort g hasn't been treated yet at time t, they are the treated group for this ATT(g,t) and cannot serve as their own controls
234+
- Control mask: `never_treated OR (first_treat > t AND first_treat != g)`
230235

231236
**Reference implementation(s):**
232237
- R: `did::att_gt()` (Callaway & Sant'Anna's official package)

tests/test_staggered.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2439,6 +2439,100 @@ def test_bootstrap_runs_for_pretreatment_effects(self):
24392439
"Overall ATT p-value should be NaN when no post-treatment"
24402440
)
24412441

2442+
def test_not_yet_treated_excludes_cohort_from_controls(self):
2443+
"""Not-yet-treated control excludes treated cohort g for pre-treatment periods.
2444+
2445+
When computing ATT(g,t) for t < g with control_group="not_yet_treated",
2446+
cohort g should NOT be included in the control group even though
2447+
they haven't been treated yet at time t.
2448+
2449+
Bug scenario (before fix):
2450+
- Computing ATT(g=5, t=3) with control_group="not_yet_treated"
2451+
- Control mask was: never_treated OR first_treat > t
2452+
- Units with first_treat=5 satisfy first_treat > 3, so they were
2453+
incorrectly included as controls for themselves!
2454+
2455+
After fix:
2456+
- Control mask is: never_treated OR (first_treat > t AND first_treat != g)
2457+
- Cohort g is always excluded from controls.
2458+
"""
2459+
# Create data with 3 distinct cohorts: g=4, g=7, and never-treated (g=0)
2460+
# This setup ensures for ATT(g=7, t=3):
2461+
# - Treated: units with first_treat=7
2462+
# - Valid controls: never-treated + cohort g=4 (since 4 > 3 and 4 != 7)
2463+
# - Invalid (excluded): cohort g=7 (even though 7 > 3)
2464+
n_units = 90 # 30 per group
2465+
n_periods = 10
2466+
np.random.seed(42)
2467+
2468+
data = []
2469+
for unit in range(n_units):
2470+
# Assign to cohorts: 0-29 -> g=4, 30-59 -> g=7, 60-89 -> never-treated
2471+
if unit < 30:
2472+
first_treat = 4
2473+
elif unit < 60:
2474+
first_treat = 7
2475+
else:
2476+
first_treat = 0 # Never-treated
2477+
2478+
for t in range(1, n_periods + 1):
2479+
# Add treatment effect after treatment
2480+
effect = 0.0
2481+
if first_treat > 0 and t >= first_treat:
2482+
effect = 2.0
2483+
2484+
outcome = np.random.randn() + effect
2485+
data.append({
2486+
'unit': unit,
2487+
'time': t,
2488+
'outcome': outcome,
2489+
'first_treat': first_treat
2490+
})
2491+
2492+
df = pd.DataFrame(data)
2493+
2494+
# Fit with not_yet_treated control group
2495+
cs = CallawaySantAnna(
2496+
control_group="not_yet_treated",
2497+
base_period="varying" # To get pre-treatment effects
2498+
)
2499+
results = cs.fit(
2500+
df,
2501+
outcome='outcome',
2502+
unit='unit',
2503+
time='time',
2504+
first_treat='first_treat'
2505+
)
2506+
2507+
# Check the group-time effects for pre-treatment ATT(g=7, t) where t < 7
2508+
# These should have been computed using valid controls only
2509+
for (g, t), eff in results.group_time_effects.items():
2510+
if g == 7 and t < g: # Pre-treatment for cohort 7
2511+
n_control = eff['n_control']
2512+
# Control should include:
2513+
# - 30 never-treated units
2514+
# - 30 units from cohort g=4 (if t < 4, they're not yet treated either)
2515+
# Control should NOT include:
2516+
# - The 30 units from cohort g=7 (they're the treated group!)
2517+
2518+
# For t < 4: controls = never-treated (30) + cohort 4 (30) = 60
2519+
# For 4 <= t < 7: controls = never-treated (30) only (cohort 4 is treated)
2520+
if t < 4:
2521+
expected_max = 60 # never-treated + cohort 4
2522+
else:
2523+
expected_max = 30 # never-treated only
2524+
2525+
# Key assertion: n_control should NOT be 90 (which would include cohort 7)
2526+
assert n_control <= expected_max, (
2527+
f"ATT(g=7, t={t}): n_control={n_control} should be <= {expected_max}. "
2528+
f"Cohort 7 (30 units) should NOT be included as controls for itself."
2529+
)
2530+
2531+
# Also verify we have a reasonable number of controls
2532+
assert n_control >= 30, (
2533+
f"ATT(g=7, t={t}): n_control={n_control} should be >= 30 (never-treated)."
2534+
)
2535+
24422536

24432537
class TestCallawaySantAnnaAnticipation:
24442538
"""Tests for anticipation parameter handling in aggregation."""

0 commit comments

Comments
 (0)