Skip to content

Commit 03532fc

Browse files
igerber and claude committed
Address CI AI review round 2: mass-point guard, rank check, R parity
P1 #1 (methodology): mse_optimal_bandwidth now rejects Design 1 mass-point designs. When boundary > 0 and the modal fraction at d.min() exceeds the REGISTRY-specified 2% threshold, it raises NotImplementedError pointing to the 2SLS sample-average estimator per de Chaisemartin et al. (2026) Section 3.2.4. Design 1' with untreated units at d=0 (boundary=0) is still accepted per Garrett et al. (2020) application precedent.

P1 #2 (code quality): qrXXinv now catches np.linalg.LinAlgError from Cholesky and re-raises as ValueError with a targeted message naming the failing dimension and suggesting remediation. Duplicate-support windows or other rank-deficient designs now fail with a clear error instead of leaking LinAlgError out of the port.

P3 (tests): Added TestStageDiagnosticsParity::test_R_parity covering all four stages. Previously only V/B1/B2 were pinned; R (BWreg) was only trivially checked for stage_d1 (scale=0 -> R=0). Now stage_b and stage_h R values are explicitly parity-tested at 1% relative tolerance against R's nprobust.

New behavioral tests:
- test_mass_point_design_rejected: 10% mass at 0.1 -> NotImplementedError
- test_continuous_near_d_lower_accepted: uniform(0.1, 1.0) passes
- test_untreated_at_zero_accepted: 15% at d=0 with boundary=0 passes
- test_rank_deficient_design_raises_valueerror: rank-1 X -> ValueError
- R parity on all four stages across 3 DGPs (12 new parametrized cases)

169 tests pass (up from 153).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5799d42 commit 03532fc

3 files changed

Lines changed: 116 additions & 2 deletions

File tree

