Skip to content

Commit b552cb8

Browse files
igerberclaude
andcommitted
Fix replicate-weight scale invariance and BRR test fixtures
Normalize survey_weights_arr to sum=n in _precompute_structures() so size_gt/size_gt_ctrl denominators are scale-invariant for replicate designs. Fix BRR test fixtures to build combined replicate weights (rep_r = weight * factor) honoring combined_weights=True semantics. Add replicate scale-invariance tests for simple/event_study/group aggregation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 35e1ff2 commit b552cb8

2 files changed

Lines changed: 95 additions & 6 deletions

File tree

diff_diff/staggered_triple_diff.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,12 @@ def _precompute_structures(
839839
.reindex(all_units)
840840
.values.astype(np.float64)
841841
)
842+
# Normalize to sum=n for aggregation/rescaling (matches pweight
843+
# convention). Raw weights preserved in resolved_survey_unit for
844+
# replicate w_r/w_full ratios — those are inherently scale-invariant.
845+
sw_sum = np.sum(survey_weights_arr)
846+
if sw_sum > 0:
847+
survey_weights_arr = survey_weights_arr * (len(survey_weights_arr) / sw_sum)
842848
resolved_survey_unit = collapse_survey_to_unit_level(
843849
resolved_survey, df, unit, all_units
844850
)

tests/test_survey_staggered_ddd.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,13 @@ def test_brr_replicate(self, sddd_data):
378378
rng = np.random.default_rng(99)
379379
n_units = data["unit"].nunique()
380380
R = 20
381-
# Generate unit-level replicate weights
381+
# Generate combined replicate weights: rep_r = weight * factor_r
382+
# (combined_weights=True means each column includes full-sample weight)
382383
unit_ids = sorted(data["unit"].unique())
383-
rep_matrix = 1.0 + rng.standard_normal((n_units, R)) * 0.1
384-
rep_matrix = np.abs(rep_matrix) # Ensure positive
384+
base_w = data.groupby("unit")["weight"].first().reindex(unit_ids).values
385385
for r in range(R):
386-
unit_w = dict(zip(unit_ids, rep_matrix[:, r]))
386+
factor = np.abs(1.0 + rng.standard_normal(n_units) * 0.1)
387+
unit_w = dict(zip(unit_ids, base_w * factor))
387388
data[f"rep_{r}"] = data["unit"].map(unit_w)
388389

389390
rep_cols = [f"rep_{r}" for r in range(R)]
@@ -422,9 +423,10 @@ def test_brr_with_bootstrap_rejected(self, sddd_data):
422423
n_units = data["unit"].nunique()
423424
unit_ids = sorted(data["unit"].unique())
424425
R = 10
425-
rep_matrix = np.abs(1.0 + rng.standard_normal((n_units, R)) * 0.1)
426+
base_w = data.groupby("unit")["weight"].first().reindex(unit_ids).values
426427
for r in range(R):
427-
unit_w = dict(zip(unit_ids, rep_matrix[:, r]))
428+
factor = np.abs(1.0 + rng.standard_normal(n_units) * 0.1)
429+
unit_w = dict(zip(unit_ids, base_w * factor))
428430
data[f"rep_{r}"] = data["unit"].map(unit_w)
429431

430432
rep_cols = [f"rep_{r}" for r in range(R)]
@@ -449,6 +451,87 @@ def test_brr_with_bootstrap_rejected(self, sddd_data):
449451
)
450452

451453

454+
# ---------------------------------------------------------------------------
455+
# Replicate-weight scale invariance
456+
# ---------------------------------------------------------------------------
457+
458+
459+
def _make_brr_data(sddd_data, rng_seed=99, R=20):
460+
"""Helper: build combined BRR replicate weights for sddd_data."""
461+
data = sddd_data.copy()
462+
rng = np.random.default_rng(rng_seed)
463+
unit_ids = sorted(data["unit"].unique())
464+
n_units = len(unit_ids)
465+
base_w = data.groupby("unit")["weight"].first().reindex(unit_ids).values
466+
for r in range(R):
467+
factor = np.abs(1.0 + rng.standard_normal(n_units) * 0.1)
468+
unit_w = dict(zip(unit_ids, base_w * factor))
469+
data[f"rep_{r}"] = data["unit"].map(unit_w)
470+
rep_cols = [f"rep_{r}" for r in range(R)]
471+
return data, rep_cols
472+
473+
474+
class TestReplicateScaleInvariance:
475+
"""Rescaling all weights + replicates by constant k must not change results."""
476+
477+
@pytest.mark.parametrize("agg", ["simple", "event_study", "group"])
478+
def test_scale_invariance(self, sddd_data, agg):
479+
data, rep_cols = _make_brr_data(sddd_data)
480+
k = 5.0
481+
482+
sd1 = SurveyDesign(
483+
weights="weight",
484+
replicate_weights=rep_cols,
485+
replicate_method="BRR",
486+
)
487+
est = StaggeredTripleDifference(estimation_method="reg")
488+
res1 = est.fit(
489+
data,
490+
"outcome",
491+
"unit",
492+
"period",
493+
"first_treat",
494+
"eligibility",
495+
aggregate=agg,
496+
survey_design=sd1,
497+
)
498+
499+
# Scale all weights and replicate columns by k
500+
data_k = data.copy()
501+
data_k["weight"] = data_k["weight"] * k
502+
for col in rep_cols:
503+
data_k[col] = data_k[col] * k
504+
505+
sd2 = SurveyDesign(
506+
weights="weight",
507+
replicate_weights=rep_cols,
508+
replicate_method="BRR",
509+
)
510+
res2 = est.fit(
511+
data_k,
512+
"outcome",
513+
"unit",
514+
"period",
515+
"first_treat",
516+
"eligibility",
517+
aggregate=agg,
518+
survey_design=sd2,
519+
)
520+
521+
np.testing.assert_allclose(
522+
res2.overall_att,
523+
res1.overall_att,
524+
atol=1e-12,
525+
err_msg=f"ATT changed with weight rescaling (agg={agg})",
526+
)
527+
np.testing.assert_allclose(
528+
res2.overall_se,
529+
res1.overall_se,
530+
rtol=1e-6,
531+
err_msg=f"SE changed with weight rescaling (agg={agg})",
532+
)
533+
534+
452535
# ---------------------------------------------------------------------------
453536
# Survey-weighted aggregation point estimates
454537
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)