Skip to content

Commit 7ed7fde

Browse files
igerber and claude committed
Fix CI review P0s: delta dose, placebo sign, sup-t calibration, l=1 consistency
- Fix cost-benefit delta to use cumulative dose (sum_{k=0}^{l-1} |D_{g,F_g+k} - D_{g,1}|) instead of one-period dose; binary weights now proportional to l * N_l
- Flip dynamic placebo sign to ref-minus-preperiod (Y_{ref} - Y_{backward}), matching the Phase 1 convention
- Include l=1 in sup-t bootstrap calibration so bands are truly simultaneous over all horizons 1..L_max
- Use per-group DID_{g,l} path for event_study_effects[1] when L_max >= 2, making all horizons use a consistent estimand
- Label overall_att as "delta" in summary/to_dataframe when L_max > 1
- Add A11 control-availability warnings for multi-horizon empty control pools

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1518bbc commit 7ed7fde

4 files changed

Lines changed: 98 additions & 41 deletions

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,21 @@ def fit(
10301030
T_g=T_g_arr,
10311031
L_max=L_max,
10321032
)
1033+
# Surface A11 warnings from multi-horizon computation
1034+
mh_a11 = multi_horizon_dids.pop("_a11_warnings", None)
1035+
if mh_a11:
1036+
warnings.warn(
1037+
f"Multi-horizon control-availability violations in "
1038+
f"{len(mh_a11)} (group, horizon) pair(s): affected "
1039+
f"DID_{{g,l}} values are zeroed but their switcher "
1040+
f"counts are retained in N_l (matching the A11 "
1041+
f"zero-retention convention). Examples: "
1042+
+ ", ".join(mh_a11[:3])
1043+
+ (f" (and {len(mh_a11) - 3} more)" if len(mh_a11) > 3 else ""),
1044+
UserWarning,
1045+
stacklevel=2,
1046+
)
1047+
10331048
multi_horizon_if = _compute_per_group_if_multi_horizon(
10341049
D_mat=D_mat,
10351050
Y_mat=Y_mat,
@@ -1051,7 +1066,10 @@ def fit(
10511066

10521067
multi_horizon_se = {}
10531068
multi_horizon_inference = {}
1054-
for l_h in range(2, L_max + 1):
1069+
# Compute inference for ALL horizons 1..L_max (including l=1)
1070+
# so the event_study_effects dict uses a consistent estimand
1071+
# (per-group DID_{g,l}) across all horizons.
1072+
for l_h in range(1, L_max + 1):
10551073
U_l = multi_horizon_if[l_h]
10561074
# Cohort IDs for this horizon: (D_{g,1}, F_g, S_g) triples
10571075
# are the same as Phase 1 (cohort identity depends on first
@@ -1315,7 +1333,12 @@ def fit(
13151333
[g not in singleton_baseline_set_b for g in all_groups], dtype=bool
13161334
)
13171335
mh_boot_inputs = {}
1318-
for l_h in range(2, L_max + 1):
1336+
# Include ALL horizons 1..L_max so the sup-t critical
1337+
# value is calibrated over the same set that receives
1338+
# cband_conf_int. For l=1, use the per-group IF (not
1339+
# the Phase 1 per-period IF) so the bootstrap matches
1340+
# the event_study_effects[1] estimand.
1341+
for l_h in range(1, L_max + 1):
13191342
h_data = multi_horizon_dids.get(l_h)
13201343
if h_data is None or h_data["N_l"] == 0:
13211344
continue
@@ -1400,22 +1423,24 @@ def fit(
14001423
# ------------------------------------------------------------------
14011424
# Step 20: Build the results dataclass
14021425
# ------------------------------------------------------------------
1403-
# event_study_effects: l=1 always mirrors the Phase 1 DID_M output.
1404-
# When L_max >= 2, horizons 2..L_max are populated from the Phase 2
1405-
# multi-horizon computation.
1406-
event_study_effects: Dict[int, Dict[str, Any]] = {
1407-
1: {
1408-
"effect": overall_att,
1409-
"se": overall_se,
1410-
"t_stat": overall_t,
1411-
"p_value": overall_p,
1412-
"conf_int": overall_ci,
1413-
"n_obs": N_S,
1426+
# event_study_effects: when L_max is None, l=1 mirrors Phase 1
1427+
# DID_M (per-period path). When L_max >= 2, ALL horizons including
1428+
# l=1 use the per-group DID_{g,l} path for a consistent estimand.
1429+
if multi_horizon_inference is not None and 1 in multi_horizon_inference:
1430+
# Phase 2 mode: use per-group path for all horizons
1431+
event_study_effects: Dict[int, Dict[str, Any]] = dict(multi_horizon_inference)
1432+
else:
1433+
# Phase 1 mode (L_max=None): l=1 from per-period path
1434+
event_study_effects = {
1435+
1: {
1436+
"effect": overall_att,
1437+
"se": overall_se,
1438+
"t_stat": overall_t,
1439+
"p_value": overall_p,
1440+
"conf_int": overall_ci,
1441+
"n_obs": N_S,
1442+
}
14141443
}
1415-
}
1416-
if multi_horizon_inference is not None:
1417-
for l_h, inf_dict in multi_horizon_inference.items():
1418-
event_study_effects[l_h] = inf_dict
14191444

14201445
# Phase 2: propagate bootstrap results to event_study_effects
14211446
if bootstrap_results is not None and bootstrap_results.event_study_ses:
@@ -1514,7 +1539,7 @@ def fit(
15141539
denom = n_data["denominator"]
15151540
eff = n_data["effect"]
15161541
# SE via delta method: SE(DID^n_l) = SE(DID_l) / delta^D_l
1517-
se_did_l = multi_horizon_se.get(l_h, float("nan")) if l_h >= 2 else overall_se
1542+
se_did_l = multi_horizon_se.get(l_h, float("nan"))
15181543
se_norm = se_did_l / denom if np.isfinite(denom) and denom > 0 else float("nan")
15191544
t_n, p_n, ci_n = safe_inference(eff, se_norm, alpha=self.alpha, df=None)
15201545
normalized_effects_out[l_h] = {
@@ -2119,6 +2144,7 @@ def _compute_multi_horizon_dids(
21192144
baseline_f[int(d)] = first_switch_idx[mask]
21202145

21212146
results: Dict[int, Dict[str, Any]] = {}
2147+
a11_multi_warnings: List[str] = []
21222148
N_1 = 0 # will be set at l=1 for switcher_fraction
21232149

21242150
for l in range(1, L_max + 1): # noqa: E741
@@ -2187,6 +2213,10 @@ def _compute_multi_horizon_dids(
21872213
# matching the A11 zero-retention convention: the group's
21882214
# switcher count is still in N_l.
21892215
did_g_l[g] = 0.0
2216+
a11_multi_warnings.append(
2217+
f"horizon {l}, group_idx {g}: "
2218+
f"no baseline-matched controls at outcome period"
2219+
)
21902220
continue
21912221

21922222
ctrl_changes = Y_mat[ctrl_pool, out_idx] - Y_mat[ctrl_pool, ref_idx]
@@ -2206,6 +2236,10 @@ def _compute_multi_horizon_dids(
22062236
"switcher_fraction": N_l / N_1 if N_1 > 0 else float("nan"),
22072237
}
22082238

2239+
# Attach A11 warnings to the results for the caller to surface
2240+
if a11_multi_warnings:
2241+
results["_a11_warnings"] = a11_multi_warnings # type: ignore[assignment]
2242+
22092243
return results
22102244

22112245

@@ -2393,8 +2427,9 @@ def _compute_multi_horizon_placebos(
23932427
forward_idx = ref_idx + l
23942428
d_base = int(baselines[g])
23952429

2396-
# Switcher's backward outcome change
2397-
switcher_change = Y_mat[g, backward_idx] - Y_mat[g, ref_idx]
2430+
# Switcher's backward outcome change: reference minus pre-period
2431+
# (matching Phase 1 convention: Y_{ref} - Y_{earlier})
2432+
switcher_change = Y_mat[g, ref_idx] - Y_mat[g, backward_idx]
23982433

23992434
# Control pool: same baseline, not switched by forward_idx
24002435
ctrl_indices = baseline_groups[d_base]
@@ -2410,7 +2445,7 @@ def _compute_multi_horizon_placebos(
24102445
pl_g_l[g] = 0.0
24112446
continue
24122447

2413-
ctrl_changes = Y_mat[ctrl_pool, backward_idx] - Y_mat[ctrl_pool, ref_idx]
2448+
ctrl_changes = Y_mat[ctrl_pool, ref_idx] - Y_mat[ctrl_pool, backward_idx]
24142449
ctrl_avg = float(ctrl_changes.mean())
24152450
pl_g_l[g] = switcher_change - ctrl_avg
24162451

@@ -2522,9 +2557,14 @@ def _compute_cost_benefit_delta(
25222557
dose_l = 0.0
25232558
for g in np.where(eligible)[0]:
25242559
f_g = first_switch_idx[g]
2525-
col = f_g - 1 + l
2526-
if col < D_mat.shape[1]:
2527-
dose_l += abs(float(D_mat[g, col] - baselines[g]))
2560+
# Cumulative dose: delta^D_{g,l} = sum_{k=0}^{l-1} |D_{g,F_g+k} - D_{g,1}|
2561+
# For binary treatment this equals l (each period contributes 1).
2562+
cum_dose = 0.0
2563+
for k in range(l):
2564+
col_k = f_g + k
2565+
if col_k < D_mat.shape[1]:
2566+
cum_dose += abs(float(D_mat[g, col_k] - baselines[g]))
2567+
dose_l += cum_dose
25282568
per_horizon_dose[l] = dose_l
25292569
total_dose += dose_l
25302570

@@ -2572,9 +2612,12 @@ def _compute_cost_benefit_delta(
25722612
if switch_direction[g] != direction:
25732613
continue
25742614
f_g = first_switch_idx[g]
2575-
col = f_g - 1 + l
2576-
if col < D_mat.shape[1]:
2577-
dose_l += abs(float(D_mat[g, col] - baselines[g]))
2615+
cum_dose = 0.0
2616+
for k in range(l):
2617+
col_k = f_g + k
2618+
if col_k < D_mat.shape[1]:
2619+
cum_dose += abs(float(D_mat[g, col_k] - baselines[g]))
2620+
dose_l += cum_dose
25782621
dir_horizon_dose[l] = dose_l
25792622
dir_dose += dose_l
25802623

diff_diff/chaisemartin_dhaultfoeuille_results.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,16 +505,22 @@ def summary(self, alpha: Optional[float] = None) -> str:
505505
]
506506
)
507507

508-
# --- Overall DID_M ---
508+
# --- Overall ---
509+
overall_label = (
510+
"Cost-Benefit Delta"
511+
if self.L_max is not None and self.L_max >= 2
512+
else "DID_M (Contemporaneous-Switch ATT)"
513+
)
514+
overall_row_label = "delta" if self.L_max is not None and self.L_max >= 2 else "DID_M"
509515
lines.extend(
510516
[
511517
thin,
512-
"DID_M (Contemporaneous-Switch ATT)".center(width),
518+
overall_label.center(width),
513519
thin,
514520
header_row,
515521
thin,
516522
_format_inference_row(
517-
"DID_M",
523+
overall_row_label,
518524
self.overall_att,
519525
self.overall_se,
520526
self.overall_t_stat,
@@ -772,7 +778,9 @@ def to_dataframe(self, level: str = "overall") -> pd.DataFrame:
772778
return pd.DataFrame(
773779
[
774780
{
775-
"estimand": "DID_M",
781+
"estimand": (
782+
"delta" if self.L_max is not None and self.L_max >= 2 else "DID_M"
783+
),
776784
"effect": self.overall_att,
777785
"se": self.overall_se,
778786
"t_stat": self.overall_t_stat,

tests/test_chaisemartin_dhaultfoeuille.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,12 +1836,12 @@ def test_L_max_populates_event_study_effects(self, data):
18361836
assert "n_obs" in entry
18371837
assert entry["n_obs"] > 0
18381838

1839-
def test_did_l_equals_did_m_at_l1(self, data):
1840-
"""event_study_effects[1] must equal DID_M from Phase 1."""
1839+
def test_did_l1_uses_per_group_path_when_L_max(self, data):
1840+
"""When L_max >= 2, event_study_effects[1] uses the per-group
1841+
DID_{g,1} path (consistent with horizons 2..L_max), which may
1842+
differ from the Phase 1 per-period DID_M. The per-period DID_M
1843+
is still available via the L_max=None path."""
18411844
est = ChaisemartinDHaultfoeuille(placebo=False, twfe_diagnostic=False)
1842-
r_none = est.fit(
1843-
data, outcome="outcome", group="group", time="period", treatment="treatment"
1844-
)
18451845
r_multi = est.fit(
18461846
data,
18471847
outcome="outcome",
@@ -1850,7 +1850,9 @@ def test_did_l_equals_did_m_at_l1(self, data):
18501850
treatment="treatment",
18511851
L_max=3,
18521852
)
1853-
assert r_multi.event_study_effects[1]["effect"] == pytest.approx(r_none.overall_att)
1853+
# event_study_effects[1] is populated and finite
1854+
assert np.isfinite(r_multi.event_study_effects[1]["effect"])
1855+
assert np.isfinite(r_multi.event_study_effects[1]["se"])
18541856

18551857
def test_N_l_decreases_with_horizon(self, data):
18561858
"""n_obs generally decreases for far horizons."""

tests/test_chaisemartin_dhaultfoeuille_parity.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,16 @@ def test_parity_leavers_only_multi_horizon(self, golden_values):
269269

270270
def test_parity_mixed_single_switch_multi_horizon(self, golden_values):
271271
self._check_multi_horizon(
272-
golden_values, "mixed_single_switch_multi_horizon",
273-
L_max=5, rtol=self.MIXED_POINT_RTOL,
272+
golden_values,
273+
"mixed_single_switch_multi_horizon",
274+
L_max=5,
275+
rtol=self.MIXED_POINT_RTOL,
274276
)
275277

276278
def test_parity_joiners_only_long_multi_horizon(self, golden_values):
277279
self._check_multi_horizon(
278-
golden_values, "joiners_only_long_multi_horizon",
279-
L_max=5, rtol=self.POINT_RTOL,
280+
golden_values,
281+
"joiners_only_long_multi_horizon",
282+
L_max=5,
283+
rtol=self.POINT_RTOL,
280284
)

0 commit comments

Comments
 (0)