Skip to content

Commit 93e9cdb

Browse files
igerberclaude
andcommitted
Address AI review round 3: normalize np.inf never-treated encoding
first_treat=np.inf (documented as valid for never-treated units) passed all `> 0` checks in sun_abraham.py and staggered.py, causing np.inf to be treated as a cohort with -inf relative times. Fixed by normalizing np.inf to 0 immediately after computing _never_treated in both estimators. Added regression tests verifying equivalence between encodings. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 89e8e72 commit 93e9cdb

5 files changed

Lines changed: 70 additions & 0 deletions

File tree

METHODOLOGY_REVIEW.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,12 @@ variables appear to the left of the `|` separator.
357357
panels these are identical; for unbalanced panels the new formula correctly reflects
358358
actual sample composition at each event-time. Added unbalanced panel test.
359359

360+
7. **Normalize `np.inf` never-treated encoding** (`sun_abraham.py`, `fit()`):
361+
`first_treat=np.inf` (documented as valid for never-treated) was included in
362+
`treatment_groups` and `_rel_time` via `> 0` checks, producing `-inf` event times.
363+
Fixed by normalizing `np.inf` to `0` immediately after computing `_never_treated`.
364+
Same fix applied to `staggered.py` (`CallawaySantAnna`).
365+
360366
**Outstanding Concerns:**
361367
- **Inference distribution**: Cohort-level p-values use t-distribution (via
362368
`LinearRegression.get_inference()`), while aggregated event study and overall ATT

diff_diff/staggered.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ def _precompute_structures(
415415
cohort_masks[g] = (unit_cohorts == g)
416416

417417
# Never-treated mask
418+
# np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
418419
never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf)
419420

420421
# Pre-compute covariate matrices by time period if needed
@@ -645,6 +646,8 @@ def fit(
645646

646647
# Never-treated indicator (first_treat = 0 or inf)
647648
df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
649+
# Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
650+
df.loc[df[first_treat] == np.inf, first_treat] = 0
648651

649652
# Get unique units
650653
unit_info = df.groupby(unit).agg({

diff_diff/sun_abraham.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,8 @@ def fit(
508508

509509
# Never-treated indicator
510510
df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
511+
# Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
512+
df.loc[df[first_treat] == np.inf, first_treat] = 0
511513

512514
# Get unique units
513515
unit_info = (
@@ -1057,6 +1059,7 @@ def _run_bootstrap(
10571059
df_b[time] - df_b[first_treat],
10581060
np.nan
10591061
)
1062+
# np.inf was normalized to 0 in fit(), so the np.inf check is defensive only
10601063
df_b["_never_treated"] = (
10611064
(df_b[first_treat] == 0) | (df_b[first_treat] == np.inf)
10621065
)

tests/test_staggered.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,28 @@ def test_zero_treatment_effect(self):
102102
# Effect should be close to zero
103103
assert abs(results.overall_att) < 3 * results.overall_se
104104

105+
def test_never_treated_inf_encoding(self):
106+
"""Test that first_treat=np.inf is handled as never-treated, not as a cohort."""
107+
data = generate_staggered_data(n_units=200, n_periods=10, n_cohorts=3, seed=42)
108+
109+
cs = CallawaySantAnna(n_bootstrap=0)
110+
results_zero = cs.fit(
111+
data.copy(), outcome="outcome", unit="unit", time="time", first_treat="first_treat"
112+
)
113+
114+
# Re-encode never-treated from 0 to np.inf
115+
data_inf = data.copy()
116+
data_inf.loc[data_inf["first_treat"] == 0, "first_treat"] = np.inf
117+
118+
results_inf = cs.fit(
119+
data_inf, outcome="outcome", unit="unit", time="time", first_treat="first_treat"
120+
)
121+
122+
# Results should be identical
123+
assert np.isclose(results_inf.overall_att, results_zero.overall_att), (
124+
f"ATT differs: inf={results_inf.overall_att}, zero={results_zero.overall_att}"
125+
)
126+
105127
def test_event_study_aggregation(self):
106128
"""Test event study aggregation."""
107129
data = generate_staggered_data()

tests/test_sun_abraham.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,3 +1366,39 @@ def test_iw_weights_unbalanced_panel(self):
13661366
assert abs(w - expected_w) < 1e-10, (
13671367
f"Weight for cohort {g} at e={affected_e}: got {w}, expected {expected_w}"
13681368
)
1369+
1370+
def test_never_treated_inf_encoding(self):
1371+
"""Test that first_treat=np.inf is handled as never-treated, not as a cohort."""
1372+
data = generate_staggered_data(n_units=200, n_periods=10, n_cohorts=3, seed=42)
1373+
1374+
# Run with first_treat=0 as baseline
1375+
sa = SunAbraham(n_bootstrap=0)
1376+
results_zero = sa.fit(
1377+
data.copy(), outcome="outcome", unit="unit", time="time", first_treat="first_treat"
1378+
)
1379+
1380+
# Re-encode never-treated from 0 to np.inf
1381+
data_inf = data.copy()
1382+
data_inf.loc[data_inf["first_treat"] == 0, "first_treat"] = np.inf
1383+
1384+
results_inf = sa.fit(
1385+
data_inf, outcome="outcome", unit="unit", time="time", first_treat="first_treat"
1386+
)
1387+
1388+
# np.inf must not appear as a cohort in weights
1389+
for e, weights in results_inf.cohort_weights.items():
1390+
assert np.inf not in weights, (
1391+
f"np.inf found as cohort key in weights at e={e}"
1392+
)
1393+
1394+
# No ±inf in event study periods
1395+
for e in results_inf.event_study_effects.keys():
1396+
assert np.isfinite(e), f"Non-finite event time {e} in event study"
1397+
1398+
# Results should be identical to first_treat=0 encoding
1399+
assert np.isclose(results_inf.overall_att, results_zero.overall_att), (
1400+
f"ATT differs: inf={results_inf.overall_att}, zero={results_zero.overall_att}"
1401+
)
1402+
assert np.isclose(results_inf.overall_se, results_zero.overall_se), (
1403+
f"SE differs: inf={results_inf.overall_se}, zero={results_zero.overall_se}"
1404+
)

0 commit comments

Comments
 (0)