Skip to content

Commit 42d3dff

Browse files
igerberclaude
andcommitted
Restore replicate-weight dispatch in aggregation, raise on zero-weight never-treated
- Add uses_replicate_variance branching in _aggregate_overall() and _aggregate_event_study() to route replicate designs to compute_replicate_if_variance() instead of compute_survey_if_variance() - Change zero-weight never-treated guard from warning to ValueError for covariates path — DR nuisance estimation requires positive-weight controls - Add test_zero_weight_never_treated_raises for the new error path Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent aca9f8b commit 42d3dff

2 files changed

Lines changed: 49 additions & 16 deletions

File tree

diff_diff/efficient_did.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -608,11 +608,10 @@ def fit(
608608

609609
# Guard: never-treated with zero survey weight → no valid comparisons
610610
if cohort_fractions.get(np.inf, 0.0) <= 0 and use_covariates:
611-
warnings.warn(
612-
"Never-treated group has zero survey weight; no valid "
613-
"comparisons possible for covariates path.",
614-
UserWarning,
615-
stacklevel=2,
611+
raise ValueError(
612+
"Never-treated group has zero survey weight. The doubly "
613+
"robust covariates path requires a never-treated control "
614+
"group with positive survey weight for nuisance estimation."
616615
)
617616

618617
# ----- Covariate preparation (if provided) -----
@@ -1274,17 +1273,27 @@ def _aggregate_overall(
12741273
keepers, effects, unit_cohorts, cohort_fractions, n_units,
12751274
unit_weights=self._unit_level_weights,
12761275
)
1277-
# Compute SE: survey path uses score-level psi + compute_survey_if_variance
1278-
# to avoid double-weighting (compute_survey_vcov applies w_i internally,
1279-
# which would double-weight the survey-weighted WIF term).
1276+
# Compute SE: survey path uses score-level psi to avoid double-weighting
1277+
# (compute_survey_vcov applies w_i internally, which would double-weight
1278+
# the survey-weighted WIF term). Dispatch replicate vs TSL.
12801279
if self._unit_resolved_survey is not None:
12811280
uw = self._unit_level_weights
12821281
total_w = float(np.sum(uw))
1283-
# Score-level: standard = w*eif/sum(w), wif already has w_i in indicator
12841282
psi_total = uw * agg_eif / total_w + wif / total_w
1285-
from diff_diff.survey import compute_survey_if_variance
12861283

1287-
variance = compute_survey_if_variance(psi_total, self._unit_resolved_survey)
1284+
if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1285+
and self._unit_resolved_survey.uses_replicate_variance):
1286+
from diff_diff.survey import compute_replicate_if_variance
1287+
1288+
variance, _ = compute_replicate_if_variance(
1289+
psi_total, self._unit_resolved_survey
1290+
)
1291+
else:
1292+
from diff_diff.survey import compute_survey_if_variance
1293+
1294+
variance = compute_survey_if_variance(
1295+
psi_total, self._unit_resolved_survey
1296+
)
12881297
se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
12891298
else:
12901299
agg_eif_total = agg_eif + wif
@@ -1387,9 +1396,20 @@ def _aggregate_event_study(
13871396
uw = self._unit_level_weights
13881397
total_w = float(np.sum(uw))
13891398
psi_total = uw * agg_eif / total_w + wif_e / total_w
1390-
from diff_diff.survey import compute_survey_if_variance
13911399

1392-
variance = compute_survey_if_variance(psi_total, self._unit_resolved_survey)
1400+
if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
1401+
and self._unit_resolved_survey.uses_replicate_variance):
1402+
from diff_diff.survey import compute_replicate_if_variance
1403+
1404+
variance, _ = compute_replicate_if_variance(
1405+
psi_total, self._unit_resolved_survey
1406+
)
1407+
else:
1408+
from diff_diff.survey import compute_survey_if_variance
1409+
1410+
variance = compute_survey_if_variance(
1411+
psi_total, self._unit_resolved_survey
1412+
)
13931413
agg_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
13941414
else:
13951415
agg_eif = agg_eif + wif_e

tests/test_survey_phase3.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,14 +1088,12 @@ def test_bootstrap_se_in_ballpark_of_analytical(self, cov_survey_data):
10881088
)
10891089

10901090
def test_zero_weight_cohort_skipped(self, cov_survey_data):
1091-
"""Zero-weight cohort should be skipped with a warning."""
1091+
"""Zero-weight treated cohort should be skipped with a warning."""
10921092
from diff_diff import EfficientDiD
10931093

10941094
# Set early cohort (first_treat=4) weights to exactly zero
10951095
cov_survey_data = cov_survey_data.copy()
10961096
cov_survey_data.loc[cov_survey_data["first_treat"] == 4, "weight"] = 0.0
1097-
# Need small positive weight for pweight validation (can't be all zero)
1098-
# Keep remaining cohorts with positive weights
10991097
sd = SurveyDesign(weights="weight")
11001098
with pytest.warns(UserWarning, match="zero survey weight"):
11011099
result = EfficientDiD(n_bootstrap=0).fit(
@@ -1107,6 +1105,21 @@ def test_zero_weight_cohort_skipped(self, cov_survey_data):
11071105
assert np.isfinite(result.overall_att)
11081106
assert np.isfinite(result.overall_se)
11091107

1108+
def test_zero_weight_never_treated_raises(self, cov_survey_data):
1109+
"""Zero-weight never-treated group should raise ValueError."""
1110+
from diff_diff import EfficientDiD
1111+
1112+
cov_survey_data = cov_survey_data.copy()
1113+
cov_survey_data.loc[cov_survey_data["first_treat"] == 0, "weight"] = 0.0
1114+
sd = SurveyDesign(weights="weight")
1115+
with pytest.raises(ValueError, match="zero survey weight"):
1116+
EfficientDiD(n_bootstrap=0).fit(
1117+
cov_survey_data,
1118+
"outcome", "unit", "time", "first_treat",
1119+
covariates=["x1"],
1120+
survey_design=sd,
1121+
)
1122+
11101123

11111124
# =============================================================================
11121125
# Scale Invariance (applies to all estimators)

0 commit comments

Comments
 (0)