Skip to content

Commit f8e959c

Browse files
igerber and claude committed
Address PR #370 R4 review (1 P0 + 1 P1)
R4 P0 (Methodology) -- Yatchew test statistic was not invariant to uniform pweight rescaling. The formula `T_hr = sqrt(sum(w)) * (...)` makes T_hr scale as sqrt(c) under weights -> w * c, so weights=w and weights=100*w produced different p-values for the same design. Worse, SurveyDesign.resolve() normalizes pweights to mean=1 internally, so the survey= entry path and the weights= shortcut disagreed numerically.

Fix: normalize per-unit pweights to mean=1 at every helper entry (stute_test, yatchew_hr_test, stute_joint_pretest) and at the workflow resolution helper. Matches SurveyDesign.resolve() convention; makes the Yatchew statistic scale-invariant; ensures weights=w and survey=SurveyDesign(weights="w") produce identical results for the same design. Stute is internally scale-invariant in functional form but normalization is required so the bootstrap helper sees the same weight vector under both entry paths (cross-path numerical agreement).

R4 P1 (Code Quality) -- column-vector weights (e.g. `df[["w"]].to_numpy()` producing (G, 1)) silently broadcast through weighted moments / CvM sums instead of raising.

Fix: validate via `_validate_1d_numeric` on all `weights=` arrays in stute_test, yatchew_hr_test, stute_joint_pretest; add explicit ndim check in `_resolve_pretest_unit_weights` with a hint about the common df[["w"]].to_numpy() mistake.

6 new regression tests in TestPhase45CR1Regressions:
- test_yatchew_weights_scale_invariant (weights=w vs weights=100*w)
- test_stute_weights_scale_invariant (mirror for Stute)
- test_workflow_weights_eq_survey_at_overall_path (weights= shortcut and survey=SurveyDesign(...) produce identical Yatchew + Stute results, atol=1e-10)
- test_stute_test_rejects_2d_weights / test_yatchew_hr_test_rejects_2d_weights / test_workflow_rejects_2d_weights (column-vector rejection at all three direct-helper / workflow entry points)

177 pretest tests pass (was 171 after R3).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 39b29a9 commit f8e959c

2 files changed

Lines changed: 154 additions & 12 deletions

File tree

diff_diff/had_pretests.py

