Skip to content

Commit cb11a1f

Browse files
igerber and claude committed
Address PR #350 CI review round 3: P0 chronological cohort sort + P3 docstrings
**P0 (cohort sort key):** `_validate_had_panel_event_study` sorted `first_treat_col` values with the raw Python key `(x is None, x)`, even though `time_col` was already required to be ordered (numeric/datetime/ordered categorical). On ordered-categorical staggered panels where chronological order differs from lexicographic order, `F_last = cohorts[-1]` silently picked the lexicographically latest cohort, not the chronologically latest. That keeps the wrong cohort and returns event-study estimates for the wrong estimand.

Fix: promoted the dtype-aware `_sort_key` (ordered-categorical uses declared category index; numeric/datetime use natural order) to the top of the validator, just after the time-dtype check. Cohort sorting, pre/post period sorting, contiguity check, and the staggered-without-first_treat detection all now share this single `_sort_key`. Removed the duplicate `_sort_key` definition that was sitting further down in the same function.

**P3 (stale docstrings):**
- `fit()` no longer opens with "two-period panel"; now describes both aggregation modes with links to the respective result classes.
- `HeterogeneousAdoptionDiDEventStudyResults.n_units` docstring no longer says "only last-cohort units"; now accurately states that last-cohort PLUS never-treated units are retained.

**Test added:** `test_staggered_ordered_categorical_chooses_chronological_last` uses categories `["q1", "q2", "q3", "q10"]` where lex max of the two cohorts (`"q2", "q10"`) is `"q2"` but chronological last is `"q10"`; asserts the fix picks `"q10"` as `F_last` and retains only the q10-cohort units.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6e321ad commit cb11a1f

2 files changed

Lines changed: 108 additions & 25 deletions

File tree

diff_diff/had.py

Lines changed: 44 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -446,7 +446,9 @@ class HeterogeneousAdoptionDiDEventStudyResults:
446446
fitted data.
447447
n_units : int
448448
Number of unique units contributing to the fit. After staggered
449-
auto-filter: only last-cohort units.
449+
auto-filter: last-cohort units PLUS never-treated (``first_treat = 0``)
450+
units retained as the untreated-group comparison per paper
451+
Appendix B.2. Only earlier-treated cohorts are dropped.
450452
inference_method : str
451453
``"analytical_nonparametric"`` (continuous designs) or
452454
``"analytical_2sls"`` (mass-point). Shared across horizons.
@@ -933,6 +935,25 @@ def _validate_had_panel_event_study(
933935
f"before calling fit() with aggregate='event_study'."
934936
)
935937

