@@ -3892,6 +3892,157 @@ def _fake_rust_boot(*args, **kwargs):
38923892 assert np .isfinite (se )
38933893 assert len (dist ) == 11
38943894
3895+ def test_global_rust_bootstrap_warns_above_5pct_failure (self ):
3896+ """Global Rust happy path: 3/20 Rust successes (85% fail) warns."""
3897+ import sys
3898+ from unittest .mock import patch
3899+
3900+ df = TestTROPNValidTreated ._make_panel ()
3901+
3902+ trop_est = TROP (
3903+ method = "global" ,
3904+ lambda_time_grid = [1.0 ],
3905+ lambda_unit_grid = [1.0 ],
3906+ lambda_nn_grid = [np .inf ],
3907+ n_bootstrap = 20 ,
3908+ seed = 42 ,
3909+ )
3910+
3911+ trop_global_module = sys .modules ["diff_diff.trop_global" ]
3912+ rng = np .random .default_rng (0 )
3913+ fake_boot = rng .normal (size = 3 )
3914+
3915+ def _fake_rust_boot_global (* args , ** kwargs ):
3916+ return fake_boot , float (np .std (fake_boot , ddof = 1 ))
3917+
3918+ with (
3919+ patch .object (trop_global_module , "HAS_RUST_BACKEND" , True ),
3920+ patch .object (
3921+ trop_global_module ,
3922+ "_rust_bootstrap_trop_variance_global" ,
3923+ side_effect = _fake_rust_boot_global ,
3924+ ),
3925+ ):
3926+ with pytest .warns (
3927+ UserWarning ,
3928+ match = r"3/20 bootstrap iterations succeeded in TROP global bootstrap \(Rust\)" ,
3929+ ):
3930+ se , dist = trop_est ._bootstrap_variance_global (
3931+ df , "outcome" , "treated" , "unit" , "time" , (1.0 , 1.0 , 1e10 ), 3
3932+ )
3933+
3934+ assert np .isfinite (se )
3935+ assert len (dist ) == 3
3936+
3937+ @staticmethod
3938+ def _make_survey_panel_and_design ():
3939+ """Build a panel with per-unit PSU + weight columns and the matching
3940+ SurveyDesign/ResolvedSurveyDesign needed to reach the Rao-Wu path."""
3941+ from diff_diff import SurveyDesign
3942+ from diff_diff .survey import ResolvedSurveyDesign
3943+
3944+ df = TestTROPNValidTreated ._make_panel ().copy ()
3945+ all_units = sorted (df ["unit" ].unique ())
3946+ unit_to_psu = {u : i for i , u in enumerate (all_units )}
3947+ df ["psu" ] = df ["unit" ].map (unit_to_psu ).astype (np .int64 )
3948+ df ["weight" ] = 1.0
3949+ n_obs = len (df )
3950+
3951+ survey_design = SurveyDesign (weights = "weight" , psu = "psu" )
3952+ resolved_survey = ResolvedSurveyDesign (
3953+ weights = np .ones (n_obs , dtype = np .float64 ),
3954+ weight_type = "pweight" ,
3955+ strata = None ,
3956+ psu = df ["psu" ].values .astype (np .int64 ),
3957+ fpc = None ,
3958+ n_strata = 0 ,
3959+ n_psu = len (all_units ),
3960+ lonely_psu = "remove" ,
3961+ )
3962+ return df , survey_design , resolved_survey
3963+
3964+ def test_local_rao_wu_bootstrap_warns_above_5pct_failure (self ):
3965+ """Local Rao-Wu survey bootstrap: forced failures → proportional warn."""
3966+ from unittest .mock import patch
3967+
3968+ df , survey_design , resolved_survey = self ._make_survey_panel_and_design ()
3969+
3970+ trop_est = TROP (
3971+ method = "local" ,
3972+ lambda_time_grid = [1.0 ],
3973+ lambda_unit_grid = [1.0 ],
3974+ lambda_nn_grid = [np .inf ],
3975+ n_bootstrap = 20 ,
3976+ seed = 42 ,
3977+ )
3978+
3979+ with patch .object (
3980+ TROP ,
3981+ "_fit_with_fixed_lambda" ,
3982+ side_effect = self ._make_failing_fit (20 , 4 ),
3983+ ):
3984+ with pytest .warns (
3985+ UserWarning ,
3986+ match = r"4/20 bootstrap iterations succeeded in TROP local Rao-Wu bootstrap" ,
3987+ ):
3988+ se , dist = trop_est ._bootstrap_rao_wu_local (
3989+ df ,
3990+ "outcome" ,
3991+ "treated" ,
3992+ "unit" ,
3993+ "time" ,
3994+ (1.0 , 1.0 , 1e10 ),
3995+ resolved_survey ,
3996+ survey_design ,
3997+ )
3998+
3999+ assert np .isfinite (se )
4000+ assert len (dist ) == 4
4001+
4002+ def test_global_rao_wu_bootstrap_warns_above_5pct_failure (self ):
4003+ """Global Rao-Wu survey bootstrap: forced failures → proportional warn."""
4004+ from unittest .mock import patch
4005+
4006+ df , survey_design , resolved_survey = self ._make_survey_panel_and_design ()
4007+
4008+ trop_est = TROP (
4009+ method = "global" ,
4010+ lambda_time_grid = [1.0 ],
4011+ lambda_unit_grid = [1.0 ],
4012+ lambda_nn_grid = [np .inf ],
4013+ n_bootstrap = 20 ,
4014+ seed = 42 ,
4015+ )
4016+
4017+ n_calls = {"count" : 0 }
4018+
4019+ def _flaky_solve (* args , ** kwargs ):
4020+ n_calls ["count" ] += 1
4021+ if n_calls ["count" ] <= 3 :
4022+ n_periods , n_units = args [0 ].shape
4023+ return 0.0 , np .zeros (n_units ), np .zeros (n_periods ), np .zeros ((n_periods , n_units ))
4024+ raise ValueError ("forced Rao-Wu failure" )
4025+
4026+ with patch .object (TROP , "_solve_global_model" , side_effect = _flaky_solve ):
4027+ with pytest .warns (
4028+ UserWarning ,
4029+ match = r"3/20 bootstrap iterations succeeded in TROP global Rao-Wu bootstrap" ,
4030+ ):
4031+ se , dist = trop_est ._bootstrap_rao_wu_global (
4032+ df ,
4033+ "outcome" ,
4034+ "treated" ,
4035+ "unit" ,
4036+ "time" ,
4037+ (1.0 , 1.0 , 1e10 ),
4038+ 3 ,
4039+ resolved_survey ,
4040+ survey_design ,
4041+ )
4042+
4043+ assert np .isfinite (se ) or np .isnan (se )
4044+ assert len (dist ) == 3
4045+
38954046
38964047class TestTROPModuleSplit :
38974048 """Regression tests for the trop.py -> trop_global.py / trop_local.py split."""
0 commit comments