Skip to content

Commit 676d831

Browse files
igerber and claude committed
Address PR #355 R4 P0 + P3: resolve-normalize pweight-only weights + tighten boot_idx slice test
R4 P0 (scale-invariance): the pweight-only bootstrap branch was sourcing w_control / w_treated from raw panel-column weights via _extract_unit_survey_weights. The weighted-FW bootstrap objective is not scale-invariant in rw (loss scales as rw^2 via A·diag(rw), reg scales as rw), so two equivalent designs w and c*w could produce different bootstrap SE / p-value / CI with no warning. Fix: source w_control / w_treated from resolved_survey_unit.weights, which SurveyDesign.resolve() normalizes to mean=1 (survey.py L189-L203). Placebo / jackknife paths also consume the same w_control / w_treated but are scale-invariant, so their numerics are unchanged.

R4 P3 (test tightening): the boot_idx × Rao-Wu regression test only asserted that captured rw values stayed within the [1, 15] value range of known_rw[:n_control] — too weak to catch permutation / deduplication regressions in the slice order. Tighten by reproducing the bootstrap RNG stream externally (fake_rao_wu doesn't consume rng) and asserting exact equality between the captured rw_control vector and known_rw[:n_control][boot_idx[boot_is_control]].

New regression test: test_bootstrap_scale_invariance_under_pweight_rescaling fits the same panel under SurveyDesign("wt") vs SurveyDesign("wt_scaled") (10x rescale) and asserts SE, p-value, and CI match to near-machine-epsilon tolerance.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2bf3f93 commit 676d831

2 files changed

Lines changed: 124 additions & 26 deletions

File tree

diff_diff/synthetic_did.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,6 @@ def fit( # type: ignore[override]
292292

293293
# Resolve survey design
294294
from diff_diff.survey import (
295-
_extract_unit_survey_weights,
296295
_resolve_survey_for_fit,
297296
_validate_unit_constant_survey,
298297
)
@@ -426,22 +425,37 @@ def fit( # type: ignore[override]
426425
# consumes per draw.
427426
if resolved_survey is not None:
428427
_validate_unit_constant_survey(data, unit, survey_design)
429-
w_treated = _extract_unit_survey_weights(data, unit, survey_design, treated_units)
430-
w_control = _extract_unit_survey_weights(data, unit, survey_design, control_units)
431428
# Collapse to unit level for the bootstrap survey path. The
432429
# row order is [control_units..., treated_units...] so
433430
# boot_rw[:n_control] / boot_rw[n_control:] line up with the
434431
# bootstrap loop's column ordering. See
435432
# `collapse_survey_to_unit_level` in diff_diff/survey.py.
436-
from diff_diff.survey import collapse_survey_to_unit_level
437-
all_units_for_bootstrap = list(control_units) + list(treated_units)
438433
# Use `data` (not `working_data`) for the groupby — survey
439434
# design columns are unit-constant (validated above) and
440435
# covariate residualization doesn't shuffle row order, so the
441436
# collapse is invariant to which view we group on.
437+
from diff_diff.survey import collapse_survey_to_unit_level
438+
all_units_for_bootstrap = list(control_units) + list(treated_units)
442439
resolved_survey_unit = collapse_survey_to_unit_level(
443440
resolved_survey, data, unit, all_units_for_bootstrap,
444441
)
442+
# Source w_control / w_treated from resolved_survey_unit.weights
443+
# rather than re-extracting raw panel columns. resolved_survey.weights
444+
# is normalized to mean=1 by SurveyDesign.resolve() (survey.py L189-
445+
# L203), so the weighted-FW bootstrap objective — which is NOT
446+
# invariant to a global rescaling of rw — produces identical SE /
447+
# p-value / CI under SurveyDesign(weights="w") vs "c*w" (PR #355
448+
# R4 P0). Placebo / jackknife paths also consume w_control /
449+
# w_treated but are scale-invariant (np.average divides by sum;
450+
# ω_eff normalization likewise), so switching to resolved weights
451+
# doesn't change their numerics.
452+
n_control_for_split = len(control_units)
453+
w_control = resolved_survey_unit.weights[:n_control_for_split].astype(
454+
np.float64
455+
)
456+
w_treated = resolved_survey_unit.weights[n_control_for_split:].astype(
457+
np.float64
458+
)
445459
else:
446460
w_treated = None
447461
w_control = None

tests/test_methodology_sdid.py

Lines changed: 105 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -790,36 +790,120 @@ def capturing_helper(Y_pre_c, Y_pre_t_mean, rw, *args, **kwargs):
790790
sdid_mod, "compute_sdid_unit_weights_survey", capturing_helper
791791
)
792792

793-
SyntheticDiD(variance_method="bootstrap", n_bootstrap=10, seed=1).fit(
793+
bootstrap_seed = 1
794+
SyntheticDiD(
795+
variance_method="bootstrap", n_bootstrap=10, seed=bootstrap_seed,
796+
).fit(
794797
df, outcome="outcome", treatment="treated",
795798
unit="unit", time="period",
796799
post_periods=[5, 6, 7],
797800
survey_design=SurveyDesign(weights="wt", strata="stratum", psu="psu"),
798801
)
799802

800-
# For each captured rw vector: its values must all come from the
801-
# first n_control=15 positions of known_rw (never from the
802-
# treated slice [15:18]). Values may repeat across the vector
803-
# (bootstrap picks with replacement) but every element must be
804-
# ≤ n_control (positions 1..15, since we built known_rw as
805-
# arange(1, 19)). Catches either a slice-order bug (would mix in
806-
# treated-slice values 16..18) or a rw-drift bug (would produce
807-
# values outside [1, 15]).
808-
assert len(captured) >= 1, "no FW calls captured — survey dispatch broken"
803+
# Exact-equality check against a reproduced RNG stream (PR #355 R4
804+
# P3). The captured rw vectors must match known_rw[:n_control]
805+
# sliced by boot_idx[boot_is_control] value-for-value. Reproducing
806+
# the bootstrap's rng externally works because:
807+
# - fake_rao_wu does NOT consume rng (just returns known_rw),
808+
# so the only per-draw rng advance is ``rng.choice(n_total, ...)``
809+
# which yields boot_idx;
810+
# - known_rw is strictly positive, so the zero-mass retry branch
811+
# (synthetic_did.py ``_bootstrap_se``) never fires;
812+
# - degenerate draws (no controls: negligible; all controls: roughly
813+
# 4% per draw, (15/18)^18, for a 15/3 split) are skipped by the same retry rule in the simulation loop below.
814+
# An exact-equality regression catches the sibling bugs the old
815+
# range check missed: permuted indices, deduplicated boot_idx, or
816+
# substituted ``resolved_survey_unit.weights`` lookup in place of
817+
# the known_rw slice — any of which would silently change
818+
# bootstrap SE.
809819
n_control = 15
810-
control_slice_max = float(known_rw[:n_control].max()) # = 15.0
811-
for i, rw_captured in enumerate(captured):
812-
assert rw_captured.shape[0] > 0, f"draw {i}: empty rw"
813-
assert rw_captured.max() <= control_slice_max, (
814-
f"draw {i}: captured rw max = {rw_captured.max()} exceeds "
815-
f"control-slice max ({control_slice_max}); slice order "
816-
"regressed — Rao-Wu weights mixed with treated slice."
817-
)
818-
assert rw_captured.min() >= 1.0, (
819-
f"draw {i}: captured rw min = {rw_captured.min()} below "
820-
"known_rw[0]=1; weights drifted outside the Rao-Wu output."
820+
rng_sim = np.random.default_rng(bootstrap_seed)
821+
expected_slices = []
822+
while len(expected_slices) < len(captured):
823+
boot_idx = rng_sim.choice(n_total, size=n_total, replace=True)
824+
boot_is_control = boot_idx < n_control
825+
n_co_b = int(boot_is_control.sum())
826+
if n_co_b == 0 or n_co_b == n_total:
827+
continue
828+
expected_slices.append(known_rw[:n_control][boot_idx[boot_is_control]])
829+
830+
assert len(captured) >= 1, "no FW calls captured — survey dispatch broken"
831+
for i, (rw_captured, rw_expected) in enumerate(
832+
zip(captured, expected_slices)
833+
):
834+
np.testing.assert_array_equal(
835+
rw_captured,
836+
rw_expected,
837+
err_msg=(
838+
f"draw {i}: captured rw_control differs from expected "
839+
f"known_rw[:n_control][boot_idx[boot_is_control]]. "
840+
"Regression in hybrid pairs-bootstrap + Rao-Wu "
841+
"slice ordering."
842+
),
821843
)
822844

845+
def test_bootstrap_scale_invariance_under_pweight_rescaling(self):
    """Survey-bootstrap SE / p / CI are invariant to a global pweight rescaling.

    ``SurveyDesign.resolve()`` normalizes pweights/aweights to mean=1
    (survey.py L189-L203), which is the library's scale-invariance
    contract for survey-weighted fits. This test fits the same SDID
    panel under two SurveyDesigns — weights column ``"wt"`` vs a
    10x-rescaled copy ``"wt_scaled"`` — and asserts bootstrap SE,
    p-value, and CI agree to near-machine-epsilon tolerance.

    Regression against PR #355 R4 P0: the initial PR #352 pweight-only
    bootstrap branch bypassed the resolved (normalized) unit-level
    weights and fed raw panel-column weights into the weighted-FW
    objective. That objective is NOT invariant to a global rescale
    of rw — the loss term scales as rw^2 (``A-tilde = A * diag(rw)``)
    while the reg term scales as rw (``zeta^2 * sum rw * omega^2``) —
    so any user who rescaled their pweight column (e.g. switched
    units) would see silently different SEs. The fix
    (synthetic_did.py ``fit()`` around the ``resolved_survey`` block)
    sources ``w_control`` and ``w_treated`` from
    ``resolved_survey_unit.weights`` (post-normalization) rather
    than re-extracting raw weights via ``_extract_unit_survey_weights``.

    Tolerances are near machine epsilon rather than exact equality
    because floating-point multiply-reduce ordering inside
    ``raw * (n / (c*raw_sum))`` vs ``raw * (n / raw_sum)`` can drift by
    ~1 ULP; a raw-weight fallback would produce differences on the
    order of 1 or larger.

    Every comparison passes ``equal_nan=False``: ``assert_allclose``
    treats NaN == NaN as a match by default, which would let the
    p-value / CI checks pass vacuously if both fits returned NaN.
    """
    from diff_diff.survey import SurveyDesign

    df = _make_panel(n_control=12, n_treated=3, seed=42)
    unique_units = np.sort(df["unit"].unique())
    unit_weights = np.linspace(0.5, 2.5, len(unique_units))
    wt_map = dict(zip(unique_units, unit_weights))
    df["wt"] = df["unit"].map(wt_map)
    # Same design up to a global constant; results must not depend on it.
    df["wt_scaled"] = df["wt"] * 10.0

    kwargs = {
        "outcome": "outcome",
        "treatment": "treated",
        "unit": "unit",
        "time": "period",
        "post_periods": [5, 6, 7],
    }
    # Identical seed on both fits so any difference comes from the
    # weights column alone.
    result_base = SyntheticDiD(
        variance_method="bootstrap", n_bootstrap=50, seed=1,
    ).fit(df, survey_design=SurveyDesign(weights="wt"), **kwargs)
    result_scaled = SyntheticDiD(
        variance_method="bootstrap", n_bootstrap=50, seed=1,
    ).fit(df, survey_design=SurveyDesign(weights="wt_scaled"), **kwargs)

    assert np.isfinite(result_base.se) and result_base.se > 0
    np.testing.assert_allclose(
        result_scaled.se, result_base.se, rtol=1e-13, atol=0,
        equal_nan=False,
        err_msg="bootstrap SE is not invariant to pweight global rescaling",
    )
    np.testing.assert_allclose(
        result_scaled.p_value, result_base.p_value, rtol=1e-12, atol=1e-14,
        equal_nan=False,
        err_msg="bootstrap p-value is not invariant to pweight global rescaling",
    )
    np.testing.assert_allclose(
        result_scaled.conf_int, result_base.conf_int, rtol=1e-13, atol=0,
        equal_nan=False,
        err_msg="bootstrap CI is not invariant to pweight global rescaling",
    )
906+
823907
def test_bootstrap_single_psu_returns_nan(self):
824908
"""Unstratified single-PSU survey design returns NaN SE (PR #352).
825909

0 commit comments

Comments
 (0)