Skip to content

Commit 6e321ad

Browse files
igerber and claude committed
Address PR #350 CI review round 2: P1 first_treat cross-validation + P1 ordered time
**P1 (first_treat_col vs dose mismatch):** The last-cohort filter trusted `first_treat_col` without validating it against the observed dose path. A swapped or mistyped cohort label could silently retain the wrong cohort as F_last.

Fix: `_validate_had_panel_event_study` now cross-validates each unit's declared first_treat against its actual first-positive-dose period:
- declared == 0: unit must have D=0 at every period
- declared == F_g > 0: unit's first period with D>0 must equal F_g

Any mismatch raises `ValueError` with an example unit, declared value, and actual first-positive period.

**P1 (unordered time labels):** Event-study chronology was inferred via raw `sorted()` on period labels. For object/string dtypes that falls back to lexicographic sort, which silently misorders panels like "pre1"/"pre2"/"post1"/"post2" or month-name labels.

Fix: Event-study path now requires a numeric, datetime, or ordered-categorical time column. Object/string dtypes raise a front-door `ValueError` directing users to convert. Ordered categoricals are sorted by their declared category order (not the underlying string), via a dtype-aware `_sort_key` reused by both the validator and the multi-period aggregator.

**P3 (docstring):** Class docstring no longer says the event-study extension is "queued for Phase 2b"; now documents both aggregation modes with pointers to the respective result classes.

**Tests added:**
- `test_first_treat_col_mismatch_with_dose_raises` pins the cross-validation contract.
- `test_unordered_string_time_col_rejected` pins front-door rejection of object dtypes.
- `test_ordered_categorical_time_col_accepted` confirms ordered categoricals sort by category order and fit successfully.

Minor: added `observed=False` to the categorical-groupby in the balance check to silence the pandas FutureWarning while preserving behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1bfec37 commit 6e321ad

2 files changed

Lines changed: 229 additions & 28 deletions

File tree

diff_diff/had.py

Lines changed: 120 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -903,6 +903,36 @@ def _validate_had_panel_event_study(
903903
f"single-period WAS)."
904904
)
905905

906+
# Ordered-time-type check. Paper Appendix B.2 event-time horizons
907+
# require chronological ordering of periods (anchor at F-1, horizons
908+
# e = t - F relative to F). Phase 2a two-period panels can use the
909+
# dose invariant alone to distinguish pre from post without needing
910+
# chronological order, so string labels ("pre", "post") work there.
911+
# For multi-period event-study, multiple pre-periods all have D=0
912+
# and multiple post-periods may both have D>0, so dose alone cannot
913+
# recover chronology: we must trust the time column's natural order.
914+
# Raw lexicographic sort on object/string labels silently misorders
915+
# panels like "pre1"/"pre2"/"post1"/"post2" or month-name labels.
916+
# Require an explicitly-ordered time representation.
917+
time_dtype = data[time_col].dtype
918+
if not (
919+
pd.api.types.is_numeric_dtype(time_dtype)
920+
or pd.api.types.is_datetime64_any_dtype(time_dtype)
921+
or (isinstance(time_dtype, pd.CategoricalDtype) and bool(time_dtype.ordered))
922+
):
923+
raise ValueError(
924+
f"HAD aggregate='event_study' requires an ordered time "
925+
f"column. time_col={time_col!r} has dtype={time_dtype!r}, "
926+
f"which has no defined chronological order; raw sort would "
927+
f"fall back to lexicographic ordering and silently misindex "
928+
f"event-time horizons (e.g., 'pre1'/'pre2'/'post1'/'post2' "
929+
f"sorts lexicographically but not chronologically). "
930+
f"Convert time_col to numeric (e.g., integer year), "
931+
f"datetime, or ordered categorical "
932+
f"(``pd.Categorical(..., ordered=True, categories=[...])``) "
933+
f"before calling fit() with aggregate='event_study'."
934+
)
935+
906936
# NaN checks on key columns (before any filter).
907937
for col in [outcome_col, dose_col, unit_col]:
908938
if bool(data[col].isna().any()):
@@ -936,6 +966,45 @@ def _validate_had_panel_event_study(
936966
f"within unit for {n_bad} unit(s). Each unit must have "
937967
f"a single first_treat value across all observed periods."
938968
)
969+
# Cross-validate first_treat_col against observed first-positive-
970+
# dose period for every unit. A mislabeled cohort column would
971+
# otherwise silently select the wrong cohort as F_last and return
972+
# event-study estimates for the wrong units. Contract:
973+
# - declared first_treat == 0: unit must have D == 0 at all t
974+
# (never-treated)
975+
# - declared first_treat == F_g > 0: unit's first period with
976+
# D > 0 must equal F_g
977+
df_for_check = data.sort_values([unit_col, time_col])
978+
pos_rows = df_for_check.loc[df_for_check[dose_col] > 0]
979+
actual_first_pos = pos_rows.groupby(unit_col)[time_col].first()
980+
declared_ft = df_for_check.groupby(unit_col)[first_treat_col].first()
981+
n_mismatch = 0
982+
example_mismatch: Optional[Tuple[Any, Any, Any]] = None
983+
for u, declared in declared_ft.items():
984+
actual = actual_first_pos.get(u, None)
985+
if declared == 0:
986+
if actual is not None:
987+
n_mismatch += 1
988+
if example_mismatch is None:
989+
example_mismatch = (u, declared, actual)
990+
else:
991+
if actual is None or actual != declared:
992+
n_mismatch += 1
993+
if example_mismatch is None:
994+
example_mismatch = (u, declared, actual)
995+
if n_mismatch > 0:
996+
u, declared, actual = example_mismatch # type: ignore[misc]
997+
raise ValueError(
998+
f"first_treat_col={first_treat_col!r} disagrees with the "
999+
f"observed dose path for {n_mismatch} unit(s). Example: "
1000+
f"unit={u!r} declares first_treat={declared!r} but the "
1001+
f"unit's first period with D>0 is {actual!r} "
1002+
f"(None means never-treated). A mislabeled cohort column "
1003+
f"would silently select the wrong cohort as F_last in the "
1004+
f"last-cohort auto-filter. Fix the first_treat_col values "
1005+
f"to equal each unit's first positive-dose period (or 0 "
1006+
f"for never-treated) before calling fit()."
1007+
)
9391008
# Identify cohorts (nonzero first_treat values).
9401009
# Use pd.unique to preserve dtype; sort with a stable key.
9411010
ft_unique = list(pd.unique(ft_raw))
@@ -1015,8 +1084,9 @@ def _validate_had_panel_event_study(
10151084
)
10161085

