Skip to content

Commit ddc09e4

Browse files
igerber and claude committed
Address PR #346 CI review round 3: P1 period inference + P2 summary label
**P1 (Methodology): `_validate_had_panel` inferred pre/post by lexicographic sort**

Previously the validator sorted the two period labels alphabetically and assigned `t_pre=periods[0]`, `t_post=periods[1]`. On supported string-labelled panels like `("pre", "post")` the alphabetic order is ["post", "pre"], so the code flipped pre and post and then raised on the treated-period D>0 check for a valid design. Same bug for `("before", "after")` and any non-alphabetic-chronological label pair.

Fix: identify `t_pre` as the unique period where dose == 0 for ALL units (HAD paper Section 2 no-unit-untreated convention); `t_post` is the other period. This is a DGP-consistent invariant, not a string ordering. If neither period has all-zero dose, raise with the contract message and per-period nonzero-count diagnostics. If both periods have all-zero dose, raise (no treatment variation to estimate). The existing pre-period D=0 check is now tautological and has been removed since the inference itself enforces the invariant. Behavior on valid numeric panels (e.g., 2020/2021) is unchanged.

**P2 (Code Quality): `summary()` hardcoded 'WAS' row label**

`HeterogeneousAdoptionDiDResults.summary()` printed "WAS" as the parameter label regardless of the resolved design. For Design 1 paths (continuous_near_d_lower, mass_point) the stored `target_parameter` is "WAS_d_lower" per paper Sections 3.2.2-3.2.4, so the user-facing output misrepresented the estimand.

Fix: render `self.target_parameter` in the summary row. Now Design 1' prints "WAS", Design 1 prints "WAS_d_lower", matching the stored result metadata.
**Tests (+7 regression):**

- TestValidateHadPanel.test_semantic_pre_post_labels_not_lexicographic
- TestValidateHadPanel.test_semantic_pre_post_with_first_treat_col
- TestValidateHadPanel.test_semantic_pre_post_fit_end_to_end
- TestValidateHadPanel.test_before_after_labels
- TestValidateHadPanel.test_no_all_zero_period_raises
- TestValidateHadPanel.test_both_all_zero_periods_raises
- TestResultMethods.test_summary_uses_target_parameter_for_row_label

Targeted regression: 133 HAD tests + 512 total across Phase 1 and adjacent surfaces, all green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c320dde commit ddc09e4

2 files changed

Lines changed: 134 additions & 18 deletions

File tree

