Skip to content

Commit 35e1ff2

Browse files
igerberclaude
andcommitted
Add replicate+bootstrap guard, effective_df propagation, and test coverage
Block unsupported replicate-weight + n_bootstrap>0 combination matching CallawaySantAnna guard. Propagate _effective_df from _aggregate_simple() to df_survey for correct replicate-weight inference. Add tests for replicate+bootstrap rejection and survey-weighted aggregation point estimates. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5d62910 commit 35e1ff2

2 files changed

Lines changed: 125 additions & 1 deletion

File tree

diff_diff/staggered_triple_diff.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,9 +496,15 @@ def fit(
496496
df_agg.loc[df_agg[eligibility] == 0, "first_treat"] = 0
497497

498498
# Overall ATT via aggregation mixin
499-
overall_att, overall_se, _effective_df = self._aggregate_simple(
499+
overall_att, overall_se, overall_effective_df = self._aggregate_simple(
500500
group_time_effects, influence_func_info, df_agg, unit, precomputed_agg
501501
)
502+
# Use per-statistic effective df from replicate aggregation if available;
503+
# otherwise fall back to the original df from the survey design.
504+
if overall_effective_df is not None:
505+
df_survey = overall_effective_df
506+
if survey_metadata is not None:
507+
survey_metadata.df_survey = df_survey
502508
overall_t_stat, overall_p_value, overall_conf_int = safe_inference(
503509
overall_att, overall_se, alpha=self.alpha, df=df_survey
504510
)
@@ -527,6 +533,20 @@ def fit(
527533
unit,
528534
)
529535

536+
# Reject replicate-weight designs for bootstrap — replicate variance
537+
# is an analytical alternative, not compatible with bootstrap
538+
if (
539+
self.n_bootstrap > 0
540+
and resolved_survey is not None
541+
and hasattr(resolved_survey, "uses_replicate_variance")
542+
and resolved_survey.uses_replicate_variance
543+
):
544+
raise NotImplementedError(
545+
"StaggeredTripleDifference bootstrap (n_bootstrap > 0) is not "
546+
"supported with replicate-weight survey designs. Replicate "
547+
"weights provide analytical variance; use n_bootstrap=0 instead."
548+
)
549+
530550
# Bootstrap
531551
bootstrap_results = None
532552
cband_crit_value = None

tests/test_survey_staggered_ddd.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,110 @@ def test_brr_replicate(self, sddd_data):
408408
assert res.overall_se > 0
409409

410410

411+
# ---------------------------------------------------------------------------
412+
# Replicate + bootstrap rejection
413+
# ---------------------------------------------------------------------------
414+
415+
416+
class TestReplicateBootstrapRejection:
417+
"""Replicate weights + n_bootstrap>0 raises NotImplementedError."""
418+
419+
def test_brr_with_bootstrap_rejected(self, sddd_data):
420+
data = sddd_data.copy()
421+
rng = np.random.default_rng(99)
422+
n_units = data["unit"].nunique()
423+
unit_ids = sorted(data["unit"].unique())
424+
R = 10
425+
rep_matrix = np.abs(1.0 + rng.standard_normal((n_units, R)) * 0.1)
426+
for r in range(R):
427+
unit_w = dict(zip(unit_ids, rep_matrix[:, r]))
428+
data[f"rep_{r}"] = data["unit"].map(unit_w)
429+
430+
rep_cols = [f"rep_{r}" for r in range(R)]
431+
sd = SurveyDesign(
432+
weights="weight",
433+
replicate_weights=rep_cols,
434+
replicate_method="BRR",
435+
)
436+
est = StaggeredTripleDifference(
437+
estimation_method="reg",
438+
n_bootstrap=49,
439+
)
440+
with pytest.raises(NotImplementedError, match="replicate"):
441+
est.fit(
442+
data,
443+
"outcome",
444+
"unit",
445+
"period",
446+
"first_treat",
447+
"eligibility",
448+
survey_design=sd,
449+
)
450+
451+
452+
# ---------------------------------------------------------------------------
453+
# Survey-weighted aggregation point estimates
454+
# ---------------------------------------------------------------------------
455+
456+
457+
class TestSurveyWeightedAggregation:
458+
"""Survey weights change aggregation point estimates (not just SEs)."""
459+
460+
def test_unequal_cohort_weights_change_aggregate(self):
461+
"""Cohorts with very different survey weights produce different
462+
aggregated ATT from unweighted."""
463+
# Create data where cohort g=3 units have weight=10 and g=4 have weight=1
464+
data = _make_staggered_ddd_data(n_units=200, seed=123)
465+
rng = np.random.default_rng(123)
466+
unit_df = data.groupby("unit")["first_treat"].first()
467+
# Assign extreme weights: g=3 units 10x heavier than g=4
468+
w_map = {}
469+
for uid, g in unit_df.items():
470+
if g == 3:
471+
w_map[uid] = 10.0 + rng.uniform(0, 1)
472+
elif g == 4:
473+
w_map[uid] = 1.0 + rng.uniform(0, 0.1)
474+
else:
475+
w_map[uid] = 3.0 + rng.uniform(0, 0.5)
476+
data["skewed_w"] = data["unit"].map(w_map)
477+
478+
est = StaggeredTripleDifference(estimation_method="reg")
479+
480+
# Unweighted
481+
res_uw = est.fit(
482+
data,
483+
"outcome",
484+
"unit",
485+
"period",
486+
"first_treat",
487+
"eligibility",
488+
aggregate="simple",
489+
)
490+
491+
# Skewed survey weights
492+
sd = SurveyDesign(weights="skewed_w")
493+
res_w = est.fit(
494+
data,
495+
"outcome",
496+
"unit",
497+
"period",
498+
"first_treat",
499+
"eligibility",
500+
aggregate="simple",
501+
survey_design=sd,
502+
)
503+
504+
# Both should be finite
505+
assert np.isfinite(res_uw.overall_att)
506+
assert np.isfinite(res_w.overall_att)
507+
508+
# Aggregated ATT should differ due to different cohort weighting
509+
assert abs(res_w.overall_att - res_uw.overall_att) > 1e-6, (
510+
f"Expected aggregate ATTs to differ with skewed weights: "
511+
f"weighted={res_w.overall_att}, unweighted={res_uw.overall_att}"
512+
)
513+
514+
411515
# ---------------------------------------------------------------------------
412516
# pweight-only validation
413517
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)