Skip to content

Commit 356e3ec

Browse files
igerberclaude
andcommitted
Close Rao-Wu and global-Rust warning-path coverage gap
CI AI review on PR #324 flagged (P3) that the delta changed the Rao-Wu local, Rao-Wu global, and global-Rust warning branches without direct assertions on those paths. Add three targeted tests mirroring the pattern of the existing four (mocked inner-fit side effects, direct method invocation, pytest.warns with the context string in the match regex): - `_bootstrap_variance_global` Rust happy path - `_bootstrap_rao_wu_local` (survey design with per-unit PSU) - `_bootstrap_rao_wu_global` (same survey setup) All six changed warning sites now have direct regression coverage. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5755873 commit 356e3ec

1 file changed

Lines changed: 151 additions & 0 deletions

File tree

tests/test_trop.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3892,6 +3892,157 @@ def _fake_rust_boot(*args, **kwargs):
38923892
assert np.isfinite(se)
38933893
assert len(dist) == 11
38943894

3895+
def test_global_rust_bootstrap_warns_above_5pct_failure(self):
3896+
"""Global Rust happy path: 3/20 Rust successes (85% fail) warns."""
3897+
import sys
3898+
from unittest.mock import patch
3899+
3900+
df = TestTROPNValidTreated._make_panel()
3901+
3902+
trop_est = TROP(
3903+
method="global",
3904+
lambda_time_grid=[1.0],
3905+
lambda_unit_grid=[1.0],
3906+
lambda_nn_grid=[np.inf],
3907+
n_bootstrap=20,
3908+
seed=42,
3909+
)
3910+
3911+
trop_global_module = sys.modules["diff_diff.trop_global"]
3912+
rng = np.random.default_rng(0)
3913+
fake_boot = rng.normal(size=3)
3914+
3915+
def _fake_rust_boot_global(*args, **kwargs):
3916+
return fake_boot, float(np.std(fake_boot, ddof=1))
3917+
3918+
with (
3919+
patch.object(trop_global_module, "HAS_RUST_BACKEND", True),
3920+
patch.object(
3921+
trop_global_module,
3922+
"_rust_bootstrap_trop_variance_global",
3923+
side_effect=_fake_rust_boot_global,
3924+
),
3925+
):
3926+
with pytest.warns(
3927+
UserWarning,
3928+
match=r"3/20 bootstrap iterations succeeded in TROP global bootstrap \(Rust\)",
3929+
):
3930+
se, dist = trop_est._bootstrap_variance_global(
3931+
df, "outcome", "treated", "unit", "time", (1.0, 1.0, 1e10), 3
3932+
)
3933+
3934+
assert np.isfinite(se)
3935+
assert len(dist) == 3
3936+
3937+
@staticmethod
3938+
def _make_survey_panel_and_design():
3939+
"""Build a panel with per-unit PSU + weight columns and the matching
3940+
SurveyDesign/ResolvedSurveyDesign needed to reach the Rao-Wu path."""
3941+
from diff_diff import SurveyDesign
3942+
from diff_diff.survey import ResolvedSurveyDesign
3943+
3944+
df = TestTROPNValidTreated._make_panel().copy()
3945+
all_units = sorted(df["unit"].unique())
3946+
unit_to_psu = {u: i for i, u in enumerate(all_units)}
3947+
df["psu"] = df["unit"].map(unit_to_psu).astype(np.int64)
3948+
df["weight"] = 1.0
3949+
n_obs = len(df)
3950+
3951+
survey_design = SurveyDesign(weights="weight", psu="psu")
3952+
resolved_survey = ResolvedSurveyDesign(
3953+
weights=np.ones(n_obs, dtype=np.float64),
3954+
weight_type="pweight",
3955+
strata=None,
3956+
psu=df["psu"].values.astype(np.int64),
3957+
fpc=None,
3958+
n_strata=0,
3959+
n_psu=len(all_units),
3960+
lonely_psu="remove",
3961+
)
3962+
return df, survey_design, resolved_survey
3963+
3964+
def test_local_rao_wu_bootstrap_warns_above_5pct_failure(self):
3965+
"""Local Rao-Wu survey bootstrap: forced failures → proportional warn."""
3966+
from unittest.mock import patch
3967+
3968+
df, survey_design, resolved_survey = self._make_survey_panel_and_design()
3969+
3970+
trop_est = TROP(
3971+
method="local",
3972+
lambda_time_grid=[1.0],
3973+
lambda_unit_grid=[1.0],
3974+
lambda_nn_grid=[np.inf],
3975+
n_bootstrap=20,
3976+
seed=42,
3977+
)
3978+
3979+
with patch.object(
3980+
TROP,
3981+
"_fit_with_fixed_lambda",
3982+
side_effect=self._make_failing_fit(20, 4),
3983+
):
3984+
with pytest.warns(
3985+
UserWarning,
3986+
match=r"4/20 bootstrap iterations succeeded in TROP local Rao-Wu bootstrap",
3987+
):
3988+
se, dist = trop_est._bootstrap_rao_wu_local(
3989+
df,
3990+
"outcome",
3991+
"treated",
3992+
"unit",
3993+
"time",
3994+
(1.0, 1.0, 1e10),
3995+
resolved_survey,
3996+
survey_design,
3997+
)
3998+
3999+
assert np.isfinite(se)
4000+
assert len(dist) == 4
4001+
4002+
def test_global_rao_wu_bootstrap_warns_above_5pct_failure(self):
4003+
"""Global Rao-Wu survey bootstrap: forced failures → proportional warn."""
4004+
from unittest.mock import patch
4005+
4006+
df, survey_design, resolved_survey = self._make_survey_panel_and_design()
4007+
4008+
trop_est = TROP(
4009+
method="global",
4010+
lambda_time_grid=[1.0],
4011+
lambda_unit_grid=[1.0],
4012+
lambda_nn_grid=[np.inf],
4013+
n_bootstrap=20,
4014+
seed=42,
4015+
)
4016+
4017+
n_calls = {"count": 0}
4018+
4019+
def _flaky_solve(*args, **kwargs):
4020+
n_calls["count"] += 1
4021+
if n_calls["count"] <= 3:
4022+
n_periods, n_units = args[0].shape
4023+
return 0.0, np.zeros(n_units), np.zeros(n_periods), np.zeros((n_periods, n_units))
4024+
raise ValueError("forced Rao-Wu failure")
4025+
4026+
with patch.object(TROP, "_solve_global_model", side_effect=_flaky_solve):
4027+
with pytest.warns(
4028+
UserWarning,
4029+
match=r"3/20 bootstrap iterations succeeded in TROP global Rao-Wu bootstrap",
4030+
):
4031+
se, dist = trop_est._bootstrap_rao_wu_global(
4032+
df,
4033+
"outcome",
4034+
"treated",
4035+
"unit",
4036+
"time",
4037+
(1.0, 1.0, 1e10),
4038+
3,
4039+
resolved_survey,
4040+
survey_design,
4041+
)
4042+
4043+
assert np.isfinite(se) or np.isnan(se)
4044+
assert len(dist) == 3
4045+
38954046

38964047
class TestTROPModuleSplit:
38974048
"""Regression tests for the trop.py -> trop_global.py / trop_local.py split."""

0 commit comments

Comments
 (0)