Skip to content

Commit 336e246

Browse files
igerberclaude
andcommitted
Address PR review feedback: fix bugs and add tests
Fixes: - Fix weight computation in PreTrendsPowerResults.power_at() to match _get_violation_weights() logic (linear weights should be [n-1, n-2, ..., 0]) - Fix compute_mdv() parameter name from 'power' back to 'target_power' for consistency with compute_pretrends_power() - Update notebook cell-28 to use target_power instead of power Tests added: - TestPreTrendsPowerResultsPowerAt: 6 tests for power_at() method - TestPrePeriodsParameter: 6 tests for pre_periods parameter - TestCallawaySantAnnaNonStandardColumnNames: 10 tests for non-standard column names in CallawaySantAnna Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c668a09 commit 336e246

4 files changed

Lines changed: 473 additions & 17 deletions

File tree

diff_diff/pretrends.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,18 +224,22 @@ def power_at(self, M: float) -> float:
224224
n_pre = self.n_pre_periods
225225

226226
# Reconstruct violation weights based on violation type
227+
# Must match PreTrendsPower._get_violation_weights() exactly
227228
if self.violation_type == "linear":
228-
weights = np.arange(1, n_pre + 1).astype(float)
229+
# Linear trend: weights decrease toward treatment
230+
# [n-1, n-2, ..., 1, 0] for n pre-periods
231+
weights = np.arange(-n_pre + 1, 1, dtype=float)
232+
weights = -weights # Now [n-1, n-2, ..., 1, 0]
229233
elif self.violation_type == "constant":
230234
weights = np.ones(n_pre)
231235
elif self.violation_type == "last_period":
232236
weights = np.zeros(n_pre)
233237
weights[-1] = 1.0
234238
else:
235-
# For custom, we can't reconstruct - use equal weights
239+
# For custom, we can't reconstruct - use equal weights as fallback
236240
weights = np.ones(n_pre)
237241

