@@ -811,3 +811,218 @@ def test_power_curve_has_plot_method(self, mock_multiperiod_results):
811811
812812 assert hasattr (curve , 'plot' )
813813 assert callable (curve .plot )
814+
815+
816+ # =============================================================================
817+ # Tests for PreTrendsPowerResults.power_at() method
818+ # =============================================================================
819+
820+
821+ class TestPreTrendsPowerResultsPowerAt :
822+ """Tests for the power_at method on PreTrendsPowerResults."""
823+
824+ def test_power_at_basic (self , mock_multiperiod_results ):
825+ """Test basic power_at functionality."""
826+ pt = PreTrendsPower ()
827+ results = pt .fit (mock_multiperiod_results )
828+
829+ # Compute power at different M values
830+ power_1 = results .power_at (1.0 )
831+ power_2 = results .power_at (2.0 )
832+ power_5 = results .power_at (5.0 )
833+
834+ # Power should increase with M
835+ assert power_1 < power_2 < power_5
836+
837+ # Power should be between 0 and 1
838+ assert 0 <= power_1 <= 1
839+ assert 0 <= power_2 <= 1
840+ assert 0 <= power_5 <= 1
841+
842+ def test_power_at_zero (self , mock_multiperiod_results ):
843+ """Test power_at with M=0 (should equal alpha)."""
844+ pt = PreTrendsPower (alpha = 0.05 )
845+ results = pt .fit (mock_multiperiod_results )
846+
847+ power_0 = results .power_at (0.0 )
848+
849+ # At M=0, power should equal size (alpha)
850+ assert np .isclose (power_0 , 0.05 , atol = 0.01 )
851+
852+ def test_power_at_matches_fit (self , mock_multiperiod_results ):
853+ """Test that power_at gives same result as fitting with that M."""
854+ pt = PreTrendsPower ()
855+
856+ # Get results from fit
857+ results1 = pt .fit (mock_multiperiod_results , M = 2.0 )
858+
859+ # Get power from power_at method
860+ results_base = pt .fit (mock_multiperiod_results )
861+ power_from_method = results_base .power_at (2.0 )
862+
863+ # Should be the same (or very close)
864+ assert np .isclose (results1 .power , power_from_method , rtol = 0.01 )
865+
866+ def test_power_at_linear_weights (self , mock_multiperiod_results ):
867+ """Test power_at uses correct linear weights."""
868+ pt = PreTrendsPower (violation_type = "linear" )
869+ results = pt .fit (mock_multiperiod_results )
870+
871+ # Power_at should work without error
872+ power = results .power_at (1.0 )
873+ assert 0 <= power <= 1
874+
875+ def test_power_at_constant_weights (self , mock_multiperiod_results ):
876+ """Test power_at uses correct constant weights."""
877+ pt = PreTrendsPower (violation_type = "constant" )
878+ results = pt .fit (mock_multiperiod_results )
879+
880+ power = results .power_at (1.0 )
881+ assert 0 <= power <= 1
882+
883+ def test_power_at_last_period_weights (self , mock_multiperiod_results ):
884+ """Test power_at uses correct last_period weights."""
885+ pt = PreTrendsPower (violation_type = "last_period" )
886+ results = pt .fit (mock_multiperiod_results )
887+
888+ power = results .power_at (1.0 )
889+ assert 0 <= power <= 1
890+
891+
892+ # =============================================================================
893+ # Tests for pre_periods parameter
894+ # =============================================================================
895+
896+
897+ class TestPrePeriodsParameter :
898+ """Tests for the pre_periods parameter in fit and related methods."""
899+
900+ @pytest .fixture
901+ def event_study_all_periods_results (self ):
902+ """Create results simulating all periods estimated as post_periods.
903+
904+ This mimics the event study workflow where we estimate coefficients
905+ for ALL periods (pre and post) to get pre-period placebo effects.
906+ """
907+ # Periods 0-3 are pre-treatment, 4-7 are post
908+ # But we estimate ALL periods as "post" to get coefficients
909+ period_effects = {}
910+ coefficients = {}
911+
912+ # Pre-periods (0, 1, 2) - period 3 would be reference
913+ for p in [0 , 1 , 2 ]:
914+ period_effects [p ] = PeriodEffect (
915+ period = p , effect = np .random .normal (0 , 0.1 ), se = 0.5 ,
916+ t_stat = 0.2 , p_value = 0.84 , conf_int = (- 0.88 , 1.08 )
917+ )
918+ coefficients [f'treated:period_{ p } ' ] = period_effects [p ].effect
919+
920+ # Post-periods (4, 5, 6, 7)
921+ for p in [4 , 5 , 6 , 7 ]:
922+ period_effects [p ] = PeriodEffect (
923+ period = p , effect = 5.0 + np .random .normal (0 , 0.1 ), se = 0.5 ,
924+ t_stat = 10.0 , p_value = 0.0001 , conf_int = (4.02 , 5.98 )
925+ )
926+ coefficients [f'treated:period_{ p } ' ] = period_effects [p ].effect
927+
928+ # In this scenario, pre_periods=[3] (only reference), post_periods=[0,1,2,4,5,6,7]
929+ vcov = np .diag ([0.25 ] * 7 )
930+
931+ return MultiPeriodDiDResults (
932+ period_effects = period_effects ,
933+ avg_att = 5.0 ,
934+ avg_se = 0.25 ,
935+ avg_t_stat = 20.0 ,
936+ avg_p_value = 0.0001 ,
937+ avg_conf_int = (4.51 , 5.49 ),
938+ n_obs = 800 ,
939+ n_treated = 400 ,
940+ n_control = 400 ,
941+ pre_periods = [3 ], # Only reference period
942+ post_periods = [0 , 1 , 2 , 4 , 5 , 6 , 7 ], # All estimated periods
943+ vcov = vcov ,
944+ coefficients = coefficients ,
945+ )
946+
947+ def test_fit_with_explicit_pre_periods (self , event_study_all_periods_results ):
948+ """Test fit() with explicit pre_periods parameter."""
949+ pt = PreTrendsPower ()
950+
951+ # Without pre_periods, would fail because results.pre_periods=[3]
952+ # and period 3 has no coefficient (it's the reference)
953+ # With explicit pre_periods=[0,1,2], should work
954+ results = pt .fit (
955+ event_study_all_periods_results ,
956+ pre_periods = [0 , 1 , 2 ]
957+ )
958+
959+ assert results .n_pre_periods == 3
960+ assert results .power >= 0
961+ assert results .mdv > 0
962+
963+ def test_pre_periods_overrides_results (self , event_study_all_periods_results ):
964+ """Test that pre_periods parameter overrides results.pre_periods."""
965+ pt = PreTrendsPower ()
966+
967+ # Explicitly set pre_periods to [0, 1]
968+ results = pt .fit (
969+ event_study_all_periods_results ,
970+ pre_periods = [0 , 1 ]
971+ )
972+
973+ # Should use 2 pre-periods, not what's in results
974+ assert results .n_pre_periods == 2
975+
976+ def test_power_at_with_pre_periods (self , event_study_all_periods_results ):
977+ """Test power_at() method with pre_periods parameter."""
978+ pt = PreTrendsPower ()
979+
980+ power = pt .power_at (
981+ event_study_all_periods_results ,
982+ M = 1.0 ,
983+ pre_periods = [0 , 1 , 2 ]
984+ )
985+
986+ assert 0 <= power <= 1
987+
988+ def test_power_curve_with_pre_periods (self , event_study_all_periods_results ):
989+ """Test power_curve() with pre_periods parameter."""
990+ pt = PreTrendsPower ()
991+
992+ curve = pt .power_curve (
993+ event_study_all_periods_results ,
994+ n_points = 10 ,
995+ pre_periods = [0 , 1 , 2 ]
996+ )
997+
998+ assert len (curve .M_values ) == 10
999+ assert len (curve .powers ) == 10
1000+
1001+ def test_sensitivity_to_honest_did_with_pre_periods (self , event_study_all_periods_results ):
1002+ """Test sensitivity_to_honest_did() with pre_periods parameter."""
1003+ pt = PreTrendsPower ()
1004+
1005+ sensitivity = pt .sensitivity_to_honest_did (
1006+ event_study_all_periods_results ,
1007+ pre_periods = [0 , 1 , 2 ]
1008+ )
1009+
1010+ assert 'mdv' in sensitivity
1011+ assert sensitivity ['mdv' ] > 0
1012+
1013+ def test_convenience_functions_with_pre_periods (self , event_study_all_periods_results ):
1014+ """Test convenience functions with pre_periods parameter."""
1015+ # compute_mdv
1016+ mdv = compute_mdv (
1017+ event_study_all_periods_results ,
1018+ pre_periods = [0 , 1 , 2 ]
1019+ )
1020+ assert mdv > 0
1021+
1022+ # compute_pretrends_power
1023+ results = compute_pretrends_power (
1024+ event_study_all_periods_results ,
1025+ M = 1.0 ,
1026+ pre_periods = [0 , 1 , 2 ]
1027+ )
1028+ assert results .n_pre_periods == 3
0 commit comments