Skip to content

Commit 66ab529

Browse files
igerber and claude committed
Address P1/P2 review findings: survey-weighted WIF, bootstrap denominator, zero-weight guard
- Make _compute_wif_contribution() survey-aware: use w_i * 1{G_i=g} - pg_k formula when unit_weights present, matching staggered_aggregation.py
- Use explicit sum(unit_level_weights) denominator in bootstrap perturbation when survey design is active
- Guard zero-weight cohorts: skip in fit loop, early return in compute_generated_outcomes_cov when pi_g <= 0
- Add regression tests: analytical SE differs from unweighted, bootstrap SE in ballpark of analytical, zero-weight cohort handled gracefully
- Update tutorial notebook: remove stale note about covariates+survey

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d495e35 commit 66ab529

5 files changed

Lines changed: 140 additions & 6 deletions

File tree

diff_diff/efficient_did.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ def fit(
571571
# Use the resolved survey's weights (already normalized per weight_type)
572572
# subset to unit level via _unit_first_panel_row (aligned to all_units)
573573
unit_level_weights = self._unit_resolved_survey.weights
574+
self._unit_level_weights = unit_level_weights
574575

575576
cohort_fractions: Dict[float, float] = {}
576577
if unit_level_weights is not None:
@@ -674,6 +675,15 @@ def fit(
674675
else:
675676
effective_p1_col = period_1_col
676677

678+
# Guard: skip cohorts with zero survey weight (all units zero-weighted)
679+
if cohort_fractions[g] <= 0:
680+
warnings.warn(
681+
f"Cohort {g} has zero survey weight; skipping.",
682+
UserWarning,
683+
stacklevel=2,
684+
)
685+
continue
686+
677687
# Estimate all (g, t) cells including pre-treatment. Under PT-Post,
678688
# pre-treatment cells serve as placebo/pre-trend diagnostics, matching
679689
# the CallawaySantAnna implementation. Users filter to t >= g for
@@ -976,6 +986,7 @@ def fit(
976986
cluster_indices=unit_cluster_indices,
977987
n_clusters=n_clusters,
978988
resolved_survey=self._unit_resolved_survey,
989+
unit_level_weights=self._unit_level_weights,
979990
)
980991
# Update estimates with bootstrap inference
981992
overall_se = bootstrap_results.overall_att_se
@@ -1137,6 +1148,7 @@ def _compute_wif_contribution(
11371148
unit_cohorts: np.ndarray,
11381149
cohort_fractions: Dict[float, float],
11391150
n_units: int,
1151+
unit_weights: Optional[np.ndarray] = None,
11401152
) -> np.ndarray:
11411153
"""Compute weight influence function correction (O(1) scale, matching EIF).
11421154
@@ -1156,6 +1168,9 @@ def _compute_wif_contribution(
11561168
``{cohort: n_cohort / n}`` for each cohort.
11571169
n_units : int
11581170
Total number of units.
1171+
unit_weights : ndarray, shape (n_units,), optional
1172+
Survey weights at the unit level. When provided, uses the
1173+
survey-weighted WIF formula: IF_i(p_g) = (w_i * 1{G_i=g} - pg_k).
11591174
11601175
Returns
11611176
-------
@@ -1169,10 +1184,19 @@ def _compute_wif_contribution(
11691184
return np.zeros(n_units)
11701185

11711186
indicator = (unit_cohorts[:, None] == groups_for_keepers[None, :]).astype(float)
1172-
indicator_sum = np.sum(indicator - pg_keepers, axis=1)
1187+
1188+
if unit_weights is not None:
1189+
# Survey-weighted WIF (matches staggered_aggregation.py:392-401):
1190+
# IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT (1{G_i=g} - pg_k)
1191+
weighted_indicator = indicator * unit_weights[:, None]
1192+
indicator_diff = weighted_indicator - pg_keepers
1193+
indicator_sum = np.sum(indicator_diff, axis=1)
1194+
else:
1195+
indicator_diff = indicator - pg_keepers
1196+
indicator_sum = np.sum(indicator_diff, axis=1)
11731197

11741198
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
1175-
if1 = (indicator - pg_keepers) / sum_pg
1199+
if1 = indicator_diff / sum_pg
11761200
if2 = np.outer(indicator_sum, pg_keepers) / sum_pg**2
11771201
wif_matrix = if1 - if2
11781202
wif_contrib = wif_matrix @ effects
@@ -1229,7 +1253,8 @@ def _aggregate_overall(
12291253

12301254
# WIF correction: accounts for uncertainty in cohort-size weights
12311255
wif = self._compute_wif_contribution(
1232-
keepers, effects, unit_cohorts, cohort_fractions, n_units
1256+
keepers, effects, unit_cohorts, cohort_fractions, n_units,
1257+
unit_weights=self._unit_level_weights,
12331258
)
12341259
agg_eif_total = agg_eif + wif # both O(1) scale
12351260

@@ -1325,7 +1350,8 @@ def _aggregate_event_study(
13251350
es_keepers = [(g, t) for (g, t) in gt_pairs]
13261351
es_effects = effs
13271352
wif = self._compute_wif_contribution(
1328-
es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units
1353+
es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units,
1354+
unit_weights=self._unit_level_weights,
13291355
)
13301356
agg_eif = agg_eif + wif
13311357

diff_diff/efficient_did_bootstrap.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def _run_multiplier_bootstrap(
6363
cluster_indices: Optional[np.ndarray] = None,
6464
n_clusters: Optional[int] = None,
6565
resolved_survey: object = None,
66+
unit_level_weights: Optional[np.ndarray] = None,
6667
) -> EDiDBootstrapResults:
6768
"""Run multiplier bootstrap on stored EIF values.
6869
@@ -136,11 +137,19 @@ def _run_multiplier_bootstrap(
136137
original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs])
137138

138139
# Perturbed ATTs: (n_bootstrap, n_gt)
140+
# Under survey design, normalize by sum(survey_weights) instead of n_units
141+
# (pweights are normalized to mean=1, so numerically equivalent, but explicit
142+
# for robustness against future weight types)
143+
denom = (
144+
float(np.sum(unit_level_weights))
145+
if unit_level_weights is not None
146+
else float(n_units)
147+
)
139148
bootstrap_atts = np.zeros((self.n_bootstrap, n_gt))
140149
for j, gt in enumerate(gt_pairs):
141150
eif_gt = eif_by_gt[gt] # shape (n_units,)
142151
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
143-
perturbation = (all_weights @ eif_gt) / n_units
152+
perturbation = (all_weights @ eif_gt) / denom
144153
bootstrap_atts[:, j] = original_atts[j] + perturbation
145154

146155
# Post-treatment mask — also exclude NaN effects

diff_diff/efficient_did_covariates.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,10 @@ def compute_generated_outcomes_cov(
488488
g_mask = cohort_masks[target_g]
489489
pi_g = cohort_fractions[target_g]
490490

491+
# Guard: zero survey weight for the target cohort → no DR estimation possible
492+
if pi_g <= 0:
493+
return np.zeros((n_units, H))
494+
491495
gen_out = np.zeros((n_units, H))
492496

493497
for j, (gp, tpre) in enumerate(valid_pairs):

docs/tutorials/16_survey_did.ipynb

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,42 @@
527527
{
528528
"cell_type": "markdown",
529529
"metadata": {},
530-
"source": "## 9. Which Estimators Support Survey Design?\n\n`diff-diff` supports survey design across all estimators, though the level of support varies:\n\n| Estimator | Weights | Strata/PSU/FPC (TSL) | Replicate Weights | Survey-Aware Bootstrap |\n|-----------|---------|---------------------|-------------------|------------------------|\n| **DifferenceInDifferences** | Full | Full | -- | -- |\n| **TwoWayFixedEffects** | Full | Full | -- | -- |\n| **MultiPeriodDiD** | Full | Full | -- | -- |\n| **CallawaySantAnna** | pweight only | Full | Full | Multiplier at PSU |\n| **TripleDifference** | pweight only | Full | Full (analytical) | -- |\n| **StaggeredTripleDifference** | pweight only | Full | Full | Multiplier at PSU |\n| **SunAbraham** | Full | Full | -- | Rao-Wu rescaled |\n| **StackedDiD** | pweight only | Full (pweight only) | -- | -- |\n| **ImputationDiD** | pweight only | Partial (no FPC) | -- | Multiplier at PSU |\n| **TwoStageDiD** | pweight only | Partial (no FPC) | -- | Multiplier at PSU |\n| **ContinuousDiD** | Full | Full | Full (analytical) | Multiplier at PSU |\n| **EfficientDiD** | Full | Full | Full (analytical) | Multiplier at PSU |\n| **SyntheticDiD** | pweight only | -- | -- | Rao-Wu rescaled |\n| **TROP** | pweight only | -- | -- | Rao-Wu rescaled |\n| **BaconDecomposition** | Diagnostic | Diagnostic | -- | -- |\n\n**Legend:**\n- **Full**: All weight types (pweight/fweight/aweight) + strata/PSU/FPC + Taylor Series Linearization variance\n- **Full (pweight only)**: Full TSL support with strata/PSU/FPC, but only accepts `pweight` weight type (`fweight`/`aweight` rejected because Q-weight composition changes their semantics)\n- **Partial (no FPC)**: Weights + strata (for df) + PSU (for clustering); FPC raises `NotImplementedError`\n- **pweight only** (Weights column): Only `pweight` accepted; `fweight`/`aweight` raise an error\n- **pweight only** (TSL column): Sampling weights for point estimates; no strata/PSU/FPC design elements\n- 
**Diagnostic**: Weighted descriptive statistics only (no inference)\n- **--**: Not supported\n\n**Note:** `EfficientDiD` does not support `covariates` and `survey_design` simultaneously (the DR nuisance path does not yet thread survey weights). Use `covariates=None` with survey designs.\n\nFor full details, see `docs/survey-roadmap.md`."
530+
"source": [
531+
"## 9. Which Estimators Support Survey Design?\n",
532+
"\n",
533+
"`diff-diff` supports survey design across all estimators, though the level of support varies:\n",
534+
"\n",
535+
"| Estimator | Weights | Strata/PSU/FPC (TSL) | Replicate Weights | Survey-Aware Bootstrap |\n",
536+
"|-----------|---------|---------------------|-------------------|------------------------|\n",
537+
"| **DifferenceInDifferences** | Full | Full | -- | -- |\n",
538+
"| **TwoWayFixedEffects** | Full | Full | -- | -- |\n",
539+
"| **MultiPeriodDiD** | Full | Full | -- | -- |\n",
540+
"| **CallawaySantAnna** | pweight only | Full | Full | Multiplier at PSU |\n",
541+
"| **TripleDifference** | pweight only | Full | Full (analytical) | -- |\n",
542+
"| **StaggeredTripleDifference** | pweight only | Full | Full | Multiplier at PSU |\n",
543+
"| **SunAbraham** | Full | Full | -- | Rao-Wu rescaled |\n",
544+
"| **StackedDiD** | pweight only | Full (pweight only) | -- | -- |\n",
545+
"| **ImputationDiD** | pweight only | Partial (no FPC) | -- | Multiplier at PSU |\n",
546+
"| **TwoStageDiD** | pweight only | Partial (no FPC) | -- | Multiplier at PSU |\n",
547+
"| **ContinuousDiD** | Full | Full | Full (analytical) | Multiplier at PSU |\n",
548+
"| **EfficientDiD** | Full | Full | Full (analytical) | Multiplier at PSU |\n",
549+
"| **SyntheticDiD** | pweight only | -- | -- | Rao-Wu rescaled |\n",
550+
"| **TROP** | pweight only | -- | -- | Rao-Wu rescaled |\n",
551+
"| **BaconDecomposition** | Diagnostic | Diagnostic | -- | -- |\n",
552+
"\n",
553+
"**Legend:**\n",
554+
"- **Full**: All weight types (pweight/fweight/aweight) + strata/PSU/FPC + Taylor Series Linearization variance\n",
555+
"- **Full (pweight only)**: Full TSL support with strata/PSU/FPC, but only accepts `pweight` weight type (`fweight`/`aweight` rejected because Q-weight composition changes their semantics)\n",
556+
"- **Partial (no FPC)**: Weights + strata (for df) + PSU (for clustering); FPC raises `NotImplementedError`\n",
557+
"- **pweight only** (Weights column): Only `pweight` accepted; `fweight`/`aweight` raise an error\n",
558+
"- **pweight only** (TSL column): Sampling weights for point estimates; no strata/PSU/FPC design elements\n",
559+
"- **Diagnostic**: Weighted descriptive statistics only (no inference)\n",
560+
"- **--**: Not supported\n",
561+
"\n",
562+
"**Note:** `EfficientDiD` supports `covariates` and `survey_design` simultaneously. The doubly-robust (DR) path threads survey weights through WLS outcome regression, weighted sieve propensity ratios, and survey-weighted kernel smoothing.\n",
563+
"\n",
564+
"For full details, see `docs/survey-roadmap.md`."
565+
]
531566
},
532567
{
533568
"cell_type": "markdown",

tests/test_survey_phase3.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,66 @@ def test_bootstrap_covariates_survey(self, cov_survey_data):
10441044
assert np.isfinite(result.overall_se)
10451045
assert result.overall_se > 0
10461046

1047+
def test_analytical_se_differs_from_unweighted(self, cov_survey_data):
1048+
"""Survey analytical SE should differ from unweighted SE."""
1049+
from diff_diff import EfficientDiD
1050+
1051+
sd = SurveyDesign(weights="weight")
1052+
result_survey = EfficientDiD(n_bootstrap=0).fit(
1053+
cov_survey_data,
1054+
"outcome", "unit", "time", "first_treat",
1055+
covariates=["x1"],
1056+
survey_design=sd,
1057+
)
1058+
result_nosurv = EfficientDiD(n_bootstrap=0).fit(
1059+
cov_survey_data,
1060+
"outcome", "unit", "time", "first_treat",
1061+
covariates=["x1"],
1062+
)
1063+
# Non-uniform weights (1.0 + 0.3*stratum) should produce different SEs
1064+
assert result_survey.overall_se != result_nosurv.overall_se
1065+
assert np.isfinite(result_survey.overall_se)
1066+
assert result_survey.overall_se > 0
1067+
1068+
def test_bootstrap_se_in_ballpark_of_analytical(self, cov_survey_data):
1069+
"""Bootstrap SE should be in same ballpark as analytical SE."""
1070+
from diff_diff import EfficientDiD
1071+
1072+
sd = SurveyDesign(weights="weight")
1073+
result_analytical = EfficientDiD(n_bootstrap=0).fit(
1074+
cov_survey_data,
1075+
"outcome", "unit", "time", "first_treat",
1076+
covariates=["x1"],
1077+
survey_design=sd,
1078+
)
1079+
result_boot = EfficientDiD(n_bootstrap=199, seed=42).fit(
1080+
cov_survey_data,
1081+
"outcome", "unit", "time", "first_treat",
1082+
covariates=["x1"],
1083+
survey_design=sd,
1084+
)
1085+
ratio = result_boot.overall_se / result_analytical.overall_se
1086+
assert 0.3 < ratio < 3.0, (
1087+
f"Bootstrap/analytical SE ratio {ratio:.2f} outside [0.3, 3.0]"
1088+
)
1089+
1090+
def test_zero_weight_cohort_skipped(self, cov_survey_data):
1091+
"""Near-zero-weight cohort should still yield finite overall ATT and SE."""
1092+
from diff_diff import EfficientDiD
1093+
1094+
# Set early cohort (first_treat=4) weights to near-zero
1095+
cov_survey_data = cov_survey_data.copy()
1096+
cov_survey_data.loc[cov_survey_data["first_treat"] == 4, "weight"] = 1e-15
1097+
sd = SurveyDesign(weights="weight")
1098+
result = EfficientDiD(n_bootstrap=0).fit(
1099+
cov_survey_data,
1100+
"outcome", "unit", "time", "first_treat",
1101+
covariates=["x1"],
1102+
survey_design=sd,
1103+
)
1104+
assert np.isfinite(result.overall_att)
1105+
assert np.isfinite(result.overall_se)
1106+
10471107

10481108
# =============================================================================
10491109
# Scale Invariance (applies to all estimators)

0 commit comments

Comments
 (0)