Skip to content

Commit bc2eb77

Browse files
igerber and claude committed
Address PR review: edge case fixes for TwoStageDiD
- Always-treated warning now lists affected unit IDs (truncated at 10)
- Bootstrap handles NaN y_tilde: masks NaN obs in static, event study, and group bootstrap paths; returns None when all treated obs are NaN
- balance_e warns when no cohorts qualify instead of silently falling back
- Add 3 edge case tests and REGISTRY.md update

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bc0025d commit bc2eb77

3 files changed

Lines changed: 120 additions & 19 deletions

File tree

diff_diff/two_stage.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -593,18 +593,19 @@ def fit(
593593
# Check for always-treated units
594594
min_time = df[time].min()
595595
always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time)
596-
n_always_treated = df.loc[always_treated_mask, unit].nunique()
596+
always_treated_units = df.loc[always_treated_mask, unit].unique()
597+
n_always_treated = len(always_treated_units)
597598
if n_always_treated > 0:
599+
unit_list = ", ".join(str(u) for u in always_treated_units[:10])
600+
suffix = f" (and {n_always_treated - 10} more)" if n_always_treated > 10 else ""
598601
warnings.warn(
599602
f"{n_always_treated} unit(s) are treated in all observed periods "
600-
f"(first_treat <= {min_time}). These units have no untreated "
601-
"observations and cannot contribute to the counterfactual model. "
602-
"Excluding from estimation.",
603+
f"(first_treat <= {min_time}): [{unit_list}{suffix}]. "
604+
"These units have no untreated observations and cannot contribute "
605+
"to the counterfactual model. Excluding from estimation.",
603606
UserWarning,
604607
stacklevel=2,
605608
)
606-
# Exclude always-treated units
607-
always_treated_units = df.loc[always_treated_mask, unit].unique()
608609
df = df[~df[unit].isin(always_treated_units)].copy()
609610

