Skip to content

Commit b285dc7

Browse files
igerberclaude
andcommitted
Fix replicate df_survey to use analysis weights and mse=False rscales centering
Use analysis-weight matrix (rep * full-sample weights when combined_weights=False) for rank-based df computation, matching R's survey::degf(). When mse=False and replicate_rscales has zero entries, exclude zero-scaled replicates from the centering mean, matching R's svrVar() convention. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 81ef2f5 commit b285dc7

3 files changed

Lines changed: 188 additions & 8 deletions

File tree

diff_diff/survey.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,15 @@ def df_survey(self) -> Optional[int]:
574574
if self.uses_replicate_variance:
575575
if self.replicate_weights is None or self.n_replicates < 2:
576576
return None
577-
# Rank-based df from replicate weight matrix, matching
578-
# R's survey::degf() for svrepdesign objects
579-
rank = int(np.linalg.matrix_rank(self.replicate_weights))
577+
# Rank-based df from analysis-weight matrix, matching
578+
# R's survey::degf() which uses weights(design, "analysis").
579+
# For combined_weights=True, replicate cols ARE analysis weights.
580+
# For combined_weights=False, analysis weights = rep * full-sample.
581+
if self.combined_weights:
582+
analysis_weights = self.replicate_weights
583+
else:
584+
analysis_weights = self.replicate_weights * self.weights[:, np.newaxis]
585+
rank = int(np.linalg.matrix_rank(analysis_weights))
580586
return max(rank - 1, 1) if rank > 1 else None
581587
if self.psu is not None and self.n_psu > 0:
582588
if self.strata is not None and self.n_strata > 0:
@@ -1375,10 +1381,19 @@ def compute_replicate_vcov(
13751381

13761382
# Compute variance by method
13771383
# Support mse=False: center on replicate mean instead of full-sample estimate
1384+
# When rscales present and mse=False, center only over rscales > 0
1385+
# (R's svrVar convention — zero-scaled replicates should not shift center)
13781386
if resolved.mse:
13791387
center = c
13801388
else:
1381-
center = np.mean(coef_valid, axis=0)
1389+
if resolved.replicate_rscales is not None:
1390+
pos_scale = resolved.replicate_rscales[valid] > 0
1391+
if np.any(pos_scale):
1392+
center = np.mean(coef_valid[pos_scale], axis=0)
1393+
else:
1394+
center = np.mean(coef_valid, axis=0)
1395+
else:
1396+
center = np.mean(coef_valid, axis=0)
13821397
diffs = coef_valid - center[np.newaxis, :]
13831398

13841399
# Use custom scale/rscales if provided, else default method factor
@@ -1489,10 +1504,19 @@ def compute_replicate_if_variance(
14891504
return np.nan, n_valid
14901505

14911506
# Support mse=False: center on replicate mean
1507+
# When rscales present and mse=False, center only over rscales > 0
1508+
# (R's svrVar convention — zero-scaled replicates should not shift center)
14921509
if resolved.mse:
14931510
center = theta_full
14941511
else:
1495-
center = float(np.mean(theta_reps[valid]))
1512+
if resolved.replicate_rscales is not None:
1513+
pos_scale = resolved.replicate_rscales[valid] > 0
1514+
if np.any(pos_scale):
1515+
center = float(np.mean(theta_reps[valid][pos_scale]))
1516+
else:
1517+
center = float(np.mean(theta_reps[valid]))
1518+
else:
1519+
center = float(np.mean(theta_reps[valid]))
14961520
diffs = theta_reps[valid] - center
14971521

14981522
# Custom scale/rscales

docs/methodology/REGISTRY.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2011,8 +2011,11 @@ variance from the distribution of replicate estimates.
20112011
contrasts are formed via weight-ratio rescaling:
20122012
`theta_r = sum((w_r/w_full) * psi)` when `combined_weights=True`,
20132013
`theta_r = sum(w_r * psi)` when `combined_weights=False`.
2014-
- **Survey df**: Numerical rank of replicate weight matrix minus 1,
2015-
matching R's `survey::degf()`. Replaces `n_PSU - n_strata`.
2014+
- **Survey df**: Numerical rank of the analysis-weight matrix minus 1,
2015+
matching R's `survey::degf()`. For `combined_weights=True` (default),
2016+
analysis weights are the raw replicate columns. For `combined_weights=False`,
2017+
analysis weights are `replicate_weights * full_sample_weights`.
2018+
Replaces `n_PSU - n_strata`.
20162019
- **Mutual exclusion**: Replicate weights cannot be combined with
20172020
strata/psu/fpc (the replicates encode design structure implicitly)
20182021
- **Design parameters** (matching R `svrepdesign()`):
@@ -2023,7 +2026,9 @@ variance from the distribution of replicate estimates.
20232026
- `replicate_rscales`: per-replicate scaling factors (vector of length R)
20242027
- `mse` (default False, matching R's `survey::svrepdesign()`): if True,
20252028
center variance on full-sample estimate; if False, center on mean of
2026-
replicate estimates.
2029+
replicate estimates. When `replicate_rscales` contains zero entries
2030+
and `mse=False`, centering excludes zero-scaled replicates, matching
2031+
R's `survey::svrVar()` convention.
20272032
- **Note:** Replicate columns are NOT normalized — raw values are preserved
20282033
to maintain correct weight ratios in the IF path.
20292034
- **Note:** JKn requires explicit `replicate_strata` (per-replicate stratum

tests/test_survey_phase6.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,157 @@ def test_replicate_if_no_divide_by_zero_warning(self):
870870
assert np.isfinite(v)
871871

872872

873+
class TestReplicateEdgeCases:
874+
"""Regression tests for analysis-weight df and rscales centering."""
875+
876+
def test_df_survey_combined_weights_false(self):
877+
"""df_survey uses analysis-weight rank when combined_weights=False."""
878+
from diff_diff.survey import ResolvedSurveyDesign
879+
880+
np.random.seed(42)
881+
n = 50
882+
R = 5
883+
weights = 1.0 + np.random.exponential(0.5, n)
884+
# Perturbation factors (not full weights)
885+
rep_factors = np.random.uniform(0.8, 1.2, (n, R))
886+
887+
resolved = ResolvedSurveyDesign(
888+
weights=weights, weight_type="pweight",
889+
strata=None, psu=None, fpc=None,
890+
n_strata=0, n_psu=0, lonely_psu="remove",
891+
replicate_weights=rep_factors,
892+
replicate_method="BRR", n_replicates=R,
893+
combined_weights=False,
894+
)
895+
# df should match rank of analysis weights (rep * full-sample)
896+
analysis_weights = rep_factors * weights[:, np.newaxis]
897+
expected_rank = int(np.linalg.matrix_rank(analysis_weights))
898+
expected_df = max(expected_rank - 1, 1)
899+
assert resolved.df_survey == expected_df
900+
901+
# Verify it differs from raw perturbation-factor rank when weights
902+
# cause a rank reduction (e.g., zero full-sample weights)
903+
weights_with_zeros = weights.copy()
904+
weights_with_zeros[:10] = 0.0 # subpopulation-zeroed
905+
resolved2 = ResolvedSurveyDesign(
906+
weights=weights_with_zeros, weight_type="pweight",
907+
strata=None, psu=None, fpc=None,
908+
n_strata=0, n_psu=0, lonely_psu="remove",
909+
replicate_weights=rep_factors,
910+
replicate_method="BRR", n_replicates=R,
911+
combined_weights=False,
912+
)
913+
raw_rank = int(np.linalg.matrix_rank(rep_factors))
914+
analysis_rank = int(np.linalg.matrix_rank(
915+
rep_factors * weights_with_zeros[:, np.newaxis]
916+
))
917+
# Analysis rank should be <= raw rank when zero weights present
918+
assert analysis_rank <= raw_rank
919+
assert resolved2.df_survey == max(analysis_rank - 1, 1)
920+
921+
def test_rscales_zero_centering_vcov(self):
922+
"""mse=False with zero rscales: center only on rscales > 0 replicates."""
923+
from diff_diff.survey import compute_replicate_vcov, ResolvedSurveyDesign
924+
from diff_diff.linalg import solve_ols
925+
926+
np.random.seed(42)
927+
n = 100
928+
R = 6
929+
x = np.random.randn(n)
930+
y = 1.0 + 2.0 * x + np.random.randn(n) * 0.5
931+
X = np.column_stack([np.ones(n), x])
932+
w = np.ones(n)
933+
934+
# Build JK1-style replicates
935+
cluster_size = n // R
936+
rep_arr = np.ones((n, R))
937+
for r in range(R):
938+
start = r * cluster_size
939+
end = min((r + 1) * cluster_size, n)
940+
rep_arr[start:end, :] = 0.0
941+
# Correct column r only
942+
rep_arr[:, r] = np.where(
943+
(np.arange(n) >= start) & (np.arange(n) < end), 0.0,
944+
R / (R - 1)
945+
)
946+
947+
# rscales with one zero entry
948+
rscales = np.array([1.0, 1.0, 0.0, 1.0, 1.0, 1.0])
949+
950+
coef, _, _ = solve_ols(X, y, weights=w)
951+
952+
resolved = ResolvedSurveyDesign(
953+
weights=w, weight_type="pweight",
954+
strata=None, psu=None, fpc=None,
955+
n_strata=0, n_psu=0, lonely_psu="remove",
956+
replicate_weights=rep_arr,
957+
replicate_method="BRR", n_replicates=R,
958+
replicate_rscales=rscales, mse=False,
959+
)
960+
vcov, _nv = compute_replicate_vcov(X, y, coef, resolved)
961+
962+
# Manual computation: center only on replicates with rscales > 0
963+
coef_reps = []
964+
for r in range(R):
965+
c_r, _, _ = solve_ols(X, y, weights=rep_arr[:, r])
966+
coef_reps.append(c_r)
967+
coef_reps = np.array(coef_reps)
968+
pos_mask = rscales > 0
969+
center = np.mean(coef_reps[pos_mask], axis=0)
970+
diffs = coef_reps - center[np.newaxis, :]
971+
V_manual = np.zeros((2, 2))
972+
for r in range(R):
973+
V_manual += rscales[r] * np.outer(diffs[r], diffs[r])
974+
975+
assert np.allclose(np.diag(vcov), np.diag(V_manual), rtol=1e-10)
976+
977+
def test_rscales_zero_centering_if(self):
978+
"""mse=False with zero rscales: IF path centers only on rscales > 0."""
979+
from diff_diff.survey import compute_replicate_if_variance, ResolvedSurveyDesign
980+
981+
np.random.seed(42)
982+
n = 50
983+
R = 5
984+
psi = np.random.randn(n) * 0.1
985+
w = np.ones(n)
986+
987+
# Build simple replicates
988+
rep_arr = np.ones((n, R))
989+
for r in range(R):
990+
start = r * (n // R)
991+
end = min((r + 1) * (n // R), n)
992+
rep_arr[start:end, r] = 0.0
993+
rep_arr[:, r] = np.where(
994+
(np.arange(n) >= start) & (np.arange(n) < end), 0.0,
995+
R / (R - 1)
996+
)
997+
998+
rscales = np.array([1.0, 0.0, 1.0, 1.0, 1.0])
999+
1000+
resolved = ResolvedSurveyDesign(
1001+
weights=w, weight_type="pweight",
1002+
strata=None, psu=None, fpc=None,
1003+
n_strata=0, n_psu=0, lonely_psu="remove",
1004+
replicate_weights=rep_arr,
1005+
replicate_method="BRR", n_replicates=R,
1006+
replicate_rscales=rscales, mse=False,
1007+
)
1008+
var, _nv = compute_replicate_if_variance(psi, resolved)
1009+
1010+
# Manual: theta_r = sum((w_r/w) * psi), center on rscales > 0 only
1011+
theta_full = float(np.sum(psi))
1012+
theta_reps = np.array([
1013+
float(np.sum(np.divide(rep_arr[:, r], w, out=np.zeros(n), where=w > 0) * psi))
1014+
for r in range(R)
1015+
])
1016+
pos_mask = rscales > 0
1017+
center = float(np.mean(theta_reps[pos_mask]))
1018+
diffs = theta_reps - center
1019+
var_manual = float(np.sum(rscales * diffs**2))
1020+
1021+
assert var == pytest.approx(var_manual, rel=1e-10)
1022+
1023+
8731024
# =============================================================================
8741025
# Estimator-Level Replicate Weight Tests
8751026
# =============================================================================

0 commit comments

Comments
 (0)