Skip to content

Commit 84835de

Browse files
igerberclaude
andcommitted
Address PR #353 CI review round 2 (1 P1 + 1 P3)
P1 - ordered-categorical chronology: raw `t < base_period` / `t > base_period` comparisons in `joint_pretrends_test`, `joint_homogeneity_test`, and `did_had_pretest_workflow(aggregate="event_study")` silently misorder ordered-categorical time columns whose lexical and chronological order disagree (e.g. categories ["q1", "q2", "q10"] sort lexically as "q1" < "q10" < "q2"). On such panels the raw comparison could (a) silently drop valid pre-period horizons via the raw `<` check, (b) emit a spurious "joint pre-trends skipped" verdict from the workflow's `earlier_pre` filter, or (c) raise on valid post-period inputs. Fix: new private helper `_build_period_rank` returns a {period_label: chronological_rank} map using the ordered-categorical category order when applicable, natural sort on numeric / datetime otherwise. Both wrappers compare period labels via rank (`rank[t1] < rank[t2]`) instead of raw Python `<`/`>`. The workflow's `earlier_pre` replaces the raw-< filter with `list(t_pre_list[:-1])` - `t_pre_list` is already chronologically sorted by the validator (via its `_sort_key`), so excluding the last element yields the earlier pre-periods regardless of dtype. P3 - ordered-categorical regression tests: new `TestOrderedCategoricalChronology` class (4 tests) with a fixture using categories `["q1", "q2", "q10", "post"]`. Covers (a) direct pretrends wrapper picks up both earlier placebos, (b) pretrends wrapper rejects lexically-ordered-but-chrono-invalid input (e.g. pre=["q10"], base="q2"), (c) homogeneity wrapper accepts valid post-period input, (d) workflow event-study dispatch surfaces both earlier placebos in `pretrends_joint.horizon_labels` without the false skip note. 123 tests pass (119 + 4 new); black/ruff/mypy clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8da8e43 commit 84835de

2 files changed

Lines changed: 204 additions & 31 deletions

File tree

diff_diff/had_pretests.py

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,30 @@ def _validate_multi_period_panel(
16401640
)
16411641

16421642

1643+
def _build_period_rank(data: pd.DataFrame, time_col: str) -> Dict[Any, int]:
    """Build a ``{period_label: chronological_rank}`` map.

    For ordered categorical time columns, uses the declared category
    order so that e.g. ``["q1", "q2", "q10"]`` ranks chronologically
    even though it sorts lexically in a different order. The map then
    covers every declared category, including categories that have no
    rows in ``data``.

    For any other dtype, ranks the distinct non-missing labels by
    natural Python ``sorted`` order. Object dtypes therefore fall back
    to lexicographic order - callers relying on chronology with object-
    dtype labels should convert to an ordered categorical first
    (this mirrors the contract in ``_validate_had_panel_event_study``).

    The rank map lets the joint-pretest wrappers compare period labels
    chronologically via ``rank[t1] < rank[t2]`` instead of raw Python
    ``t1 < t2``, which would silently misorder ordered-categorical
    panels (paper Appendix B.2 support contract).

    Parameters
    ----------
    data
        Panel holding the time column.
    time_col
        Name of the time column in ``data``.

    Returns
    -------
    dict
        Maps each period label to its zero-based chronological rank.

    Raises
    ------
    TypeError
        If the non-missing labels of a non-ordered-categorical column
        cannot be compared with each other (e.g. mixed ``int`` / ``str``).
    """
    time_dtype = data[time_col].dtype
    if isinstance(time_dtype, pd.CategoricalDtype) and time_dtype.ordered:
        return {c: i for i, c in enumerate(time_dtype.categories)}
    # Missing values (NaN / NaT / None / pd.NA) are not periods. They also
    # compare False against everything, which lets ``sorted`` leave real
    # labels out of order, so drop them before ranking.
    observed = data[time_col].dropna().unique()
    try:
        periods = sorted(observed)
    except TypeError as exc:
        raise TypeError(
            f"time_col {time_col!r} must hold mutually comparable period "
            "labels (numeric, datetime, or ordered categorical values)."
        ) from exc
    return {p: i for i, p in enumerate(periods)}
