Skip to content

Commit 6a66328

Browse files
igerber and claude committed
Tighten drift test: round-based endpoint pins + exact warning-set check
CI review surfaced two refinements:

1. Endpoint bands like `11.0 <= ci_low <= 11.6` would still pass values rounding to several different one-decimal displays (11.0, 11.1, ..., 11.6) while the notebook prose stays at "11.3", "12.8", "11.4", "13.3", "11.5", "13.6". Replace those with `round(ci_low, 1) == 11.3` etc. This directly pins the displayed rounding, so any drift past the tenth fails the test.

2. The warning tests didn't pin the notebook's full warning contract. `event_study_results` suppressed A7 for fixture cleanliness while the docstring claimed "A7 visible". Two changes:
   - Fix the fixture docstring to acknowledge that A7 is muted there for the value-checking tests, with the notebook's actual warning-policy contract validated separately.
   - Add `test_event_study_warning_policy_matches_notebook`, which mirrors the notebook's exact filter (only matmul-pattern RuntimeWarnings silenced) and asserts the resulting warning set: exactly one UserWarning (A7 leavers-present, the one the markdown explains) and zero RuntimeWarnings. If a future library change emits an unexpected warning on this code path, the test fails.

12 tests pass in ~0.07s (was 11).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a7f2b0e commit 6a66328

1 file changed

Lines changed: 57 additions & 13 deletions

File tree

tests/test_t19_marketing_pulse_drift.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,11 @@ def phase1_results(panel):
7676

7777
@pytest.fixture(scope="module")
7878
def event_study_results(panel):
79-
"""Event-study fit: L_max=2 + multiplier bootstrap. Same warning
80-
treatment as the notebook (Accelerate matmul filter; A7 visible)."""
79+
"""Event-study fit: L_max=2 + multiplier bootstrap. The A7
80+
UserWarning is intentionally muted here so the fixture is quiet
81+
for the value-checking tests below; the notebook's actual
82+
warning-policy contract (A7 visible, only matmul filtered) is
83+
validated separately by `test_event_study_warning_policy_matches_notebook`."""
8184
with warnings.catch_warnings():
8285
warnings.filterwarnings(
8386
"ignore",
@@ -123,13 +126,12 @@ def test_overall_ci_covers_truth(phase1_results):
123126

124127

125128
def test_overall_ci_endpoints_match_quoted(phase1_results):
126-
"""Section 3 narrative quotes '95% CI: 11.3 to 12.8'. Lock the
127-
rounded endpoints so prose drift fails this test."""
129+
"""Section 3 narrative quotes '95% CI: 11.3 to 12.8'. Pin the
130+
one-decimal display exactly so any drift past the displayed
131+
rounding fails this test."""
128132
ci_low, ci_high = phase1_results.overall_conf_int
129-
# CI lower endpoint rounds to 11.3 -> band covers 11.0..11.6
130-
assert 11.0 <= ci_low <= 11.6, ci_low
131-
# CI upper endpoint rounds to 12.8 -> band covers 12.5..13.1
132-
assert 12.5 <= ci_high <= 13.1, ci_high
133+
assert round(ci_low, 1) == 11.3, ci_low
134+
assert round(ci_high, 1) == 12.8, ci_high
133135

134136

135137
def test_joiners_leavers_consistent(phase1_results):
@@ -154,14 +156,14 @@ def test_event_study_horizons_cover_truth(event_study_results):
154156

155157
def test_event_study_ci_endpoints_match_quoted(event_study_results):
156158
"""Section 4 narrative quotes l=1 CI [11.4, 13.3] and l=2 CI
157-
[11.5, 13.6]. Lock the rounded endpoints so prose drift fails."""
159+
[11.5, 13.6]. Pin the one-decimal display exactly."""
158160
es = event_study_results.event_study_effects
159161
# l=1 CI [11.4, 13.3]
160-
assert 11.1 <= es[1]["conf_int"][0] <= 11.7, es[1]["conf_int"]
161-
assert 13.0 <= es[1]["conf_int"][1] <= 13.6, es[1]["conf_int"]
162+
assert round(es[1]["conf_int"][0], 1) == 11.4, es[1]["conf_int"]
163+
assert round(es[1]["conf_int"][1], 1) == 13.3, es[1]["conf_int"]
162164
# l=2 CI [11.5, 13.6]
163-
assert 11.2 <= es[2]["conf_int"][0] <= 11.8, es[2]["conf_int"]
164-
assert 13.3 <= es[2]["conf_int"][1] <= 13.9, es[2]["conf_int"]
165+
assert round(es[2]["conf_int"][0], 1) == 11.5, es[2]["conf_int"]
166+
assert round(es[2]["conf_int"][1], 1) == 13.6, es[2]["conf_int"]
165167

166168

167169
def test_event_study_significance(event_study_results):
@@ -211,6 +213,48 @@ def test_assumption7_warning_fires_as_expected(panel):
211213
assert len(a7_warnings) >= 1, [str(w.message)[:80] for w in ws]
212214

213215

216+
def test_event_study_warning_policy_matches_notebook(panel):
217+
"""Mirror the notebook's exact warning policy on the visible
218+
event-study fit and assert the resulting warning set matches the
219+
documented contract: exactly one UserWarning (the A7 leavers-present
220+
warning that the notebook's markdown explains), and zero
221+
RuntimeWarnings (matmul-pattern ones filtered; everything else
222+
surfaces). If the library starts emitting an unexpected warning on
223+
this code path, this test fails and the notebook prose may need to
224+
be updated."""
225+
with warnings.catch_warnings(record=True) as ws:
226+
warnings.simplefilter("always")
227+
# MIRROR the notebook's narrow filter exactly (no np.errstate, no
228+
# blanket A7 suppression).
229+
warnings.filterwarnings(
230+
"ignore",
231+
message=r".*encountered in matmul",
232+
category=RuntimeWarning,
233+
)
234+
model = DCDH(
235+
twfe_diagnostic=False, placebo=True, n_bootstrap=199, seed=42
236+
)
237+
model.fit(
238+
panel,
239+
outcome="sessions",
240+
group="market_id",
241+
time="week",
242+
treatment="promo_on",
243+
L_max=2,
244+
)
245+
user_warnings = [w for w in ws if w.category is UserWarning]
246+
runtime_warnings = [w for w in ws if w.category is RuntimeWarning]
247+
# Exactly one UserWarning, and it's the documented A7 warning.
248+
assert len(user_warnings) == 1, [str(w.message)[:120] for w in user_warnings]
249+
msg = str(user_warnings[0].message)
250+
assert "Assumption 7" in msg, msg
251+
assert "leavers present" in msg, msg
252+
# All RuntimeWarnings should be the matmul pattern (filtered) - so
253+
# zero remaining. If a new RuntimeWarning fires from somewhere else,
254+
# this fails.
255+
assert len(runtime_warnings) == 0, [str(w.message)[:120] for w in runtime_warnings]
256+
257+
214258
def test_a11_warning_does_not_fire():
215259
"""The notebook claims this seed/DGP is in the A11-clean regime
216260
(no warning fires). If a library change starts triggering A11 on

0 commit comments

Comments
 (0)