Lines changed: 54 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1548,19 +1548,15 @@ def stute_test(
15481548
# Phase 4.5 C: resolve effective per-unit weights (None on the
15491549
# unweighted path, preserves bit-exact regression). When survey= is
15501550
# supplied, w is taken from the resolved design.
1551+
# R4 P1: validate 1D explicitly so column-vector inputs (e.g.
1552+
# df[["w"]].to_numpy()) raise instead of silently broadcasting.
15511553
if survey is not None:
1552-
w_arr = np.asarray(survey.weights, dtype=np.float64)
1554+
w_arr = _validate_1d_numeric(np.asarray(survey.weights), "stute_test: survey.weights")
15531555
if w_arr.shape[0] != G:
15541556
raise ValueError(
15551557
f"stute_test: survey.weights length {w_arr.shape[0]} does not "
15561558
f"match d/dy length {G}."
15571559
)
1558-
# R1 P0: strictly-positive weights at the per-unit level (mirrors
1559-
# workflow guard in _resolve_pretest_unit_weights). Zero-weight
1560-
# units would leak into the dose-variation check + CvM cusum +
1561-
# bootstrap refit, producing silent wrong pretest decisions on
1562-
# subpopulation-restricted designs (e.g. only zero-weight units
1563-
# carry dose variation -> spurious finite test statistic).
15641560
if (w_arr <= 0).any():
15651561
raise ValueError(
15661562
"stute_test: survey weights must be strictly positive. "
@@ -1570,7 +1566,7 @@ def stute_test(
15701566
"weight subpopulation before calling stute_test."
15711567
)
15721568
elif weights is not None:
1573-
w_arr = np.asarray(weights, dtype=np.float64)
1569+
w_arr = _validate_1d_numeric(np.asarray(weights), "stute_test: weights")
15741570
if w_arr.shape[0] != G:
15751571
raise ValueError(
15761572
f"stute_test: weights length {w_arr.shape[0]} does not match " f"d/dy length {G}."
@@ -1584,6 +1580,17 @@ def stute_test(
15841580
else:
15851581
w_arr = None
15861582

1583+
# R4 P0: normalize pweights to mean=1 (matches SurveyDesign.resolve()
1584+
# convention). Makes the test statistic scale-invariant under uniform
1585+
# rescaling of weights AND ensures weights= shortcut and
1586+
# survey=SurveyDesign(weights=...) produce identical results for the
1587+
# same design. Stute is internally scale-invariant in functional form,
1588+
# but the survey-aware bootstrap helper consumes weight values
1589+
# directly under non-trivial PSU/strata, so normalization is required
1590+
# for cross-path agreement.
1591+
if w_arr is not None:
1592+
w_arr = w_arr * (float(w_arr.shape[0]) / float(np.sum(w_arr)))
1593+
15871594
if w_arr is None:
15881595
a_hat, b_hat, eps = _fit_ols_intercept_slope(d_arr, dy_arr)
15891596
else:
@@ -1895,8 +1902,9 @@ def yatchew_hr_test(
18951902
# Phase 4.5 C: resolve effective per-unit weights. Strictly positive
18961903
# required (the adjacent-difference formula divides by sum(w_avg) which
18971904
# collapses to zero in any contiguous-zero block).
1905+
# R4 P1: validate 1D explicitly so column-vector inputs raise.
18981906
if survey is not None:
1899-
w_arr = np.asarray(survey.weights, dtype=np.float64)
1907+
w_arr = _validate_1d_numeric(np.asarray(survey.weights), "yatchew_hr_test: survey.weights")
19001908
if w_arr.shape[0] != G:
19011909
raise ValueError(
19021910
f"yatchew_hr_test: survey.weights length {w_arr.shape[0]} "
@@ -1909,7 +1917,7 @@ def yatchew_hr_test(
19091917
"zero-weight blocks)."
19101918
)
19111919
elif weights is not None:
1912-
w_arr = np.asarray(weights, dtype=np.float64)
1920+
w_arr = _validate_1d_numeric(np.asarray(weights), "yatchew_hr_test: weights")
19131921
if w_arr.shape[0] != G:
19141922
raise ValueError(
19151923
f"yatchew_hr_test: weights length {w_arr.shape[0]} does not "
@@ -1924,6 +1932,17 @@ def yatchew_hr_test(
19241932
else:
19251933
w_arr = None
19261934

1935+
# R4 P0: normalize pweights to mean=1 (matches SurveyDesign.resolve()
1936+
# convention). Yatchew uses sqrt(sum(w)) as the effective sample size,
1937+
# which without normalization would scale as sqrt(c) under uniform
1938+
# rescaling weights -> w * c, producing different p-values for
1939+
# weights=w vs weights=100*w. Normalization makes the statistic
1940+
# scale-invariant AND ensures weights= and survey=SurveyDesign(...)
1941+
# produce identical results (the latter resolve()s to mean=1
1942+
# internally, the former previously did not).
1943+
if w_arr is not None:
1944+
w_arr = w_arr * (float(w_arr.shape[0]) / float(np.sum(w_arr)))
1945+
19271946
if G < _MIN_G_YATCHEW:
19281947
warnings.warn(
19291948
f"yatchew_hr_test: G = {G} is below the minimum {_MIN_G_YATCHEW} "
@@ -2682,8 +2701,11 @@ def stute_joint_pretest(
26822701

26832702
# Phase 4.5 C: resolve effective per-unit weights (None → bit-exact
26842703
# unweighted path).
2704+
# R4 P1: validate 1D explicitly so column-vector inputs raise.
26852705
if survey is not None:
2686-
w_arr = np.asarray(survey.weights, dtype=np.float64)
2706+
w_arr = _validate_1d_numeric(
2707+
np.asarray(survey.weights), "stute_joint_pretest: survey.weights"
2708+
)
26872709
if w_arr.shape[0] != G:
26882710
raise ValueError(
26892711
f"stute_joint_pretest: survey.weights length {w_arr.shape[0]} "
@@ -2698,7 +2720,7 @@ def stute_joint_pretest(
26982720
"population mass."
26992721
)
27002722
elif weights is not None:
2701-
w_arr = np.asarray(weights, dtype=np.float64)
2723+
w_arr = _validate_1d_numeric(np.asarray(weights), "stute_joint_pretest: weights")
27022724
if w_arr.shape[0] != G:
27032725
raise ValueError(
27042726
f"stute_joint_pretest: weights length {w_arr.shape[0]} does "
@@ -2712,6 +2734,11 @@ def stute_joint_pretest(
27122734
else:
27132735
w_arr = None
27142736

2737+
# R4 P0: normalize pweights to mean=1 (matches SurveyDesign.resolve()
2738+
# convention; same fix as stute_test / yatchew_hr_test).
2739+
if w_arr is not None:
2740+
w_arr = w_arr * (float(w_arr.shape[0]) / float(np.sum(w_arr)))
2741+
27152742
idx = np.argsort(doses_arr, kind="stable")
27162743
d_sorted = doses_arr[idx]
27172744

@@ -2915,6 +2942,16 @@ def _resolve_pretest_unit_weights(
29152942
)
29162943
if weights is not None:
29172944
weights_arr = np.asarray(weights, dtype=np.float64)
2945+
# R4 P1: validate 1D explicitly (column-vector inputs would otherwise
2946+
# broadcast through downstream computations and silently corrupt
2947+
# results).
2948+
if weights_arr.ndim != 1:
2949+
raise ValueError(
2950+
f"{caller_name}: weights must be 1-dimensional, got shape "
2951+
f"{weights_arr.shape}. (A common mistake is passing "
2952+
"df[['w']].to_numpy() which produces (N, 1); use "
2953+
"df['w'].to_numpy() for (N,).)"
2954+
)
29182955
weights_unit = _aggregate_unit_weights(data, weights_arr, unit_col)
29192956
# R1 P0: strictly-positive weights required on the pweight shortcut
29202957
# (matches stute_test/yatchew_hr_test direct entry behavior; the CvM
@@ -2927,6 +2964,11 @@ def _resolve_pretest_unit_weights(
29272964
"mass; use survey= with explicit lonely-PSU handling for "
29282965
"principled subpopulation analysis."
29292966
)
2967+
# R4 P0: normalize per-unit weights to mean=1 (matches
2968+
# SurveyDesign.resolve() convention so weights= and survey= entry
2969+
# paths produce identical statistic values; ensures Yatchew is
2970+
# scale-invariant under uniform rescaling).
2971+
weights_unit = weights_unit * (float(weights_unit.shape[0]) / float(np.sum(weights_unit)))
29302972
return weights_unit, None
29312973
# survey is not None
29322974
if not hasattr(survey, "resolve"):

tests/test_had_pretests.py

Lines changed: 100 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3724,3 +3724,103 @@ def test_workflow_single_psu_propagates_nan_through_stute(self):
37243724
assert report.yatchew is not None and np.isfinite(report.yatchew.p_value)
37253725
# Verdict carries the linearity-conditional suffix.
37263726
assert "linearity-conditional verdict" in report.verdict
3727+
3728+
# --- R4 P0: weight-scale invariance + cross-path agreement ------------
3729+
3730+
def test_yatchew_weights_scale_invariant(self):
3731+
"""R4 P0: Yatchew test statistic must be invariant under uniform
3732+
rescaling of weights. Pre-fix `T_hr = sqrt(sum(w)) * (...)` made
3733+
the stat scale as sqrt(c), so weights=w and weights=100*w gave
3734+
different p-values. Fix: helper normalizes pweights to mean=1
3735+
before any computation."""
3736+
d, dy = _linear_dgp(G=30, beta=2.0, sigma=0.3)
3737+
w = np.random.default_rng(7).uniform(0.5, 2.0, size=30)
3738+
r1 = yatchew_hr_test(d, dy, weights=w)
3739+
r2 = yatchew_hr_test(d, dy, weights=100.0 * w)
3740+
np.testing.assert_allclose(r1.t_stat_hr, r2.t_stat_hr, atol=1e-12, rtol=1e-12)
3741+
np.testing.assert_allclose(r1.p_value, r2.p_value, atol=1e-12, rtol=1e-12)
3742+
3743+
def test_stute_weights_scale_invariant(self):
3744+
"""R4 P0 mirror: Stute is internally scale-invariant in functional
3745+
form, but normalization is required so weights= and survey=
3746+
entry paths agree numerically."""
3747+
d, dy = _linear_dgp(G=30, beta=2.0, sigma=0.3)
3748+
w = np.random.default_rng(7).uniform(0.5, 2.0, size=30)
3749+
r1 = stute_test(d, dy, weights=w, n_bootstrap=199, seed=0)
3750+
r2 = stute_test(d, dy, weights=100.0 * w, n_bootstrap=199, seed=0)
3751+
np.testing.assert_allclose(r1.cvm_stat, r2.cvm_stat, atol=1e-12, rtol=1e-12)
3752+
np.testing.assert_allclose(r1.p_value, r2.p_value, atol=1e-12, rtol=1e-12)
3753+
3754+
def test_workflow_weights_eq_survey_at_overall_path(self):
3755+
"""R4 P0: workflow's weights= shortcut and survey=SurveyDesign(
3756+
weights="w") must produce identical Yatchew/Stute results for
3757+
the same design. SurveyDesign.resolve() normalizes pweights to
3758+
mean=1; the helper now applies the same normalization on the
3759+
weights= path so both paths agree numerically."""
3760+
from diff_diff import SurveyDesign
3761+
3762+
df = self._make_overall_panel(with_w_col=True)
3763+
# Build a per-row weights array matching df["w"] for the shortcut.
3764+
weights_per_row = df["w"].to_numpy()
3765+
with pytest.warns(UserWarning):
3766+
r_weights = did_had_pretest_workflow(
3767+
df,
3768+
"y",
3769+
"d",
3770+
"time",
3771+
"unit",
3772+
weights=weights_per_row,
3773+
n_bootstrap=199,
3774+
seed=0,
3775+
)
3776+
with pytest.warns(UserWarning):
3777+
r_survey = did_had_pretest_workflow(
3778+
df,
3779+
"y",
3780+
"d",
3781+
"time",
3782+
"unit",
3783+
survey=SurveyDesign(weights="w"),
3784+
n_bootstrap=199,
3785+
seed=0,
3786+
)
3787+
# Yatchew: closed-form, must match exactly under mean=1 normalization.
3788+
assert r_weights.yatchew is not None and r_survey.yatchew is not None
3789+
np.testing.assert_allclose(
3790+
r_weights.yatchew.t_stat_hr,
3791+
r_survey.yatchew.t_stat_hr,
3792+
atol=1e-10,
3793+
rtol=1e-10,
3794+
)
3795+
# Stute: bootstrap is seeded; same multiplier matrix shape under
3796+
# both paths means same RNG draws -> identical p-values.
3797+
assert r_weights.stute is not None and r_survey.stute is not None
3798+
np.testing.assert_allclose(
3799+
r_weights.stute.cvm_stat, r_survey.stute.cvm_stat, atol=1e-10, rtol=1e-10
3800+
)
3801+
np.testing.assert_allclose(
3802+
r_weights.stute.p_value, r_survey.stute.p_value, atol=1e-10, rtol=1e-10
3803+
)
3804+
3805+
# --- R4 P1: 1D weights validation ------------------------------------
3806+
3807+
def test_stute_test_rejects_2d_weights(self):
3808+
"""R4 P1: column-vector weights must raise, not silently broadcast."""
3809+
d, dy = _linear_dgp(G=30)
3810+
w_2d = np.ones((30, 1)) # common df[["w"]].to_numpy() pattern
3811+
with pytest.raises(ValueError, match="1-dimensional"):
3812+
stute_test(d, dy, weights=w_2d, n_bootstrap=199, seed=0)
3813+
3814+
def test_yatchew_hr_test_rejects_2d_weights(self):
3815+
d, dy = _linear_dgp(G=30)
3816+
w_2d = np.ones((30, 1))
3817+
with pytest.raises(ValueError, match="1-dimensional"):
3818+
yatchew_hr_test(d, dy, weights=w_2d)
3819+
3820+
def test_workflow_rejects_2d_weights(self):
3821+
df = self._make_overall_panel()
3822+
w_2d = np.ones((40, 1))
3823+
with pytest.raises(ValueError, match="1-dimensional"):
3824+
did_had_pretest_workflow(
3825+
df, "y", "d", "time", "unit", weights=w_2d, n_bootstrap=199, seed=0
3826+
)

0 commit comments

Comments
 (0)