@@ -202,6 +202,63 @@ def to_dataframe(self) -> pd.DataFrame:
202202 """Convert results to DataFrame."""
203203 return pd .DataFrame ([self .to_dict ()])
204204
205+ def power_at (self , M : float ) -> float :
206+ """
207+ Compute power to detect a specific violation magnitude.
208+
209+ This method allows computing power at different M values without
210+ re-fitting the model, using the stored variance-covariance matrix.
211+
212+ Parameters
213+ ----------
214+ M : float
215+ Violation magnitude to evaluate.
216+
217+ Returns
218+ -------
219+ float
220+ Power to detect violation of magnitude M.
221+ """
222+ from scipy import stats
223+
224+ n_pre = self .n_pre_periods
225+
226+ # Reconstruct violation weights based on violation type
227+ # Must match PreTrendsPower._get_violation_weights() exactly
228+ if self .violation_type == "linear" :
229+ # Linear trend: weights decrease toward treatment
230+ # [n-1, n-2, ..., 1, 0] for n pre-periods
231+ weights = np .arange (- n_pre + 1 , 1 , dtype = float )
232+ weights = - weights # Now [n-1, n-2, ..., 1, 0]
233+ elif self .violation_type == "constant" :
234+ weights = np .ones (n_pre )
235+ elif self .violation_type == "last_period" :
236+ weights = np .zeros (n_pre )
237+ weights [- 1 ] = 1.0
238+ else :
239+ # For custom, we can't reconstruct - use equal weights as fallback
240+ weights = np .ones (n_pre )
241+
242+ # Normalize weights to unit L2 norm
243+ norm = np .linalg .norm (weights )
244+ if norm > 0 :
245+ weights = weights / norm
246+
247+ # Compute non-centrality parameter
248+ try :
249+ vcov_inv = np .linalg .inv (self .vcov )
250+ except np .linalg .LinAlgError :
251+ vcov_inv = np .linalg .pinv (self .vcov )
252+
253+ # delta = M * weights
254+ # nc = delta' * V^{-1} * delta
255+ noncentrality = M ** 2 * (weights @ vcov_inv @ weights )
256+
257+ # Compute power using non-central chi-squared
258+ power = 1 - stats .ncx2 .cdf (self .critical_value , df = n_pre , nc = noncentrality )
259+
260+ return float (power )
261+
205262
206263@dataclass
207264class PreTrendsPowerCurve :
@@ -471,10 +528,18 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray:
471528 def _extract_pre_period_params (
472529 self ,
473530 results : Union [MultiPeriodDiDResults , Any ],
531+ pre_periods : Optional [List [int ]] = None ,
474532 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , int ]:
475533 """
476534 Extract pre-period parameters from results.
477535
536+ Parameters
537+ ----------
538+ results : MultiPeriodDiDResults or similar
539+ Results object from event study estimation.
540+ pre_periods : list of int, optional
541+ Explicit list of pre-treatment periods. If None, uses results.pre_periods.
542+
478543 Returns
479544 -------
480545 effects : np.ndarray
@@ -487,13 +552,18 @@ def _extract_pre_period_params(
487552 Number of pre-periods.
488553 """
489554 if isinstance (results , MultiPeriodDiDResults ):
490- # Get pre-period information
491- all_pre_periods = results .pre_periods
555+ # Get pre-period information - use explicit pre_periods if provided
556+ if pre_periods is not None :
557+ all_pre_periods = list (pre_periods )
558+ else :
559+ all_pre_periods = results .pre_periods
492560
493561 if len (all_pre_periods ) == 0 :
494562 raise ValueError (
495563 "No pre-treatment periods found in results. "
496- "Pre-trends power analysis requires pre-period coefficients."
564+ "Pre-trends power analysis requires pre-period coefficients. "
565+ "If you estimated all periods as post_periods, use the pre_periods "
566+ "parameter to specify which are actually pre-treatment."
497567 )
498568
499569 # Only include periods with actual estimated coefficients
@@ -775,6 +845,7 @@ def fit(
775845 self ,
776846 results : Union [MultiPeriodDiDResults , Any ],
777847 M : Optional [float ] = None ,
848+ pre_periods : Optional [List [int ]] = None ,
778849 ) -> PreTrendsPowerResults :
779850 """
780851 Compute pre-trends power analysis.
@@ -786,14 +857,19 @@ def fit(
786857 M : float, optional
787858 Specific violation magnitude to evaluate. If None, evaluates at
788859 a default magnitude based on the data.
860+ pre_periods : list of int, optional
861+ Explicit list of pre-treatment periods to use for power analysis.
862+ If None, attempts to infer from results.pre_periods. Use this when
863+ you've estimated an event study with all periods in post_periods
864+ and need to specify which are actually pre-treatment.
789865
790866 Returns
791867 -------
792868 PreTrendsPowerResults
793869 Power analysis results including power and MDV.
794870 """
795871 # Extract pre-period parameters
796- effects , ses , vcov , n_pre = self ._extract_pre_period_params (results )
872+ effects , ses , vcov , n_pre = self ._extract_pre_period_params (results , pre_periods )
797873
798874 # Get violation weights
799875 weights = self ._get_violation_weights (n_pre )
@@ -831,6 +907,7 @@ def power_at(
831907 self ,
832908 results : Union [MultiPeriodDiDResults , Any ],
833909 M : float ,
910+ pre_periods : Optional [List [int ]] = None ,
834911 ) -> float :
835912 """
836913 Compute power to detect a specific violation magnitude.
@@ -841,20 +918,23 @@ def power_at(
841918 Event study results.
842919 M : float
843920 Violation magnitude.
921+ pre_periods : list of int, optional
922+ Explicit list of pre-treatment periods. See fit() for details.
844923
845924 Returns
846925 -------
847926 float
848927 Power to detect violation of magnitude M.
849928 """
850- result = self .fit (results , M = M )
929+ result = self .fit (results , M = M , pre_periods = pre_periods )
851930 return result .power
852931
853932 def power_curve (
854933 self ,
855934 results : Union [MultiPeriodDiDResults , Any ],
856935 M_grid : Optional [List [float ]] = None ,
857936 n_points : int = 50 ,
937+ pre_periods : Optional [List [int ]] = None ,
858938 ) -> PreTrendsPowerCurve :
859939 """
860940 Compute power across a range of violation magnitudes.
@@ -868,14 +948,16 @@ def power_curve(
868948 automatic grid from 0 to 2.5 * MDV.
869949 n_points : int, default=50
870950 Number of points in automatic grid.
951+ pre_periods : list of int, optional
952+ Explicit list of pre-treatment periods. See fit() for details.
871953
872954 Returns
873955 -------
874956 PreTrendsPowerCurve
875957 Power curve data with plot method.
876958 """
877959 # Extract parameters
878- effects , ses , vcov , n_pre = self ._extract_pre_period_params (results )
960+ _ , ses , vcov , n_pre = self ._extract_pre_period_params (results , pre_periods )
879961 weights = self ._get_violation_weights (n_pre )
880962
881963 # Compute MDV
@@ -906,6 +988,7 @@ def power_curve(
906988 def sensitivity_to_honest_did (
907989 self ,
908990 results : Union [MultiPeriodDiDResults , Any ],
991+ pre_periods : Optional [List [int ]] = None ,
909992 ) -> Dict [str , Any ]:
910993 """
911994 Compare pre-trends power analysis with HonestDiD sensitivity.
@@ -917,6 +1000,8 @@ def sensitivity_to_honest_did(
9171000 ----------
9181001 results : results object
9191002 Event study results.
1003+ pre_periods : list of int, optional
1004+ Explicit list of pre-treatment periods. See fit() for details.
9201005
9211006 Returns
9221007 -------
@@ -926,7 +1011,7 @@ def sensitivity_to_honest_did(
9261011 - honest_M_at_mdv: Corresponding M value for HonestDiD
9271012 - interpretation: Text explaining the relationship
9281013 """
929- pt_results = self .fit (results )
1014+ pt_results = self .fit (results , pre_periods = pre_periods )
9301015 mdv = pt_results .mdv
9311016
9321017 # The MDV represents the size of violation the test could detect
@@ -993,6 +1078,7 @@ def compute_pretrends_power(
9931078 alpha : float = 0.05 ,
9941079 target_power : float = 0.80 ,
9951080 violation_type : str = "linear" ,
1081+ pre_periods : Optional [List [int ]] = None ,
9961082) -> PreTrendsPowerResults :
9971083 """
9981084 Convenience function for pre-trends power analysis.
@@ -1009,6 +1095,9 @@ def compute_pretrends_power(
10091095 Target power for MDV calculation.
10101096 violation_type : str, default='linear'
10111097 Type of violation pattern.
1098+ pre_periods : list of int, optional
1099+ Explicit list of pre-treatment periods. If None, attempts to infer
1100+ from results. Use when you've estimated all periods as post_periods.
10121101
10131102 Returns
10141103 -------
@@ -1021,7 +1110,7 @@ def compute_pretrends_power(
10211110 >>> from diff_diff.pretrends import compute_pretrends_power
10221111 >>>
10231112 >>> results = MultiPeriodDiD().fit(data, ...)
1024- >>> power_results = compute_pretrends_power(results)
1113+ >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3] )
10251114 >>> print(f"MDV: {power_results.mdv:.3f}")
10261115 >>> print(f"Power: {power_results.power:.1%}")
10271116 """
@@ -1030,14 +1119,15 @@ def compute_pretrends_power(
10301119 power = target_power ,
10311120 violation_type = violation_type ,
10321121 )
1033- return pt .fit (results , M = M )
1122+ return pt .fit (results , M = M , pre_periods = pre_periods )
10341123
10351124
10361125def compute_mdv (
10371126 results : Union [MultiPeriodDiDResults , Any ],
10381127 alpha : float = 0.05 ,
10391128 target_power : float = 0.80 ,
10401129 violation_type : str = "linear" ,
1130+ pre_periods : Optional [List [int ]] = None ,
10411131) -> float :
10421132 """
10431133 Compute minimum detectable violation.
@@ -1049,9 +1139,12 @@ def compute_mdv(
10491139 alpha : float, default=0.05
10501140 Significance level.
10511141 target_power : float, default=0.80
1052- Target power.
1142+ Target power for MDV calculation .
10531143 violation_type : str, default='linear'
10541144 Type of violation pattern.
1145+ pre_periods : list of int, optional
1146+ Explicit list of pre-treatment periods. If None, attempts to infer
1147+ from results. Use when you've estimated all periods as post_periods.
10551148
10561149 Returns
10571150 -------
@@ -1063,5 +1156,5 @@ def compute_mdv(
10631156 power = target_power ,
10641157 violation_type = violation_type ,
10651158 )
1066- result = pt .fit (results )
1159+ result = pt .fit (results , pre_periods = pre_periods )
10671160 return result .mdv
0 commit comments