610611
# Treatment indicator with anticipation
@@ -1183,11 +1184,25 @@ def _stage2_event_study(
11831184
for g, horizons in cohort_rel_times.items():
11841185
if required_range.issubset(horizons):
11851186
balanced_cohorts.add(g)
1186-
balance_mask = (
1187-
df[first_treat].isin(balanced_cohorts).values
1188-
if balanced_cohorts
1189-
else np.ones(n, dtype=bool)
1190-
)
1187+
if not balanced_cohorts:
1188+
warnings.warn(
1189+
f"No cohorts satisfy balance_e={balance_e} requirement. "
1190+
"Event study results will contain only the reference period. "
1191+
"Consider reducing balance_e.",
1192+
UserWarning,
1193+
stacklevel=2,
1194+
)
1195+
return {
1196+
ref_period: {
1197+
"effect": 0.0,
1198+
"se": 0.0,
1199+
"t_stat": np.nan,
1200+
"p_value": np.nan,
1201+
"conf_int": (0.0, 0.0),
1202+
"n_obs": 0,
1203+
}
1204+
}
1205+
balance_mask = df[first_treat].isin(balanced_cohorts).values
11911206
else:
11921207
balance_mask = np.ones(n, dtype=bool)
11931208

@@ -1724,7 +1739,7 @@ def _run_bootstrap(
17241739
original_event_study: Optional[Dict[int, Dict[str, Any]]],
17251740
original_group: Optional[Dict[Any, Dict[str, Any]]],
17261741
aggregate: Optional[str],
1727-
) -> TwoStageBootstrapResults:
1742+
) -> Optional[TwoStageBootstrapResults]:
17281743
"""Run multiplier bootstrap on GMM influence function."""
17291744
if self.n_bootstrap < 50:
17301745
warnings.warn(
@@ -1738,12 +1753,23 @@ def _run_bootstrap(
17381753

17391754
from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch
17401755

1741-
y_tilde = df["_y_tilde"].values
1756+
y_tilde = df["_y_tilde"].values.copy() # .copy() to avoid mutating df column
17421757
n = len(df)
17431758
cluster_ids = df[cluster_var].values
17441759

1760+
# Handle NaN y_tilde (from unidentified FEs) — matches _stage2_static logic
1761+
nan_mask = ~np.isfinite(y_tilde)
1762+
if nan_mask.any():
1763+
y_tilde[nan_mask] = 0.0
1764+
17451765
# --- Static specification bootstrap ---
1746-
D = omega_1_mask.values.astype(float)
1766+
D = omega_1_mask.values.astype(float) # .astype() already creates a copy
1767+
D[nan_mask] = 0.0 # Exclude NaN y_tilde obs from bootstrap estimation
1768+
1769+
# Degenerate case: all treated obs have NaN y_tilde
1770+
if D.sum() == 0:
1771+
return None
1772+
17471773
X_2_static = D.reshape(-1, 1)
17481774
coef_static = solve_ols(X_2_static, y_tilde, return_vcov=False)[0]
17491775
eps_2_static = y_tilde - X_2_static @ coef_static
@@ -1811,11 +1837,10 @@ def _run_bootstrap(
18111837
for g, horizons in cohort_rel_times.items():
18121838
if required_range.issubset(horizons):
18131839
balanced_cohorts.add(g)
1814-
balance_mask = (
1815-
df[first_treat].isin(balanced_cohorts).values
1816-
if balanced_cohorts
1817-
else np.ones(n, dtype=bool)
1818-
)
1840+
if not balanced_cohorts:
1841+
all_horizons = [] # No qualifying cohorts -> skip event study bootstrap
1842+
else:
1843+
balance_mask = df[first_treat].isin(balanced_cohorts).values
18191844
else:
18201845
balance_mask = np.ones(n, dtype=bool)
18211846

@@ -1827,6 +1852,8 @@ def _run_bootstrap(
18271852
for i in range(n):
18281853
if not balance_mask[i]:
18291854
continue
1855+
if nan_mask[i]:
1856+
continue # NaN y_tilde -> exclude from bootstrap event study
18301857
h = rel_times[i]
18311858
if np.isfinite(h):
18321859
h_int = int(h)
@@ -1890,6 +1917,8 @@ def _run_bootstrap(
18901917
treated_mask = omega_1_mask.values
18911918
for i in range(n):
18921919
if treated_mask[i]:
1920+
if nan_mask[i]:
1921+
continue # NaN y_tilde -> exclude from group bootstrap
18931922
g = ft_vals[i]
18941923
if g in group_to_col:
18951924
X_2_grp[i, group_to_col[g]] = 1.0

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus
618618
- **NaN y_tilde handling:** When Stage 1 FE are unidentified for some observations, the residualized outcome `y_tilde` is NaN. These observations are zeroed out (excluded) from the Stage 2 regression and variance computation, matching the treatment of unimputable observations in ImputationDiD.
619619
- **NaN inference for undefined statistics:** t_stat uses NaN when SE is non-finite or zero; p_value and CI also NaN. Matches CallawaySantAnna/ImputationDiD NaN convention.
620620
- **Event study aggregation:** Horizon-specific effects use the same two-stage procedure with horizon indicator dummies in Stage 2. Unidentified horizons (e.g., long-run effects without never-treated units, per Proposition 5 of Borusyak et al. 2024) produce NaN.
621+
- **balance_e with no qualifying cohorts:** If no cohorts have sufficient pre/post coverage for the requested `balance_e`, a warning is emitted and event study results contain only the reference period.
621622
- **No never-treated units:** Long-run effects may be unidentified (same limitation as ImputationDiD). Warning emitted for affected horizons.
622623

623624
**Reference implementation(s):**

tests/test_two_stage.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,77 @@ def test_horizon_max(self):
723723
if results.event_study_effects[h].get("n_obs", 0) > 0:
724724
assert abs(h) <= 2
725725

726+
def test_always_treated_warning_lists_unit_ids(self):
727+
"""Always-treated warning should include affected unit IDs."""
728+
data = generate_test_data()
729+
730+
# Add two always-treated units (first_treat before min_time=0)
731+
always_treated = pd.DataFrame(
732+
{
733+
"unit": np.repeat([997, 998], 10),
734+
"time": np.tile(np.arange(10), 2),
735+
"outcome": np.random.default_rng(42).standard_normal(20),
736+
"first_treat": np.repeat([-1, -2], 10),
737+
}
738+
)
739+
data_with_always = pd.concat([data, always_treated], ignore_index=True)
740+
741+
with warnings.catch_warnings(record=True) as w:
742+
warnings.simplefilter("always")
743+
TwoStageDiD().fit(
744+
data_with_always,
745+
outcome="outcome",
746+
unit="unit",
747+
time="time",
748+
first_treat="first_treat",
749+
)
750+
always_warns = [x for x in w if "treated in all observed periods" in str(x.message)]
751+
assert len(always_warns) == 1
752+
msg = str(always_warns[0].message)
753+
assert "997" in msg
754+
assert "998" in msg
755+
756+
def test_bootstrap_with_nan_y_tilde(self, ci_params):
757+
"""Bootstrap should handle NaN y_tilde from unidentified FEs."""
758+
# No never-treated units: cohorts 3, 5, 7 on periods 0-9 means
759+
# periods 7-9 have zero untreated obs -> NaN y_tilde
760+
data = generate_test_data(never_treated_frac=0.0)
761+
n_boot = ci_params.bootstrap(20)
762+
763+
results = TwoStageDiD(n_bootstrap=n_boot).fit(
764+
data,
765+
outcome="outcome",
766+
unit="unit",
767+
time="time",
768+
first_treat="first_treat",
769+
)
770+
771+
assert np.isfinite(results.overall_att)
772+
assert results.overall_se > 0
773+
774+
def test_balance_e_empty_cohorts_warns(self):
775+
"""Unreasonably large balance_e should warn when no cohorts qualify."""
776+
data = generate_test_data()
777+
778+
with warnings.catch_warnings(record=True) as w:
779+
warnings.simplefilter("always")
780+
results = TwoStageDiD().fit(
781+
data,
782+
outcome="outcome",
783+
unit="unit",
784+
time="time",
785+
first_treat="first_treat",
786+
aggregate="event_study",
787+
balance_e=100, # No cohort can satisfy this
788+
)
789+
balance_warns = [x for x in w if "No cohorts satisfy" in str(x.message)]
790+
assert len(balance_warns) > 0
791+
792+
# Event study should contain only the reference period
793+
assert len(results.event_study_effects) == 1
794+
ref_key = list(results.event_study_effects.keys())[0]
795+
assert results.event_study_effects[ref_key]["n_obs"] == 0
796+
726797

727798
# =============================================================================
728799
# TestTwoStageDiDParameters

0 commit comments

Comments
 (0)