Skip to content

Commit 1ed49bc

Browse files
igerberclaude
andcommitted
Fix CI review R1: SE divisor scaling, zero-weight cells, dead helper
- P1-A: Scale U_centered by 1/divisor before survey IF expansion. dCDH IFs are numerator-scale (U.sum() == N_S * DID_M), but compute_survey_if_variance() expects estimator-scale psi. - P1-B: Zero-weight cells (w_gt <= 0) now treated as absent by setting n_gt=0, preventing NaN propagation into estimates. - P2: Add SE-pinning test (uniform weights + PSU=group matches plug-in SE) and zero-weight cell exclusion test. - P3: Delete unused _validate_group_constant_survey() from survey.py that contradicted the supported within-group variation contract. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d7ddb19 commit 1ed49bc

3 files changed

Lines changed: 78 additions & 42 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ def _validate_and_aggregate_to_cells(
227227
)
228228
cell["y_gt"] = cell["_wy_sum"] / cell["w_gt"]
229229
cell = cell.drop(columns=["_wy_sum"])
230+
# Zero-weight cells: treat as absent so downstream presence
231+
# logic (N_mat > 0) correctly excludes them.
232+
zero_w_mask = cell["w_gt"] <= 0
233+
if zero_w_mask.any():
234+
cell.loc[zero_w_mask, "n_gt"] = 0
235+
cell.loc[zero_w_mask, "y_gt"] = 0.0
230236
df.drop(columns=["_w_", "_wy_"], inplace=True)
231237
else:
232238
cell = df.groupby([group, time], as_index=False).agg(
@@ -4548,8 +4554,14 @@ def _compute_se(
45484554
return _plugin_se(U_centered=U_centered, divisor=divisor)
45494555
if eligible_groups is None:
45504556
return _plugin_se(U_centered=U_centered, divisor=divisor)
4557+
if divisor <= 0:
4558+
return float("nan")
4559+
# dCDH IFs are numerator-scale (U.sum() == N_S * DID_M).
4560+
# compute_survey_if_variance() expects estimator-scale psi.
4561+
# Scale by 1/divisor to normalize before survey expansion.
4562+
U_scaled = U_centered / divisor
45514563
return _survey_se_from_group_if(
4552-
U_centered=U_centered,
4564+
U_centered=U_scaled,
45534565
eligible_groups=eligible_groups,
45544566
obs_survey_info=obs_survey_info,
45554567
)

diff_diff/survey.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -911,47 +911,6 @@ def _validate_unit_constant_survey(data, unit_col, survey_design):
911911
)
912912

913913

914-
def _validate_group_constant_survey(data, group_col, survey_design):
915-
"""Validate that survey design columns are constant within groups.
916-
917-
The dCDH estimator aggregates to ``(group, time)`` cells and then
918-
works at the group level. Survey columns (weights, strata, PSU)
919-
must not vary within groups for the IF expansion and survey variance
920-
to be well-defined.
921-
922-
Parameters
923-
----------
924-
data : pd.DataFrame
925-
Input data (pre-aggregation).
926-
group_col : str
927-
Group identifier column name.
928-
survey_design : SurveyDesign
929-
Survey design specification (uses attribute names, not resolved arrays).
930-
931-
Raises
932-
------
933-
ValueError
934-
If any survey column varies within groups.
935-
"""
936-
cols_to_check = [
937-
survey_design.weights,
938-
survey_design.strata,
939-
survey_design.psu,
940-
survey_design.fpc,
941-
]
942-
for col in cols_to_check:
943-
if col is not None and col in data.columns:
944-
n_unique = data.groupby(group_col)[col].nunique()
945-
varying_groups = n_unique[n_unique > 1]
946-
if len(varying_groups) > 0:
947-
raise ValueError(
948-
f"Survey column '{col}' varies within groups "
949-
f"(found {len(varying_groups)} groups with multiple values). "
950-
f"dCDH survey support requires survey design columns to be "
951-
f"constant within groups."
952-
)
953-
954-
955914
def _resolve_pweight_only(resolved_survey, estimator_name):
956915
"""Guard: reject non-pweight and strata/PSU/FPC for pweight-only estimators.
957916

tests/test_survey_dcdh.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,68 @@ def test_bootstrap_survey_emits_warning(self, data_with_survey):
377377
treatment="treatment",
378378
survey_design=sd,
379379
)
380+
381+
382+
# ── Test: SE scale pinning ──────────────────────────────────────────
383+
384+
385+
class TestSEScalePinning:
386+
"""Survey SE with uniform weights and no strata/PSU must match plug-in SE."""
387+
388+
def test_uniform_survey_se_matches_plugin(self, base_data):
389+
"""Pins the divisor normalization: uniform survey SE with group-level
390+
PSU clustering should be close to plug-in SE.
391+
392+
Without PSU clustering, survey treats each observation as independent
393+
(N_obs observations), while plug-in treats each group as independent
394+
(N_groups). Clustering at the group level aligns the two.
395+
"""
396+
df = base_data.copy()
397+
df["pw"] = 1.0
398+
sd = SurveyDesign(weights="pw", psu="group")
399+
400+
r_plain = ChaisemartinDHaultfoeuille(seed=1).fit(
401+
base_data, outcome="outcome", group="group",
402+
time="period", treatment="treatment",
403+
)
404+
r_survey = ChaisemartinDHaultfoeuille(seed=1).fit(
405+
df, outcome="outcome", group="group",
406+
time="period", treatment="treatment",
407+
survey_design=sd,
408+
)
409+
# With PSU=group and uniform weights, survey SE should be
410+
# close to plug-in SE (both assume group-level independence).
411+
# Small-sample corrections (n/(n-1)) cause minor differences.
412+
if np.isfinite(r_plain.overall_se) and np.isfinite(r_survey.overall_se):
413+
assert r_plain.overall_se == pytest.approx(
414+
r_survey.overall_se, rel=0.15
415+
), (
416+
f"Survey SE ({r_survey.overall_se:.6f}) should be close to "
417+
f"plug-in SE ({r_plain.overall_se:.6f}) with uniform weights "
418+
f"and PSU=group"
419+
)
420+
421+
422+
# ── Test: Zero-weight cells ─────────────────────────────────────────
423+
424+
425+
class TestZeroWeightCells:
426+
427+
def test_zero_weight_cell_excluded(self, base_data):
428+
"""A cell with zero survey weight is treated as absent."""
429+
df = base_data.copy()
430+
df["pw"] = 1.0
431+
# Zero out weight for one group at one period
432+
target_group = df["group"].unique()[0]
433+
target_period = df["period"].unique()[1]
434+
mask = (df["group"] == target_group) & (df["period"] == target_period)
435+
df.loc[mask, "pw"] = 0.0
436+
sd = SurveyDesign(weights="pw")
437+
438+
# Should not raise; the zero-weight cell is just absent
439+
result = ChaisemartinDHaultfoeuille(seed=1).fit(
440+
df, outcome="outcome", group="group",
441+
time="period", treatment="treatment",
442+
survey_design=sd,
443+
)
444+
assert np.isfinite(result.overall_att)

0 commit comments

Comments
 (0)