938+
# Construct the chronological sort key once, shared across every
939+
# downstream ordering: cohort ranking, pre/post period sorting, and
940+
# contiguity checks. Ordered categoricals use their declared
941+
# category index (``list(categorical)`` strips the ordering and
942+
# falls back to string comparison); numeric / datetime use natural
943+
# Python order. ``_aggregate_multi_period_first_differences`` does not
944+
# call this key; it builds an equivalent one in parallel (both read
945+
# the same ``time_dtype``).
946+
if isinstance(time_dtype, pd.CategoricalDtype) and time_dtype.ordered:
947+
_cat_order = {c: i for i, c in enumerate(time_dtype.categories)}
948+
949+
def _sort_key(x: Any) -> Tuple[bool, Any]:
950+
return (x is None, _cat_order.get(x, len(_cat_order)))
951+
952+
else:
953+
954+
def _sort_key(x: Any) -> Tuple[bool, Any]:
955+
return (x is None, x)
956+
936957
# NaN checks on key columns (before any filter).
937958
for col in [outcome_col, dose_col, unit_col]:
938959
if bool(data[col].isna().any()):
@@ -1005,12 +1026,17 @@ def _validate_had_panel_event_study(
10051026
f"to equal each unit's first positive-dose period (or 0 "
10061027
f"for never-treated) before calling fit()."
10071028
)
1008-
# Identify cohorts (nonzero first_treat values).
1009-
# Use pd.unique to preserve dtype; sort with a stable key.
1029+
# Identify cohorts (nonzero first_treat values). Sort using
1030+
# ``_sort_key`` (chronological order from ``time_dtype``), NOT
1031+
# raw Python sort: first_treat values are period labels and
1032+
# must rank chronologically so ``F_last = cohorts[-1]`` is the
1033+
# chronologically latest cohort. Under ordered-categorical time
1034+
# labels (e.g. month names), raw Python sort is lexicographic
1035+
# and would silently pick the wrong ``F_last``.
10101036
ft_unique = list(pd.unique(ft_raw))
10111037
cohorts = sorted(
10121038
[v for v in ft_unique if v != 0 and not (isinstance(v, float) and np.isnan(v))],
1013-
key=lambda x: (x is None, x),
1039+
key=_sort_key,
10141040
)
10151041
if len(cohorts) == 0:
10161042
raise ValueError(
@@ -1127,22 +1153,9 @@ def _validate_had_panel_event_study(
11271153
f"zero dose; there is no treatment to estimate."
11281154
)
11291155

1130-
# Sort by natural ordering on the time column dtype. For ordered
1131-
# categorical dtypes, use the declared category order (since
1132-
# ``list(categorical)`` strips the ordered semantics and falls back
1133-
# to string comparison). For numeric / datetime, use natural Python
1134-
# order. Tuple key places None at the end.
1135-
if isinstance(time_dtype, pd.CategoricalDtype) and time_dtype.ordered:
1136-
_cat_order = {c: i for i, c in enumerate(time_dtype.categories)}
1137-
1138-
def _sort_key(x: Any) -> Tuple[bool, Any]:
1139-
return (x is None, _cat_order.get(x, len(_cat_order)))
1140-
1141-
else:
1142-
1143-
def _sort_key(x: Any) -> Tuple[bool, Any]:
1144-
return (x is None, x)
1145-
1156+
# Sort using the same ``_sort_key`` already constructed for cohorts
1157+
# (ordered-categorical uses declared category order; numeric /
1158+
# datetime use natural Python order).
11461159
t_pre_list = sorted(t_pre_list_unsorted, key=_sort_key)
11471160
t_post_list = sorted(t_post_list_unsorted, key=_sort_key)
11481161

@@ -1203,10 +1216,8 @@ def _sort_key(x: Any) -> Tuple[bool, Any]:
12031216
first_pos_per_unit = df_sorted.loc[pos_mask_global].groupby(unit_col)[time_col].first()
12041217
cohort_labels = list(first_pos_per_unit.unique())
12051218
if len(cohort_labels) > 1:
1206-
try:
1207-
distinct_cohorts = sorted(cohort_labels, key=lambda x: (x is None, x))
1208-
except TypeError:
1209-
distinct_cohorts = list(cohort_labels)
1219+
# Sort chronologically via the validated time-column order.
1220+
distinct_cohorts = sorted(cohort_labels, key=_sort_key)
12101221
raise ValueError(
12111222
f"Staggered-timing panel detected (first_treat_col is "
12121223
f"None): {len(distinct_cohorts)} distinct first-positive-"
@@ -1940,7 +1951,15 @@ def fit(
19401951
survey: Any = None,
19411952
weights: Optional[np.ndarray] = None,
19421953
) -> HeterogeneousAdoptionDiDResults:
1943-
"""Fit the HAD estimator on a two-period panel.
1954+
"""Fit the HAD estimator.
1955+
1956+
``aggregate="overall"`` (default) fits on a two-period panel and
1957+
returns a :class:`HeterogeneousAdoptionDiDResults` with the
1958+
single-period WAS estimate. ``aggregate="event_study"`` fits on
1959+
a multi-period panel (``T > 2``) and returns a
1960+
:class:`HeterogeneousAdoptionDiDEventStudyResults` with per-
1961+
event-time WAS estimates using a uniform ``F-1`` anchor (paper
1962+
Appendix B.2).
19441963
19451964
Both the overall and event-study paths are **panel-only**: the paper
19461965
(Section 2) defines HAD on panel or repeated-cross-section data,

tests/test_had.py

Lines changed: 64 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2727,6 +2727,70 @@ def test_time_varying_post_F_dose_rejected(self):
27272727
panel, "outcome", "dose", "period", "unit", aggregate="event_study"
27282728
)
27292729

2730+
def test_staggered_ordered_categorical_chooses_chronological_last(self):
2731+
"""Staggered filter uses chronological (not lexicographic) last.
2732+
2733+
Constructs an ordered-categorical time column where lexicographic
2734+
and chronological orderings disagree. With category order
2735+
``["q1", "q2", "q3", "q10"]``, chronological last is ``"q10"``
2736+
but lexicographic last is ``"q3"``. If cohorts are ``{"q2", "q10"}``,
2737+
a raw-sort implementation would pick ``F_last = "q2"`` (lex-max
2738+
of the two strings); the fixed version must pick ``F_last = "q10"``.
2739+
2740+
Covers CI reviewer round 3 P0: cohort sorting must use
2741+
chronological order from ``time_dtype``, not raw Python sort.
2742+
"""
2743+
rng = np.random.default_rng(0)
2744+
G = 80
2745+
periods = ["q1", "q2", "q3", "q10"]
2746+
cat_dtype = pd.CategoricalDtype(categories=periods, ordered=True)
2747+
# Half of units treated at q2 (cohort 1), half at q10 (cohort 2).
2748+
rows = []
2749+
for g in range(G):
2750+
F_g = "q2" if g < G // 2 else "q10"
2751+
d_g = float(rng.uniform(0.1, 1.0))
2752+
for p in periods:
2753+
# Dose = d_g once the period >= F_g in chronological order.
2754+
chrono_g = periods.index(F_g)
2755+
chrono_p = periods.index(p)
2756+
dose = d_g if chrono_p >= chrono_g else 0.0
2757+
rows.append(
2758+
{
2759+
"unit": g,
2760+
"period": p,
2761+
"dose": dose,
2762+
"outcome": rng.standard_normal(),
2763+
"first_treat": F_g,
2764+
}
2765+
)
2766+
panel = pd.DataFrame(rows)
2767+
panel["period"] = panel["period"].astype(cat_dtype)
2768+
panel["first_treat"] = panel["first_treat"].astype(cat_dtype)
2769+
2770+
with warnings.catch_warnings():
2771+
warnings.simplefilter("ignore", UserWarning)
2772+
result = HeterogeneousAdoptionDiD(design="auto").fit(
2773+
panel,
2774+
"outcome",
2775+
"dose",
2776+
"period",
2777+
"unit",
2778+
first_treat_col="first_treat",
2779+
aggregate="event_study",
2780+
)
2781+
2782+
# Chronological last cohort = "q10", not lexicographic last ("q3"
2783+
# is not even a cohort here; lex last of the two cohorts would
2784+
# be "q2" since "q10" < "q2" lexicographically).
2785+
assert result.filter_info is not None
2786+
assert result.filter_info["F_last"] == "q10"
2787+
assert result.F == "q10"
2788+
# q2-cohort units (G/2) are dropped; q10-cohort units (G/2)
2789+
# retained.
2790+
assert result.n_units == G // 2
2791+
# Dropped cohorts should list "q2".
2792+
assert "q2" in result.filter_info["dropped_cohorts"]
2793+
27302794
def test_first_treat_col_mismatch_with_dose_raises(self):
27312795
"""first_treat_col disagreeing with observed dose path must raise.
27322796

0 commit comments

Comments
 (0)