Skip to content

Commit 657b62b

Browse files
igerber and claude committed
Fix CI review R12: preserve full survey design under auto-inject
R11's auto-inject of psu=group filtered `data` to positive-weight rows before re-resolving the effective SurveyDesign. That silently shrank `df_survey` on SurveyDesign.subpopulation() inputs without an explicit PSU — violating the documented subpopulation contract that keeps the full design intact so t critical values, p-values, CIs, and HonestDiD bounds match full-design expectations.

Replace the pre-filter with a synthesized PSU column built on a private copy of `data`:

- Valid group values flow through unchanged as the per-row PSU label.
- NaN / invalid group values on zero-weight rows (the edge case that motivated the R11 filter) are replaced with a single shared dummy label so the PSU resolver accepts them.
- Zero-weight rows contribute psi_i = 0 to the variance, but remain in the resolved design so n_psu / n_strata / df_survey reflect the full sample — matching the library's subpopulation contract.

Added TestSurveyWithinGroupValidation.test_subpopulation_preserves_full_design_df_survey: zero-weights an entire group (mimicking SurveyDesign.subpopulation) and asserts that auto-inject df_survey equals the explicit psu='group' df_survey — the full-design reference.

All 271 targeted tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent bfee956 commit 657b62b

2 files changed

Lines changed: 63 additions & 9 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -672,25 +672,39 @@ def fit(
672672
):
673673
from diff_diff.survey import SurveyDesign as _SurveyDesign
674674

675-
# Pre-filter zero-weight rows so NaN / invalid group IDs on
676-
# excluded subpopulation rows don't block PSU resolution
677-
# (group becomes the PSU column after auto-inject). Updates
678-
# local bindings only; caller's DataFrame is untouched.
679-
pos_mask_sv = np.asarray(survey_weights) > 0
680-
if not pos_mask_sv.all():
681-
data = data.loc[pos_mask_sv].reset_index(drop=True)
675+
# Build a synthesized PSU column on a private copy of data
676+
# so the caller's DataFrame is untouched. Valid group values
677+
# flow through as their own PSU label; NaN/invalid group
678+
# values on zero-weight rows (SurveyDesign.subpopulation()
679+
# excluded rows) are replaced with a single shared dummy
680+
# label so the PSU resolver accepts them. Zero-weight rows
681+
# contribute psi_i = 0 to the variance; keeping them in the
682+
# resolved design preserves the full-design df_survey
683+
# contract (n_psu / n_strata reflect the full sample, not
684+
# the positive-weight subset).
685+
psu_col_name = "__dcdh_eff_psu__"
686+
synth_data = data.copy()
687+
synth_psu = synth_data[group].copy()
688+
try:
689+
invalid_mask = synth_psu.isna().to_numpy()
690+
except (AttributeError, TypeError):
691+
invalid_mask = np.zeros(len(synth_psu), dtype=bool)
692+
if invalid_mask.any():
693+
synth_psu = synth_psu.astype(object)
694+
synth_psu.loc[invalid_mask] = "__dcdh_excluded_null_psu__"
695+
synth_data[psu_col_name] = synth_psu
682696

683697
eff_design = _SurveyDesign(
684698
weights=survey_design.weights,
685699
strata=survey_design.strata,
686-
psu=group,
700+
psu=psu_col_name,
687701
fpc=getattr(survey_design, "fpc", None),
688702
weight_type=getattr(survey_design, "weight_type", "pweight"),
689703
nest=getattr(survey_design, "nest", False),
690704
lonely_psu=getattr(survey_design, "lonely_psu", "remove"),
691705
)
692706
resolved_survey, survey_weights, _, survey_metadata = (
693-
_resolve_survey_for_fit(eff_design, data, "analytical")
707+
_resolve_survey_for_fit(eff_design, synth_data, "analytical")
694708
)
695709

696710
if resolved_survey is not None:

tests/test_survey_dcdh.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,46 @@ def test_auto_inject_psu_matches_explicit_group_psu(self, base_data):
11171117
== r_explicit.survey_metadata.df_survey
11181118
)
11191119

1120+
def test_subpopulation_preserves_full_design_df_survey(self, base_data):
1121+
"""Under dCDH auto-inject, zero-weighting an entire group must not
1122+
shrink df_survey below what the full-design PSU count would give.
1123+
1124+
Mirrors SurveyDesign.subpopulation() semantics where excluded
1125+
rows keep their weights at zero but remain in the design so
1126+
that t critical values, p-values, CIs, and HonestDiD bounds
1127+
reflect the full sampling structure."""
1128+
df_ = base_data.copy()
1129+
df_["pw"] = 1.0
1130+
# Mimic subpopulation() by zero-weighting one entire group
1131+
excluded_group = df_["group"].unique()[0]
1132+
df_.loc[df_["group"] == excluded_group, "pw"] = 0.0
1133+
1134+
sd = SurveyDesign(weights="pw")
1135+
r_subpop = ChaisemartinDHaultfoeuille(seed=1).fit(
1136+
df_, outcome="outcome", group="group",
1137+
time="period", treatment="treatment",
1138+
survey_design=sd,
1139+
)
1140+
# Reference: explicit psu='group' preserves the full-design
1141+
# PSU count because the resolver sees all groups (even those
1142+
# entirely zero-weighted). The auto-inject path must match this.
1143+
r_explicit = ChaisemartinDHaultfoeuille(seed=1).fit(
1144+
df_, outcome="outcome", group="group",
1145+
time="period", treatment="treatment",
1146+
survey_design=SurveyDesign(weights="pw", psu="group"),
1147+
)
1148+
assert r_subpop.survey_metadata is not None
1149+
assert r_explicit.survey_metadata is not None
1150+
assert (
1151+
r_subpop.survey_metadata.df_survey
1152+
== r_explicit.survey_metadata.df_survey
1153+
), (
1154+
f"Auto-inject df_survey={r_subpop.survey_metadata.df_survey} "
1155+
f"must match explicit psu='group' df_survey="
1156+
f"{r_explicit.survey_metadata.df_survey} "
1157+
f"(full-design subpopulation contract)."
1158+
)
1159+
11201160
def test_off_horizon_row_duplication_does_not_change_se(self, base_data):
11211161
"""Under auto-injected psu=group, duplicating an observation
11221162
within a group (cell mean unchanged because the duplicate matches

0 commit comments

Comments
 (0)