Skip to content

Commit 3d21dd8

Browse files
igerber and claude committed
Extend TROP convergence tests to cover LOOCV and bootstrap aggregation
AI review on PR #317 flagged that my earlier fit()-level test only covered the per-treated-observation aggregation path, not the LOOCV or bootstrap wrapper paths. A regression in the _nonconvergence_tracker plumbing for those paths could slip through.

- test_local_fit_emits_single_aggregate_warning: expanded to assert that per-obs, LOOCV, and bootstrap warnings each appear at most once per .fit().
- test_global_fit_emits_single_aggregate_warning: new test mirroring the local one for method="global" (LOOCV + bootstrap paths).

Both use n_bootstrap=2, a minimal lambda grid, and max_iter=1/tol=1e-15 to keep cost low: ~3.4s for all 6 TROP convergence tests combined.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 810d862 commit 3d21dd8

1 file changed

Lines changed: 43 additions & 14 deletions

File tree

tests/test_trop.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4047,11 +4047,9 @@ def test_local_alternating_min_no_warning_on_convergence(self, simple_panel_data
40474047
assert not any("did not converge" in str(x.message) for x in w)
40484048

40494049
def test_local_fit_emits_single_aggregate_warning(self, simple_panel_data):
4050-
"""Fit-level warning aggregation: per-treated-observation non-convergence must
4051-
surface as at most one aggregate warning per call, not one per observation.
4052-
4053-
Pins the P2 fan-out fix: warnings are accumulated via the
4054-
`_nonconvergence_tracker` kwarg and emitted once at the top-level fit."""
4050+
"""Fit-level warning aggregation: per-treated-observation, LOOCV, and
4051+
bootstrap non-convergence each surface as at most one aggregate warning
4052+
per wrapping call, not one per inner fit. Pins the P2 fan-out fix."""
40554053
trop_est = TROP(
40564054
method="local",
40574055
lambda_time_grid=[1.0],
@@ -4073,13 +4071,44 @@ def test_local_fit_emits_single_aggregate_warning(self, simple_panel_data):
40734071
time="period",
40744072
)
40754073

4076-
# The per-treated-observation fit loop must emit exactly one aggregate
4077-
# warning of the form "TROP local per-treated-observation fit: N of M fits
4078-
# did not converge", not N separate warnings.
4079-
per_obs_warnings = [
4080-
x for x in w if "per-treated-observation" in str(x.message)
4081-
]
4082-
assert len(per_obs_warnings) <= 1, (
4083-
f"Expected at most one aggregated per-treated-observation warning, "
4084-
f"got {len(per_obs_warnings)}: {[str(x.message) for x in per_obs_warnings]}"
4074+
def count_matching(needle: str) -> int:
4075+
return sum(1 for x in w if needle in str(x.message))
4076+
4077+
# Per-treated-observation aggregation (called once per .fit()).
4078+
assert count_matching("per-treated-observation") <= 1
4079+
# LOOCV aggregation (called once per (lambda_time, lambda_unit, lambda_nn) combo;
4080+
# grid has exactly 1 combo).
4081+
assert count_matching("local LOOCV") <= 1
4082+
# Bootstrap aggregation (called once per .fit()).
4083+
assert count_matching("local bootstrap") <= 1
4084+
4085+
def test_global_fit_emits_single_aggregate_warning(self, simple_panel_data):
    """Check warning aggregation for a ``method="global"`` fit.

    A single ``fit`` call is run on a one-point lambda grid with
    ``max_iter=1``, ``tol=1e-15`` and ``n_bootstrap=2`` while every warning
    is recorded. The test then requires that at most one recorded warning
    mentions "global LOOCV" and at most one mentions "global bootstrap",
    i.e. each wrapper reports non-convergence in aggregate rather than once
    per inner fit.
    """
    estimator = TROP(
        method="global",
        lambda_time_grid=[1.0],
        lambda_unit_grid=[1.0],
        lambda_nn_grid=[0.1],
        max_iter=1,
        tol=1e-15,
        n_bootstrap=2,
        seed=42,
    )

    # Record everything; "always" prevents Python's default de-duplication
    # from hiding repeated warnings that this test is meant to catch.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        estimator.fit(
            simple_panel_data,
            outcome="outcome",
            treatment="treated",
            unit="unit",
            time="period",
        )

    recorded_messages = [str(entry.message) for entry in caught]

    # One aggregate warning at most for each wrapper path.
    for needle in ("global LOOCV", "global bootstrap"):
        occurrences = [msg for msg in recorded_messages if needle in msg]
        assert len(occurrences) <= 1

0 commit comments

Comments
 (0)