10171086
# Balanced panel on the (possibly-filtered) data: every unit appears
1018-
# exactly once per period.
1019-
counts = data_filtered.groupby([unit_col, time_col]).size()
1087+
# exactly once per period. ``observed=False`` preserves current
1088+
# behavior on categorical time columns (pandas' default is changing).
1089+
counts = data_filtered.groupby([unit_col, time_col], observed=False).size()
10201090
if (counts != 1).any():
10211091
n_bad = int((counts != 1).sum())
10221092
raise ValueError(
@@ -1057,36 +1127,35 @@ def _validate_had_panel_event_study(
10571127
f"zero dose; there is no treatment to estimate."
10581128
)
10591129

1060-
# Sort by natural ordering on the time column dtype. Tuple key
1061-
# ``(x is None, x)`` places None at the end and sorts the rest by
1062-
# natural order (works for int/float/str/datetime when the dtype is
1063-
# homogeneous; mixed dtypes would raise at comparison time, which is
1064-
# the desired failure mode).
1065-
t_pre_list = sorted(t_pre_list_unsorted, key=lambda x: (x is None, x))
1066-
t_post_list = sorted(t_post_list_unsorted, key=lambda x: (x is None, x))
1130+
# Sort by natural ordering on the time column dtype. For ordered
1131+
# categorical dtypes, use the declared category order (since
1132+
# ``list(categorical)`` strips the ordered semantics and falls back
1133+
# to string comparison). For numeric / datetime, use natural Python
1134+
# order. Tuple key places None at the end.
1135+
if isinstance(time_dtype, pd.CategoricalDtype) and time_dtype.ordered:
1136+
_cat_order = {c: i for i, c in enumerate(time_dtype.categories)}
1137+
1138+
def _sort_key(x: Any) -> Tuple[bool, Any]:
1139+
return (x is None, _cat_order.get(x, len(_cat_order)))
1140+
1141+
else:
1142+
1143+
def _sort_key(x: Any) -> Tuple[bool, Any]:
1144+
return (x is None, x)
1145+
1146+
t_pre_list = sorted(t_pre_list_unsorted, key=_sort_key)
1147+
t_post_list = sorted(t_post_list_unsorted, key=_sort_key)
10671148

10681149
# Contiguity check: all pre < all post in the natural ordering.
10691150
# The HAD dose invariant requires a single transition from all-zero
10701151
# to any-nonzero; interleaved pre/post periods indicate a malformed
10711152
# panel (e.g., dose going back to zero after treatment, or mixing
1072-
# never-treated units with out-of-order labels).
1153+
# never-treated units with out-of-order labels). Uses ``_sort_key``
1154+
# so ordered categoricals respect their declared category order.
10731155
if t_pre_list and t_post_list:
10741156
max_pre = t_pre_list[-1]
10751157
min_post = t_post_list[0]
1076-
# Check all pre-periods are less than all post-periods via the
1077-
# natural order. If types are comparable, direct comparison works;
1078-
# otherwise fall back to the sorted-key view.
1079-
try:
1080-
contiguous = max_pre < min_post
1081-
except TypeError:
1082-
# Mixed incomparable dtypes (e.g., None vs int after removing
1083-
# None above). Fall back to sorted-position check.
1084-
contiguous = True
1085-
for pre_p in t_pre_list:
1086-
for post_p in t_post_list:
1087-
if not (pre_p < post_p):
1088-
contiguous = False
1089-
break
1158+
contiguous = _sort_key(max_pre) < _sort_key(min_post)
10901159
if not contiguous:
10911160
raise ValueError(
10921161
f"HAD dose invariant violated: pre-periods (all D=0) "
@@ -1318,7 +1387,23 @@ def _aggregate_multi_period_first_differences(
13181387
equal to the LAST pre-period).
13191388
"""
13201389
df = data.sort_values([unit_col, time_col]).reset_index(drop=True)
1321-
all_periods = sorted(t_pre_list + t_post_list, key=lambda x: (x is None, x))
1390+
# Period sort respects ordered categorical dtypes (matches
1391+
# ``_validate_had_panel_event_study``). The validator already
1392+
# enforces a numeric / datetime / ordered-categorical dtype on the
1393+
# event-study path, so ``_sort_key`` lookups are well-defined here.
1394+
time_dtype = data[time_col].dtype
1395+
if isinstance(time_dtype, pd.CategoricalDtype) and time_dtype.ordered:
1396+
_cat_order = {c: i for i, c in enumerate(time_dtype.categories)}
1397+
1398+
def _sort_key(x: Any) -> Tuple[bool, Any]:
1399+
return (x is None, _cat_order.get(x, len(_cat_order)))
1400+
1401+
else:
1402+
1403+
def _sort_key(x: Any) -> Tuple[bool, Any]:
1404+
return (x is None, x)
1405+
1406+
all_periods = sorted(t_pre_list + t_post_list, key=_sort_key)
13221407
# Event-time mapping: natural rank of each period relative to F.
13231408
F_idx = all_periods.index(F)
13241409
period_to_event_time: Dict[Any, int] = {p: (i - F_idx) for i, p in enumerate(all_periods)}
@@ -1604,9 +1689,16 @@ class HeterogeneousAdoptionDiD:
16041689
Weighted-Average-Slope (WAS) estimator with three design-dispatch
16051690
paths: Design 1' (continuous-at-zero), Design 1 continuous-near-
16061691
d_lower, and Design 1 mass-point (2SLS sample-average per paper
1607-
Section 3.2.4). Phase 2a ships the single-period path only; the
1608-
multi-period event-study extension (Appendix B.2) is queued for
1609-
Phase 2b.
1692+
Section 3.2.4). Two aggregation modes:
1693+
1694+
- ``aggregate="overall"`` (Phase 2a, default) returns a single-period
1695+
:class:`HeterogeneousAdoptionDiDResults` on a two-period panel.
1696+
- ``aggregate="event_study"`` (Phase 2b, paper Appendix B.2) returns
1697+
a :class:`HeterogeneousAdoptionDiDEventStudyResults` with per-
1698+
event-time WAS estimates on a multi-period panel, using a uniform
1699+
``F-1`` anchor and pointwise CIs per horizon. Staggered-timing
1700+
panels auto-filter to the last-treatment cohort plus never-treated
1701+
units (paper Appendix B.2 prescription).
16101702
16111703
Parameters
16121704
----------

tests/test_had.py

Lines changed: 109 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2727,6 +2727,115 @@ def test_time_varying_post_F_dose_rejected(self):
27272727
panel, "outcome", "dose", "period", "unit", aggregate="event_study"
27282728
)
27292729

2730+
def test_first_treat_col_mismatch_with_dose_raises(self):
2731+
"""first_treat_col disagreeing with observed dose path must raise.
2732+
2733+
A mislabeled cohort column would otherwise silently select the
2734+
wrong cohort as F_last in the last-cohort auto-filter and
2735+
produce event-study estimates for the wrong units. Covers CI
2736+
reviewer round 2 P1.
2737+
"""
2738+
rng = np.random.default_rng(0)
2739+
G = 40
2740+
rows = []
2741+
for g in range(G):
2742+
# Actual first-positive-dose period: t=3 for half, t=5 for half.
2743+
F_actual = 3 if g < G // 2 else 5
2744+
# But deliberately mislabel: swap the first_treat labels so
2745+
# G/2 units declare 5 when actual is 3, and vice versa.
2746+
F_declared = 5 if g < G // 2 else 3
2747+
d_g = float(rng.uniform(0.1, 1.0))
2748+
for t in range(1, 7):
2749+
dose = d_g if t >= F_actual else 0.0
2750+
rows.append(
2751+
{
2752+
"unit": g,
2753+
"period": t,
2754+
"dose": dose,
2755+
"outcome": rng.standard_normal(),
2756+
"first_treat": F_declared,
2757+
}
2758+
)
2759+
panel = pd.DataFrame(rows)
2760+
with pytest.raises(ValueError, match="disagrees with the observed dose"):
2761+
HeterogeneousAdoptionDiD(design="auto").fit(
2762+
panel,
2763+
"outcome",
2764+
"dose",
2765+
"period",
2766+
"unit",
2767+
first_treat_col="first_treat",
2768+
aggregate="event_study",
2769+
)
2770+
2771+
def test_unordered_string_time_col_rejected(self):
2772+
"""Object/string time columns raise on event-study path.
2773+
2774+
Raw sort on arbitrary string labels is lexicographic, not
2775+
chronological (e.g., 'pre1'/'pre2'/'post1'/'post2' would map
2776+
to wrong event-time horizons). Covers CI reviewer round 2 P1.
2777+
"""
2778+
rng = np.random.default_rng(0)
2779+
G = 50
2780+
rows = []
2781+
d_post = rng.uniform(0.0, 1.0, G)
2782+
d_post[0] = 0.0
2783+
for g in range(G):
2784+
for label, dose in [
2785+
("pre1", 0.0),
2786+
("pre2", 0.0),
2787+
("post1", d_post[g]),
2788+
("post2", d_post[g]),
2789+
]:
2790+
rows.append(
2791+
{
2792+
"unit": g,
2793+
"period": label, # object dtype
2794+
"dose": dose,
2795+
"outcome": rng.standard_normal(),
2796+
}
2797+
)
2798+
panel = pd.DataFrame(rows)
2799+
with pytest.raises(ValueError, match="ordered time column|dtype"):
2800+
HeterogeneousAdoptionDiD(design="auto").fit(
2801+
panel, "outcome", "dose", "period", "unit", aggregate="event_study"
2802+
)
2803+
2804+
def test_ordered_categorical_time_col_accepted(self):
2805+
"""Ordered categorical time dtype passes the ordered-time check."""
2806+
rng = np.random.default_rng(0)
2807+
G = 50
2808+
labels = ["pre1", "pre2", "post1", "post2"]
2809+
cat_dtype = pd.CategoricalDtype(categories=labels, ordered=True)
2810+
rows = []
2811+
d_post = rng.uniform(0.1, 1.0, G)
2812+
d_post[0] = 0.0
2813+
for g in range(G):
2814+
for label, dose in [
2815+
("pre1", 0.0),
2816+
("pre2", 0.0),
2817+
("post1", d_post[g]),
2818+
("post2", d_post[g]),
2819+
]:
2820+
rows.append(
2821+
{
2822+
"unit": g,
2823+
"period": label,
2824+
"dose": dose,
2825+
"outcome": rng.standard_normal(),
2826+
}
2827+
)
2828+
panel = pd.DataFrame(rows)
2829+
panel["period"] = panel["period"].astype(cat_dtype)
2830+
# Should fit without raising the ordered-time error.
2831+
with warnings.catch_warnings():
2832+
warnings.simplefilter("ignore", UserWarning)
2833+
result = HeterogeneousAdoptionDiD(design="auto").fit(
2834+
panel, "outcome", "dose", "period", "unit", aggregate="event_study"
2835+
)
2836+
# post1 is F; e=-2 (pre1) and e=0 (post1), e=1 (post2) expected.
2837+
assert result.F == "post1"
2838+
27302839
def test_staggered_without_first_treat_col_rejected(self):
27312840
"""Multi-cohort panel without first_treat_col raises (not silent).
27322841

0 commit comments

Comments
 (0)