238-
# Normalize weights
242+
# Normalize weights to unit L2 norm
239243
norm = np.linalg.norm(weights)
240244
if norm > 0:
241245
weights = weights / norm
@@ -1121,7 +1125,7 @@ def compute_pretrends_power(
11211125
def compute_mdv(
11221126
results: Union[MultiPeriodDiDResults, Any],
11231127
alpha: float = 0.05,
1124-
power: float = 0.80,
1128+
target_power: float = 0.80,
11251129
violation_type: str = "linear",
11261130
pre_periods: Optional[List[int]] = None,
11271131
) -> float:
@@ -1134,8 +1138,8 @@ def compute_mdv(
11341138
Event study results.
11351139
alpha : float, default=0.05
11361140
Significance level.
1137-
power : float, default=0.80
1138-
Target power.
1141+
target_power : float, default=0.80
1142+
Target power for MDV calculation.
11391143
violation_type : str, default='linear'
11401144
Type of violation pattern.
11411145
pre_periods : list of int, optional
@@ -1149,7 +1153,7 @@ def compute_mdv(
11491153
"""
11501154
pt = PreTrendsPower(
11511155
alpha=alpha,
1152-
power=power,
1156+
power=target_power,
11531157
violation_type=violation_type,
11541158
)
11551159
result = pt.fit(results, pre_periods=pre_periods)

docs/tutorials/07_pretrends_power.ipynb

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -659,15 +659,7 @@
659659
}
660660
},
661661
"outputs": [],
662-
"source": [
663-
"# Quick MDV calculation\n",
664-
"mdv = compute_mdv(event_results, power=0.80, violation_type='linear', pre_periods=pre_treatment_periods)\n",
665-
"print(f\"MDV: {mdv:.3f}\")\n",
666-
"\n",
667-
"# Quick power calculation at a specific violation\n",
668-
"power_result = compute_pretrends_power(event_results, M=2.0, pre_periods=pre_treatment_periods)\n",
669-
"print(f\"Power at violation=2.0: {power_result.power:.1%}\")"
670-
]
662+
"source": "# Quick MDV calculation\nmdv = compute_mdv(event_results, target_power=0.80, violation_type='linear', pre_periods=pre_treatment_periods)\nprint(f\"MDV: {mdv:.3f}\")\n\n# Quick power calculation at a specific violation\npower_result = compute_pretrends_power(event_results, M=2.0, pre_periods=pre_treatment_periods)\nprint(f\"Power at violation=2.0: {power_result.power:.1%}\")"
671663
},
672664
{
673665
"cell_type": "markdown",
@@ -855,4 +847,4 @@
855847
},
856848
"nbformat": 4,
857849
"nbformat_minor": 5
858-
}
850+
}

tests/test_pretrends.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,3 +811,218 @@ def test_power_curve_has_plot_method(self, mock_multiperiod_results):
811811

812812
assert hasattr(curve, 'plot')
813813
assert callable(curve.plot)
814+
815+
816+
# =============================================================================
817+
# Tests for PreTrendsPowerResults.power_at() method
818+
# =============================================================================
819+
820+
821+
class TestPreTrendsPowerResultsPowerAt:
822+
"""Tests for the power_at method on PreTrendsPowerResults."""
823+
824+
def test_power_at_basic(self, mock_multiperiod_results):
825+
"""Test basic power_at functionality."""
826+
pt = PreTrendsPower()
827+
results = pt.fit(mock_multiperiod_results)
828+
829+
# Compute power at different M values
830+
power_1 = results.power_at(1.0)
831+
power_2 = results.power_at(2.0)
832+
power_5 = results.power_at(5.0)
833+
834+
# Power should increase with M
835+
assert power_1 < power_2 < power_5
836+
837+
# Power should be between 0 and 1
838+
assert 0 <= power_1 <= 1
839+
assert 0 <= power_2 <= 1
840+
assert 0 <= power_5 <= 1
841+
842+
def test_power_at_zero(self, mock_multiperiod_results):
843+
"""Test power_at with M=0 (should equal alpha)."""
844+
pt = PreTrendsPower(alpha=0.05)
845+
results = pt.fit(mock_multiperiod_results)
846+
847+
power_0 = results.power_at(0.0)
848+
849+
# At M=0, power should equal size (alpha)
850+
assert np.isclose(power_0, 0.05, atol=0.01)
851+
852+
def test_power_at_matches_fit(self, mock_multiperiod_results):
853+
"""Test that power_at gives same result as fitting with that M."""
854+
pt = PreTrendsPower()
855+
856+
# Get results from fit
857+
results1 = pt.fit(mock_multiperiod_results, M=2.0)
858+
859+
# Get power from power_at method
860+
results_base = pt.fit(mock_multiperiod_results)
861+
power_from_method = results_base.power_at(2.0)
862+
863+
# Should be the same (or very close)
864+
assert np.isclose(results1.power, power_from_method, rtol=0.01)
865+
866+
def test_power_at_linear_weights(self, mock_multiperiod_results):
867+
"""Test power_at uses correct linear weights."""
868+
pt = PreTrendsPower(violation_type="linear")
869+
results = pt.fit(mock_multiperiod_results)
870+
871+
# Power_at should work without error
872+
power = results.power_at(1.0)
873+
assert 0 <= power <= 1
874+
875+
def test_power_at_constant_weights(self, mock_multiperiod_results):
876+
"""Test power_at uses correct constant weights."""
877+
pt = PreTrendsPower(violation_type="constant")
878+
results = pt.fit(mock_multiperiod_results)
879+
880+
power = results.power_at(1.0)
881+
assert 0 <= power <= 1
882+
883+
def test_power_at_last_period_weights(self, mock_multiperiod_results):
884+
"""Test power_at uses correct last_period weights."""
885+
pt = PreTrendsPower(violation_type="last_period")
886+
results = pt.fit(mock_multiperiod_results)
887+
888+
power = results.power_at(1.0)
889+
assert 0 <= power <= 1
890+
891+
892+
# =============================================================================
893+
# Tests for pre_periods parameter
894+
# =============================================================================
895+
896+
897+
class TestPrePeriodsParameter:
898+
"""Tests for the pre_periods parameter in fit and related methods."""
899+
900+
@pytest.fixture
901+
def event_study_all_periods_results(self):
902+
"""Create results simulating all periods estimated as post_periods.
903+
904+
This mimics the event study workflow where we estimate coefficients
905+
for ALL periods (pre and post) to get pre-period placebo effects.
906+
"""
907+
# Periods 0-3 are pre-treatment, 4-7 are post
908+
# But we estimate ALL periods as "post" to get coefficients
909+
period_effects = {}
910+
coefficients = {}
911+
912+
# Pre-periods (0, 1, 2) - period 3 would be reference
913+
for p in [0, 1, 2]:
914+
period_effects[p] = PeriodEffect(
915+
period=p, effect=np.random.normal(0, 0.1), se=0.5,
916+
t_stat=0.2, p_value=0.84, conf_int=(-0.88, 1.08)
917+
)
918+
coefficients[f'treated:period_{p}'] = period_effects[p].effect
919+
920+
# Post-periods (4, 5, 6, 7)
921+
for p in [4, 5, 6, 7]:
922+
period_effects[p] = PeriodEffect(
923+
period=p, effect=5.0 + np.random.normal(0, 0.1), se=0.5,
924+
t_stat=10.0, p_value=0.0001, conf_int=(4.02, 5.98)
925+
)
926+
coefficients[f'treated:period_{p}'] = period_effects[p].effect
927+
928+
# In this scenario, pre_periods=[3] (only reference), post_periods=[0,1,2,4,5,6,7]
929+
vcov = np.diag([0.25] * 7)
930+
931+
return MultiPeriodDiDResults(
932+
period_effects=period_effects,
933+
avg_att=5.0,
934+
avg_se=0.25,
935+
avg_t_stat=20.0,
936+
avg_p_value=0.0001,
937+
avg_conf_int=(4.51, 5.49),
938+
n_obs=800,
939+
n_treated=400,
940+
n_control=400,
941+
pre_periods=[3], # Only reference period
942+
post_periods=[0, 1, 2, 4, 5, 6, 7], # All estimated periods
943+
vcov=vcov,
944+
coefficients=coefficients,
945+
)
946+
947+
def test_fit_with_explicit_pre_periods(self, event_study_all_periods_results):
948+
"""Test fit() with explicit pre_periods parameter."""
949+
pt = PreTrendsPower()
950+
951+
# Without pre_periods, would fail because results.pre_periods=[3]
952+
# and period 3 has no coefficient (it's the reference)
953+
# With explicit pre_periods=[0,1,2], should work
954+
results = pt.fit(
955+
event_study_all_periods_results,
956+
pre_periods=[0, 1, 2]
957+
)
958+
959+
assert results.n_pre_periods == 3
960+
assert results.power >= 0
961+
assert results.mdv > 0
962+
963+
def test_pre_periods_overrides_results(self, event_study_all_periods_results):
964+
"""Test that pre_periods parameter overrides results.pre_periods."""
965+
pt = PreTrendsPower()
966+
967+
# Explicitly set pre_periods to [0, 1]
968+
results = pt.fit(
969+
event_study_all_periods_results,
970+
pre_periods=[0, 1]
971+
)
972+
973+
# Should use 2 pre-periods, not what's in results
974+
assert results.n_pre_periods == 2
975+
976+
def test_power_at_with_pre_periods(self, event_study_all_periods_results):
977+
"""Test power_at() method with pre_periods parameter."""
978+
pt = PreTrendsPower()
979+
980+
power = pt.power_at(
981+
event_study_all_periods_results,
982+
M=1.0,
983+
pre_periods=[0, 1, 2]
984+
)
985+
986+
assert 0 <= power <= 1
987+
988+
def test_power_curve_with_pre_periods(self, event_study_all_periods_results):
989+
"""Test power_curve() with pre_periods parameter."""
990+
pt = PreTrendsPower()
991+
992+
curve = pt.power_curve(
993+
event_study_all_periods_results,
994+
n_points=10,
995+
pre_periods=[0, 1, 2]
996+
)
997+
998+
assert len(curve.M_values) == 10
999+
assert len(curve.powers) == 10
1000+
1001+
def test_sensitivity_to_honest_did_with_pre_periods(self, event_study_all_periods_results):
1002+
"""Test sensitivity_to_honest_did() with pre_periods parameter."""
1003+
pt = PreTrendsPower()
1004+
1005+
sensitivity = pt.sensitivity_to_honest_did(
1006+
event_study_all_periods_results,
1007+
pre_periods=[0, 1, 2]
1008+
)
1009+
1010+
assert 'mdv' in sensitivity
1011+
assert sensitivity['mdv'] > 0
1012+
1013+
def test_convenience_functions_with_pre_periods(self, event_study_all_periods_results):
1014+
"""Test convenience functions with pre_periods parameter."""
1015+
# compute_mdv
1016+
mdv = compute_mdv(
1017+
event_study_all_periods_results,
1018+
pre_periods=[0, 1, 2]
1019+
)
1020+
assert mdv > 0
1021+
1022+
# compute_pretrends_power
1023+
results = compute_pretrends_power(
1024+
event_study_all_periods_results,
1025+
M=1.0,
1026+
pre_periods=[0, 1, 2]
1027+
)
1028+
assert results.n_pre_periods == 3

0 commit comments

Comments
 (0)