1665+
1666+
16431667
def _aggregate_for_joint_test(
16441668
data: pd.DataFrame,
16451669
outcome_col: str,
@@ -2157,25 +2181,32 @@ def joint_pretrends_test(
21572181
f"base_period={base_period!r} must not appear in " f"pre_periods {list(pre_periods)!r}."
21582182
)
21592183

2160-
# Ordering check: all pre_periods strictly < base_period (natural
2161-
# order on the column dtype). We rely on the time column being
2162-
# comparable (numeric, datetime, or ordered categorical); other
2163-
# dtypes would silently misorder. The multi-period validator (when
2164-
# called via the workflow) enforces an ordered dtype; direct callers
2165-
# get a TypeError here on incomparable types.
2166-
try:
2167-
out_of_order = [t for t in pre_periods if not (t < base_period)]
2168-
except TypeError as exc:
2169-
raise TypeError(
2170-
"pre_periods and base_period must be comparable "
2171-
"(numeric, datetime, or ordered categorical values). "
2172-
f"Got pre_periods={list(pre_periods)!r}, "
2173-
f"base_period={base_period!r}."
2174-
) from exc
2184+
# Ordering check: all pre_periods strictly < base_period in
2185+
# chronological order. Uses `_build_period_rank` to handle ordered-
2186+
# categorical time columns correctly (raw Python `<` would fail on
2187+
# categories whose lexical order disagrees with chronology, e.g.
2188+
# ["q1", "q2", "q10"]). Numeric / datetime dtypes get natural order.
2189+
period_rank = _build_period_rank(data, time_col)
2190+
if base_period not in period_rank:
2191+
raise ValueError(
2192+
f"base_period={base_period!r} not found in time_col "
2193+
f"{time_col!r}. Available: "
2194+
f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}."
2195+
)
2196+
missing_pre_in_data = [t for t in pre_periods if t not in period_rank]
2197+
if missing_pre_in_data:
2198+
raise ValueError(
2199+
f"pre_periods entries {missing_pre_in_data!r} not found in "
2200+
f"time_col {time_col!r}. Available: "
2201+
f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}."
2202+
)
2203+
base_rank = period_rank[base_period]
2204+
out_of_order = [t for t in pre_periods if period_rank[t] >= base_rank]
21752205
if out_of_order:
21762206
raise ValueError(
2177-
f"All pre_periods must be strictly < base_period. "
2178-
f"Violators: {out_of_order!r} (base_period={base_period!r})."
2207+
f"All pre_periods must be strictly < base_period in "
2208+
f"chronological order. Violators: {out_of_order!r} "
2209+
f"(base_period={base_period!r})."
21792210
)
21802211

21812212
# Event-study validation contract (paper Appendix B.2):
@@ -2341,21 +2372,31 @@ def joint_homogeneity_test(
23412372
f"post_periods {list(post_periods)!r}."
23422373
)
23432374

