Skip to content

Commit bee6805

Browse files
igerberclaude
andcommitted
Reject object-dtype missing cluster IDs across all three surfaces
CI review follow-up: the floating-dtype-only missing-ID guard in `bias_corrected_local_linear` and `lprobust` let object-dtype arrays with `None` / object `np.nan` sentinels bypass validation. The downstream `lprobust_vce` cluster loop would then group on `np.unique`, treating the sentinel as a real cluster and silently misstating clustered SE. Extract the dtype-agnostic `_cluster_has_missing` helper already used inside `lpbwselect_mse_dpi` and apply it at all three entry points: the selector, the port-level `lprobust`, and the public `bias_corrected_local_linear` wrapper. Regression tests: object-dtype cluster arrays with None and with np.nan sentinels raise a targeted ValueError at both the wrapper (test_bias_corrected_lprobust) and the port (test_nprobust_port) entry points. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 237a800 commit bee6805

4 files changed

Lines changed: 109 additions & 32 deletions

File tree

diff_diff/_nprobust_port.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,33 @@
8484
_VALID_VCE = ("nn", "hc0", "hc1", "hc2", "hc3")
8585

8686

87+
def _cluster_has_missing(cluster: np.ndarray) -> bool:
88+
"""Detect missing cluster IDs across float / object / string dtypes.
89+
90+
nprobust::lpbwselect complete-case-filters (x, y, cluster) before
91+
dispatch. This port deliberately rejects missingness instead so
92+
callers see it rather than silently losing rows. Used by
93+
``lpbwselect_mse_dpi`` and ``lprobust`` (and the public
94+
``bias_corrected_local_linear`` wrapper) so all three surfaces
95+
honor the same contract.
96+
"""
97+
if cluster.dtype.kind in ("f", "c"):
98+
return bool(np.any(~np.isfinite(cluster)))
99+
# Object / string / None-containing arrays: treat None and NaN-like
100+
# sentinels as missing.
101+
try:
102+
if bool(np.any([v is None for v in cluster])):
103+
return True
104+
except TypeError:
105+
pass
106+
try:
107+
# np.nan comparisons are False; cast to float and check finiteness.
108+
cluster_f = cluster.astype(np.float64, copy=False)
109+
return bool(np.any(~np.isfinite(cluster_f)))
110+
except (TypeError, ValueError):
111+
return False
112+
113+
87114
# =============================================================================
88115
# Kernel (W.fun, npfunctions.R:1-7)
89116
# =============================================================================
@@ -684,25 +711,9 @@ def lpbwselect_mse_dpi(
684711
# before dispatch; this port deliberately rejects instead so
685712
# callers see the missingness rather than lose rows silently.
686713
# The "reject" vs "filter" choice is documented in the module
687-
# docstring deviations list.
688-
has_missing = False
689-
if cluster.dtype.kind in ("f", "c"):
690-
has_missing = bool(np.any(~np.isfinite(cluster)))
691-
else:
692-
# object / string / None-containing arrays: treat None and
693-
# NaN-like sentinels as missing.
694-
try:
695-
has_missing = bool(np.any([x is None for x in cluster]))
696-
except TypeError:
697-
has_missing = False
698-
if not has_missing:
699-
try:
700-
# np.nan comparisons are False; use pd-style check.
701-
cluster_f = cluster.astype(np.float64, copy=False)
702-
has_missing = bool(np.any(~np.isfinite(cluster_f)))
703-
except (TypeError, ValueError):
704-
pass
705-
if has_missing:
714+
# docstring deviations list. Dtype-agnostic via
715+
# `_cluster_has_missing`.
716+
if _cluster_has_missing(cluster):
706717
raise ValueError(
707718
"cluster contains missing values (NaN / None). Unlike "
708719
"nprobust::lpbwselect which complete-case-filters "
@@ -1130,13 +1141,17 @@ def lprobust(
11301141
raise ValueError(
11311142
f"cluster length ({cluster.shape[0]}) does not match x/y ({N})."
11321143
)
1133-
# Reject NaN cluster IDs (Phase 1b convention: surface missingness
1134-
# rather than silently drop rows).
1135-
cluster_float = np.asarray(cluster, dtype=np.float64).ravel() if np.issubdtype(
1136-
cluster.dtype, np.floating
1137-
) else None
1138-
if cluster_float is not None and np.any(~np.isfinite(cluster_float)):
1139-
raise ValueError("cluster contains non-finite values (NaN or Inf).")
1144+
# Dtype-agnostic missingness check. Float NaN/Inf, object None,
1145+
# and object np.nan all get rejected here (shared with
1146+
# `lpbwselect_mse_dpi` via `_cluster_has_missing`) so the
1147+
# downstream `lprobust_vce` cluster grouping on `np.unique`
1148+
# cannot silently treat a missing sentinel as a real cluster.
1149+
if _cluster_has_missing(cluster):
1150+
raise ValueError(
1151+
"cluster contains missing values (NaN / None). "
1152+
"Filter your data before the call or drop missing "
1153+
"observations explicitly."
1154+
)
11401155

11411156
# --- vce="nn" setup: sort ascending, precompute dups ---
11421157
dups: Optional[np.ndarray] = None

diff_diff/local_linear.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,19 +1050,27 @@ def bias_corrected_local_linear(
10501050
d, y = _validate_had_inputs(d, y, boundary)
10511051
n_total = int(d.shape[0])
10521052

1053-
# Reject missing cluster IDs up front (Phase 1b convention).
1053+
# Reject missing cluster IDs up front (Phase 1b convention). Delegates
1054+
# to the dtype-agnostic `_cluster_has_missing` helper in the port so
1055+
# wrapper, port-level `lprobust`, and `lpbwselect_mse_dpi` all enforce
1056+
# the same missing-sentinel contract across float / object / string
1057+
# dtypes (CI review PR #340 P1 follow-up).
10541058
cluster_arr: Optional[np.ndarray] = None
10551059
if cluster is not None:
1060+
from diff_diff._nprobust_port import _cluster_has_missing
1061+
10561062
cluster_arr = np.asarray(cluster).ravel()
10571063
if cluster_arr.shape[0] != n_total:
10581064
raise ValueError(
10591065
f"cluster length ({cluster_arr.shape[0]}) does not match "
10601066
f"d/y ({n_total})."
10611067
)
1062-
if np.issubdtype(cluster_arr.dtype, np.floating) and not np.all(
1063-
np.isfinite(cluster_arr)
1064-
):
1065-
raise ValueError("cluster contains non-finite values (NaN or Inf).")
1068+
if _cluster_has_missing(cluster_arr):
1069+
raise ValueError(
1070+
"cluster contains missing values (NaN / None). Filter "
1071+
"your data before the call or drop missing observations "
1072+
"explicitly."
1073+
)
10661074

10671075
# --- Resolve (h, b) ---
10681076
# nprobust's lprobust() with the default rho=1 sets b = h / rho = h

tests/test_bias_corrected_lprobust.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,34 @@ def test_weights_raises_not_implemented(self):
311311
bias_corrected_local_linear(d, y, h=0.3, weights=np.ones(100))
312312

313313
def test_cluster_nan_raises(self):
314+
"""Float NaN in cluster IDs is rejected."""
314315
d = np.linspace(0.0, 1.0, 100)
315316
y = d.copy()
316317
cluster = np.repeat(np.arange(10), 10).astype(np.float64)
317318
cluster[5] = np.nan
318-
with pytest.raises(ValueError, match="cluster contains non-finite"):
319+
with pytest.raises(ValueError, match="cluster contains missing"):
320+
bias_corrected_local_linear(d, y, h=0.3, cluster=cluster)
321+
322+
def test_cluster_object_none_raises(self):
323+
"""Object-dtype cluster with a ``None`` sentinel is rejected.
324+
Float-only checks let this through; the dtype-agnostic helper
325+
catches it (CI review PR #340 follow-up P1)."""
326+
d = np.linspace(0.0, 1.0, 100)
327+
y = d.copy()
328+
cluster = np.array(
329+
[i // 10 if i != 5 else None for i in range(100)], dtype=object
330+
)
331+
with pytest.raises(ValueError, match="cluster contains missing"):
332+
bias_corrected_local_linear(d, y, h=0.3, cluster=cluster)
333+
334+
def test_cluster_object_nan_raises(self):
335+
"""Object-dtype cluster with np.nan is rejected."""
336+
d = np.linspace(0.0, 1.0, 100)
337+
y = d.copy()
338+
cluster = np.array(
339+
[i // 10 if i != 5 else np.nan for i in range(100)], dtype=object
340+
)
341+
with pytest.raises(ValueError, match="cluster contains missing"):
319342
bias_corrected_local_linear(d, y, h=0.3, cluster=cluster)
320343

321344
def test_unknown_kernel_raises(self):

tests/test_nprobust_port.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,37 @@ def test_lprobust_shifted_boundary_dgp5(self, lprobust_golden):
588588
np.testing.assert_allclose(res.se_cl, g["se_cl"], atol=1e-12, rtol=1e-12)
589589
np.testing.assert_allclose(res.se_rb, g["se_rb"], atol=1e-12, rtol=1e-12)
590590

591+
def test_lprobust_cluster_object_none_raises(self):
592+
"""Port-level: object-dtype cluster with None sentinel is rejected.
593+
Mirror of the wrapper test; pins that `_cluster_has_missing` fires
594+
inside the direct port entry point too (CI review PR #340
595+
follow-up P1)."""
596+
from diff_diff._nprobust_port import lprobust
597+
598+
rng = np.random.default_rng(0)
599+
G = 200
600+
d = rng.uniform(0.0, 1.0, G)
601+
y = d + rng.normal(0, 0.3, G)
602+
cluster = np.array(
603+
[i // 10 if i != 5 else None for i in range(G)], dtype=object
604+
)
605+
with pytest.raises(ValueError, match="cluster contains missing"):
606+
lprobust(y, d, eval_point=0.0, h=0.3, b=0.3, cluster=cluster)
607+
608+
def test_lprobust_cluster_object_nan_raises(self):
609+
"""Port-level: object-dtype cluster with np.nan is rejected."""
610+
from diff_diff._nprobust_port import lprobust
611+
612+
rng = np.random.default_rng(0)
613+
G = 200
614+
d = rng.uniform(0.0, 1.0, G)
615+
y = d + rng.normal(0, 0.3, G)
616+
cluster = np.array(
617+
[i // 10 if i != 5 else np.nan for i in range(G)], dtype=object
618+
)
619+
with pytest.raises(ValueError, match="cluster contains missing"):
620+
lprobust(y, d, eval_point=0.0, h=0.3, b=0.3, cluster=cluster)
621+
591622
def test_lprobust_h_gt_b_selects_h_window(self):
592623
"""When h > b, the active window is ind.h (lprobust.R:182 conditional
593624
replacement), not a union of ind.h and ind.b."""

0 commit comments

Comments
 (0)