Skip to content

Commit 894955e

Browse files
authored
Merge pull request #74 from igerber/fix/tutorial-notebook-validation
Fix tutorial notebook validation errors and add pre_periods parameter
2 parents ae8ce01 + 021957a commit 894955e

16 files changed

Lines changed: 2814 additions & 346 deletions

TODO.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ Enhancements for `honest_did.py`:
8585
## CallawaySantAnna Bootstrap Improvements
8686

8787
- [ ] Consider aligning p-value computation with R `did` package (symmetric percentile method)
88+
- [ ] Investigate RuntimeWarnings in influence function aggregation (`staggered.py:1722`, `staggered.py:1999-2018`)
89+
- Warnings: "divide by zero", "overflow", "invalid value" in matmul operations
90+
- Occurs during bootstrap SE computation with small sample sizes or edge cases
91+
- Does not affect correctness (results are still valid), but should be suppressed or handled gracefully
8892

8993
---
9094

diff_diff/pretrends.py

Lines changed: 104 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,63 @@ def to_dataframe(self) -> pd.DataFrame:
202202
"""Convert results to DataFrame."""
203203
return pd.DataFrame([self.to_dict()])
204204

205+
def power_at(self, M: float) -> float:
    """
    Compute power to detect a specific violation magnitude.

    This method allows computing power at different M values without
    re-fitting the model, using the stored variance-covariance matrix.

    Parameters
    ----------
    M : float
        Violation magnitude to evaluate. Only the magnitude matters:
        the non-centrality parameter depends on ``M**2``.

    Returns
    -------
    float
        Power to detect violation of magnitude M.

    Raises
    ------
    ValueError
        If the stored ``vcov`` is not an
        ``n_pre_periods`` x ``n_pre_periods`` matrix.

    Warns
    -----
    UserWarning
        If ``violation_type`` is not one of ``"linear"``, ``"constant"``
        or ``"last_period"``. The original weights cannot be
        reconstructed here, so equal weights are used as a fallback.
    """
    import warnings

    from scipy import stats

    n_pre = self.n_pre_periods

    # Reconstruct violation weights based on violation type.
    # Must match PreTrendsPower._get_violation_weights() exactly.
    if self.violation_type == "linear":
        # Linear trend: weights decrease toward treatment,
        # i.e. [n-1, n-2, ..., 1, 0] for n pre-periods.
        weights = np.arange(-n_pre + 1, 1, dtype=float)
        weights = -weights
    elif self.violation_type == "constant":
        weights = np.ones(n_pre)
    elif self.violation_type == "last_period":
        weights = np.zeros(n_pre)
        weights[-1] = 1.0
    else:
        # Custom weights cannot be reconstructed from this object; keep the
        # equal-weights fallback but tell the caller it is an approximation.
        warnings.warn(
            f"Cannot reconstruct violation weights for "
            f"violation_type={self.violation_type!r}; using equal weights "
            f"as a fallback, so the reported power is approximate.",
            UserWarning,
            stacklevel=2,
        )
        weights = np.ones(n_pre)

    # Normalize weights to unit L2 norm (an all-zero vector is left as is).
    norm = np.linalg.norm(weights)
    if norm > 0:
        weights = weights / norm

    # Fail early with a readable message instead of an opaque matmul error.
    vcov = np.atleast_2d(np.asarray(self.vcov, dtype=float))
    if vcov.shape != (n_pre, n_pre):
        raise ValueError(
            f"vcov has shape {vcov.shape}, expected ({n_pre}, {n_pre}) "
            f"to match n_pre_periods."
        )

    # Exactly singular matrices fall back to the pseudo-inverse.
    try:
        vcov_inv = np.linalg.inv(vcov)
    except np.linalg.LinAlgError:
        vcov_inv = np.linalg.pinv(vcov)

    # delta = M * weights
    # nc = delta' * V^{-1} * delta
    noncentrality = float(M**2 * (weights @ vcov_inv @ weights))

    # Round-off in the (pseudo-)inverse can make the quadratic form slightly
    # negative, which ncx2 rejects; NaN passes through this clamp unchanged.
    noncentrality = max(noncentrality, 0.0)

    # Upper-tail probability via the survival function (more accurate than
    # 1 - cdf). With nc == 0 the distribution is the central chi-squared.
    if noncentrality == 0.0:
        power = stats.chi2.sf(self.critical_value, df=n_pre)
    else:
        power = stats.ncx2.sf(self.critical_value, df=n_pre, nc=noncentrality)

    return float(power)
