Skip to content

Commit e7eefa1

Browse files
igerber and claude
committed
Normalize weights in SRS fallback for scale invariance
Both SRS fallback branches now normalize positive weights to mean=1 before computing variance, ensuring SEs are invariant to constant weight rescaling (important for replicate designs that preserve raw weight scale).
Add scale-invariance regression test with 5x rescaling.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bb5d660 commit e7eefa1

2 files changed

Lines changed: 56 additions & 4 deletions

File tree

diff_diff/prep.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1373,10 +1373,17 @@ def _cell_mean_variance(
13731373
y_bar = float(np.sum(w_valid * y_clean) / sum_w)
13741374

13751375
# SRS fallback if below min_n threshold
1376+
# Normalize positive weights to mean=1 so fallback is scale-invariant
1377+
# (replicate designs preserve raw weight scale per survey.py:L189-240)
13761378
used_srs = False
13771379
if n_valid < min_n:
1378-
resid_sq = w_valid * (y_clean - y_bar) ** 2
1379-
variance = float(np.sum(resid_sq) / (sum_w**2) * n_valid / (n_valid - 1))
1380+
w_norm = w_valid.copy()
1381+
w_pos = w_norm[w_norm > 0]
1382+
if len(w_pos) > 0:
1383+
w_norm[w_norm > 0] = w_pos / w_pos.mean()
1384+
sum_wn = float(np.sum(w_norm))
1385+
resid_sq = w_norm * (y_clean - y_bar) ** 2
1386+
variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1))
13801387
return y_bar, max(variance, 0.0), n_valid, True
13811388

13821389
# Full-design domain estimation: construct full-length psi with zeros
@@ -1396,8 +1403,13 @@ def _cell_mean_variance(
13961403

13971404
# SRS fallback when design-based variance is unidentifiable
13981405
if np.isnan(variance):
1399-
resid_sq = w_valid * (y_clean - y_bar) ** 2
1400-
variance = float(np.sum(resid_sq) / (sum_w**2) * n_valid / (n_valid - 1))
1406+
w_norm = w_valid.copy()
1407+
w_pos = w_norm[w_norm > 0]
1408+
if len(w_pos) > 0:
1409+
w_norm[w_norm > 0] = w_pos / w_pos.mean()
1410+
sum_wn = float(np.sum(w_norm))
1411+
resid_sq = w_norm * (y_clean - y_bar) ** 2
1412+
variance = float(np.sum(resid_sq) / (sum_wn**2) * n_valid / (n_valid - 1))
14011413
used_srs = True
14021414

14031415
return y_bar, max(float(variance), 0.0), n_valid, used_srs

tests/test_prep.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2855,3 +2855,43 @@ def test_replicate_weight_min_n_fallback(self):
28552855
panel_rep["outcome_se"].values,
28562856
rtol=1e-6,
28572857
)
2858+
2859+
def test_srs_fallback_scale_invariant(self):
2860+
"""SRS fallback SEs are invariant to constant weight rescaling."""
2861+
rng = np.random.RandomState(55)
2862+
n = 60
2863+
data = pd.DataFrame(
2864+
{
2865+
"geo": np.repeat(["A", "B", "C"], n // 3),
2866+
"time": np.ones(n, dtype=int),
2867+
"wt": rng.uniform(0.5, 2.0, n),
2868+
"y": rng.normal(10, 2, n),
2869+
}
2870+
)
2871+
design1 = SurveyDesign(weights="wt")
2872+
2873+
# Force SRS fallback with high min_n
2874+
with pytest.warns(UserWarning, match="SRS fallback"):
2875+
panel1, _ = aggregate_survey(
2876+
data,
2877+
by=["geo", "time"],
2878+
outcomes="y",
2879+
survey_design=design1,
2880+
min_n=9999,
2881+
)
2882+
2883+
# Rescale weights by 5x → should give identical SEs
2884+
data2 = data.copy()
2885+
data2["wt"] = data2["wt"] * 5.0
2886+
design2 = SurveyDesign(weights="wt")
2887+
with pytest.warns(UserWarning, match="SRS fallback"):
2888+
panel2, _ = aggregate_survey(
2889+
data2,
2890+
by=["geo", "time"],
2891+
outcomes="y",
2892+
survey_design=design2,
2893+
min_n=9999,
2894+
)
2895+
2896+
np.testing.assert_allclose(panel1["y_se"].values, panel2["y_se"].values, rtol=1e-10)
2897+
np.testing.assert_allclose(panel1["y_mean"].values, panel2["y_mean"].values, rtol=1e-10)

0 commit comments

Comments
 (0)