Skip to content

Commit 1998d75

Browse files
igerberclaude
andcommitted
Address PR #359 CI review round 6 (2 P1)
P1a — preserve full survey design for Binder-TSL variance. Round 5 filtered the ResolvedSurveyDesign to the positive-weight subset for BOTH design resolution AND survey variance composition. That's wrong under the standard subpopulation / domain-estimation convention (keep the sampling frame, zero the contributions) that ``SurveyDesign.subpopulation()`` + ``compute_survey_if_variance()`` both implement elsewhere in this repo. Silent consequences: ``n_psu`` / ``n_strata`` / ``df_survey`` / FPC application / lonely-PSU behavior would all have been computed on the in-domain subset instead of the full design. Refactored HAD ``fit()`` to preserve ``weights_unit_full`` / ``resolved_survey_unit_full`` / ``raw_weights_unit_full`` alongside the filtered copies used ONLY for design resolution (``_detect_design`` / ``d_lower`` / mass-point threshold / cohort counts). The fit itself (``_fit_continuous`` → ``bias_corrected_local_linear`` → ``lprobust``) now receives the FULL unfiltered arrays; ``bias_corrected_local_linear`` filters internally (see P1b below) and zero-pads the IF back to full ordering, so ``compute_survey_if_variance(if_full, resolved_full)`` preserves the sampling-frame structure. ``compute_survey_metadata`` and ``effective_dose_mean`` likewise consume the full arrays so ``n_psu`` / ``n_strata`` / ``sum_weights`` / ``weight_range`` reflect the sampling frame. Dropped the now-unused ``_filter_resolved_survey`` helper. P1b — ``bias_corrected_local_linear(weights=...)`` filters internally. The public wrapper previously ran ``_validate_had_inputs`` / mass-point threshold / Design 1' support heuristic + auto-bandwidth MSE-DPI selector on the full ``d`` array, so zero-weight units at the boundary or at ``d.min()`` could flip the detected regime or trigger spurious off-support rejections on valid weighted domain fits. Now computes ``positive_mask = weights > 0`` upfront, validates weights (finite, non-negative, positive sum — consistent with the port's ValueErrors), filters ``d`` / ``y`` / ``weights`` / ``cluster`` to the positive- weight subset BEFORE all downstream validation + selection + fit, and zero-pads the returned IF back to the original ordering when ``return_influence=True``. ``n_total`` on the returned ``BiasCorrectedFit`` is the original length (matches the caller's input shape). Four new regression tests: - ``test_zero_weight_survey_metadata_preserves_full_design``: survey= path with PSU/strata + 25% zero-weight units reports full-frame n_psu (not in-domain subset); ATT bit-parity with physically-dropped fit. - ``test_bias_corrected_local_linear_zero_weight_matches_filtered``: explicit-h/b wrapper with zero-weight unit at d=0 (would have flipped Design 1') matches dropped-sample fit at 1e-12; IF zero-padded to original length with IF[0]=0 at the zero-weight position. - ``test_bias_corrected_local_linear_zero_weight_auto_bandwidth``: auto-bandwidth path (DPI selector also runs on positive-weight support) produces same h + same bias-corrected estimate as the dropped-sample fit. - The existing Round 5 zero-weight tests (``n_obs`` reflects positive- weight subset; UserWarning; design-not-flipped) pass unchanged since the design-resolution filter is preserved. All 366 tests pass (across test_had, test_nprobust_port, test_bias_corrected_lprobust, test_np_npreg_weighted_parity, and the slow MC suite). Ruff clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6bd07e7 commit 1998d75

3 files changed

Lines changed: 284 additions & 115 deletions

File tree

diff_diff/had.py

Lines changed: 64 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,63 +1673,6 @@ def _collapse(arr: Optional[np.ndarray], name: str) -> Optional[np.ndarray]:
16731673
)
16741674

16751675

