Skip to content

Commit 966a00d

Browse files
igerberclaude
andcommitted
Fix CI review Round 2: mixed panel overall_att, TWFE guard, bootstrap N_S=0
P0: When L_max >= 1, always set overall_att from per-group DID_{g,1} (not conditional on NaN). Fixes mixed binary/non-binary panels where per-period N_S > 0 but excludes non-binary switches. P1: Gate TWFE diagnostic and twowayfeweights() to binary-only treatment. Emit warning on fit(), raise ValueError on standalone helper. P1: Refactor _compute_dcdh_bootstrap() to skip scalar DID_M when divisor_overall <= 0 but still process multi_horizon_inputs and placebo_horizon_inputs. Fixes non-binary bootstrap path. P2: Add regressions for mixed 0->1/0->2 panel at L_max=1, non-binary bootstrap, and TWFE diagnostic skip on non-binary. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent df6db2c commit 966a00d

3 files changed

Lines changed: 150 additions & 53 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,18 @@ def fit(
608608
# the same _validate_and_aggregate_to_cells() output.
609609
# ------------------------------------------------------------------
610610
twfe_diagnostic_payload = None
611-
if self.twfe_diagnostic:
611+
# TWFE diagnostic assumes binary treatment (d_arr == 1 for
612+
# treated mask). Skip for non-binary data with a warning.
613+
is_binary_pre = set(cell["d_gt"].unique()).issubset({0.0, 1.0, 0, 1})
614+
if self.twfe_diagnostic and not is_binary_pre:
615+
warnings.warn(
616+
"TWFE diagnostic (twfe_diagnostic=True) is not supported for "
617+
"non-binary treatment. The diagnostic assumes binary {0, 1} "
618+
"treatment. Skipping TWFE diagnostic for this fit.",
619+
UserWarning,
620+
stacklevel=2,
621+
)
622+
elif self.twfe_diagnostic:
612623
try:
613624
twfe_diagnostic_payload = _compute_twfe_diagnostic(
614625
cell=cell,
@@ -1576,15 +1587,16 @@ def fit(
15761587
# l=1 use the per-group DID_{g,l} path for a consistent estimand.
15771588
if multi_horizon_inference is not None and 1 in multi_horizon_inference:
15781589
# Per-group mode: use per-group path for all horizons.
1579-
# Also populate overall_att from l=1 when per-period path
1580-
# yielded NaN (non-binary treatment or no binary switchers).
1581-
if np.isnan(overall_att):
1582-
l1_inf = multi_horizon_inference[1]
1583-
overall_att = l1_inf["effect"]
1584-
overall_se = l1_inf["se"]
1585-
overall_t = l1_inf["t_stat"]
1586-
overall_p = l1_inf["p_value"]
1587-
overall_ci = l1_inf["conf_int"]
1590+
# When L_max >= 1, the per-group DID_{g,1} is the correct
1591+
# estimand for overall_att (not the binary-only per-period
1592+
# DID_M). This handles both pure non-binary (N_S=0) and
1593+
# mixed binary/non-binary panels (N_S > 0 but incomplete).
1594+
l1_inf = multi_horizon_inference[1]
1595+
overall_att = l1_inf["effect"]
1596+
overall_se = l1_inf["se"]
1597+
overall_t = l1_inf["t_stat"]
1598+
overall_p = l1_inf["p_value"]
1599+
overall_ci = l1_inf["conf_int"]
15881600
event_study_effects: Dict[int, Dict[str, Any]] = dict(multi_horizon_inference)
15891601
else:
15901602
# Phase 1 mode (L_max=None): l=1 from per-period path
@@ -3656,6 +3668,14 @@ def twowayfeweights(
36563668
time=time,
36573669
treatment=treatment,
36583670
)
3671+
# TWFE diagnostic assumes binary treatment (d_arr == 1 for treated mask).
3672+
if not set(cell["d_gt"].unique()).issubset({0.0, 1.0, 0, 1}):
3673+
raise ValueError(
3674+
"twowayfeweights() requires binary treatment {0, 1}. "
3675+
"Non-binary treatment is supported by fit() with L_max >= 1 "
3676+
"but the TWFE diagnostic (Theorem 1 of AER 2020) assumes "
3677+
"binary treatment."
3678+
)
36593679
return _compute_twfe_diagnostic(
36603680
cell=cell,
36613681
group_col=group,

diff_diff/chaisemartin_dhaultfoeuille_bootstrap.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
produce a bootstrap distribution per target.
2020
"""
2121

22-
import warnings
2322
from typing import TYPE_CHECKING, Dict, Optional, Tuple
2423

2524
import numpy as np
@@ -159,29 +158,29 @@ def _compute_dcdh_bootstrap(
159158
f"u_centered_overall length ({u_centered_overall.shape[0]}) does not "
160159
f"match n_groups_for_overall ({n_groups_for_overall})"
161160
)
162-
if divisor_overall <= 0:
163-
warnings.warn(
164-
f"_compute_dcdh_bootstrap: divisor_overall={divisor_overall} <= 0; "
165-
"returning all-NaN bootstrap results.",
166-
RuntimeWarning,
167-
stacklevel=2,
168-
)
169-
return _empty_bootstrap_results(self.n_bootstrap, self.bootstrap_weights, self.alpha)
170-
171161
rng = np.random.default_rng(self.seed)
172162

173163
# --- Overall DID_M ---
174-
overall_se, overall_ci, overall_p, overall_dist = _bootstrap_one_target(
175-
u_centered=u_centered_overall,
176-
divisor=divisor_overall,
177-
original=original_overall,
178-
n_bootstrap=self.n_bootstrap,
179-
weight_type=self.bootstrap_weights,
180-
alpha=self.alpha,
181-
rng=rng,
182-
context="dCDH overall DID_M bootstrap",
183-
return_distribution=True,
184-
)
164+
# Skip the scalar DID_M bootstrap when divisor_overall <= 0
165+
# (e.g., pure non-binary panels where N_S=0), but continue
166+
# to process multi_horizon_inputs and placebo_horizon_inputs.
167+
if divisor_overall > 0:
168+
overall_se, overall_ci, overall_p, overall_dist = _bootstrap_one_target(
169+
u_centered=u_centered_overall,
170+
divisor=divisor_overall,
171+
original=original_overall,
172+
n_bootstrap=self.n_bootstrap,
173+
weight_type=self.bootstrap_weights,
174+
alpha=self.alpha,
175+
rng=rng,
176+
context="dCDH overall DID_M bootstrap",
177+
return_distribution=True,
178+
)
179+
else:
180+
overall_se = np.nan
181+
overall_ci = (np.nan, np.nan)
182+
overall_p = np.nan
183+
overall_dist = None
185184

186185
results = DCDHBootstrapResults(
187186
n_bootstrap=self.n_bootstrap,
@@ -398,15 +397,3 @@ def _bootstrap_one_target(
398397
return se, ci, p_value, (boot_dist if return_distribution else None)
399398

400399

401-
def _empty_bootstrap_results(
402-
n_bootstrap: int, weight_type: str, alpha: float
403-
) -> DCDHBootstrapResults:
404-
"""Return an all-NaN bootstrap results object as a graceful fallback."""
405-
return DCDHBootstrapResults(
406-
n_bootstrap=n_bootstrap,
407-
weight_type=weight_type,
408-
alpha=alpha,
409-
overall_se=np.nan,
410-
overall_ci=(np.nan, np.nan),
411-
overall_p_value=np.nan,
412-
)

tests/test_chaisemartin_dhaultfoeuille.py

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,18 +1671,18 @@ def test_twowayfeweights_rejects_nan_outcome(self):
16711671
treatment="treatment",
16721672
)
16731673

1674-
def test_twowayfeweights_accepts_non_binary_treatment(self):
1675-
"""Non-binary treatment is now supported."""
1674+
def test_twowayfeweights_rejects_non_binary_treatment(self):
1675+
"""TWFE diagnostic requires binary treatment."""
16761676
data = generate_reversible_did_data(n_groups=20, n_periods=4, seed=1)
16771677
data.loc[data.index[0], "treatment"] = 2 # non-binary
1678-
result = twowayfeweights(
1679-
data,
1680-
outcome="outcome",
1681-
group="group",
1682-
time="period",
1683-
treatment="treatment",
1684-
)
1685-
assert result is not None
1678+
with pytest.raises(ValueError, match="binary treatment"):
1679+
twowayfeweights(
1680+
data,
1681+
outcome="outcome",
1682+
group="group",
1683+
time="period",
1684+
treatment="treatment",
1685+
)
16861686

16871687
def test_twowayfeweights_rejects_nan_group(self):
16881688
data = generate_reversible_did_data(n_groups=20, n_periods=4, seed=1)
@@ -2333,6 +2333,96 @@ def test_monotone_multi_step_dropped(self):
23332333
# Group 0 (0->1->2, 2 change periods) should be dropped
23342334
assert r.n_groups_dropped_crossers >= 1
23352335

2336+
def test_mixed_binary_nonbinary_panel_lmax1(self):
2337+
"""Mixed panel with both 0->1 and 0->2 switches at L_max=1.
2338+
overall_att should use the per-group path (includes all switches),
2339+
not the per-period path (binary-only)."""
2340+
np.random.seed(88)
2341+
rows = []
2342+
# Binary switchers: 0->1
2343+
for g in range(10):
2344+
for t in range(6):
2345+
d = 0 if t < 3 else 1
2346+
y = 10 + t + d * 2 + np.random.randn() * 0.3
2347+
rows.append({"group": g, "period": t, "treatment": d, "outcome": y})
2348+
# Non-binary switchers: 0->2
2349+
for g in range(10, 20):
2350+
for t in range(6):
2351+
d = 0 if t < 3 else 2
2352+
y = 10 + t + d * 1.5 + np.random.randn() * 0.3
2353+
rows.append({"group": g, "period": t, "treatment": d, "outcome": y})
2354+
# Controls
2355+
for g in range(20, 40):
2356+
for t in range(6):
2357+
y = 10 + t + np.random.randn() * 0.3
2358+
rows.append({"group": g, "period": t, "treatment": 0, "outcome": y})
2359+
df = pd.DataFrame(rows)
2360+
est = ChaisemartinDHaultfoeuille(twfe_diagnostic=False)
2361+
with warnings.catch_warnings():
2362+
warnings.simplefilter("ignore")
2363+
r = est.fit(
2364+
df, outcome="outcome", group="group", time="period",
2365+
treatment="treatment", L_max=1,
2366+
)
2367+
# overall_att should be from per-group path (includes both 0->1 and 0->2)
2368+
assert np.isfinite(r.overall_att)
2369+
# event_study_effects[1] and overall_att should be the same estimand
2370+
assert r.overall_att == r.event_study_effects[1]["effect"]
2371+
2372+
def test_nonbinary_bootstrap(self, ci_params):
2373+
"""Non-binary panel with bootstrap should produce finite event study SEs."""
2374+
np.random.seed(66)
2375+
n_boot = ci_params.bootstrap(99)
2376+
rows = []
2377+
for g in range(20):
2378+
for t in range(6):
2379+
d = 0 if t < 3 else 2
2380+
y = 10 + t + d * 1.5 + np.random.randn() * 0.3
2381+
rows.append({"group": g, "period": t, "treatment": d, "outcome": y})
2382+
for g in range(20, 40):
2383+
for t in range(6):
2384+
y = 10 + t + np.random.randn() * 0.3
2385+
rows.append({"group": g, "period": t, "treatment": 0, "outcome": y})
2386+
df = pd.DataFrame(rows)
2387+
est = ChaisemartinDHaultfoeuille(
2388+
twfe_diagnostic=False, n_bootstrap=n_boot, seed=42
2389+
)
2390+
with warnings.catch_warnings():
2391+
warnings.simplefilter("ignore")
2392+
r = est.fit(
2393+
df, outcome="outcome", group="group", time="period",
2394+
treatment="treatment", L_max=1,
2395+
)
2396+
assert r.bootstrap_results is not None
2397+
assert r.bootstrap_results.event_study_ses is not None
2398+
assert 1 in r.bootstrap_results.event_study_ses
2399+
assert np.isfinite(r.bootstrap_results.event_study_ses[1])
2400+
2401+
def test_twfe_diagnostic_skipped_nonbinary(self):
2402+
"""TWFE diagnostic should be skipped (with warning) for non-binary."""
2403+
np.random.seed(77)
2404+
rows = []
2405+
for g in range(20):
2406+
for t in range(6):
2407+
d = 0 if t < 3 else 2
2408+
y = 10 + t + d + np.random.randn() * 0.3
2409+
rows.append({"group": g, "period": t, "treatment": d, "outcome": y})
2410+
for g in range(20, 40):
2411+
for t in range(6):
2412+
y = 10 + t + np.random.randn() * 0.3
2413+
rows.append({"group": g, "period": t, "treatment": 0, "outcome": y})
2414+
df = pd.DataFrame(rows)
2415+
est = ChaisemartinDHaultfoeuille(twfe_diagnostic=True)
2416+
with warnings.catch_warnings(record=True) as w:
2417+
warnings.simplefilter("always")
2418+
r = est.fit(
2419+
df, outcome="outcome", group="group", time="period",
2420+
treatment="treatment", L_max=1,
2421+
)
2422+
twfe_warnings = [x for x in w if "TWFE diagnostic" in str(x.message)]
2423+
assert len(twfe_warnings) >= 1
2424+
assert r.twfe_weights is None # diagnostic was skipped
2425+
23362426
def test_normalized_effects_general_formula(self):
23372427
"""For non-binary treatment, normalized denominator uses actual dose change."""
23382428
np.random.seed(99)

0 commit comments

Comments
 (0)