261+
205262

206263
@dataclass
207264
class PreTrendsPowerCurve:
@@ -471,10 +528,18 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray:
471528
def _extract_pre_period_params(
472529
self,
473530
results: Union[MultiPeriodDiDResults, Any],
531+
pre_periods: Optional[List[int]] = None,
474532
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
475533
"""
476534
Extract pre-period parameters from results.
477535
536+
Parameters
537+
----------
538+
results : MultiPeriodDiDResults or similar
539+
Results object from event study estimation.
540+
pre_periods : list of int, optional
541+
Explicit list of pre-treatment periods. If None, uses results.pre_periods.
542+
478543
Returns
479544
-------
480545
effects : np.ndarray
@@ -487,13 +552,18 @@ def _extract_pre_period_params(
487552
Number of pre-periods.
488553
"""
489554
if isinstance(results, MultiPeriodDiDResults):
490-
# Get pre-period information
491-
all_pre_periods = results.pre_periods
555+
# Get pre-period information - use explicit pre_periods if provided
556+
if pre_periods is not None:
557+
all_pre_periods = list(pre_periods)
558+
else:
559+
all_pre_periods = results.pre_periods
492560

493561
if len(all_pre_periods) == 0:
494562
raise ValueError(
495563
"No pre-treatment periods found in results. "
496-
"Pre-trends power analysis requires pre-period coefficients."
564+
"Pre-trends power analysis requires pre-period coefficients. "
565+
"If you estimated all periods as post_periods, use the pre_periods "
566+
"parameter to specify which are actually pre-treatment."
497567
)
498568

499569
# Only include periods with actual estimated coefficients
@@ -775,6 +845,7 @@ def fit(
775845
self,
776846
results: Union[MultiPeriodDiDResults, Any],
777847
M: Optional[float] = None,
848+
pre_periods: Optional[List[int]] = None,
778849
) -> PreTrendsPowerResults:
779850
"""
780851
Compute pre-trends power analysis.
@@ -786,14 +857,19 @@ def fit(
786857
M : float, optional
787858
Specific violation magnitude to evaluate. If None, evaluates at
788859
a default magnitude based on the data.
860+
pre_periods : list of int, optional
861+
Explicit list of pre-treatment periods to use for power analysis.
862+
If None, attempts to infer from results.pre_periods. Use this when
863+
you've estimated an event study with all periods in post_periods
864+
and need to specify which are actually pre-treatment.
789865
790866
Returns
791867
-------
792868
PreTrendsPowerResults
793869
Power analysis results including power and MDV.
794870
"""
795871
# Extract pre-period parameters
796-
effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
872+
effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
797873

798874
# Get violation weights
799875
weights = self._get_violation_weights(n_pre)
@@ -831,6 +907,7 @@ def power_at(
831907
self,
832908
results: Union[MultiPeriodDiDResults, Any],
833909
M: float,
910+
pre_periods: Optional[List[int]] = None,
834911
) -> float:
835912
"""
836913
Compute power to detect a specific violation magnitude.
@@ -841,20 +918,23 @@ def power_at(
841918
Event study results.
842919
M : float
843920
Violation magnitude.
921+
pre_periods : list of int, optional
922+
Explicit list of pre-treatment periods. See fit() for details.
844923
845924
Returns
846925
-------
847926
float
848927
Power to detect violation of magnitude M.
849928
"""
850-
result = self.fit(results, M=M)
929+
result = self.fit(results, M=M, pre_periods=pre_periods)
851930
return result.power
852931

853932
def power_curve(
854933
self,
855934
results: Union[MultiPeriodDiDResults, Any],
856935
M_grid: Optional[List[float]] = None,
857936
n_points: int = 50,
937+
pre_periods: Optional[List[int]] = None,
858938
) -> PreTrendsPowerCurve:
859939
"""
860940
Compute power across a range of violation magnitudes.
@@ -868,14 +948,16 @@ def power_curve(
868948
automatic grid from 0 to 2.5 * MDV.
869949
n_points : int, default=50
870950
Number of points in automatic grid.
951+
pre_periods : list of int, optional
952+
Explicit list of pre-treatment periods. See fit() for details.
871953
872954
Returns
873955
-------
874956
PreTrendsPowerCurve
875957
Power curve data with plot method.
876958
"""
877959
# Extract parameters
878-
effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
960+
_, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
879961
weights = self._get_violation_weights(n_pre)
880962

881963
# Compute MDV
@@ -906,6 +988,7 @@ def power_curve(
906988
def sensitivity_to_honest_did(
907989
self,
908990
results: Union[MultiPeriodDiDResults, Any],
991+
pre_periods: Optional[List[int]] = None,
909992
) -> Dict[str, Any]:
910993
"""
911994
Compare pre-trends power analysis with HonestDiD sensitivity.
@@ -917,6 +1000,8 @@ def sensitivity_to_honest_did(
9171000
----------
9181001
results : results object
9191002
Event study results.
1003+
pre_periods : list of int, optional
1004+
Explicit list of pre-treatment periods. See fit() for details.
9201005
9211006
Returns
9221007
-------
@@ -926,7 +1011,7 @@ def sensitivity_to_honest_did(
9261011
- honest_M_at_mdv: Corresponding M value for HonestDiD
9271012
- interpretation: Text explaining the relationship
9281013
"""
929-
pt_results = self.fit(results)
1014+
pt_results = self.fit(results, pre_periods=pre_periods)
9301015
mdv = pt_results.mdv
9311016

9321017
# The MDV represents the size of violation the test could detect
@@ -993,6 +1078,7 @@ def compute_pretrends_power(
9931078
alpha: float = 0.05,
9941079
target_power: float = 0.80,
9951080
violation_type: str = "linear",
1081+
pre_periods: Optional[List[int]] = None,
9961082
) -> PreTrendsPowerResults:
9971083
"""
9981084
Convenience function for pre-trends power analysis.
@@ -1009,6 +1095,9 @@ def compute_pretrends_power(
10091095
Target power for MDV calculation.
10101096
violation_type : str, default='linear'
10111097
Type of violation pattern.
1098+
pre_periods : list of int, optional
1099+
Explicit list of pre-treatment periods. If None, attempts to infer
1100+
from results. Use when you've estimated all periods as post_periods.
10121101
10131102
Returns
10141103
-------
@@ -1021,7 +1110,7 @@ def compute_pretrends_power(
10211110
>>> from diff_diff.pretrends import compute_pretrends_power
10221111
>>>
10231112
>>> results = MultiPeriodDiD().fit(data, ...)
1024-
>>> power_results = compute_pretrends_power(results)
1113+
>>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
10251114
>>> print(f"MDV: {power_results.mdv:.3f}")
10261115
>>> print(f"Power: {power_results.power:.1%}")
10271116
"""
@@ -1030,14 +1119,15 @@ def compute_pretrends_power(
10301119
power=target_power,
10311120
violation_type=violation_type,
10321121
)
1033-
return pt.fit(results, M=M)
1122+
return pt.fit(results, M=M, pre_periods=pre_periods)
10341123

10351124

10361125
def compute_mdv(
10371126
results: Union[MultiPeriodDiDResults, Any],
10381127
alpha: float = 0.05,
10391128
target_power: float = 0.80,
10401129
violation_type: str = "linear",
1130+
pre_periods: Optional[List[int]] = None,
10411131
) -> float:
10421132
"""
10431133
Compute minimum detectable violation.
@@ -1049,9 +1139,12 @@ def compute_mdv(
10491139
alpha : float, default=0.05
10501140
Significance level.
10511141
target_power : float, default=0.80
1052-
Target power.
1142+
Target power for MDV calculation.
10531143
violation_type : str, default='linear'
10541144
Type of violation pattern.
1145+
pre_periods : list of int, optional
1146+
Explicit list of pre-treatment periods. If None, attempts to infer
1147+
from results. Use when you've estimated all periods as post_periods.
10551148
10561149
Returns
10571150
-------
@@ -1063,5 +1156,5 @@ def compute_mdv(
10631156
power=target_power,
10641157
violation_type=violation_type,
10651158
)
1066-
result = pt.fit(results)
1159+
result = pt.fit(results, pre_periods=pre_periods)
10671160
return result.mdv

diff_diff/staggered.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,10 @@ def fit(
10531053
df[time] = pd.to_numeric(df[time])
10541054
df[first_treat] = pd.to_numeric(df[first_treat])
10551055

1056+
# Standardize the first_treat column name for internal use
1057+
# so internal methods can refer to it as 'first_treat' whatever the caller named it (overwrites any existing 'first_treat' column in df)
1058+
df['first_treat'] = df[first_treat]
1059+
10561060
# Identify groups and time periods
10571061
time_periods = sorted(df[time].unique())
10581062
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])

0 commit comments

Comments
 (0)