diff_diff/_nprobust_port.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,29 @@ def qrXXinv(x: np.ndarray) -> np.ndarray:
122122
-------
123123
np.ndarray, shape (k, k)
124124
Inverse of ``x.T @ x``.
125+
126+
Raises
127+
------
128+
ValueError
129+
If the Cholesky factorization of ``x.T @ x`` fails, i.e. the matrix is not numerically positive definite (typically because it is rank-deficient). Converts
130+
the raw ``np.linalg.LinAlgError`` into a targeted message so
131+
callers (``lprobust_bw``) can surface a clear failure reason
132+
instead of an opaque linear-algebra error.
125133
"""
126134
xtx = x.T @ x
127-
# Cholesky solve for the inverse. Matches R's chol2inv(chol(.)).
128-
L = np.linalg.cholesky(xtx)
129135
k = xtx.shape[0]
136+
# Cholesky solve for the inverse. Matches R's chol2inv(chol(.)).
137+
try:
138+
L = np.linalg.cholesky(xtx)
139+
except np.linalg.LinAlgError as exc:
140+
raise ValueError(
141+
f"qrXXinv: Cholesky decomposition of X'X ({k}x{k}) failed. "
142+
f"The weighted design matrix is rank-deficient, likely "
143+
f"because the in-window support has fewer than {k} distinct "
144+
f"points. Increase sample size, widen the bandwidth, or pick "
145+
f"a boundary with more distinct values nearby. "
146+
f"(LinAlgError: {exc})"
147+
) from exc
130148
Linv = np.linalg.solve(L, np.eye(k))
131149
return Linv.T @ Linv
132150

diff_diff/local_linear.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,36 @@ def mse_optimal_bandwidth(
633633
f"separately parity-tested against nprobust."
634634
)
635635

636+
# Mass-point design check (REGISTRY.md specifies a `> 2%` modal-min rule).
637+
# If d_lower > 0 and the distribution bunches at d_lower, the paper
638+
# (de Chaisemartin et al. 2026 Section 3.2.4) prescribes the 2SLS
639+
# sample-average path, NOT the nonparametric CCF local-polynomial
640+
# path. Detect bunching and redirect the caller.
641+
#
642+
# We only flag when boundary > 0 (Design 1 continuous-near-d_lower
643+
# vs Design 1 mass-point). For boundary = 0 (Design 1' or "untreated
644+
# units present" subcase), the paper accepts nonparametric even with
645+
# mass at 0 (Garrett et al. 2020 application with 12/2954 at 0).
646+
if boundary > _boundary_tol:
647+
eps_eq = 1e-12 * max(1.0, abs(d_min))
648+
at_boundary_mask = np.abs(d - d_min) <= eps_eq
649+
modal_fraction = float(np.mean(at_boundary_mask))
650+
_MASS_POINT_THRESHOLD = 0.02 # REGISTRY rule: > 2% modal-min
651+
if modal_fraction > _MASS_POINT_THRESHOLD:
652+
raise NotImplementedError(
653+
f"Detected mass-point design: the lower boundary "
654+
f"d_lower={d_min!r} has modal fraction "
655+
f"{modal_fraction:.4f} > {_MASS_POINT_THRESHOLD:.2f}. "
656+
f"Per de Chaisemartin et al. (2026) Section 3.2.4 and "
657+
f"the methodology registry, this case requires the 2SLS "
658+
f"sample-average estimator with instrument 1{{D_2 > "
659+
f"d_lower}}, not the nonparametric CCF local-polynomial "
660+
f"bandwidth selector. That estimator is queued for "
661+
f"Phase 2 (HeterogeneousAdoptionDiD). For continuous "
662+
f"near-d_lower designs (modal fraction <= "
663+
f"{_MASS_POINT_THRESHOLD:.2f}), this wrapper is applicable."
664+
)
665+
636666
# Defer heavy import to call time to avoid import-cycle risk.
637667
from diff_diff._nprobust_port import lpbwselect_mse_dpi
638668

tests/test_bandwidth_selector.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,25 @@ def test_B2_parity(self, dgp_case, stage):
116116
else:
117117
assert actual == pytest.approx(expected, rel=_PARITY_TOL), f"{name} {stage}"
118118

119+
@pytest.mark.parametrize(
120+
"stage",
121+
["stage_d1", "stage_d2", "stage_b", "stage_h"],
122+
)
123+
def test_R_parity(self, dgp_case, stage):
124+
"""R (BWreg) parity. stage_d1 / stage_d2 use scale=0 so R=0;
125+
stage_b / stage_h use scale=bwregul=1 so R is non-trivial and
126+
must match nprobust."""
127+
name, d, y, g = dgp_case
128+
br = mse_optimal_bandwidth(d, y, return_diagnostics=True)
129+
actual = getattr(br, f"{stage}_R")
130+
expected = g[stage]["R"]
131+
if expected == 0:
132+
assert actual == pytest.approx(0, abs=1e-10), f"{name} {stage}"
133+
else:
134+
assert actual == pytest.approx(expected, rel=_PARITY_TOL), (
135+
f"{name} {stage}: py={actual!r} R={expected!r}"
136+
)
137+
119138

120139
# =============================================================================
121140
# Behavioral tests
@@ -303,6 +322,53 @@ def test_boundary_below_min_d_accepted(self):
303322
assert np.isfinite(h)
304323
assert h > 0.0
305324

325+
def test_mass_point_design_rejected(self):
326+
"""Design 1 mass-point case (boundary > 0, modal fraction > 2%)
327+
must be rejected with NotImplementedError pointing to 2SLS."""
328+
rng = np.random.default_rng(2026)
329+
n_mass = 200 # 10% mass at d_lower
330+
n_cont = 1800
331+
d_mass = np.full(n_mass, 0.1)
332+
d_cont = rng.uniform(0.1, 1.0, size=n_cont)
333+
d = np.concatenate([d_mass, d_cont])
334+
y = d + rng.normal(0, 0.5, size=d.size)
335+
with pytest.raises(NotImplementedError, match="mass-point"):
336+
mse_optimal_bandwidth(d, y, boundary=float(d.min()))
337+
338+
def test_continuous_near_d_lower_accepted(self):
339+
"""Design 1 continuous-near-d_lower (boundary > 0, modal
340+
fraction <= 2%) must pass through to nonparametric."""
341+
rng = np.random.default_rng(20260419)
342+
d = rng.uniform(0.1, 1.0, size=1500) # no mass point
343+
y = d + rng.normal(0, 0.3, size=1500)
344+
h = mse_optimal_bandwidth(d, y, boundary=float(d.min()))
345+
assert np.isfinite(h)
346+
assert h > 0.0
347+
348+
def test_untreated_at_zero_accepted(self):
349+
"""Paper Section 3.1.5 / Garrett et al. application: untreated
350+
units at d=0 are OK for Design 1'. boundary=0 with mass at 0
351+
must NOT trigger the mass-point rejection."""
352+
rng = np.random.default_rng(2026)
353+
# 15% at d=0 (300 of 2000, genuinely untreated), rest continuous on [0.01, 1.0).
354+
d_zero = np.zeros(300)
355+
d_pos = rng.uniform(0.01, 1.0, size=1700)
356+
d = np.concatenate([d_zero, d_pos])
357+
y = d + rng.normal(0, 0.5, size=d.size)
358+
h = mse_optimal_bandwidth(d, y, boundary=0.0)
359+
assert np.isfinite(h)
360+
assert h > 0.0
361+
362+
def test_rank_deficient_design_raises_valueerror(self):
363+
"""Duplicate-support windows must fail with a clear ValueError
364+
from qrXXinv's Cholesky guard, not an opaque LinAlgError."""
365+
from diff_diff._nprobust_port import qrXXinv
366+
367+
# Rank-1 X: all rows identical -> X.T @ X is rank-1.
368+
X = np.tile([[1.0, 2.0, 3.0]], (10, 1))
369+
with pytest.raises(ValueError, match="qrXXinv"):
370+
qrXXinv(X)
371+
306372

307373
class TestKernelDispatch:
308374
"""Different kernels produce different bandwidths."""

0 commit comments

Comments
 (0)