Skip to content

Commit 049a6ef

Browse files
igerberclaude
andcommitted
Address AI review round 4: move np.inf normalization before treatment_groups
The np.inf → 0 normalization was placed after treatment_groups was computed, so np.inf passed the `g > 0` filter and leaked into treatment cohorts. Reorder in both sun_abraham.py and staggered.py so normalization precedes treatment_groups derivation. Add results.groups assertion and all-never-treated ValueError test. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 93e9cdb commit 049a6ef

3 files changed

Lines changed: 31 additions & 10 deletions

File tree

diff_diff/staggered.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -640,15 +640,15 @@ def fit(
640640
# This avoids hardcoding column names in internal methods
641641
df['first_treat'] = df[first_treat]
642642

643-
# Identify groups and time periods
644-
time_periods = sorted(df[time].unique())
645-
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
646-
647-
# Never-treated indicator (first_treat = 0 or inf)
643+
# Never-treated indicator (must precede treatment_groups to exclude np.inf)
648644
df['_never_treated'] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
649645
# Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
650646
df.loc[df[first_treat] == np.inf, first_treat] = 0
651647

648+
# Identify groups and time periods
649+
time_periods = sorted(df[time].unique())
650+
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
651+
652652
# Get unique units
653653
unit_info = df.groupby(unit).agg({
654654
first_treat: 'first',

diff_diff/sun_abraham.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,15 +502,15 @@ def fit(
502502
df[time] = pd.to_numeric(df[time])
503503
df[first_treat] = pd.to_numeric(df[first_treat])
504504

505-
# Identify groups and time periods
506-
time_periods = sorted(df[time].unique())
507-
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
508-
509-
# Never-treated indicator
505+
# Never-treated indicator (must precede treatment_groups to exclude np.inf)
510506
df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf)
511507
# Normalize np.inf → 0 so all downstream `> 0` checks exclude never-treated
512508
df.loc[df[first_treat] == np.inf, first_treat] = 0
513509

510+
# Identify groups and time periods
511+
time_periods = sorted(df[time].unique())
512+
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
513+
514514
# Get unique units
515515
unit_info = (
516516
df.groupby(unit)

tests/test_sun_abraham.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,10 +1395,31 @@ def test_never_treated_inf_encoding(self):
13951395
for e in results_inf.event_study_effects.keys():
13961396
assert np.isfinite(e), f"Non-finite event time {e} in event study"
13971397

1398+
# np.inf must not appear in results.groups
1399+
assert np.inf not in results_inf.groups, (
1400+
f"np.inf found in results.groups: {results_inf.groups}"
1401+
)
1402+
13981403
# Results should be identical to first_treat=0 encoding
13991404
assert np.isclose(results_inf.overall_att, results_zero.overall_att), (
14001405
f"ATT differs: inf={results_inf.overall_att}, zero={results_zero.overall_att}"
14011406
)
14021407
assert np.isclose(results_inf.overall_se, results_zero.overall_se), (
14031408
f"SE differs: inf={results_inf.overall_se}, zero={results_zero.overall_se}"
14041409
)
1410+
1411+
def test_all_never_treated_inf_raises(self):
1412+
"""Test that all-never-treated data with np.inf encoding raises ValueError."""
1413+
data = generate_staggered_data(n_units=100, n_periods=10, n_cohorts=3, seed=42)
1414+
# Set ALL units to never-treated via np.inf
1415+
data["first_treat"] = np.inf
1416+
1417+
sa = SunAbraham(n_bootstrap=0)
1418+
with pytest.raises(ValueError, match="No treated units found"):
1419+
sa.fit(
1420+
data,
1421+
outcome="outcome",
1422+
unit="unit",
1423+
time="time",
1424+
first_treat="first_treat",
1425+
)

0 commit comments

Comments
 (0)