Skip to content

Commit 622be06

Browse files
authored
Merge pull request #104 from igerber/feature/cs-event-study-reference-period
Include reference period (e=-1) in CS event study for universal base period
2 parents 6791431 + 98ba6db commit 622be06

10 files changed

Lines changed: 392 additions & 12 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,3 +76,7 @@ target/
7676

7777
# Local scripts (not part of package)
7878
scripts/
79+
80+
# Launch directories (local only)
81+
launch/
82+
launch-video/

diff_diff/honest_did.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -584,7 +584,12 @@ def _extract_event_study_params(
584584
)
585585

586586
# Extract event study effects by relative time
587-
event_effects = results.event_study_effects
587+
# Filter out normalization constraints (n_groups=0) and non-finite SEs
588+
event_effects = {
589+
t: data for t, data in results.event_study_effects.items()
590+
if data.get('n_groups', 1) > 0
591+
and np.isfinite(data.get('se', np.nan))
592+
}
588593
rel_times = sorted(event_effects.keys())
589594

590595
# Split into pre and post
@@ -1261,10 +1266,12 @@ def _estimate_max_pre_violation(
12611266
from diff_diff.staggered import CallawaySantAnnaResults
12621267
if isinstance(results, CallawaySantAnnaResults):
12631268
if results.event_study_effects:
1269+
# Filter out normalization constraints (n_groups=0, e.g. reference period)
12641270
pre_effects = [
12651271
abs(results.event_study_effects[t]['effect'])
12661272
for t in results.event_study_effects
12671273
if t < 0
1274+
and results.event_study_effects[t].get('n_groups', 1) > 0
12681275
]
12691276
if pre_effects:
12701277
return max(pre_effects)

diff_diff/pretrends.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -656,9 +656,12 @@ def _extract_pre_period_params(
656656
)
657657

658658
# Get pre-period effects (negative relative times)
659+
# Filter out normalization constraints (n_groups=0) and non-finite SEs
659660
pre_effects = {
660661
t: data for t, data in results.event_study_effects.items()
661662
if t < 0
663+
and data.get('n_groups', 1) > 0
664+
and np.isfinite(data.get('se', np.nan))
662665
}
663666

664667
if not pre_effects:
@@ -680,9 +683,12 @@ def _extract_pre_period_params(
680683
from diff_diff.sun_abraham import SunAbrahamResults
681684
if isinstance(results, SunAbrahamResults):
682685
# Get pre-period effects (negative relative times)
686+
# Filter out normalization constraints (n_groups=0) and non-finite SEs
683687
pre_effects = {
684688
t: data for t, data in results.event_study_effects.items()
685689
if t < 0
690+
and data.get('n_groups', 1) > 0
691+
and np.isfinite(data.get('se', np.nan))
686692
}
687693

688694
if not pre_effects:

diff_diff/staggered_aggregation.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,9 @@ class CallawaySantAnnaAggregationMixin:
3434
# Type hint for anticipation attribute accessed from main class
3535
anticipation: int
3636

37+
# Type hint for base_period attribute accessed from main class
38+
base_period: str
39+
3740
def _aggregate_simple(
3841
self,
3942
group_time_effects: Dict,
@@ -414,6 +417,22 @@ def _aggregate_event_study(
414417
'n_groups': len(effect_list),
415418
}
416419

420+
# Add reference period for universal base period mode (matches R did package)
421+
# The reference period e = -1 - anticipation has effect = 0 by construction
422+
# Only add if there are actual computed effects (guard against empty data)
423+
if getattr(self, 'base_period', 'varying') == "universal":
424+
ref_period = -1 - self.anticipation
425+
# Only inject reference if we have at least one real effect
426+
if event_study_effects and ref_period not in event_study_effects:
427+
event_study_effects[ref_period] = {
428+
'effect': 0.0,
429+
'se': np.nan, # Undefined - no data, normalization constraint
430+
't_stat': np.nan, # Undefined - normalization constraint
431+
'p_value': np.nan,
432+
'conf_int': (np.nan, np.nan), # NaN propagation for undefined inference
433+
'n_groups': 0, # No groups contribute - fixed by construction
434+
}
435+
417436
return event_study_effects
418437

419438
def _aggregate_by_group(

diff_diff/visualization.py

Lines changed: 32 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -197,11 +197,17 @@ def plot_event_study(
197197
effect = effects.get(period, np.nan)
198198
std_err = se.get(period, np.nan)
199199

200-
if np.isnan(effect) or np.isnan(std_err):
200+
# Skip entries with NaN effect, but allow NaN SE (will plot without error bars)
201+
if np.isnan(effect):
201202
continue
202203

203-
ci_lower = effect - critical_value * std_err
204-
ci_upper = effect + critical_value * std_err
204+
# Compute CI only if SE is finite
205+
if np.isfinite(std_err):
206+
ci_lower = effect - critical_value * std_err
207+
ci_upper = effect + critical_value * std_err
208+
else:
209+
ci_lower = np.nan
210+
ci_upper = np.nan
205211

206212
plot_data.append({
207213
'period': period,
@@ -244,13 +250,20 @@ def plot_event_study(
244250
ref_x = period_to_x[reference_period]
245251
ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1)
246252

247-
# Plot error bars
248-
yerr = [df['effect'] - df['ci_lower'], df['ci_upper'] - df['effect']]
249-
ax.errorbar(
250-
x_vals, df['effect'], yerr=yerr,
251-
fmt='none', color=color, capsize=capsize, linewidth=linewidth,
252-
capthick=linewidth, zorder=2
253-
)
253+
# Plot error bars (only for entries with finite CI)
254+
has_ci = df['ci_lower'].notna() & df['ci_upper'].notna()
255+
if has_ci.any():
256+
df_with_ci = df[has_ci]
257+
x_with_ci = [period_to_x[p] for p in df_with_ci['period']]
258+
yerr = [
259+
df_with_ci['effect'] - df_with_ci['ci_lower'],
260+
df_with_ci['ci_upper'] - df_with_ci['effect']
261+
]
262+
ax.errorbar(
263+
x_with_ci, df_with_ci['effect'], yerr=yerr,
264+
fmt='none', color=color, capsize=capsize, linewidth=linewidth,
265+
capthick=linewidth, zorder=2
266+
)
254267

255268
# Plot point estimates
256269
for i, row in df.iterrows():
@@ -351,7 +364,15 @@ def _extract_plot_data(
351364

352365
# Reference period is typically -1 for event study
353366
if reference_period is None:
354-
reference_period = -1
367+
# Detect reference period from n_groups=0 marker (normalization constraint)
368+
# This handles anticipation > 0 where reference is at e = -1 - anticipation
369+
for period, effect_data in results.event_study_effects.items():
370+
if effect_data.get('n_groups', 1) == 0:
371+
reference_period = period
372+
break
373+
# Fallback to -1 if no marker found (backward compatibility)
374+
if reference_period is None:
375+
reference_period = -1
355376

356377
if pre_periods is None:
357378
pre_periods = [p for p in periods if p < 0]

docs/methodology/REGISTRY.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -229,6 +229,9 @@ Aggregations:
229229
- "universal": All comparisons use g-anticipation-1 as base
230230
- Both produce identical post-treatment ATT(g,t); differ only pre-treatment
231231
- Matches R `did::att_gt()` base_period parameter
232+
- **Event study output**: With "universal", includes reference period (e=-1-anticipation)
233+
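with effect=0, se=NaN, t_stat=NaN, p_value=NaN, conf_int=(NaN, NaN), n_groups=0. Inference fields are NaN since this is
234+
a normalization constraint, not an estimated effect; n_groups=0 marks the entry so downstream consumers can filter it out. Only added when real effects exist.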
232235
- Base period interaction with Sun-Abraham comparison:
233236
- CS with `base_period="varying"` produces different pre-treatment estimates than SA
234237
- This is expected: CS uses consecutive comparisons, SA uses fixed reference (e=-1-anticipation)

tests/test_honest_did.py

Lines changed: 77 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -681,6 +681,83 @@ def test_very_large_M(self, mock_multiperiod_results):
681681
assert isinstance(results, HonestDiDResults)
682682
assert results.ci_width > 0
683683

684+
def test_callaway_santanna_universal_base_period(self):
685+
"""Test that reference period (e=-1) is correctly filtered out with universal base period.
686+
687+
The reference period has n_groups=0 and se=NaN, so it should be excluded
688+
from HonestDiD analysis to avoid contaminating the vcov matrix.
689+
"""
690+
from diff_diff import CallawaySantAnna, generate_staggered_data
691+
692+
# Generate data and fit with universal base period
693+
data = generate_staggered_data(n_units=200, n_periods=10, seed=42)
694+
cs = CallawaySantAnna(base_period="universal")
695+
results = cs.fit(
696+
data,
697+
outcome='outcome',
698+
unit='unit',
699+
time='period',
700+
first_treat='first_treat',
701+
aggregate='event_study'
702+
)
703+
704+
# Verify reference period exists with NaN SE
705+
assert -1 in results.event_study_effects
706+
assert np.isnan(results.event_study_effects[-1]['se'])
707+
708+
# HonestDiD should work without errors (reference period filtered out)
709+
honest = HonestDiD(method='relative_magnitude', M=1.0)
710+
bounds = honest.fit(results)
711+
712+
# Should have valid (non-NaN) results
713+
assert isinstance(bounds, HonestDiDResults)
714+
assert np.isfinite(bounds.ci_lb)
715+
assert np.isfinite(bounds.ci_ub)
716+
717+
def test_max_pre_violation_excludes_reference_period(self):
718+
"""Test that reference period (effect=0, n_groups=0) is excluded from max pre-violation.
719+
720+
With universal base period, the reference period e=-1 is a normalization constraint
721+
with n_groups=0. It should not be used in _estimate_max_pre_violation because
722+
its effect is artificially set to 0, which would collapse RM bounds incorrectly.
723+
"""
724+
from diff_diff import CallawaySantAnna, generate_staggered_data
725+
726+
# Generate data with universal base period
727+
data = generate_staggered_data(n_units=200, n_periods=10, seed=42)
728+
cs = CallawaySantAnna(base_period="universal")
729+
results = cs.fit(
730+
data,
731+
outcome='outcome',
732+
unit='unit',
733+
time='period',
734+
first_treat='first_treat',
735+
aggregate='event_study'
736+
)
737+
738+
# Verify reference period exists with n_groups=0
739+
assert -1 in results.event_study_effects
740+
assert results.event_study_effects[-1]['n_groups'] == 0
741+
742+
# The max pre-violation calculation should exclude the reference period
743+
honest = HonestDiD(method='relative_magnitude', M=1.0)
744+
745+
# Get pre_periods excluding reference (n_groups=0)
746+
real_pre_periods = [
747+
t for t in results.event_study_effects
748+
if t < 0 and results.event_study_effects[t].get('n_groups', 1) > 0
749+
]
750+
751+
# If there are real pre-periods, max_violation should be > 0
752+
# (based on actual pre-period effects, not the reference period's effect=0)
753+
if real_pre_periods:
754+
max_violation = honest._estimate_max_pre_violation(results, real_pre_periods)
755+
# Max violation should reflect actual pre-period coefficients, not 0
756+
# The actual effects are non-zero due to sampling variation
757+
assert max_violation > 0, (
758+
"max_pre_violation should be > 0 when real pre-periods exist"
759+
)
760+
684761

685762
# =============================================================================
686763
# Tests for Visualization (without matplotlib)

tests/test_pretrends.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -795,6 +795,38 @@ def test_unsupported_results_type_raises(self):
795795
with pytest.raises(TypeError, match="Unsupported results type"):
796796
pt.fit("not a results object")
797797

798+
def test_callaway_santanna_universal_base_period(self):
799+
"""Test that reference period (e=-1) is correctly filtered out with universal base period.
800+
801+
The reference period has n_groups=0 and se=NaN, so it should be excluded
802+
from pre-trends power analysis to avoid contaminating the vcov matrix.
803+
"""
804+
from diff_diff import CallawaySantAnna, generate_staggered_data
805+
806+
# Generate data and fit with universal base period
807+
data = generate_staggered_data(n_units=200, n_periods=10, seed=42)
808+
cs = CallawaySantAnna(base_period="universal")
809+
results = cs.fit(
810+
data,
811+
outcome='outcome',
812+
unit='unit',
813+
time='period',
814+
first_treat='first_treat',
815+
aggregate='event_study'
816+
)
817+
818+
# Verify reference period exists with NaN SE
819+
assert -1 in results.event_study_effects
820+
assert np.isnan(results.event_study_effects[-1]['se'])
821+
822+
# PreTrendsPower should work without errors (reference period filtered out)
823+
pt = PreTrendsPower()
824+
power_results = pt.fit(results)
825+
826+
# Should have valid (non-NaN) results
827+
assert np.isfinite(power_results.power)
828+
assert power_results.n_pre_periods >= 1
829+
798830

799831
# =============================================================================
800832
# Tests for visualization (without rendering)

0 commit comments

Comments
 (0)