|
5 | 5 | equations, known analytical cases, and expected mathematical properties. |
6 | 6 | """ |
7 | 7 |
|
8 | | -import os |
| 8 | + |
| 9 | +import warnings |
9 | 10 |
|
10 | 11 | import numpy as np |
11 | 12 | import pytest |
@@ -245,34 +246,37 @@ def test_optimal_flci_is_finite_and_valid(self): |
245 | 246 | assert ci_lb_opt <= lb, "CI lower should be <= identified set lower" |
246 | 247 | assert ci_ub_opt >= ub, "CI upper should be >= identified set upper" |
247 | 248 |
|
248 | | - @pytest.mark.skipif( |
249 | | - os.environ.get("CI") == "true", |
250 | | - reason="wall-clock timing is flaky on shared CI runners; short-circuit " |
251 | | - "correctness signal will be replaced with a mock/spy per TODO.md " |
252 | | - "(see PR #330 follow-up note)", |
253 | | - ) |
254 | 249 | def test_m0_short_circuit(self): |
255 | | - """M=0 should use standard CI without optimization. |
256 | | -
|
257 | | - Uses wall-clock elapsed time as a proxy for "short-circuit path |
258 | | - taken" — fast path is ``<0.5s``, slow optimization would be ``>> |
259 | | - 0.5s``. Skipped on CI because neighbor-VM contention on shared |
260 | | - runners can push even the short-circuit path past the threshold. |
261 | | - Run locally to validate the fast-path invariant; the TODO.md entry |
262 | | - added by PR #330 tracks replacing this with a mock/spy so the |
263 | | - correctness signal becomes CI-safe. |
| 250 | + """M=0 takes the bias=0 fast path and never invokes the LP solver. |
| 251 | +
|
| 252 | + ``_compute_worst_case_bias`` returns ``0.0`` immediately when ``M=0`` |
| 253 | + (diff_diff/honest_did.py:1650), so ``scipy.optimize.linprog`` is |
| 254 | + never reached. Patching the LP solver and asserting ``call_count |
| 255 | + == 0`` is a direct correctness signal — CI-safe (no wall-clock |
| 256 | + dependency) and faster than the prior timing-based proxy. |
264 | 257 | """ |
| 258 | + from unittest.mock import patch |
| 259 | + |
265 | 260 | beta_pre = np.array([0.3, 0.2, 0.1]) |
266 | 261 | beta_post = np.array([2.0]) |
267 | 262 | sigma = np.eye(4) * 0.01 |
268 | 263 | l_vec = np.array([1.0]) |
269 | 264 |
|
270 | | - import time |
271 | | - t0 = time.time() |
272 | | - _compute_optimal_flci(beta_pre, beta_post, sigma, l_vec, 3, 1, M=0.0) |
273 | | - elapsed = time.time() - t0 |
| 265 | + with patch("diff_diff.honest_did.optimize.linprog") as mock_linprog: |
| 266 | + ci_lb, ci_ub = _compute_optimal_flci( |
| 267 | + beta_pre, beta_post, sigma, l_vec, 3, 1, M=0.0 |
| 268 | + ) |
274 | 269 |
|
275 | | - assert elapsed < 0.5, f"M=0 should be fast, took {elapsed:.2f}s" |
| 270 | + assert mock_linprog.call_count == 0, ( |
| 271 | + f"M=0 must skip the LP solver (fast path at " |
| 272 | + f"_compute_worst_case_bias:1650); got " |
| 273 | + f"{mock_linprog.call_count} linprog call(s)." |
| 274 | + ) |
| 275 | + # End-to-end correctness: M=0 CI is still well-defined. |
| 276 | + assert np.isfinite(ci_lb) and np.isfinite(ci_ub), ( |
| 277 | + f"M=0 CI must be finite; got [{ci_lb}, {ci_ub}]" |
| 278 | + ) |
| 279 | + assert ci_lb <= ci_ub, f"M=0 CI must be ordered; got [{ci_lb}, {ci_ub}]" |
276 | 280 |
|
277 | 281 | def test_smoothness_flci_with_survey_df(self): |
278 | 282 | """Survey df should widen the smoothness FLCI (folded t vs folded normal).""" |
@@ -493,3 +497,64 @@ def test_breakdown_monotonicity(self): |
493 | 497 | # The optimal FLCI is efficient, so need large M for a weak effect. |
494 | 498 | r_large = honest.fit(results, M=20.0) |
495 | 499 | assert r_large.ci_lb <= 0 <= r_large.ci_ub, "Should lose significance at large M" |
| 500 | + |
| 501 | + |
| 502 | +class TestARPVertexEnumeration: |
| 503 | + """Diagnostic warnings on `_enumerate_vertices` vertex-search pathologies.""" |
| 504 | + |
| 505 | + def test_enumerate_vertices_warns_on_exhausted_search(self): |
| 506 | + """All-LinAlgError path: fully-zero nuisance column makes A_sys |
| 507 | + singular on every basis, so the enumeration exhausts without |
| 508 | + feasible vertices and the user should see a RuntimeWarning rather |
| 509 | + than a silent empty-list return.""" |
| 510 | + from diff_diff.honest_did import _enumerate_vertices |
| 511 | + |
| 512 | + # 4 moments, 1 nuisance column (all zeros) → A_sys singular on every basis |
| 513 | + X_tilde = np.zeros((4, 1)) |
| 514 | + sigma_tilde_diag = np.array([1.0, 1.0, 1.0, 1.0]) |
| 515 | + with pytest.warns(RuntimeWarning, match="exhausted"): |
| 516 | + vertices = _enumerate_vertices(X_tilde, sigma_tilde_diag, n_moments=4) |
| 517 | + assert vertices == [] |
| 518 | + |
| 519 | + def test_enumerate_vertices_warns_on_heavy_rejection(self): |
| 520 | + """Mixed-basis path: 5 moments, 1 nuisance column. C(5, 2) = 10 |
| 521 | + bases. By design, 6 bases hit LinAlgError (the singular pairs |
| 522 | + among indices {0,1,2,3} that share aligned nuisance/sigma values) |
| 523 | + and 4 bases produce feasible vertices (the (i, 4) bases, each pairing |
| 524 | + a positive-X_tilde row with the unique negative-X_tilde row at |
| 525 | + index 4). 60% rejection rate trips the `heavily constrained` |
| 526 | + branch specifically, not the exhausted branch.""" |
| 527 | + from diff_diff.honest_did import _enumerate_vertices |
| 528 | + |
| 529 | + X_tilde = np.array([[1.0], [1.0], [1.0], [2.0], [-1.0]]) |
| 530 | + sigma_tilde_diag = np.array([1.0, 1.0, 1.0, 2.0, 1.0]) |
| 531 | + with pytest.warns(RuntimeWarning, match="heavily constrained"): |
| 532 | + vertices = _enumerate_vertices(X_tilde, sigma_tilde_diag, n_moments=5) |
| 533 | + assert len(vertices) >= 1, ( |
| 534 | + f"Heavy-rejection construction must still produce some feasible " |
| 535 | + f"vertices (otherwise the exhausted branch fires); got " |
| 536 | + f"{len(vertices)} vertices." |
| 537 | + ) |
| 538 | + |
| 539 | + def test_enumerate_vertices_quiet_on_healthy_enumeration(self): |
| 540 | + """Well-conditioned X_tilde: most bases solve cleanly and feasible |
| 541 | + vertices are recovered. No RuntimeWarning should fire.""" |
| 542 | + from diff_diff.honest_did import _enumerate_vertices |
| 543 | + |
| 544 | + rng = np.random.default_rng(0) |
| 545 | + # 4 moments, 1 nuisance — small and well-conditioned |
| 546 | + X_tilde = rng.normal(size=(4, 1)) |
| 547 | + sigma_tilde_diag = np.array([1.0, 1.0, 1.0, 1.0]) |
| 548 | + with warnings.catch_warnings(record=True) as caught: |
| 549 | + warnings.simplefilter("always", RuntimeWarning) |
| 550 | + vertices = _enumerate_vertices(X_tilde, sigma_tilde_diag, n_moments=4) |
| 551 | + diag_warnings = [ |
| 552 | + w for w in caught |
| 553 | + if "exhausted" in str(w.message) or "heavily constrained" in str(w.message) |
| 554 | + ] |
| 555 | + assert not diag_warnings, ( |
| 556 | + f"Healthy enumeration must not emit ARP diagnostics; got " |
| 557 | + f"{[str(w.message) for w in diag_warnings]}" |
| 558 | + ) |
| 559 | + # Sanity: the call returns a list; the feasible-vertex count is not asserted here |
| 560 | + assert isinstance(vertices, list) |
0 commit comments