2344-
# Ordering: all post_periods >= base_period (and in fact strictly
2345-
# greater under the HAD contract where base is the last pre-period).
2346-
try:
2347-
out_of_order = [t for t in post_periods if not (t > base_period)]
2348-
except TypeError as exc:
2349-
raise TypeError(
2350-
"post_periods and base_period must be comparable "
2351-
"(numeric, datetime, or ordered categorical values). "
2352-
f"Got post_periods={list(post_periods)!r}, "
2353-
f"base_period={base_period!r}."
2354-
) from exc
2375+
# Ordering: all post_periods strictly > base_period in
2376+
# chronological order. Uses `_build_period_rank` for ordered-
2377+
# categorical correctness (raw Python `>` would misorder e.g.
2378+
# "q10" > "q2").
2379+
period_rank = _build_period_rank(data, time_col)
2380+
if base_period not in period_rank:
2381+
raise ValueError(
2382+
f"base_period={base_period!r} not found in time_col "
2383+
f"{time_col!r}. Available: "
2384+
f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}."
2385+
)
2386+
missing_post_in_data = [t for t in post_periods if t not in period_rank]
2387+
if missing_post_in_data:
2388+
raise ValueError(
2389+
f"post_periods entries {missing_post_in_data!r} not found in "
2390+
f"time_col {time_col!r}. Available: "
2391+
f"{sorted(period_rank.keys(), key=lambda t: period_rank[t])!r}."
2392+
)
2393+
base_rank = period_rank[base_period]
2394+
out_of_order = [t for t in post_periods if period_rank[t] <= base_rank]
23552395
if out_of_order:
23562396
raise ValueError(
2357-
f"All post_periods must be strictly > base_period. "
2358-
f"Violators: {out_of_order!r} (base_period={base_period!r})."
2397+
f"All post_periods must be strictly > base_period in "
2398+
f"chronological order. Violators: {out_of_order!r} "
2399+
f"(base_period={base_period!r})."
23592400
)
23602401