diff_diff/had.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def summary(self) -> str:
289289
bc = self.bias_corrected_fit
290290
lines.append(f"{'Bandwidth h used:':<30} {bc.h:>20.6g}")
291291
lines.append(f"{'Obs in window (n_used):':<30} {bc.n_used:>20}")
292+
param_label = self.target_parameter
292293
lines.extend(
293294
[
294295
"",
@@ -299,7 +300,7 @@ def summary(self) -> str:
299300
),
300301
"-" * width,
301302
(
302-
f"{'WAS':<15} {self.att:>12.4f} {self.se:>12.4f} "
303+
f"{param_label:<15} {self.att:>12.4f} {self.se:>12.4f} "
303304
f"{self.t_stat:>10.3f} {self.p_value:>10.4f}"
304305
),
305306
"-" * width,
@@ -395,30 +396,27 @@ def _validate_had_panel(
395396
if missing:
396397
raise ValueError(f"Missing column(s) in data: {missing}. Required: {required}.")
397398

398-
periods = np.sort(np.asarray(data[time_col].unique()))
399-
if len(periods) < 2:
399+
periods_list = list(data[time_col].unique())
400+
if len(periods_list) < 2:
400401
raise ValueError(
401-
f"HAD requires a two-period panel; got {len(periods)} distinct "
402+
f"HAD requires a two-period panel; got {len(periods_list)} distinct "
402403
f"period(s) in column {time_col!r}."
403404
)
404-
if len(periods) > 2:
405+
if len(periods_list) > 2:
405406
if first_treat_col is None:
406407
raise ValueError(
407408
f"HAD Phase 2a requires exactly two time periods "
408-
f"(got {len(periods)} in {time_col!r}) when "
409+
f"(got {len(periods_list)} in {time_col!r}) when "
409410
f"first_treat_col=None. Multi-period / staggered adoption "
410411
f"support is queued for Phase 2b (Appendix B.2 event-study)."
411412
)
412413
raise ValueError(
413414
f"HAD Phase 2a requires exactly two time periods "
414-
f"(got {len(periods)} in {time_col!r}). Staggered adoption "
415+
f"(got {len(periods_list)} in {time_col!r}). Staggered adoption "
415416
f"reduction (first_treat_col supplied with >2 periods) is "
416417
f"queued for Phase 2b (Appendix B.2 event-study)."
417418
)
418419

419-
t_pre = int(periods[0]) if np.issubdtype(periods.dtype, np.integer) else periods[0]
420-
t_post = int(periods[1]) if np.issubdtype(periods.dtype, np.integer) else periods[1]
421-
422420
# Balanced-panel check: every unit appears exactly once per period.
423421
counts = data.groupby([unit_col, time_col]).size()
424422
if (counts != 1).any():
@@ -446,17 +444,35 @@ def _validate_had_panel(
446444
f"calling fit()."
447445
)
448446

449-
# Pre-period no-unit-untreated check.
450-
pre_mask = data[time_col] == t_pre
451-
pre_doses = np.asarray(data.loc[pre_mask, dose_col], dtype=np.float64)
452-
nonzero_pre = pre_doses != 0
453-
if nonzero_pre.any():
454-
n_bad = int(nonzero_pre.sum())
447+
# Identify t_pre and t_post by the HAD invariant rather than by
448+
# lexicographic sort on the time labels: D_{g, t_pre} = 0 for all
449+
# units (paper Section 2 no-unit-untreated pre-period convention).
450+
# Sorting labels alphabetically reverses valid chronologies like
451+
# ("pre", "post") where ordering is semantic, not alphabetic.
452+
per_period_nonzero: Dict[Any, int] = {}
453+
for p in periods_list:
454+
p_doses = np.asarray(data.loc[data[time_col] == p, dose_col], dtype=np.float64)
455+
per_period_nonzero[p] = int((p_doses != 0).sum())
456+
all_zero_periods = [p for p, nz in per_period_nonzero.items() if nz == 0]
457+
if len(all_zero_periods) == 0:
458+
# Neither period has all-zero dose: HAD pre-period contract violated.
459+
stats_str = ", ".join(f"{p!r}: {nz} nonzero" for p, nz in per_period_nonzero.items())
455460
raise ValueError(
456461
f"HAD requires D_{{g,1}} = 0 for all units (pre-period "
457-
f"untreated). {n_bad} unit(s) have nonzero dose at "
458-
f"t_pre={t_pre}. Drop these units or verify the dose column."
462+
f"untreated). Neither period in column {time_col!r} has "
463+
f"all-zero dose ({stats_str}). Exactly one period must be "
464+
f"the pre-treatment period with D_{{g,1}} = 0 for every unit; "
465+
f"drop rows with nonzero pre-period dose or verify the dose "
466+
f"column."
467+
)
468+
if len(all_zero_periods) == 2:
469+
raise ValueError(
470+
f"HAD requires variation in D_{{g,2}} for estimation. Both "
471+
f"periods in column {time_col!r} have all-zero dose, so "
472+
f"there is no treatment assignment to estimate."
459473
)
474+
t_pre = all_zero_periods[0]
475+
t_post = [p for p in periods_list if p != t_pre][0]
460476

461477
# Post-period nonnegative-dose check on the ORIGINAL (unshifted) dose
462478
# scale. Front-door rejection per paper Assumption (dose definition

tests/test_had.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,42 @@ def test_summary_returns_string(self):
899899
assert "WAS" in s
900900
assert "Confidence Interval" in s
901901

902+
def test_summary_uses_target_parameter_for_row_label(self):
903+
"""Review P2: the estimate row must render target_parameter (WAS or
904+
WAS_d_lower), not hardcoded 'WAS'.
905+
"""
906+
# Design 1' -> target_parameter = "WAS"
907+
d, dy = _dgp_continuous_at_zero(400, seed=0)
908+
panel = _make_panel(d, dy)
909+
r_d1p = HeterogeneousAdoptionDiD(design="continuous_at_zero").fit(
910+
panel, "outcome", "dose", "period", "unit"
911+
)
912+
s_d1p = r_d1p.summary()
913+
assert r_d1p.target_parameter == "WAS"
914+
assert "WAS" in s_d1p
915+
916+
# Design 1 continuous-near-d_lower -> target_parameter = "WAS_d_lower"
917+
d, dy = _dgp_continuous_near_d_lower(400, seed=0)
918+
panel = _make_panel(d, dy)
919+
with warnings.catch_warnings():
920+
warnings.simplefilter("ignore", UserWarning)
921+
r_d1 = HeterogeneousAdoptionDiD(design="continuous_near_d_lower").fit(
922+
panel, "outcome", "dose", "period", "unit"
923+
)
924+
assert r_d1.target_parameter == "WAS_d_lower"
925+
assert "WAS_d_lower" in r_d1.summary()
926+
927+
# Design 1 mass-point -> target_parameter = "WAS_d_lower"
928+
d, dy = _dgp_mass_point(400, seed=0)
929+
panel = _make_panel(d, dy)
930+
with warnings.catch_warnings():
931+
warnings.simplefilter("ignore", UserWarning)
932+
r_mp = HeterogeneousAdoptionDiD(design="mass_point").fit(
933+
panel, "outcome", "dose", "period", "unit"
934+
)
935+
assert r_mp.target_parameter == "WAS_d_lower"
936+
assert "WAS_d_lower" in r_mp.summary()
937+
902938
def test_print_summary_executes(self, capsys):
903939
r = self._result()
904940
r.print_summary()
@@ -1430,6 +1466,70 @@ def test_first_treat_col_dtype_agnostic_rejects_invalid_string(self):
14301466
with pytest.raises(ValueError, match="first_treat_col"):
14311467
_validate_had_panel(panel, "outcome", "dose", "period", "unit", "ft")
14321468

1469+
def test_semantic_pre_post_labels_not_lexicographic(self):
1470+
"""Review P1 round 3: pre/post inference must be dose-based.
1471+
1472+
("pre", "post") sorts alphabetically to ["post", "pre"], which
1473+
previously flipped the pre/post labels and raised on a valid
1474+
panel. The validator now infers pre from the all-zero-dose
1475+
period.
1476+
"""
1477+
d, dy = _dgp_continuous_at_zero(100, seed=0)
1478+
panel = _make_panel(d, dy, periods=("pre", "post"))
1479+
t_pre, t_post = _validate_had_panel(panel, "outcome", "dose", "period", "unit", None)
1480+
assert t_pre == "pre"
1481+
assert t_post == "post"
1482+
1483+
def test_semantic_pre_post_with_first_treat_col(self):
1484+
"""Combined: string periods + first_treat_col in {0, 'post'}."""
1485+
d, dy = _dgp_continuous_at_zero(100, seed=0)
1486+
panel = _make_panel(d, dy, periods=("pre", "post"))
1487+
ft_unit = np.array([0 if i % 2 == 0 else "post" for i in range(100)], dtype=object)
1488+
panel["ft"] = np.repeat(ft_unit, 2)
1489+
t_pre, t_post = _validate_had_panel(panel, "outcome", "dose", "period", "unit", "ft")
1490+
assert t_pre == "pre"
1491+
assert t_post == "post"
1492+
1493+
def test_semantic_pre_post_fit_end_to_end(self):
1494+
"""End-to-end: fit() runs on ("pre","post")-labelled panel."""
1495+
d, dy = _dgp_continuous_at_zero(500, seed=0)
1496+
panel = _make_panel(d, dy, periods=("pre", "post"))
1497+
r = HeterogeneousAdoptionDiD(design="continuous_at_zero").fit(
1498+
panel, "outcome", "dose", "period", "unit"
1499+
)
1500+
assert np.isfinite(r.att)
1501+
1502+
def test_before_after_labels(self):
1503+
"""("before","after") is also reversed alphabetically; must not fail."""
1504+
d, dy = _dgp_continuous_at_zero(100, seed=0)
1505+
panel = _make_panel(d, dy, periods=("before", "after"))
1506+
t_pre, t_post = _validate_had_panel(panel, "outcome", "dose", "period", "unit", None)
1507+
assert t_pre == "before"
1508+
assert t_post == "after"
1509+
1510+
def test_no_all_zero_period_raises(self):
1511+
"""If neither period has all-zero dose, HAD's D_{g,1}=0 contract fails."""
1512+
d, dy = _dgp_continuous_at_zero(100, seed=0)
1513+
panel = _make_panel(d, dy)
1514+
# Inject nonzero dose into the pre period so neither period is all-zero.
1515+
panel.loc[panel["period"] == 1, "dose"] = 0.5
1516+
with pytest.raises(ValueError, match=r"D_\{g,1\}|pre-treatment"):
1517+
_validate_had_panel(panel, "outcome", "dose", "period", "unit", None)
1518+
1519+
def test_both_all_zero_periods_raises(self):
1520+
"""If both periods have all-zero dose, no treatment to estimate."""
1521+
G = 100
1522+
panel = pd.DataFrame(
1523+
{
1524+
"unit": np.repeat(np.arange(G), 2),
1525+
"period": np.tile([1, 2], G),
1526+
"dose": np.zeros(2 * G),
1527+
"outcome": np.random.default_rng(0).standard_normal(2 * G),
1528+
}
1529+
)
1530+
with pytest.raises(ValueError, match="variation"):
1531+
_validate_had_panel(panel, "outcome", "dose", "period", "unit", None)
1532+
14331533

14341534
# =============================================================================
14351535
# Review P1: continuous_near_d_lower on a true mass-point sample rejects

0 commit comments

Comments
 (0)