Skip to content

Commit 5de3634

Browse files
igerberclaude
andcommitted
Fix validation order in lpbwselect_mse_dpi
Prior commit moved shape / emptiness / finiteness checks into the port but left bwcheck validation above them. As a result, empty or non-finite inputs got "bwcheck exceeds sample size" errors instead of the targeted contract messages the tests expect. Reorder so input-shape validation runs first, then bwcheck, then kernel/vce. Drop duplicate N assignment. 185 tests pass (unchanged). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 350e6db commit 5de3634

1 file changed

Lines changed: 22 additions & 18 deletions

File tree

diff_diff/_nprobust_port.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -633,23 +633,10 @@ def lpbwselect_mse_dpi(
633633
If ``bwcheck`` is supplied and falls outside the valid range
634634
``[1, len(x)]``.
635635
"""
636-
N = x.shape[0] if hasattr(x, "shape") else len(x)
637-
if bwcheck is not None:
638-
if bwcheck < 1:
639-
raise ValueError(
640-
f"bwcheck must be a positive integer (>= 1); got {bwcheck}"
641-
)
642-
if bwcheck > N:
643-
raise ValueError(
644-
f"bwcheck={bwcheck} exceeds sample size N={N}. Either "
645-
f"reduce bwcheck or increase sample size; pass "
646-
f"bwcheck=None to skip the nearest-neighbor floor."
647-
)
648-
if kernel not in _VALID_KERNELS:
649-
raise ValueError(f"Unknown kernel {kernel!r}. Expected one of {_VALID_KERNELS}.")
650-
if vce not in _VALID_VCE:
651-
raise ValueError(f"Unknown vce {vce!r}. Expected one of {_VALID_VCE}.")
652-
636+
# Front-door input contract (shape / emptiness / finiteness).
637+
# Must run BEFORE the bwcheck range check so empty-array or
638+
# non-finite inputs get targeted messages instead of "bwcheck
639+
# exceeds sample size".
653640
x = np.asarray(x, dtype=np.float64).ravel()
654641
y = np.asarray(y, dtype=np.float64).ravel()
655642
if x.shape != y.shape:
@@ -675,11 +662,28 @@ def lpbwselect_mse_dpi(
675662
f"cluster must have the same shape as x; got "
676663
f"{cluster.shape} and {x.shape}"
677664
)
665+
666+
N = x.shape[0]
667+
if bwcheck is not None:
668+
if bwcheck < 1:
669+
raise ValueError(
670+
f"bwcheck must be a positive integer (>= 1); got {bwcheck}"
671+
)
672+
if bwcheck > N:
673+
raise ValueError(
674+
f"bwcheck={bwcheck} exceeds sample size N={N}. Either "
675+
f"reduce bwcheck or increase sample size; pass "
676+
f"bwcheck=None to skip the nearest-neighbor floor."
677+
)
678+
if kernel not in _VALID_KERNELS:
679+
raise ValueError(f"Unknown kernel {kernel!r}. Expected one of {_VALID_KERNELS}.")
680+
if vce not in _VALID_VCE:
681+
raise ValueError(f"Unknown vce {vce!r}. Expected one of {_VALID_VCE}.")
682+
678683
if q is None:
679684
q = p + 1
680685

681686
even = (p - deriv) % 2 == 0
682-
N = x.shape[0]
683687
x_min = float(x.min())
684688
x_max = float(x.max())
685689
range_ = x_max - x_min

0 commit comments

Comments
 (0)