Skip to content

Commit e82faab

Browse files
igerber and claude committed
Fix CI review R13: preserve NaN-SE contract on degenerate-cohort survey fits
_plugin_se returns NaN when the cohort-recentered IF is empty or identically zero (documented degenerate-cohort contract — every variance-eligible group is its own (D_{g,1}, F_g, S_g) singleton). _survey_se_from_group_if only rejected negative variance, so on the same panel it computed sqrt(0) = 0.0, suppressed the degenerate-cohort warning (gated on np.isnan(overall_se)), and exposed a false zero SE. The bug affected every surface routed through _compute_se — top-level ATT, joiners/leavers, multi-horizon ATT, placebos, and derived normalized/cumulated SEs.

Mirror the _plugin_se contract: short-circuit to NaN when U_centered is empty or sum(U_centered**2) <= 0, before delegating to compute_survey_if_variance.

Added TestSurveyWithinGroupValidation.test_degenerate_cohort_survey_se_is_nan: 4 groups × 5 periods, each switching at a unique F_g so every cohort is a singleton; asserts overall_se is NaN (not 0.0) and that the degenerate-cohort warning fires under the survey path.

All 272 targeted tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 657b62b commit e82faab

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4874,6 +4874,18 @@ def _survey_se_from_group_if(
48744874
"""
48754875
from diff_diff.survey import compute_survey_if_variance
48764876

4877+
# Degenerate-cohort contract (mirror _plugin_se): when the centered
4878+
# IF is empty or every cohort is a singleton (→ recentered IF is
4879+
# identically zero), the variance is unidentified. Return NaN
4880+
# rather than sqrt(0)=0 so the fit-time warning fires and
4881+
# inference stays NaN-consistent across every surface routed
4882+
# through _compute_se (overall, joiners/leavers, multi-horizon
4883+
# ATT, placebos, normalized/cumulated, heterogeneity).
4884+
if U_centered.size == 0:
4885+
return float("nan")
4886+
if float((U_centered ** 2).sum()) <= 0:
4887+
return float("nan")
4888+
48774889
group_ids = obs_survey_info["group_ids"]
48784890
weights = obs_survey_info["weights"]
48794891
resolved = obs_survey_info["resolved"]

tests/test_survey_dcdh.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,52 @@ def test_auto_inject_psu_matches_explicit_group_psu(self, base_data):
11171117
== r_explicit.survey_metadata.df_survey
11181118
)
11191119

1120+
def test_degenerate_cohort_survey_se_is_nan(self):
1121+
"""When every variance-eligible group is its own singleton
1122+
cohort (D_{g,1}, F_g, S_g), the cohort-recentered IF is
1123+
identically zero. The survey SE path must return NaN (not 0.0)
1124+
so the degenerate-cohort warning fires and inference stays
1125+
NaN-consistent — matching the _plugin_se contract documented
1126+
in REGISTRY.md."""
1127+
# 4 groups × 5 periods, each group switches at a unique F_g so
1128+
# the (D_{g,1}=0, F_g, S_g=+1) cohort key is unique per group.
1129+
rows = []
1130+
for g, f_switch in enumerate([1, 2, 3, 4]):
1131+
for t in range(5):
1132+
d = 1 if t >= f_switch else 0
1133+
y = float(g) + 0.5 * t + float(d)
1134+
rows.append({
1135+
"group": g,
1136+
"period": t,
1137+
"treatment": d,
1138+
"outcome": y,
1139+
"pw": 1.0,
1140+
})
1141+
df_ = pd.DataFrame(rows)
1142+
sd = SurveyDesign(weights="pw")
1143+
1144+
import warnings as _warnings
1145+
with _warnings.catch_warnings(record=True) as w:
1146+
_warnings.simplefilter("always")
1147+
result = ChaisemartinDHaultfoeuille(seed=1).fit(
1148+
df_,
1149+
outcome="outcome", group="group",
1150+
time="period", treatment="treatment",
1151+
survey_design=sd,
1152+
)
1153+
1154+
# overall_se must be NaN on degenerate cohorts (not 0.0)
1155+
assert np.isnan(result.overall_se), (
1156+
f"Degenerate-cohort survey overall_se must be NaN, "
1157+
f"got {result.overall_se}"
1158+
)
1159+
# Degenerate-cohort warning must fire
1160+
assert any(
1161+
"cohort" in str(wi.message).lower()
1162+
and "identically zero" in str(wi.message).lower()
1163+
for wi in w
1164+
), "Expected degenerate-cohort warning to fire under survey path"
1165+
11201166
def test_subpopulation_preserves_full_design_df_survey(self, base_data):
11211167
"""Under dCDH auto-inject, zero-weighting an entire group must not
11221168
shrink df_survey below what the full-design PSU count would give.

0 commit comments

Comments
 (0)