Skip to content

Commit de8ff5e

Browse files
igerber and claude committed
Fix CI review R8: extend zero-weight contract to all validators + survey branch tests
- P1 #1: The R5 zero-weight filter only ran inside the cell aggregation step, after the NaN/coercion checks for group/time/treatment/outcome. Moved the filter to the very top of _validate_and_aggregate_to_cells so validation only sees the effective sample. fit()'s controls, trends_nonparam, and heterogeneity blocks now also scope their NaN/time-invariance checks to positive-weight rows when survey_weights is active. Legitimate SurveyDesign.subpopulation() inputs with NaN in excluded rows now fit cleanly. TSL variance path is unchanged (zero-weight obs still contribute zero psi).
- P2: 5 new regression tests in test_survey_dcdh.py — TestZeroWeightSubpopulation now covers NaN outcome and NaN het columns in excluded rows; new TestSurveyTrendsLinear / TestSurveyTrendsNonparam / TestSurveyDesign2 classes exercise survey_design combined with those previously-untested branches. All 262 targeted tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5b2939c commit de8ff5e

2 files changed

Lines changed: 194 additions & 23 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,16 @@ def _validate_and_aggregate_to_cells(
157157

158158
df = data.copy()
159159

160+
# 1a. SurveyDesign.subpopulation() contract: zero-weight rows are
161+
# out-of-sample. Pre-filter them *before* any NaN/coercion validation
162+
# so that invalid values in excluded rows do not abort the fit.
163+
if weights is not None:
164+
weights_arr = np.asarray(weights, dtype=np.float64)
165+
pos_mask = weights_arr > 0
166+
if not pos_mask.all():
167+
df = df.loc[pos_mask].reset_index(drop=True)
168+
weights = weights_arr[pos_mask]
169+
160170
# 1b. Group and time NaN checks (before groupby, which silently drops NaN keys)
161171
n_nan_group = int(df[group].isna().sum())
162172
if n_nan_group > 0:
@@ -210,19 +220,8 @@ def _validate_and_aggregate_to_cells(
210220

211221
# 5. Cell aggregation (compute min/max for within-cell check)
212222
if weights is not None:
213-
# Zero-weight rows are out-of-sample (e.g., via
214-
# SurveyDesign.subpopulation()). Pre-filter them before the
215-
# groupby so that d_min/d_max/n_gt reflect the effective sample
216-
# and a zero-weight obs cannot trip the within-cell treatment-
217-
# constancy check or inflate downstream n_gt counts.
218-
weights_arr = np.asarray(weights, dtype=np.float64)
219-
pos_mask = weights_arr > 0
220-
if not pos_mask.all():
221-
df = df.loc[pos_mask].reset_index(drop=True)
222-
weights_arr = weights_arr[pos_mask]
223-
weights = weights_arr
224-
225-
# Survey-weighted cell aggregation.
223+
# Survey-weighted cell aggregation (zero-weight rows already
224+
# filtered upstream at step 1a).
226225
# y_gt = sum(w_i * y_i) / sum(w_i) within each (g, t) cell.
227226
# Treatment is constant within cells (checked below), so weighted
228227
# and unweighted means are identical for d_gt.
@@ -730,8 +729,17 @@ def fit(
730729
f"Control column(s) {missing_controls!r} not found in "
731730
f"data. Available columns: {list(data.columns)}"
732731
)
733-
# Work on a copy to avoid mutating the caller's DataFrame
734-
data_controls = data[controls].copy()
732+
# SurveyDesign.subpopulation() contract: zero-weight rows are
733+
# out-of-sample. Scope NaN/Inf validation to positive-weight
734+
# rows so that excluded obs with missing covariates do not
735+
# abort the fit. The downstream weighted aggregation
736+
# (sum(w*x)/sum(w)) handles zero-weight rows correctly on
737+
# its own.
738+
if survey_weights is not None:
739+
pos_mask_ctrl = np.asarray(survey_weights) > 0
740+
data_controls = data.loc[pos_mask_ctrl, controls].copy()
741+
else:
742+
data_controls = data[controls].copy()
735743
for c in controls:
736744
try:
737745
data_controls[c] = pd.to_numeric(data_controls[c])
@@ -1196,16 +1204,24 @@ def fit(
11961204
f"trends_nonparam column {set_col!r} not found in "
11971205
f"data. Available columns: {list(data.columns)}"
11981206
)
1199-
# Reject NaN/missing set assignments
1200-
n_na_set = int(data[set_col].isna().sum())
1207+
# SurveyDesign.subpopulation() contract: scope NaN and
1208+
# time-invariance validation to positive-weight rows so
1209+
# excluded obs with missing set IDs do not abort the fit.
1210+
if survey_weights is not None:
1211+
pos_mask_tnp = np.asarray(survey_weights) > 0
1212+
data_tnp = data.loc[pos_mask_tnp]
1213+
else:
1214+
data_tnp = data
1215+
# Reject NaN/missing set assignments (effective sample only)
1216+
n_na_set = int(data_tnp[set_col].isna().sum())
12011217
if n_na_set > 0:
12021218
raise ValueError(
12031219
f"trends_nonparam column {set_col!r} contains "
12041220
f"{n_na_set} NaN/missing value(s). All groups must "
12051221
f"have a valid set assignment."
12061222
)
12071223
# Aggregate set membership per group (must be time-invariant)
1208-
set_per_group = data.groupby(group)[set_col].nunique()
1224+
set_per_group = data_tnp.groupby(group)[set_col].nunique()
12091225
time_varying = set_per_group[set_per_group > 1]
12101226
if len(time_varying) > 0:
12111227
raise ValueError(
@@ -1217,7 +1233,7 @@ def fit(
12171233
# Set partition must be coarser than group (multiple groups
12181234
# per set). A group-level partition creates singleton sets
12191235
# with no within-set controls available.
1220-
set_map_check = data.groupby(group)[set_col].first()
1236+
set_map_check = data_tnp.groupby(group)[set_col].first()
12211237
n_sets = set_map_check.nunique()
12221238
n_groups_total = len(set_map_check)
12231239
if n_sets >= n_groups_total:
@@ -1229,7 +1245,7 @@ def fit(
12291245
f"within-set controls."
12301246
)
12311247
# Extract set membership per group aligned with all_groups
1232-
set_map = data.groupby(group)[set_col].first()
1248+
set_map = data_tnp.groupby(group)[set_col].first()
12331249
set_ids_arr = np.array(
12341250
[set_map.loc[g] for g in all_groups], dtype=object
12351251
)
@@ -2376,16 +2392,24 @@ def fit(
23762392
"control-pool restrictions; the results would be "
23772393
"inconsistent with the fitted estimator."
23782394
)
2379-
# Extract per-group covariate (must be time-invariant)
2380-
het_per_group = data.groupby(group)[het_col].nunique()
2395+
# Extract per-group covariate (must be time-invariant).
2396+
# SurveyDesign.subpopulation() contract: scope time-invariance
2397+
# check to positive-weight rows so excluded obs with NaN/varying
2398+
# het values do not abort the fit.
2399+
if survey_weights is not None:
2400+
pos_mask_het = np.asarray(survey_weights) > 0
2401+
data_het = data.loc[pos_mask_het]
2402+
else:
2403+
data_het = data
2404+
het_per_group = data_het.groupby(group)[het_col].nunique()
23812405
het_varying = het_per_group[het_per_group > 1]
23822406
if len(het_varying) > 0:
23832407
raise ValueError(
23842408
f"heterogeneity column {het_col!r} must be "
23852409
f"time-invariant within each group. "
23862410
f"{len(het_varying)} group(s) have varying values."
23872411
)
2388-
het_map = data.groupby(group)[het_col].first()
2412+
het_map = data_het.groupby(group)[het_col].first()
23892413
X_het = np.array(
23902414
[float(het_map.loc[g]) for g in all_groups]
23912415
)

tests/test_survey_dcdh.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,3 +841,150 @@ def test_mixed_zero_weight_row_excluded_from_validation(self, base_data):
841841
survey_design=sd,
842842
)
843843
assert np.isfinite(result.overall_att)
844+
845+
def test_zero_weight_row_with_nan_outcome(self, base_data):
    """Fit must tolerate a NaN outcome that sits on a zero-weight row.

    A copy of the first row of ``base_data`` is appended with ``outcome``
    set to NaN and ``pw`` set to 0.0; every other row gets weight 1.0.
    Under the SurveyDesign.subpopulation() contract the zero-weight row
    is out-of-sample, so the outcome NaN validator must not reject it.
    """
    panel = base_data.copy()
    panel["pw"] = 1.0

    # Out-of-sample observation: identical to row 0 except for the
    # missing outcome and the zero survey weight.
    excluded_row = panel.iloc[0].copy()
    excluded_row["outcome"] = np.nan
    excluded_row["pw"] = 0.0
    panel = pd.concat(
        [panel, pd.DataFrame([excluded_row])], ignore_index=True
    )

    column_roles = {
        "outcome": "outcome",
        "group": "group",
        "time": "period",
        "treatment": "treatment",
    }
    design = SurveyDesign(weights="pw")
    estimator = ChaisemartinDHaultfoeuille(seed=1)
    fitted = estimator.fit(panel, survey_design=design, **column_roles)

    # A finite point estimate shows the fit ran to completion.
    assert np.isfinite(fitted.overall_att)
863+
864+
def test_zero_weight_row_with_nan_heterogeneity(self, base_data):
    """Zero-weight rows with NaN or conflicting heterogeneity values must
    not trip the heterogeneity time-invariance validator.

    Two out-of-sample (``pw == 0``) copies of the first row are appended:

    * one with ``x_het = NaN``;
    * one with a finite ``x_het`` that differs from its group's value.

    The second row is what makes this a real regression test. pandas'
    ``groupby(...).nunique()`` ignores NaN by default, so a NaN-only row
    cannot by itself make a nunique-based time-invariance check fail. The
    conflicting finite value does make the *unfiltered* data time-varying,
    so the fit can only succeed if validation is scoped to positive-weight
    rows (SurveyDesign.subpopulation() contract).
    """
    rng = np.random.default_rng(0)
    df_ = base_data.copy()
    df_["pw"] = 1.0
    groups = sorted(df_["group"].unique())
    het_map = {g: rng.uniform(-1, 1) for g in groups}
    df_["x_het"] = df_["group"].map(het_map)

    # Zero-weight row with a missing het value for an existing group.
    nan_row = df_.iloc[0].copy()
    nan_row["x_het"] = np.nan
    nan_row["pw"] = 0.0

    # Zero-weight row whose het value conflicts with the rest of its
    # group (+1.0 guarantees it differs from the group's true value).
    conflict_row = df_.iloc[0].copy()
    conflict_row["x_het"] = conflict_row["x_het"] + 1.0
    conflict_row["pw"] = 0.0

    df_ = pd.concat(
        [df_, pd.DataFrame([nan_row, conflict_row])], ignore_index=True
    )

    # Precondition: without the positive-weight filter the het column is
    # NOT time-invariant, so the fit below genuinely exercises the filter.
    assert df_.groupby("group")["x_het"].nunique().max() > 1

    sd = SurveyDesign(weights="pw")
    # Must succeed — both injected rows are out-of-sample.
    result = ChaisemartinDHaultfoeuille(seed=1).fit(
        df_,
        outcome="outcome", group="group",
        time="period", treatment="treatment",
        L_max=1, heterogeneity="x_het", survey_design=sd,
    )
    assert result.heterogeneity_effects is not None
887+
888+
889+
# ── Test: Survey + trends_linear ────────────────────────────────────
890+
891+
892+
class TestSurveyTrendsLinear:
    """Smoke test for ``trends_linear=True`` combined with a survey design:
    the fit must populate ``linear_trends_effects``."""

    def test_survey_trends_linear_runs(self, data_with_survey):
        """Fit with L_max=2 and trends_linear=True under survey weights and
        require survey metadata plus at least one finite linear-trends
        effect."""
        design = SurveyDesign(weights="pw")
        estimator = ChaisemartinDHaultfoeuille(seed=1)
        fitted = estimator.fit(
            data_with_survey,
            outcome="outcome",
            group="group",
            time="period",
            treatment="treatment",
            L_max=2,
            trends_linear=True,
            survey_design=design,
        )

        assert fitted.survey_metadata is not None
        assert fitted.linear_trends_effects is not None

        # One flag per horizon: is its "effect" entry a finite number?
        # A missing "effect" key counts as non-finite.
        finite_flags = [
            np.isfinite(entry.get("effect", np.nan))
            for _horizon, entry in fitted.linear_trends_effects.items()
        ]
        assert any(finite_flags), (
            "expected at least one horizon with finite linear_trends_effect"
        )
914+
915+
916+
# ── Test: Survey + trends_nonparam ──────────────────────────────────
917+
918+
919+
class TestSurveyTrendsNonparam:
    """Smoke test for ``trends_nonparam`` combined with a survey design:
    set restrictions must be threaded through a survey-backed fit."""

    def test_survey_trends_nonparam_runs(self, data_with_survey):
        """Fit with the ``stratum`` column as the set ID under survey
        weights and require at least one event-study horizon with a
        finite standard error."""
        # The "stratum" column doubles as the set ID (time-invariant per
        # group on this fixture).
        design = SurveyDesign(weights="pw")
        estimator = ChaisemartinDHaultfoeuille(seed=1)
        fitted = estimator.fit(
            data_with_survey,
            outcome="outcome",
            group="group",
            time="period",
            treatment="treatment",
            L_max=2,
            trends_nonparam="stratum",
            survey_design=design,
        )

        assert fitted.survey_metadata is not None
        assert fitted.event_study_effects is not None

        # Support trimming may shrink the counts, but this fixture should
        # still leave some horizon with a finite SE. A missing "se" key
        # counts as non-finite.
        se_is_finite = [
            np.isfinite(entry.get("se", np.nan))
            for entry in fitted.event_study_effects.values()
        ]
        assert any(se_is_finite), (
            "expected at least one event-study horizon with finite SE "
            "under trends_nonparam + survey"
        )
944+
945+
946+
# ── Test: Survey + design2 ──────────────────────────────────────────
947+
948+
949+
class TestSurveyDesign2:
    """Smoke test for ``design2=True`` combined with a survey design: the
    fit must populate ``design2_effects``."""

    @staticmethod
    def _make_join_then_leave_panel(seed=42, n_groups=30, n_periods=8):
        """Build a balanced panel containing join-then-leave groups.

        Treatment paths by group index:

        * ``g < 10``: treated for periods 2, 3, 4, then untreated again;
        * ``10 <= g < 20``: treated from period 3 onwards;
        * ``g >= 20``: never treated.

        The outcome is a group effect plus ``2.0 * t`` plus ``5.0 * d``
        plus noise; every row carries survey weight ``pw = 1.0``. The
        original docstring says this mirrors the design2 fixture in
        test_chaisemartin_dhaultfoeuille.py.
        """
        rng = np.random.RandomState(seed)

        def treatment_status(g, t):
            # Join-then-leave, join-and-stay, never-treated.
            if g < 10:
                return int(2 <= t < 5)
            if g < 20:
                return int(t >= 3)
            return 0

        records = []
        for g in range(n_groups):
            # One group-effect draw per group, then one noise draw per
            # period; the draw order determines the generated data.
            group_fe = rng.normal(0, 2)
            for t in range(n_periods):
                d = treatment_status(g, t)
                noise = rng.normal(0, 0.3)
                records.append(
                    {
                        "group": g,
                        "period": t,
                        "treatment": d,
                        "outcome": group_fe + 2.0 * t + 5.0 * d + noise,
                        "pw": 1.0,
                    }
                )
        return pd.DataFrame(records)

    def test_survey_design2_runs(self):
        """Fit the join-then-leave panel with design2=True under survey
        weights and check the design2 summary."""
        panel = self._make_join_then_leave_panel()
        design = SurveyDesign(weights="pw")

        # drop_larger_lower=False keeps the 2-switch groups
        estimator = ChaisemartinDHaultfoeuille(seed=1, drop_larger_lower=False)
        fitted = estimator.fit(
            panel,
            outcome="outcome",
            group="group",
            time="period",
            treatment="treatment",
            L_max=1,
            design2=True,
            survey_design=design,
        )

        assert fitted.survey_metadata is not None

        design2 = fitted.design2_effects
        assert design2 is not None
        # The helper makes groups 0-9 the join-then-leave groups.
        assert design2["n_design2_groups"] == 10
        # Both switch directions must produce a finite mean effect.
        assert np.isfinite(design2["switch_in"]["mean_effect"])
        assert np.isfinite(design2["switch_out"]["mean_effect"])

0 commit comments

Comments
 (0)