Skip to content

Commit 98ba6db

Browse files
igerberclaude
andcommitted
Address code review round 4: fix reference period detection with anticipation
Fix _extract_plot_data() to detect reference period from n_groups=0 marker instead of hardcoding -1. This correctly handles anticipation > 0 where the reference period is at e = -1 - anticipation (e.g., e=-2 when anticipation=1). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent d3da585 commit 98ba6db

2 files changed

Lines changed: 42 additions & 1 deletion

File tree

diff_diff/visualization.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,15 @@ def _extract_plot_data(
364364

365365
# Reference period is typically -1 for event study
366366
if reference_period is None:
367-
reference_period = -1
367+
# Detect reference period from n_groups=0 marker (normalization constraint)
368+
# This handles anticipation > 0 where reference is at e = -1 - anticipation
369+
for period, effect_data in results.event_study_effects.items():
370+
if effect_data.get('n_groups', 1) == 0:
371+
reference_period = period
372+
break
373+
# Fallback to -1 if no marker found (backward compatibility)
374+
if reference_period is None:
375+
reference_period = -1
368376

369377
if pre_periods is None:
370378
pre_periods = [p for p in periods if p < 0]

tests/test_visualization.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,39 @@ def test_plot_cs_universal_base_period(self):
285285

286286
plt.close()
287287

288+
def test_plot_cs_with_anticipation(self):
289+
"""Test plotting CallawaySantAnna results with anticipation > 0.
290+
291+
When anticipation=1, the reference period should be e=-2, not e=-1.
292+
"""
293+
pytest.importorskip("matplotlib")
294+
import matplotlib.pyplot as plt
295+
from diff_diff import generate_staggered_data
296+
297+
data = generate_staggered_data(n_units=200, n_periods=10, seed=42)
298+
cs = CallawaySantAnna(base_period="universal", anticipation=1)
299+
results = cs.fit(
300+
data,
301+
outcome='outcome',
302+
unit='unit',
303+
time='period',
304+
first_treat='first_treat',
305+
aggregate='event_study'
306+
)
307+
308+
# Reference period should be at e=-2 (not e=-1) with anticipation=1
309+
assert -2 in results.event_study_effects
310+
assert results.event_study_effects[-2]['n_groups'] == 0
311+
312+
ax = plot_event_study(results, show=False)
313+
assert ax is not None
314+
315+
# Verify -2 is in the plot (the true reference period)
316+
xtick_labels = [t.get_text() for t in ax.get_xticklabels()]
317+
assert '-2' in xtick_labels
318+
319+
plt.close()
320+
288321

289322
class TestPlotEventStudyIntegration:
290323
"""Integration tests for event study plotting."""

0 commit comments

Comments
 (0)