1676-
def _filter_resolved_survey(resolved: Any, keep_mask: np.ndarray) -> Any:
1677-
"""Filter a ResolvedSurveyDesign to a boolean unit-level subset.
1678-
1679-
Used by HAD's continuous path to drop zero-weight units from the
1680-
design-resolution sub-population while preserving the attribute shape
1681-
that ``compute_survey_if_variance`` expects. PSU/strata counts are
1682-
recomputed on the positive-weight subset so degenerate singleton
1683-
strata (after filtering) are counted correctly.
1684-
1685-
Parameters
1686-
----------
1687-
resolved : ResolvedSurveyDesign
1688-
Unit-level resolved design (typically from
1689-
``_aggregate_unit_resolved_survey``).
1690-
keep_mask : np.ndarray, shape (G,), bool
1691-
True for units to keep (e.g., ``weights > 0``).
1692-
1693-
Returns
1694-
-------
1695-
ResolvedSurveyDesign with all (G,) arrays filtered by ``keep_mask``.
1696-
"""
1697-
from diff_diff.survey import ResolvedSurveyDesign
1698-
1699-
def _f(arr: Optional[np.ndarray]) -> Optional[np.ndarray]:
1700-
return arr[keep_mask] if arr is not None else None
1701-
1702-
strata_f = _f(resolved.strata)
1703-
psu_f = _f(resolved.psu)
1704-
n_strata_f = (
1705-
int(np.unique(strata_f).shape[0]) if strata_f is not None else 1
1706-
)
1707-
n_psu_f = (
1708-
int(np.unique(psu_f).shape[0])
1709-
if psu_f is not None
1710-
else int(keep_mask.sum())
1711-
)
1712-
return ResolvedSurveyDesign(
1713-
weights=resolved.weights[keep_mask],
1714-
weight_type=resolved.weight_type,
1715-
strata=strata_f,
1716-
psu=psu_f,
1717-
fpc=_f(resolved.fpc),
1718-
n_strata=n_strata_f,
1719-
n_psu=n_psu_f,
1720-
lonely_psu=resolved.lonely_psu,
1721-
replicate_weights=None,
1722-
replicate_method=None,
1723-
fay_rho=0.0,
1724-
n_replicates=0,
1725-
replicate_strata=None,
1726-
combined_weights=resolved.combined_weights,
1727-
replicate_scale=None,
1728-
replicate_rscales=None,
1729-
mse=resolved.mse,
1730-
)
1731-
1732-
17331676
def _aggregate_multi_period_first_differences(
17341677
data: pd.DataFrame,
17351678
outcome_col: str,
@@ -2563,16 +2506,26 @@ def fit(
25632506
resolved_survey_unit.weights, dtype=np.float64
25642507
)
25652508

2566-
# Zero-weight units (e.g., from SurveyDesign.subpopulation(), or
2509+
# Zero-weight units (e.g. SurveyDesign.subpopulation() output, or
25672510
# a user-supplied pweight column with excluded observations) must
2568-
# not drive design resolution. Filter d_arr, dy_arr, weights_unit,
2569-
# raw_weights_unit, and resolved_survey_unit to the positive-
2570-
# weight subset BEFORE _detect_design / d_lower / mass-point
2571-
# threshold / treated+control counts / bandwidth selection run.
2572-
# The weighted kernel already drops zero-weight observations via
2573-
# the ``w > 0`` selector in lprobust, so the FIT is unchanged;
2574-
# only the design-decision logic was previously contaminated
2575-
# (CI review PR #359 round 5, P0).
2511+
# not drive design resolution — ``_detect_design`` / ``d_lower``
2512+
# / mass-point threshold / cohort counts run on the POSITIVE-
2513+
# weight subset. But the survey VARIANCE and ``SurveyMetadata``
2514+
# preserve the FULL ResolvedSurveyDesign (zero-weight PSUs /
2515+
# strata kept in the design with zero in-domain mass) — that is
2516+
# the standard subpopulation / domain-estimation convention in
2517+
# ``diff_diff.survey``: keep the sampling frame, zero the
2518+
# contributions. The weighted kernel in ``lprobust`` drops
2519+
# zero-weight observations via its ``w > 0`` selector, and
2520+
# ``bias_corrected_local_linear`` zero-pads the returned IF back
2521+
# to the full unit ordering so the survey composition at the
2522+
# HAD level sees IF=0 for zero-weight units on the FULL design.
2523+
# (CI review PR #359 round 5 P0 + round 6 P1 cascade.)
2524+
d_arr_full = d_arr # unfiltered (G units); passed to _fit_continuous
2525+
dy_arr_full = dy_arr
2526+
weights_unit_full = weights_unit # may contain zeros; used for FIT
2527+
resolved_survey_unit_full = resolved_survey_unit # full design for VARIANCE
2528+
raw_weights_unit_full = raw_weights_unit # full for SurveyMetadata
25762529
if weights_unit is not None:
25772530
positive_mask = weights_unit > 0.0
25782531
if not bool(positive_mask.all()):
@@ -2581,22 +2534,18 @@ def fit(
25812534
f"HAD continuous path: {n_dropped} unit(s) have "
25822535
f"weight == 0 and are excluded from design resolution "
25832536
f"(auto-detect design, d_lower, mass-point threshold, "
2584-
f"cohort counts) + the weighted fit. Standard survey "
2585-
f"subpopulation designs (SurveyDesign.subpopulation) "
2586-
f"zero-out excluded units by design; the estimator "
2587-
f"treats them as absent from the analysis sample.",
2537+
f"cohort counts). They are RETAINED in the survey "
2538+
f"design for variance + SurveyMetadata (subpopulation "
2539+
f"convention: zero-weight contributions but full "
2540+
f"sampling frame), and their IF is 0 on the full "
2541+
f"design.",
25882542
UserWarning,
25892543
stacklevel=2,
25902544
)
2545+
# Filter arrays used for DESIGN-RESOLUTION ONLY.
25912546
d_arr = d_arr[positive_mask]
25922547
dy_arr = dy_arr[positive_mask]
25932548
weights_unit = weights_unit[positive_mask]
2594-
if raw_weights_unit is not None:
2595-
raw_weights_unit = raw_weights_unit[positive_mask]
2596-
if resolved_survey_unit is not None:
2597-
resolved_survey_unit = _filter_resolved_survey(
2598-
resolved_survey_unit, positive_mask
2599-
)
26002549

26012550
n_obs = int(d_arr.shape[0])
26022551
if n_obs < 3:
@@ -2838,13 +2787,19 @@ def fit(
28382787
UserWarning,
28392788
stacklevel=2,
28402789
)
2790+
# Fit on FULL (unfiltered) arrays so the IF aligns with the
2791+
# full survey design. bias_corrected_local_linear drops
2792+
# zero-weight rows internally for its validation + selector +
2793+
# fit, then zero-pads the IF back to full length. Survey
2794+
# composition below runs on the full design, preserving
2795+
# domain-estimation semantics.
28412796
att, se, bc_fit, bw_diag = self._fit_continuous(
2842-
d_arr,
2843-
dy_arr,
2797+
d_arr_full,
2798+
dy_arr_full,
28442799
resolved_design,
28452800
d_lower_val,
2846-
weights_arr=weights_unit,
2847-
resolved_survey_unit=resolved_survey_unit,
2801+
weights_arr=weights_unit_full,
2802+
resolved_survey_unit=resolved_survey_unit_full,
28482803
)
28492804
inference_method = "analytical_nonparametric"
28502805
vcov_label: Optional[str] = None
@@ -2907,46 +2862,47 @@ def fit(
29072862
survey_metadata: Optional[SurveyMetadata] = None
29082863
variance_formula_label: Optional[str] = None
29092864
effective_dose_mean_value: Optional[float] = None
2910-
if weights_unit is not None:
2911-
if resolved_survey_unit is not None:
2912-
# survey= path: build metadata from the ResolvedSurveyDesign
2913-
# already aggregated to unit-level by
2914-
# _aggregate_unit_resolved_survey. Pass the RAW
2915-
# pre-normalization per-unit weights (captured above before
2916-
# survey.resolve() rescaled pweights/aweights to mean=1)
2917-
# so ``sum_weights`` and ``weight_range`` reflect the
2918-
# user-supplied scale — matching both the ``weights=``
2919-
# shortcut and ``compute_survey_metadata``'s contract.
2920-
assert raw_weights_unit is not None # set in survey= branch
2865+
if weights_unit_full is not None:
2866+
if resolved_survey_unit_full is not None:
2867+
# survey= path: build metadata from the FULL
2868+
# ResolvedSurveyDesign (pre-zero-weight-filter), so
2869+
# ``n_strata`` / ``n_psu`` / ``df_survey`` / weight sums
2870+
# reflect the sampling frame, not the in-domain subset.
2871+
# Pass the RAW pre-normalization per-unit weights
2872+
# (captured before survey.resolve() rescaled pweights/
2873+
# aweights to mean=1) so ``sum_weights`` / ``weight_range``
2874+
# reflect the user-supplied scale — matching both the
2875+
# ``weights=`` shortcut and ``compute_survey_metadata``'s
2876+
# contract.
2877+
assert raw_weights_unit_full is not None # set in survey= branch
29212878
survey_metadata = compute_survey_metadata(
2922-
resolved_survey_unit, raw_weights_unit
2879+
resolved_survey_unit_full, raw_weights_unit_full
29232880
)
29242881
variance_formula_label = "survey_binder_tsl"
29252882
else:
29262883
# weights=<array> shortcut: construct a minimal resolved
2927-
# SurveyDesign with just the unit-level weights (no strata /
2928-
# PSU / FPC) so compute_survey_metadata returns a
2929-
# SurveyMetadata with the same schema as the survey= path.
2930-
# This keeps shared reporting consumers on a single code
2931-
# path — they read attributes regardless of entry point.
2884+
# SurveyDesign with the FULL user-supplied weights
2885+
# (including zero-weight units) so SurveyMetadata
2886+
# summarizes the full sample. No strata / PSU / FPC
2887+
# structure — the shortcut is pweight-only by contract.
29322888
from diff_diff.survey import ResolvedSurveyDesign
29332889

29342890
minimal_resolved = ResolvedSurveyDesign(
2935-
weights=weights_unit,
2891+
weights=weights_unit_full,
29362892
weight_type="pweight",
29372893
strata=None,
29382894
psu=None,
29392895
fpc=None,
29402896
n_strata=1,
2941-
n_psu=int(weights_unit.shape[0]),
2897+
n_psu=int(weights_unit_full.shape[0]),
29422898
lonely_psu="remove",
29432899
combined_weights=True,
29442900
mse=False,
29452901
)
2946-
# weights_unit is already the raw user-supplied array
2947-
# (no SurveyDesign.resolve() normalization on this path).
2902+
# weights_unit_full is already the raw user-supplied
2903+
# array (no SurveyDesign.resolve() normalization here).
29482904
survey_metadata = compute_survey_metadata(
2949-
minimal_resolved, weights_unit
2905+
minimal_resolved, weights_unit_full
29502906
)
29512907
# On the ``weights=`` shortcut, inference stays Normal
29522908
# (df=None in safe_inference) — no PSU / strata / FPC
@@ -2964,16 +2920,17 @@ def fit(
29642920
survey_metadata.df_survey = None
29652921
variance_formula_label = "pweight"
29662922
# Expose the effective weighted denominator used by the
2967-
# beta-scale rescaling (bc_fit carries it via its internal
2968-
# weighted means, but users inspecting the result directly
2969-
# need the value alongside the raw ``dose_mean``).
2923+
# beta-scale rescaling. Use FULL arrays (same numerical
2924+
# result — zero-weight units contribute 0 to both numerator
2925+
# and denominator — but preserves symmetry with the FULL-
2926+
# array fit path above).
29702927
if resolved_design == "continuous_at_zero":
29712928
effective_dose_mean_value = float(
2972-
np.average(d_arr, weights=weights_unit)
2929+
np.average(d_arr_full, weights=weights_unit_full)
29732930
)
29742931
elif resolved_design == "continuous_near_d_lower":
29752932
effective_dose_mean_value = float(
2976-
np.average(d_arr - d_lower_val, weights=weights_unit)
2933+
np.average(d_arr_full - d_lower_val, weights=weights_unit_full)
29772934
)
29782935
# else (mass_point): unreachable here because mass_point with
29792936
# weights raises NotImplementedError upstream.

diff_diff/local_linear.py

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,14 +1060,59 @@ def bias_corrected_local_linear(
10601060
IDs are in first-appearance order; otherwise BLAS reduction
10611061
ordering can drift to ``atol=1e-10``.
10621062
"""
1063+
# Zero-weight unit handling (Phase 4.5 survey support). Filter
1064+
# d/y/weights to the positive-weight support BEFORE
1065+
# ``_validate_had_inputs`` runs. Otherwise zero-weight observations
1066+
# at the boundary or at ``d.min()`` taint the Design 1' support
1067+
# heuristic and the mass-point threshold, causing spurious
1068+
# off-support rejections or design misidentification on valid
1069+
# weighted domain fits (e.g. ``SurveyDesign.subpopulation()``).
1070+
# The auto-bandwidth MSE-DPI selector below also sees only
1071+
# positive-weight observations — kernel density + variance estimates
1072+
# at the boundary are incorrect if zero-weight units contaminate
1073+
# the sample. Callers that want the IF aligned with the ORIGINAL
1074+
# ``d`` ordering get zero-padded positions back via
1075+
# ``return_influence=True``: positive-weight positions carry the
1076+
# active IF, zero-weight positions are 0 (consistent with their
1077+
# zero contribution to the fit).
1078+
_positive_mask_full: Optional[np.ndarray] = None
1079+
_n_full_for_if: Optional[int] = None
10631080
if weights is not None:
10641081
weights = np.asarray(weights, dtype=np.float64).ravel()
1065-
# NOTE: bandwidth selection (auto mode) remains unweighted; the
1066-
# plug-in MSE-optimal DPI is not yet weight-aware. Weights only
1067-
# enter the final lprobust fit + its variance propagation. Users
1068-
# who want a weight-aware bandwidth should pass ``h``/``b`` that
1069-
# reflect the weighted DGP. See REGISTRY "Weighted extension"
1070-
# subsection for the documented methodology gap.
1082+
d_np = np.asarray(d).ravel()
1083+
# Weight validation fires BEFORE the zero-weight filter below,
1084+
# so invalid weights (NaN, Inf, negative, zero-sum) raise
1085+
# consistent ValueErrors regardless of how many zero-weight
1086+
# units the caller also passed. Duplicates the lprobust-level
1087+
# validation locally so errors surface at the public-wrapper
1088+
# boundary with the same messages the port raises downstream.
1089+
if weights.shape[0] != d_np.shape[0]:
1090+
raise ValueError(
1091+
f"weights length ({weights.shape[0]}) does not match "
1092+
f"d/y ({d_np.shape[0]})."
1093+
)
1094+
if not np.all(np.isfinite(weights)):
1095+
raise ValueError("weights contains non-finite values (NaN or Inf).")
1096+
if np.any(weights < 0):
1097+
raise ValueError("weights must be non-negative.")
1098+
if np.sum(weights) <= 0:
1099+
raise ValueError(
1100+
"weights sum to zero — no observations have positive weight."
1101+
)
1102+
if weights.shape[0] == d_np.shape[0]:
1103+
_positive_mask_full = weights > 0.0
1104+
if not bool(_positive_mask_full.all()):
1105+
_n_full_for_if = int(d_np.shape[0])
1106+
d = d_np[_positive_mask_full]
1107+
y = np.asarray(y).ravel()[_positive_mask_full]
1108+
weights = weights[_positive_mask_full]
1109+
if cluster is not None:
1110+
cluster = np.asarray(cluster).ravel()[_positive_mask_full]
1111+
# NOTE: bandwidth selection (auto mode) on the POSITIVE-weight
1112+
# subset remains unweighted — the plug-in MSE-optimal DPI is
1113+
# not yet weight-aware. Users who want a weight-aware bandwidth
1114+
# should pass ``h``/``b`` that reflect the weighted DGP. See
1115+
# REGISTRY "Weighted extension" subsection.
10711116

10721117
if kernel not in _KERNEL_NAME_TO_NPROBUST:
10731118
raise ValueError(
@@ -1248,6 +1293,25 @@ def bias_corrected_local_linear(
12481293

12491294
_, _, (ci_low, ci_high) = safe_inference(result.tau_bc, result.se_rb, alpha=float(alpha))
12501295

1296+
# Zero-pad the IF back to the ORIGINAL (pre-positive-weight-filter)
1297+
# sample ordering when the caller passed a zero-weight vector, so
1298+
# downstream estimator-level Binder-TSL composition lines up with
1299+
# the full survey design rather than with the positive-weight
1300+
# subset. Zero-weight positions get IF=0 (consistent with their
1301+
# zero contribution to the fit); the survey-design variance
1302+
# (``compute_survey_if_variance``) can then preserve PSU / stratum
1303+
# / df_survey structure of the full design — the correct
1304+
# subpopulation / domain-analysis convention.
1305+
if_out = result.influence_function
1306+
if (
1307+
if_out is not None
1308+
and _positive_mask_full is not None
1309+
and _n_full_for_if is not None
1310+
):
1311+
if_full = np.zeros(_n_full_for_if, dtype=np.float64)
1312+
if_full[_positive_mask_full] = if_out
1313+
if_out = if_full
1314+
12511315
return BiasCorrectedFit(
12521316
estimate_classical=result.tau_cl,
12531317
estimate_bias_corrected=result.tau_bc,
@@ -1261,8 +1325,8 @@ def bias_corrected_local_linear(
12611325
bandwidth_source=bw_source,
12621326
bandwidth_diagnostics=bw_diag,
12631327
n_used=result.n_used,
1264-
n_total=n_total,
1328+
n_total=int(_n_full_for_if) if _n_full_for_if is not None else n_total,
12651329
kernel=kernel,
12661330
boundary=float(boundary),
1267-
influence_function=result.influence_function,
1331+
influence_function=if_out,
12681332
)

0 commit comments

Comments
 (0)