Commit cfb3200

igerber and claude committed
Address CI R8 codex review (1 P1 + 1 P3) on PreTrendsPower PR-B
R8 CI codex caught a P1 that my local R7 reviewer missed, which is exactly the `feedback_local_codex_vs_ci_codex_divergence.md` pattern.

**P1: MPD non-numeric labels silently fell back to count-based, undocumented as a deviation in REGISTRY**

R3's MPD branch returned `relative_times=None` for non-numeric `reference_period` values (string period IDs, etc.), silently using the legacy count-based normalized direction, but the REGISTRY note described the γ-unit deviation as "resolved" without qualifying that exception. Two-part fix:

1. **Better coercion** for datetime-like labels: new module-level helper `_coerce_relative_times_from_reference` (`pretrends.py:92`) handles three regimes:
   - Numeric (`int` / `float` / `np.int64`): direct `float()`.
   - `pandas.Period` / `Timestamp` / `np.datetime64`: subtraction-based offset arithmetic (`.n` for Period, `.days` for Timedelta, falling through to `/ np.timedelta64(1, 'D')`).
   - Genuinely non-numeric (string period IDs, unranked categoricals): emits an explicit `UserWarning` documenting that the reported MDV is NOT in Roth's γ units under this fallback, and recommends re-fitting with numeric labels.
2. **Documentation alignment**: REGISTRY `## PreTrendsPower` convention note and METHODOLOGY_REVIEW.md `## PreTrendsPower` Verified Components checklist both enumerate the supported label types (numeric + pandas.Period + Timestamp + datetime64) and explicitly call out the non-numeric warn-and-fallback behavior as a documented edge case (not a "resolved" deviation).

**P3: `docs/api/pretrends.rst` still referenced removed `custom_delta` parameter name**

The custom-violation entry in the violation-types section used the parameter name `custom_delta`, but the actual API exposes `violation_weights` (both on `PreTrendsPower` and on the helper functions per PR-B Step 6). Fix: rename in docs and add a one-line note that both the class and the helpers accept the kwarg.

**Tests** (`tests/test_methodology_pretrends.py::TestPretrendsLinearGrid`):

- `test_mpd_non_numeric_reference_falls_back_to_legacy_weights` renamed to `..._warns_and_falls_back...` and now asserts the explicit `UserWarning` is emitted (mentioning "γ units").
- NEW `test_mpd_pandas_period_reference_yields_numeric_relative_times`: constructs a `MultiPeriodDiDResults` with `pd.Period('2019Q1..Q3')` pre-periods and `reference_period=pd.Period('2019Q4')`, asserts the derived `relative_times == [-3, -2, -1]` (quarters) and linear weights = `[3, 2, 1]` in γ units. Locks the Period-arithmetic path the codex specifically flagged.

The P3 R-parity-script placeholder is deferred to PR-C per the existing TODO row (codex labeled it informational / non-blocker).

Tests: 403 pass across pretrends + DR + BR. 4 skipped (R-parity stubs + 1 fixture skip). No regressions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent da2a7bd commit cfb3200

5 files changed

Lines changed: 180 additions & 29 deletions

METHODOLOGY_REVIEW.md

Lines changed: 1 addition & 1 deletion
@@ -1063,7 +1063,7 @@ and covariate-adjusted specifications.)
 - [x] Non-bootstrap CS adapter consumes full `event_study_vcov` sub-block (not diag)
 - [x] Non-bootstrap SA adapter consumes full `event_study_vcov` sub-block (W-matrix construction `event_study_vcov = W @ vcov_cohort @ W.T` added to `SunAbrahamResults`)
 - [x] Bootstrap CS/SA and replicate-weight survey paths fall through to `diag(ses^2)` (analytical VCV cleared to prevent mixing with bootstrap/replicate SE overrides)
-- [x] `_get_violation_weights('linear')` honors actual pre-period relative-time labels via `fit()` threading → reported MDV is in Roth's γ units on irregular and anticipation-shifted grids
+- [x] `_get_violation_weights('linear')` honors actual pre-period relative-time labels via `fit()` threading → reported MDV is in Roth's γ units on irregular and anticipation-shifted grids. For `MultiPeriodDiDResults`, supported label types are numeric (`int` / `float` / `np.int64`) and `pandas.Period` / `pandas.Timestamp` / `np.datetime64`; **genuinely non-numeric labels** (string period IDs, unranked categoricals) emit an explicit `UserWarning` and fall through to the legacy count-based normalized direction (MDV is NOT in γ units in that case; re-fit with numeric labels)
 - [x] `PreTrendsPowerResults` persists fitted `violation_weights` + `pretest_form` + `nis_box_probability`; `power_at(M)` works for all four violation types on fresh fits
 - [x] Helper API (`compute_pretrends_power`, `compute_mdv`) accepts `violation_weights` and `pretest_form`; closes the PR-A R18 helper/class API gap
 - [x] Summary, `to_dict`, `to_dataframe` dispatch on `pretest_form` (NIS prints box probability; Wald prints noncentrality)

diff_diff/pretrends.py

Lines changed: 102 additions & 18 deletions
@@ -25,6 +25,7 @@
 diff_diff.honest_did - Sensitivity analysis for parallel trends violations
 """

+import warnings
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union

@@ -88,6 +89,89 @@ def _compute_nis_acceptance_prob(
     return float(np.clip(accept_prob, 0.0, 1.0))


+def _coerce_relative_times_from_reference(
+    estimated_pre_periods: List[Any],
+    reference_period: Any,
+) -> Optional[np.ndarray]:
+    """
+    Convert ``estimated_pre_periods`` to Roth-style relative-time offsets
+    from a numeric / Period / datetime ``reference_period``.
+
+    Returns ``np.ndarray`` of float relative times when conversion succeeds,
+    or ``None`` when the labels are genuinely non-numeric / unordered
+    (string period IDs, categoricals, etc.). In the ``None`` case, the
+    caller's downstream linear-violation weight construction falls back to
+    the legacy count-based normalized direction; the reported MDV is then
+    NOT in Roth's γ units. We emit a ``UserWarning`` so the user knows
+    the γ-unit contract did not hold and can re-fit with numeric labels.
+
+    Supported regimes:
+
+    - Numeric (``int`` / ``float`` / ``np.int64``): direct ``float()``
+      coercion gives the correct relative offset.
+    - ``pandas.Period`` / ``pandas.Timestamp`` / ``np.datetime64``: period
+      arithmetic returns an offset / ``Timedelta`` that we coerce to a
+      float via ``.n`` (for Period frequencies) or ``.days`` (for
+      Timedelta-like). The result is in units of the reference's
+      frequency for Period, days for Timestamp / datetime64; the linear
+      γ-units scale is per-unit-of-frequency.
+    - Anything else (string period IDs, categoricals with no ordering,
+      mixed types): returns ``None`` with a warning.
+    """
+    # Path 1: direct float coercion (numeric scalars).
+    try:
+        ref_float = float(reference_period)
+        return np.asarray(
+            [float(p) - ref_float for p in estimated_pre_periods],
+            dtype=float,
+        )
+    except (TypeError, ValueError):
+        pass
+
+    # Path 2: pandas.Period / pandas.Timestamp / datetime64: try
+    # subtraction-based offset arithmetic.
+    try:
+        diffs = [p - reference_period for p in estimated_pre_periods]
+        floats: List[float] = []
+        for d in diffs:
+            # pandas.tseries.offsets.* (result of Period subtraction): has `.n`.
+            n_attr = getattr(d, "n", None)
+            if n_attr is not None:
+                floats.append(float(n_attr))
+                continue
+            # pandas.Timedelta / datetime.timedelta: whole-day component via `.days`.
+            days_attr = getattr(d, "days", None)
+            if days_attr is not None:
+                floats.append(float(days_attr))
+                continue
+            # Bare numpy.timedelta64 fallback (no `.n` or `.days` attribute).
+            try:
+                floats.append(float(d / np.timedelta64(1, "D")))
+                continue
+            except (TypeError, ValueError):
+                raise TypeError(
+                    f"cannot coerce difference {d!r} of type {type(d).__name__} "
+                    "to float days/periods"
+                )
+        return np.asarray(floats, dtype=float)
+    except (TypeError, ValueError):
+        pass
+
+    # Path 3: genuinely non-numeric labels: warn and fall back to legacy.
+    warnings.warn(
+        f"PreTrendsPower: reference_period {reference_period!r} (type "
+        f"{type(reference_period).__name__}) is not numeric or datetime-like, "
+        "so per-period relative times cannot be derived. Linear-violation "
+        "weights will use the legacy count-based [n_pre-1, ..., 0]/||·||_2 "
+        "direction; the reported MDV is NOT in Roth (2022) γ units. Re-fit "
+        "with numeric period labels (int year, pandas.Period, datetime) to "
+        "obtain γ-unit MDV.",
+        UserWarning,
+        stacklevel=3,
+    )
+    return None
+
+
 def _extract_event_study_vcov_subblock(
     results: Any,
     pre_periods: List[int],
@@ -914,27 +998,27 @@ def _extract_pre_period_params(
             # For MultiPeriodDiDResults, period identifiers are generic
             # (often calendar years, sometimes pre-shifted relative times).
             # Roth's δ_t = γ·t convention needs RELATIVE offsets from the
-            # treatment / reference period. Derive them from
-            # `results.reference_period` when numeric:
-            #   relative_times = estimated_pre_periods - reference_period
-            # If `reference_period` is None or non-numeric (string, categorical),
-            # return None so `_get_violation_weights('linear')` falls back to
-            # the legacy count-based [n_pre-1, ..., 0] / ||·||_2 direction
-            # (the pre-PR-B shipped behavior; preserves backwards-compat for
-            # MPD callers that don't expose a numeric reference period).
+            # treatment / reference period. Three label-type regimes:
+            #
+            # 1. Numeric (int / float / np.int64): direct float() coercion
+            #    gives the correct relative offset.
+            # 2. pandas.Period: period arithmetic works on the Period
+            #    object directly (``p - ref`` returns ordinal-difference);
+            #    we cast via the `n` attribute on the resulting offset for
+            #    sub-period frequencies. Datetime-like labels (Timestamp,
+            #    np.datetime64) are caught the same way and converted to
+            #    days via numpy timedelta semantics.
+            # 3. Genuinely non-numeric / unordered labels (string period
+            #    IDs, categoricals without a ranking): emit an explicit
+            #    UserWarning and fall back to the legacy count-based
+            #    [n_pre-1, ..., 0] / ||·||_2 normalized direction. The
+            #    reported MDV under this fallback is NOT in Roth's γ
+            #    units; users on non-numeric labels who need γ-unit MDV
+            #    should re-fit with numeric period labels.
             ref = getattr(results, "reference_period", None)
             relative_times: Optional[np.ndarray] = None
             if ref is not None:
-                try:
-                    ref_float = float(ref)
-                    relative_times = np.asarray(
-                        [float(p) - ref_float for p in estimated_pre_periods],
-                        dtype=float,
-                    )
-                except (TypeError, ValueError):
-                    # Non-numeric labels (string period IDs, etc.): fall
-                    # back to legacy normalized linear direction.
-                    relative_times = None
+                relative_times = _coerce_relative_times_from_reference(estimated_pre_periods, ref)
             return effects, ses, vcov, n_pre, relative_times, covariance_source

         # Try CallawaySantAnnaResults
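The subtraction-based path in `_coerce_relative_times_from_reference` leans on three behaviours of pandas and NumPy. The following is a minimal standalone sketch of those behaviours, not code from this commit; the variable names are illustrative only.

```python
import numpy as np
import pandas as pd

# pandas.Period: same-frequency subtraction yields a DateOffset whose
# `.n` is the signed number of periods between the two labels.
ref_q = pd.Period("2019Q4", freq="Q")
period_offsets = [
    float((pd.Period(f"2019Q{q}", freq="Q") - ref_q).n) for q in (1, 2, 3)
]

# pandas.Timestamp: subtraction yields a Timedelta; `.days` is its
# whole-day component (any intraday remainder is not included).
ts_delta = pd.Timestamp("2019-01-01") - pd.Timestamp("2019-12-31")
timestamp_days = float(ts_delta.days)

# numpy.datetime64: subtraction yields a timedelta64, which exposes neither
# `.n` nor `.days`; dividing by a one-day timedelta64 gives float days.
np_delta = np.datetime64("2019-01-01") - np.datetime64("2019-01-31")
datetime64_days = float(np_delta / np.timedelta64(1, "D"))

# Plain strings do not support subtraction, which is why string period IDs
# end up on the warn-and-fallback path.
label, ref_label = "A", "REF_STRING"
try:
    label - ref_label
    strings_subtract = True
except TypeError:
    strings_subtract = False

print(period_offsets, timestamp_days, datetime64_days, strings_subtract)
```

One consequence worth keeping in mind: Period labels give offsets in units of the Period frequency, while Timestamp and `datetime64` labels give offsets in days, so the scale of a γ-unit MDV depends on which label type the fit used.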

docs/api/pretrends.rst

Lines changed: 3 additions & 1 deletion
@@ -133,7 +133,9 @@ The module supports several types of pre-trends violations:
     ``delta[-1] = M``, all other pre-periods are zero.

 **custom**
-    User-specified violation pattern via the ``custom_delta`` parameter.
+    User-specified violation pattern via the ``violation_weights`` parameter.
+    Accepted by both ``PreTrendsPower`` (constructor kwarg) and the convenience
+    helpers ``compute_pretrends_power`` / ``compute_mdv`` (forwarded kwarg).

 Complete Example
 ----------------

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 1 deletion
@@ -2817,7 +2817,7 @@ Violation types:

 - **Note (paper-supported alternative: Wald pretest form):** the library retains the Wald noncentral-χ² form as `pretest_form='wald'`. NIS is the paper's primary analysis convention (used for all 12 surveyed papers' empirical exercises in Section I), but the Wald form is also a paper-supported alternative: Roth's Propositions 1, 3, and 4 apply to any (measurable) acceptance region for the conditional moments (Props 1+3) and to any convex acceptance region for the variance-reduction guarantee (Prop 4). The Wald ellipsoid is convex, so all four propositions apply. Wald is faster (no MVN CDF call) and matches the pre-PR-B shipped numerical baseline. Use Wald for backwards-compat / speed; use NIS for canonical paper alignment and R `pretrends` parity.

-- **Note (convention: `linear` violation pattern, γ-unit MDV):** `_get_violation_weights('linear')` consumes actual pre-period relative-time labels threaded through `fit()` (PR-B 2026-05-17 resolution of the PR-A linear-pattern deviation). When `relative_times` is provided (e.g., `[-3, -2, -1]` for a regular grid or `[-5, -3, -1]` for an irregular grid), weights = `|t|` directly with NO L2 normalization, so `δ_pre = M · |t|` reflects Roth's `δ_t = γ · t` convention and the reported MDV equals γ. Callers that bypass `fit()` and supply only `n_pre` retain the previous count-based, L2-normalized `[n_pre-1, ..., 0]` direction (preserves shipped Wald numerical baselines for unit tests).
+- **Note (convention: `linear` violation pattern, γ-unit MDV):** `_get_violation_weights('linear')` consumes actual pre-period relative-time labels threaded through `fit()` (PR-B 2026-05-17 resolution of the PR-A linear-pattern deviation). When `relative_times` is provided (e.g., `[-3, -2, -1]` for a regular grid or `[-5, -3, -1]` for an irregular grid), weights = `|t|` directly with NO L2 normalization, so `δ_pre = M · |t|` reflects Roth's `δ_t = γ · t` convention and the reported MDV equals γ. Callers that bypass `fit()` and supply only `n_pre` retain the previous count-based, L2-normalized `[n_pre-1, ..., 0]` direction (preserves shipped Wald numerical baselines for unit tests). **MPD period-label coverage:** for `MultiPeriodDiDResults`, the relative-time derivation in `_extract_pre_period_params` supports numeric labels (`int` / `float` / `np.int64`) and `pandas.Period` / `pandas.Timestamp` / `np.datetime64` (via Period or Timedelta arithmetic with units of frequency / days respectively). For genuinely non-numeric or unordered labels (string period IDs, unranked categoricals), the helper emits an explicit `UserWarning` and falls back to the legacy count-based normalized direction; the reported MDV is then NOT in Roth's γ units. Users on string period IDs who need γ-unit MDV should re-fit with numeric labels.

 *Standard errors:*
 - Power calculations are exact (no sampling variability; power is computed against a hypothesized population trend, not estimated)
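The two `linear` conventions in the note above differ only in how the weight vector is built. A minimal sketch of both follows, assuming nothing beyond what the note states; `sketch_linear_weights` is a hypothetical stand-in and not the library's `_get_violation_weights`.

```python
import numpy as np


def sketch_linear_weights(n_pre, relative_times=None):
    """Hypothetical illustration of the two 'linear' weight conventions."""
    if relative_times is not None:
        # gamma-unit convention: weights are |t|, with no L2 normalization,
        # so delta_pre = M * |t| and M can be read as the slope gamma.
        return np.abs(np.asarray(relative_times, dtype=float))
    # Legacy convention: count-based [n_pre-1, ..., 0], scaled to unit L2 norm.
    raw = np.arange(n_pre - 1, -1, -1, dtype=float)
    norm = np.linalg.norm(raw)
    # Guard for n_pre == 1, where the raw vector is all zeros.
    return raw / norm if norm > 0 else raw


M = 0.5  # illustrative violation magnitude

# Irregular grid from the note: weights follow the actual spacing.
gamma_weights = sketch_linear_weights(3, relative_times=[-5, -3, -1])
delta_gamma = M * gamma_weights

# Same number of pre-periods without labels: spacing information is lost.
legacy_weights = sketch_linear_weights(3)
delta_legacy = M * legacy_weights

print(delta_gamma, delta_legacy)
```

Under the legacy direction the same `M` scales a unit-norm vector, so it cannot be interpreted as a per-period slope; that is the sense in which the fallback MDV is not in γ units.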

tests/test_methodology_pretrends.py

Lines changed: 73 additions & 8 deletions
@@ -526,15 +526,20 @@ def test_mpd_calendar_period_ids_derive_relative_times_from_reference(self):
         weights = pt._get_violation_weights(n_pre, relative_times=relative_times)
         np.testing.assert_allclose(weights, [4.0, 3.0, 2.0, 1.0])

-    def test_mpd_non_numeric_reference_falls_back_to_legacy_weights(self):
-        """MPD with non-numeric reference_period falls back to legacy direction.
+    def test_mpd_non_numeric_reference_warns_and_falls_back_to_legacy_weights(self):
+        """MPD with non-numeric reference_period warns + falls back to legacy.

-        When ``reference_period`` is a string / categorical (e.g., "2019Q4"),
-        the MPD branch returns ``relative_times=None`` so
+        When ``reference_period`` is a genuinely non-numeric / non-datetime
+        label (e.g., the string "REF_STRING"), the MPD branch emits an
+        explicit ``UserWarning`` and returns ``relative_times=None`` so
         ``_get_violation_weights('linear')`` uses the legacy count-based
-        direction. Preserves backwards-compat for MPD callers that don't
-        expose a numeric reference period.
+        direction. The warning surfaces the contract that the reported
+        MDV is NOT in Roth's γ units under this fallback (R8 CI codex
+        fix: was previously a silent fallback, undocumented as a
+        deviation in REGISTRY).
         """
+        import warnings as _warnings
+
         from diff_diff.results import MultiPeriodDiDResults, PeriodEffect

         period_ids = ["A", "B", "C"]
@@ -556,12 +561,72 @@ def test_mpd_non_numeric_reference_falls_back_to_legacy_weights(self):
             n_control=50,
             pre_periods=period_ids,
             post_periods=["D", "E"],
-            reference_period="REF_STRING",  # non-numeric
+            reference_period="REF_STRING",  # non-numeric, non-datetime
         )

         pt = PreTrendsPower(pretest_form="nis", violation_type="linear")
-        _, _, _, _, relative_times, _ = pt._extract_pre_period_params(mpd_results)
+        with _warnings.catch_warnings(record=True) as caught:
+            _warnings.simplefilter("always")
+            _, _, _, _, relative_times, _ = pt._extract_pre_period_params(mpd_results)
+
         assert relative_times is None, "Non-numeric reference should yield None"
+        nis_warns = [
+            w
+            for w in caught
+            if "reference_period" in str(w.message) and "γ units" in str(w.message)
+        ]
+        assert len(nis_warns) >= 1, (
+            "Non-numeric reference_period must emit an explicit UserWarning "
+            f"noting the γ-unit contract is not held; got warnings: {[str(w.message) for w in caught]}"
+        )
+
+    def test_mpd_pandas_period_reference_yields_numeric_relative_times(self):
+        """MPD with pandas.Period reference_period produces γ-unit weights.
+
+        Quarterly-Period labels ``[2019Q1, 2019Q2, 2019Q3]`` with
+        ``reference_period=2019Q4`` produce relative offsets in units of
+        quarters: ``[-3, -2, -1]``. Validates the R8 CI codex fix that
+        datetime-like labels are NOT silently fall-through cases; Period
+        / Timestamp arithmetic supplies the γ-unit relative times the
+        legacy fallback would have lost.
+        """
+        from diff_diff.results import MultiPeriodDiDResults, PeriodEffect
+
+        periods = [pd.Period(f"2019Q{q}", freq="Q") for q in (1, 2, 3)]
+        reference_period = pd.Period("2019Q4", freq="Q")
+        period_effects = {
+            p: PeriodEffect(
+                period=p, effect=0.1, se=0.2, t_stat=0.0, p_value=0.5, conf_int=(0.0, 0.0)
+            )
+            for p in periods
+        }
+        mpd_results = MultiPeriodDiDResults(
+            period_effects=period_effects,
+            avg_att=0.0,
+            avg_se=0.2,
+            avg_t_stat=0.0,
+            avg_p_value=0.5,
+            avg_conf_int=(0.0, 0.0),
+            n_obs=100,
+            n_treated=50,
+            n_control=50,
+            pre_periods=periods,
+            post_periods=[pd.Period(f"2020Q{q}", freq="Q") for q in (1, 2)],
+            reference_period=reference_period,
+        )
+
+        pt = PreTrendsPower(pretest_form="nis", violation_type="linear")
+        _, _, _, n_pre, relative_times, _ = pt._extract_pre_period_params(mpd_results)
+
+        # Period subtraction yields a Period offset whose `.n` is the
+        # number-of-frequencies difference; signs matter and pre-periods
+        # are NEGATIVE offsets from the reference.
+        assert relative_times is not None
+        np.testing.assert_allclose(relative_times, [-3.0, -2.0, -1.0])
+
+        # Plumbed through to linear weights: |t| = [3, 2, 1] in γ units.
+        weights = pt._get_violation_weights(n_pre, relative_times=relative_times)
+        np.testing.assert_allclose(weights, [3.0, 2.0, 1.0])

     def test_backwards_compat_no_relative_times_uses_legacy_normalized(self):
         """Without relative_times: legacy [n-1, ..., 0]/||·||_2 direction.
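The warn-and-fallback test above records warnings and then filters them by message content. The same pattern can be tried in isolation with only the standard library; in this sketch `coerce_or_warn` is a hypothetical helper invented for illustration and is not part of `diff_diff`.

```python
import warnings


def coerce_or_warn(label):
    """Hypothetical helper: float-coerce a label, or warn and return None."""
    try:
        return float(label)
    except (TypeError, ValueError):
        warnings.warn(
            f"label {label!r} is not numeric; result is NOT in γ units",
            UserWarning,
            stacklevel=2,
        )
        return None


with warnings.catch_warnings(record=True) as caught:
    # "always" disables the default once-per-location suppression, so the
    # warning is recorded even if the same call site already warned earlier.
    warnings.simplefilter("always")
    result = coerce_or_warn("REF_STRING")

# Filter on both category and message text, so unrelated warnings raised
# inside the block do not make the check pass by accident.
matching = [
    w
    for w in caught
    if issubclass(w.category, UserWarning) and "γ units" in str(w.message)
]
print(result, len(matching))
```

`pytest.warns(UserWarning, match="γ units")` would be a more compact alternative where only the presence of the warning matters; recording with `catch_warnings` keeps the full list available for the failure message, which is what the test in this commit does.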
