@@ -202,6 +202,59 @@ def to_dataframe(self) -> pd.DataFrame:
202202 """Convert results to DataFrame."""
203203 return pd .DataFrame ([self .to_dict ()])
204204
205+ def power_at (self , M : float ) -> float :
206+ """
207+ Compute power to detect a specific violation magnitude.
208+
209+ This method allows computing power at different M values without
210+ re-fitting the model, using the stored variance-covariance matrix.
211+
212+ Parameters
213+ ----------
214+ M : float
215+ Violation magnitude to evaluate.
216+
217+ Returns
218+ -------
219+ float
220+ Power to detect violation of magnitude M.
221+ """
222+ from scipy import stats
223+
224+ n_pre = self .n_pre_periods
225+
226+ # Reconstruct violation weights based on violation type
227+ if self .violation_type == "linear" :
228+ weights = np .arange (1 , n_pre + 1 ).astype (float )
229+ elif self .violation_type == "constant" :
230+ weights = np .ones (n_pre )
231+ elif self .violation_type == "last_period" :
232+ weights = np .zeros (n_pre )
233+ weights [- 1 ] = 1.0
234+ else :
235+ # For custom, we can't reconstruct - use equal weights
236+ weights = np .ones (n_pre )
237+
238+ # Normalize weights
239+ norm = np .linalg .norm (weights )
240+ if norm > 0 :
241+ weights = weights / norm
242+
243+ # Compute non-centrality parameter
244+ try :
245+ vcov_inv = np .linalg .inv (self .vcov )
246+ except np .linalg .LinAlgError :
247+ vcov_inv = np .linalg .pinv (self .vcov )
248+
249+ # delta = M * weights
250+ # nc = delta' * V^{-1} * delta
251+ noncentrality = M ** 2 * (weights @ vcov_inv @ weights )
252+
253+ # Compute power using non-central chi-squared
254+ power = 1 - stats .ncx2 .cdf (self .critical_value , df = n_pre , nc = noncentrality )
255+
256+ return float (power )
257+
205258
206259@dataclass
207260class PreTrendsPowerCurve :
@@ -471,10 +524,18 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray:
471524 def _extract_pre_period_params (
472525 self ,
473526 results : Union [MultiPeriodDiDResults , Any ],
527+ pre_periods : Optional [List [int ]] = None ,
474528 ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , int ]:
475529 """
476530 Extract pre-period parameters from results.
477531
532+ Parameters
533+ ----------
534+ results : MultiPeriodDiDResults or similar
535+ Results object from event study estimation.
536+ pre_periods : list of int, optional
537+ Explicit list of pre-treatment periods. If None, uses results.pre_periods.
538+
478539 Returns
479540 -------
480541 effects : np.ndarray
@@ -487,13 +548,18 @@ def _extract_pre_period_params(
487548 Number of pre-periods.
488549 """
489550 if isinstance (results , MultiPeriodDiDResults ):
490- # Get pre-period information
491- all_pre_periods = results .pre_periods
551+ # Get pre-period information - use explicit pre_periods if provided
552+ if pre_periods is not None :
553+ all_pre_periods = list (pre_periods )
554+ else :
555+ all_pre_periods = results .pre_periods
492556
493557 if len (all_pre_periods ) == 0 :
494558 raise ValueError (
495559 "No pre-treatment periods found in results. "
496- "Pre-trends power analysis requires pre-period coefficients."
560+ "Pre-trends power analysis requires pre-period coefficients. "
561+ "If you estimated all periods as post_periods, use the pre_periods "
562+ "parameter to specify which are actually pre-treatment."
497563 )
498564
499565 # Only include periods with actual estimated coefficients
@@ -775,6 +841,7 @@ def fit(
775841 self ,
776842 results : Union [MultiPeriodDiDResults , Any ],
777843 M : Optional [float ] = None ,
844+ pre_periods : Optional [List [int ]] = None ,
778845 ) -> PreTrendsPowerResults :
779846 """
780847 Compute pre-trends power analysis.
@@ -786,14 +853,19 @@ def fit(
786853 M : float, optional
787854 Specific violation magnitude to evaluate. If None, evaluates at
788855 a default magnitude based on the data.
856+ pre_periods : list of int, optional
857+ Explicit list of pre-treatment periods to use for power analysis.
858+ If None, attempts to infer from results.pre_periods. Use this when
859+ you've estimated an event study with all periods in post_periods
860+ and need to specify which are actually pre-treatment.
789861
790862 Returns
791863 -------
792864 PreTrendsPowerResults
793865 Power analysis results including power and MDV.
794866 """
795867 # Extract pre-period parameters
796- effects , ses , vcov , n_pre = self ._extract_pre_period_params (results )
868+ effects , ses , vcov , n_pre = self ._extract_pre_period_params (results , pre_periods )
797869
798870 # Get violation weights
799871 weights = self ._get_violation_weights (n_pre )
@@ -831,6 +903,7 @@ def power_at(
831903 self ,
832904 results : Union [MultiPeriodDiDResults , Any ],
833905 M : float ,
906+ pre_periods : Optional [List [int ]] = None ,
834907 ) -> float :
835908 """
836909 Compute power to detect a specific violation magnitude.
@@ -841,20 +914,23 @@ def power_at(
841914 Event study results.
842915 M : float
843916 Violation magnitude.
917+ pre_periods : list of int, optional
918+ Explicit list of pre-treatment periods. See fit() for details.
844919
845920 Returns
846921 -------
847922 float
848923 Power to detect violation of magnitude M.
849924 """
850- result = self .fit (results , M = M )
925+ result = self .fit (results , M = M , pre_periods = pre_periods )
851926 return result .power
852927
853928 def power_curve (
854929 self ,
855930 results : Union [MultiPeriodDiDResults , Any ],
856931 M_grid : Optional [List [float ]] = None ,
857932 n_points : int = 50 ,
933+ pre_periods : Optional [List [int ]] = None ,
858934 ) -> PreTrendsPowerCurve :
859935 """
860936 Compute power across a range of violation magnitudes.
@@ -868,14 +944,16 @@ def power_curve(
868944 automatic grid from 0 to 2.5 * MDV.
869945 n_points : int, default=50
870946 Number of points in automatic grid.
947+ pre_periods : list of int, optional
948+ Explicit list of pre-treatment periods. See fit() for details.
871949
872950 Returns
873951 -------
874952 PreTrendsPowerCurve
875953 Power curve data with plot method.
876954 """
877955 # Extract parameters
878- effects , ses , vcov , n_pre = self ._extract_pre_period_params (results )
956+ _ , ses , vcov , n_pre = self ._extract_pre_period_params (results , pre_periods )
879957 weights = self ._get_violation_weights (n_pre )
880958
881959 # Compute MDV
@@ -906,6 +984,7 @@ def power_curve(
906984 def sensitivity_to_honest_did (
907985 self ,
908986 results : Union [MultiPeriodDiDResults , Any ],
987+ pre_periods : Optional [List [int ]] = None ,
909988 ) -> Dict [str , Any ]:
910989 """
911990 Compare pre-trends power analysis with HonestDiD sensitivity.
@@ -917,6 +996,8 @@ def sensitivity_to_honest_did(
917996 ----------
918997 results : results object
919998 Event study results.
999+ pre_periods : list of int, optional
1000+ Explicit list of pre-treatment periods. See fit() for details.
9201001
9211002 Returns
9221003 -------
@@ -926,7 +1007,7 @@ def sensitivity_to_honest_did(
9261007 - honest_M_at_mdv: Corresponding M value for HonestDiD
9271008 - interpretation: Text explaining the relationship
9281009 """
929- pt_results = self .fit (results )
1010+ pt_results = self .fit (results , pre_periods = pre_periods )
9301011 mdv = pt_results .mdv
9311012
9321013 # The MDV represents the size of violation the test could detect
@@ -993,6 +1074,7 @@ def compute_pretrends_power(
9931074 alpha : float = 0.05 ,
9941075 target_power : float = 0.80 ,
9951076 violation_type : str = "linear" ,
1077+ pre_periods : Optional [List [int ]] = None ,
9961078) -> PreTrendsPowerResults :
9971079 """
9981080 Convenience function for pre-trends power analysis.
@@ -1009,6 +1091,9 @@ def compute_pretrends_power(
10091091 Target power for MDV calculation.
10101092 violation_type : str, default='linear'
10111093 Type of violation pattern.
1094+ pre_periods : list of int, optional
1095+ Explicit list of pre-treatment periods. If None, attempts to infer
1096+ from results. Use when you've estimated all periods as post_periods.
10121097
10131098 Returns
10141099 -------
@@ -1021,7 +1106,7 @@ def compute_pretrends_power(
10211106 >>> from diff_diff.pretrends import compute_pretrends_power
10221107 >>>
10231108 >>> results = MultiPeriodDiD().fit(data, ...)
1024- >>> power_results = compute_pretrends_power(results)
1109+ >>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3] )
10251110 >>> print(f"MDV: {power_results.mdv:.3f}")
10261111 >>> print(f"Power: {power_results.power:.1%}")
10271112 """
@@ -1030,14 +1115,15 @@ def compute_pretrends_power(
10301115 power = target_power ,
10311116 violation_type = violation_type ,
10321117 )
1033- return pt .fit (results , M = M )
1118+ return pt .fit (results , M = M , pre_periods = pre_periods )
10341119
10351120
10361121def compute_mdv (
10371122 results : Union [MultiPeriodDiDResults , Any ],
10381123 alpha : float = 0.05 ,
1039- target_power : float = 0.80 ,
1124+ power : float = 0.80 ,
10401125 violation_type : str = "linear" ,
1126+ pre_periods : Optional [List [int ]] = None ,
10411127) -> float :
10421128 """
10431129 Compute minimum detectable violation.
@@ -1048,10 +1134,13 @@ def compute_mdv(
10481134 Event study results.
10491135 alpha : float, default=0.05
10501136 Significance level.
1051- target_power : float, default=0.80
1137+ power : float, default=0.80
10521138 Target power.
10531139 violation_type : str, default='linear'
10541140 Type of violation pattern.
1141+ pre_periods : list of int, optional
1142+ Explicit list of pre-treatment periods. If None, attempts to infer
1143+ from results. Use when you've estimated all periods as post_periods.
10551144
10561145 Returns
10571146 -------
@@ -1060,8 +1149,8 @@ def compute_mdv(
10601149 """
10611150 pt = PreTrendsPower (
10621151 alpha = alpha ,
1063- power = target_power ,
1152+ power = power ,
10641153 violation_type = violation_type ,
10651154 )
1066- result = pt .fit (results )
1155+ result = pt .fit (results , pre_periods = pre_periods )
10671156 return result .mdv
0 commit comments