Skip to content

Commit 71bf84a

Browse files
igerber and claude committed
Add empty-treatment guard, last_cohort aggregate/bootstrap tests, fix n_obs
Address rerun review findings:

- Guard against empty treatment_groups before max() in last_cohort branch
- Fix n_obs to report effective sample size (n_units * n_periods) after trimming
- Add last_cohort tests for aggregate=event_study/all and bootstrap
- Add all-never-treated + last_cohort regression test
- Update influence_functions docstring to match dict contract

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2cda8b9 commit 71bf84a

3 files changed

Lines changed: 66 additions & 3 deletions

File tree

diff_diff/efficient_did.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,11 @@ def fit(
415415
# Control group logic
416416
if self.control_group == "last_cohort":
417417
# Always reclassify last cohort as pseudo-control when requested
418+
if not treatment_groups:
419+
raise ValueError(
420+
"No treated cohorts found. control_group='last_cohort' requires "
421+
"at least 2 treatment cohorts."
422+
)
418423
last_g = max(treatment_groups)
419424
treatment_groups = [g for g in treatment_groups if g != last_g]
420425
if not treatment_groups:
@@ -900,7 +905,7 @@ def fit(
900905
overall_conf_int=overall_ci,
901906
groups=treatment_groups,
902907
time_periods=time_periods,
903-
n_obs=len(df),
908+
n_obs=n_units * len(time_periods),
904909
n_treated_units=n_treated_units,
905910
n_control_units=n_control_units,
906911
alpha=self.alpha,

diff_diff/efficient_did_results.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,10 @@ class EfficientDiDResults:
105105
``{(g, t): ndarray}`` — diagnostic: weight vector per target.
106106
omega_condition_numbers : dict, optional
107107
``{(g, t): float}`` — diagnostic: Omega* condition numbers.
108-
influence_functions : ndarray, optional
109-
Stored EIF matrix for bootstrap / manual SE computation.
108+
influence_functions : dict, optional
109+
``{(g, t): ndarray(n_units,)}`` — per-unit EIF values for each
110+
group-time cell. Only populated when ``store_eif=True`` in
111+
:meth:`~EfficientDiD.fit` (used internally by ``hausman_pretest``).
110112
bootstrap_results : EDiDBootstrapResults, optional
111113
Bootstrap inference results.
112114
estimation_path : str

tests/test_efficient_did.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,62 @@ def test_control_group_invalid_raises(self):
10261026
with pytest.raises(ValueError, match="control_group"):
10271027
EfficientDiD(control_group="invalid")
10281028

1029+
def test_last_cohort_no_treated_raises(self):
1030+
"""All-never-treated data with last_cohort should raise."""
1031+
df = _make_staggered_panel(n_per_group=0, n_control=100, groups=())
1032+
with pytest.raises(ValueError, match="No treated cohorts"):
1033+
EfficientDiD(control_group="last_cohort").fit(df, "y", "unit", "time", "first_treat")
1034+
1035+
def test_last_cohort_aggregate_event_study(self):
1036+
"""last_cohort with aggregate='event_study' should produce finite results."""
1037+
df = _make_staggered_panel(
1038+
n_per_group=60,
1039+
n_control=0,
1040+
groups=(3, 5, 7),
1041+
effects={3: 2.0, 5: 1.5, 7: 1.0},
1042+
)
1043+
result = EfficientDiD(control_group="last_cohort").fit(
1044+
df, "y", "unit", "time", "first_treat", aggregate="event_study"
1045+
)
1046+
assert result.event_study_effects is not None
1047+
assert 7 not in result.groups
1048+
for e, d in result.event_study_effects.items():
1049+
assert np.isfinite(d["effect"])
1050+
1051+
def test_last_cohort_aggregate_all(self):
1052+
"""last_cohort with aggregate='all' should produce finite results."""
1053+
df = _make_staggered_panel(
1054+
n_per_group=60,
1055+
n_control=0,
1056+
groups=(3, 5, 7),
1057+
effects={3: 2.0, 5: 1.5, 7: 1.0},
1058+
)
1059+
result = EfficientDiD(control_group="last_cohort").fit(
1060+
df, "y", "unit", "time", "first_treat", aggregate="all"
1061+
)
1062+
assert result.event_study_effects is not None
1063+
assert result.group_effects is not None
1064+
assert 7 not in result.groups
1065+
for g, d in result.group_effects.items():
1066+
assert g != 7
1067+
assert np.isfinite(d["effect"])
1068+
1069+
def test_last_cohort_bootstrap(self, ci_params):
1070+
"""last_cohort with bootstrap should produce finite inference."""
1071+
n_boot = ci_params.bootstrap(99)
1072+
df = _make_staggered_panel(
1073+
n_per_group=60,
1074+
n_control=0,
1075+
groups=(3, 5, 7),
1076+
effects={3: 2.0, 5: 1.5, 7: 1.0},
1077+
)
1078+
result = EfficientDiD(control_group="last_cohort", n_bootstrap=n_boot, seed=42).fit(
1079+
df, "y", "unit", "time", "first_treat"
1080+
)
1081+
assert np.isfinite(result.overall_se)
1082+
assert result.overall_se > 0
1083+
assert 7 not in result.groups
1084+
10291085

10301086
class TestBalanceE:
10311087
"""Test balance_e event study balancing."""

0 commit comments

Comments
 (0)