Skip to content

Commit 350e6db

Browse files
igerber and claude committed
Add front-door validation to lpbwselect_mse_dpi (P1)
CI AI review P1: the port's lpbwselect_mse_dpi is advertised as the advanced-use entry point for callers outside the HAD Phase 1b wrapper surface. It accepted x, y, and cluster without validating their shapes, and the default vce="nn" branch reindexed y/cluster by argsort(x) -- so a longer y would be silently truncated to match x's length, producing a bandwidth on misaligned data with no warning.

Added front-door validation at the top of lpbwselect_mse_dpi:
- x and y are ravel()ed and required to have the same shape
- x must be non-empty
- x, y, and eval_point must be finite
- cluster (if supplied) must match x.shape

Seven new tests in TestLpbwselectMseDpiValidation:
- test_mismatched_shapes_raise
- test_longer_y_silent_truncation_rejected (regression for the specific nn-reindex bug)
- test_cluster_wrong_length_rejected
- test_empty_direct_port_input_rejected
- test_non_finite_x_rejected
- test_non_finite_y_rejected
- test_non_finite_eval_point_rejected

185 tests pass (up from 178).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a313e05 commit 350e6db

2 files changed

Lines changed: 99 additions & 3 deletions

File tree

diff_diff/_nprobust_port.py

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -650,10 +650,31 @@ def lpbwselect_mse_dpi(
650650
if vce not in _VALID_VCE:
651651
raise ValueError(f"Unknown vce {vce!r}. Expected one of {_VALID_VCE}.")
652652

653-
x = np.asarray(x, dtype=np.float64)
654-
y = np.asarray(y, dtype=np.float64)
653+
x = np.asarray(x, dtype=np.float64).ravel()
654+
y = np.asarray(y, dtype=np.float64).ravel()
655+
if x.shape != y.shape:
656+
raise ValueError(
657+
f"x and y must have the same 1-D shape; got "
658+
f"{x.shape} and {y.shape}"
659+
)
660+
if x.size == 0:
661+
raise ValueError(
662+
"x and y must be non-empty; lpbwselect_mse_dpi cannot "
663+
"estimate a bandwidth from zero observations."
664+
)
665+
if not np.all(np.isfinite(x)):
666+
raise ValueError("x contains non-finite values (NaN or Inf)")
667+
if not np.all(np.isfinite(y)):
668+
raise ValueError("y contains non-finite values (NaN or Inf)")
669+
if not np.isfinite(eval_point):
670+
raise ValueError(f"eval_point must be finite; got {eval_point}")
655671
if cluster is not None:
656-
cluster = np.asarray(cluster)
672+
cluster = np.asarray(cluster).ravel()
673+
if cluster.shape != x.shape:
674+
raise ValueError(
675+
f"cluster must have the same shape as x; got "
676+
f"{cluster.shape} and {x.shape}"
677+
)
657678
if q is None:
658679
q = p + 1
659680

tests/test_nprobust_port.py

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -241,3 +241,78 @@ def test_R_is_zero_when_scale_zero(self):
241241
)
242242
# With scale=0, BWreg is never computed -> R stays 0.
243243
assert C_d1.R == 0.0
244+
245+
246+
# =============================================================================
247+
# lpbwselect_mse_dpi: input validation on the advanced-use entry point
248+
# =============================================================================
249+
250+
251+
class TestLpbwselectMseDpiValidation:
252+
"""The public wrapper is restricted to the HAD surface; the port is
253+
the advertised advanced-use entry point. It must enforce its own
254+
shape / emptiness / finiteness contract -- silently truncating a
255+
longer y or cluster through sort-index reindexing would be a real
256+
bug."""
257+
258+
def test_mismatched_shapes_raise(self):
259+
from diff_diff._nprobust_port import lpbwselect_mse_dpi
260+
261+
x = np.array([0.1, 0.2, 0.3])
262+
y = np.array([1.0, 2.0, 3.0, 4.0]) # length 4 != 3
263+
with pytest.raises(ValueError, match="same 1-D shape"):
264+
lpbwselect_mse_dpi(y, x, eval_point=0.0)
265+
266+
def test_longer_y_silent_truncation_rejected(self):
267+
"""Regression: len(y) > len(x) previously got truncated via
268+
the sort-indexer under vce='nn'. Must now raise."""
269+
from diff_diff._nprobust_port import lpbwselect_mse_dpi
270+
271+
rng = np.random.default_rng(0)
272+
x = rng.uniform(0.0, 1.0, size=100)
273+
y = rng.normal(size=200) # twice the length
274+
with pytest.raises(ValueError, match="same 1-D shape"):
275+
lpbwselect_mse_dpi(y, x, eval_point=0.0, vce="nn")
276+
277+
def test_cluster_wrong_length_rejected(self):
278+
from diff_diff._nprobust_port import lpbwselect_mse_dpi
279+
280+
rng = np.random.default_rng(0)
281+
x = rng.uniform(0.0, 1.0, size=100)
282+
y = rng.normal(size=100)
283+
cluster = np.arange(50) # wrong length
284+
with pytest.raises(ValueError, match="cluster must have"):
285+
lpbwselect_mse_dpi(y, x, cluster=cluster, eval_point=0.0)
286+
287+
def test_empty_direct_port_input_rejected(self):
288+
from diff_diff._nprobust_port import lpbwselect_mse_dpi
289+
290+
x = np.array([], dtype=np.float64)
291+
y = np.array([], dtype=np.float64)
292+
with pytest.raises(ValueError, match="non-empty"):
293+
lpbwselect_mse_dpi(y, x, eval_point=0.0)
294+
295+
def test_non_finite_x_rejected(self):
296+
from diff_diff._nprobust_port import lpbwselect_mse_dpi
297+
298+
x = np.array([0.1, np.nan, 0.3])
299+
y = np.array([1.0, 2.0, 3.0])
300+
with pytest.raises(ValueError, match="x contains non-finite"):
301+
lpbwselect_mse_dpi(y, x, eval_point=0.0)
302+
303+
def test_non_finite_y_rejected(self):
304+
from diff_diff._nprobust_port import lpbwselect_mse_dpi
305+
306+
x = np.array([0.1, 0.2, 0.3])
307+
y = np.array([1.0, np.inf, 3.0])
308+
with pytest.raises(ValueError, match="y contains non-finite"):
309+
lpbwselect_mse_dpi(y, x, eval_point=0.0)
310+
311+
def test_non_finite_eval_point_rejected(self):
312+
from diff_diff._nprobust_port import lpbwselect_mse_dpi
313+
314+
rng = np.random.default_rng(0)
315+
x = rng.uniform(0.0, 1.0, size=100)
316+
y = rng.normal(size=100)
317+
with pytest.raises(ValueError, match="eval_point"):
318+
lpbwselect_mse_dpi(y, x, eval_point=np.nan)

0 commit comments

Comments
 (0)