23612402
# Event-study validation contract (paper Appendix B.2) - twin of
@@ -2595,7 +2636,15 @@ def did_had_pretest_workflow(
25952636
# strictly before base_period). If only the base pre-period is
25962637
# available (len(t_pre_list) == 1), there are no earlier
25972638
# placebos; set pretrends_joint=None and flag in verdict.
2598-
earlier_pre = [t for t in t_pre_list if t < base_period]
2639+
# ``t_pre_list`` is returned chronologically sorted by
2640+
# ``_validate_had_panel_event_study`` (using the column's
2641+
# ordered-categorical category order or the natural numeric /
2642+
# datetime order), so taking everything but the last element
2643+
# gives the earlier pre-periods regardless of dtype. Raw
2644+
# ``t < base_period`` would misorder ordered-categorical labels
2645+
# whose lexical and chronological order disagree (e.g. "q10" <
2646+
# "q2" lexically but > chronologically).
2647+
earlier_pre = list(t_pre_list[:-1])
25992648
if len(earlier_pre) >= 1:
26002649
pretrends_joint = joint_pretrends_test(
26012650
data_filtered,

tests/test_had_pretests.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,6 +2357,130 @@ def test_event_study_all_conclusive_no_reject_admissible(self):
23572357
assert "TWFE admissible under Section 4" in verdict
23582358

23592359

2360+
class TestOrderedCategoricalChronology:
    """R2 P1 regressions for ordered-categorical time columns.

    The fixture's labels are declared chronologically as
    ``q1 < q2 < q10 < post``, whereas plain string comparison orders
    them ``"post" < "q1" < "q10" < "q2"``. Raw ``t < base_period`` /
    ``t > base_period`` comparisons therefore misorder these panels;
    the wrappers and the workflow must compare periods by their
    validated chronological rank instead.
    """

    # Outcome, dose, time and unit column names, passed positionally
    # (in this order) to every entry point exercised below.
    _COLUMNS = ("y", "d", "period", "unit")

    @staticmethod
    def _categorical_panel(
        G: int = 60,
        categories=("q1", "q2", "q10", "post"),
        first_treat="post",
        seed: int = 501,
    ) -> pd.DataFrame:
        """Return a balanced panel with an ordered-categorical ``period``.

        One row is emitted per (unit, category). The dose ``d`` is zero
        everywhere except in the ``first_treat`` period, where it is the
        unit's draw from ``uniform(0.05, 1.0)``. The outcome ``y`` is a
        unit-specific level plus ``0.4 * d`` in the treated period plus
        Gaussian noise.
        """
        rng = np.random.default_rng(seed)
        unit_doses = rng.uniform(0.05, 1.0, size=G)
        records = []
        for unit in range(G):
            for label in categories:
                treated = label == first_treat
                dose = float(unit_doses[unit]) if treated else 0.0
                effect = 0.4 * dose if treated else 0.0
                outcome = 0.1 * unit + effect + rng.normal(0.0, 0.1)
                records.append({"unit": unit, "period": label, "y": outcome, "d": dose})
        ordered_dtype = pd.CategoricalDtype(categories=list(categories), ordered=True)
        return pd.DataFrame(records).astype({"period": ordered_dtype})

    def test_joint_pretrends_test_uses_chronological_rank(self):
        """Direct wrapper call with ``base_period="q10"``.

        Lexically ``"q10"`` sorts BEFORE ``"q2"``, but chronologically
        it comes AFTER. Both earlier pre-periods (``"q1"`` and ``"q2"``)
        must be accepted without a false out-of-order error.
        """
        panel = self._categorical_panel()
        result = joint_pretrends_test(
            panel,
            *self._COLUMNS,
            pre_periods=["q1", "q2"],
            base_period="q10",
            n_bootstrap=199,
            seed=3,
        )
        assert result.n_horizons == 2
        assert set(result.horizon_labels) == {"q1", "q2"}
        # Only finiteness of the p-value is checked here; no significance
        # threshold is asserted on this weakly-noisy DGP.
        assert np.isfinite(result.p_value)

    def test_joint_pretrends_raises_on_lexically_ordered_but_chrono_invalid(self):
        """``pre_periods=["q10"]`` with ``base_period="q2"`` must raise.

        Chronologically q10 comes after q2, so the request is
        out-of-order. A raw ``<`` comparison would INCORRECTLY accept it
        because ``"q10" < "q2"`` holds lexically.
        """
        panel = self._categorical_panel()
        with pytest.raises(ValueError, match="chronological order"):
            joint_pretrends_test(
                panel,
                *self._COLUMNS,
                pre_periods=["q10"],
                base_period="q2",
                n_bootstrap=199,
                seed=0,
            )

    def test_joint_homogeneity_test_uses_chronological_rank(self):
        """Homogeneity-wrapper twin of the pretrends rank test.

        ``"post"`` follows every pre-period chronologically and
        ``base_period="q10"`` is the last pre-period. Lexically
        ``"post" < "q10"``, so a raw ``>`` comparison would wrongly flag
        this valid post-period as out-of-order; the rank-based check
        must accept it.
        """
        panel = self._categorical_panel()
        result = joint_homogeneity_test(
            panel,
            *self._COLUMNS,
            post_periods=["post"],
            base_period="q10",
            n_bootstrap=199,
            seed=7,
        )
        assert result.n_horizons == 1
        assert result.horizon_labels == ["post"]
        assert np.isfinite(result.p_value)

    def test_workflow_event_study_ordered_categorical(self):
        """Event-study workflow must surface BOTH earlier placebos.

        ``did_had_pretest_workflow(aggregate="event_study")`` has to pick
        up ``"q1"`` and ``"q2"`` from the ordered-categorical panel,
        where a lexical comparison would silently drop one of them.
        Regression against the ``earlier_pre`` raw-``<`` filter.
        """
        panel = self._categorical_panel()
        report = did_had_pretest_workflow(
            panel,
            *self._COLUMNS,
            aggregate="event_study",
            n_bootstrap=199,
            seed=13,
        )
        assert report.aggregate == "event_study"
        assert report.pretrends_joint is not None
        # Chronological pre-periods are ["q1", "q2", "q10"] and the base
        # is the last of them ("q10"), so the placebo horizons handed to
        # the joint pre-trends test should be exactly "q1" and "q2".
        assert set(report.pretrends_joint.horizon_labels) == {"q1", "q2"}
        assert report.homogeneity_joint is not None
        assert report.homogeneity_joint.horizon_labels == ["post"]
        # Both earlier placebos were found, so the verdict must not carry
        # the step-2-skipped flag.
        assert "joint pre-trends skipped" not in report.verdict
2482+
2483+
23602484
class TestHADPretestReportSerialization:
23612485
"""Tests for HADPretestReport serialization branching by aggregate."""
23622486

0 commit comments

Comments
 (0)