Skip to content

Commit 84809cd

Browse files
igerberclaude
andcommitted
Return n_valid from replicate variance functions, fix df properly
Stop mutating resolved.n_replicates in place — instead return (result, n_valid) tuples from compute_replicate_vcov() and compute_replicate_if_variance(). Callers unpack the tuple and LinearRegression.fit() uses n_valid-1 for survey_df. This eliminates the shared-object mutation that the CI reviewer flagged as P0 (order-dependent bugs on reused resolved designs) while properly threading the effective df through inference. Updated all 7 callers across 5 files + 4 test call sites. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 91d778f commit 84809cd

7 files changed

Lines changed: 29 additions & 30 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def fit(
643643

644644
# Score-scale: psi = w * if_es (matches TSL bread)
645645
psi_es = unit_resolved_es.weights * if_es
646-
variance = compute_replicate_if_variance(psi_es, unit_resolved_es)
646+
variance, _nv = compute_replicate_if_variance(psi_es, unit_resolved_es)
647647
es_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
648648
else:
649649
X_ones_es = np.ones((n_units, 1))
@@ -1240,7 +1240,7 @@ def _compute_analytical_se(
12401240

12411241
def _rep_se(if_vals):
12421242
psi_scaled = _w_rep * if_vals
1243-
v = compute_replicate_if_variance(psi_scaled, unit_resolved)
1243+
v, _nv = compute_replicate_if_variance(psi_scaled, unit_resolved)
12441244
return float(np.sqrt(max(v, 0.0))) if np.isfinite(v) else np.nan
12451245

12461246
overall_att_se = _rep_se(if_att_glob)

diff_diff/efficient_did.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
10881088
# Score-scale IFs to match TSL bread: psi = w * eif / sum(w)
10891089
w = self._unit_resolved_survey.weights
10901090
psi_scaled = w * eif_vals / w.sum()
1091-
variance = compute_replicate_if_variance(psi_scaled, self._unit_resolved_survey)
1091+
variance, _n_valid = compute_replicate_if_variance(psi_scaled, self._unit_resolved_survey)
10921092
return float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
10931093

10941094
from diff_diff.survey import compute_survey_vcov

diff_diff/linalg.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,19 +1802,22 @@ def fit(
18021802
if np.any(nan_mask):
18031803
kept_cols = np.where(~nan_mask)[0]
18041804
if len(kept_cols) > 0:
1805-
vcov_reduced = compute_replicate_vcov(
1805+
vcov_reduced, _n_valid_rep = compute_replicate_vcov(
18061806
X[:, kept_cols], y, coefficients[kept_cols],
18071807
_effective_survey_design,
18081808
weight_type=self.weight_type,
18091809
)
18101810
vcov = _expand_vcov_with_nan(vcov_reduced, X.shape[1], kept_cols)
18111811
else:
18121812
vcov = np.full((X.shape[1], X.shape[1]), np.nan)
1813+
_n_valid_rep = 0
18131814
else:
1814-
vcov = compute_replicate_vcov(
1815+
vcov, _n_valid_rep = compute_replicate_vcov(
18151816
X, y, coefficients, _effective_survey_design,
18161817
weight_type=self.weight_type,
18171818
)
1819+
# Store effective replicate df (n_valid - 1) for later use
1820+
self._replicate_df = _n_valid_rep - 1 if _n_valid_rep > 1 else None
18181821
else:
18191822
from diff_diff.survey import compute_survey_vcov
18201823

@@ -1858,6 +1861,9 @@ def fit(
18581861

18591862
if isinstance(_effective_survey_design, ResolvedSurveyDesign):
18601863
self.survey_df_ = _effective_survey_design.df_survey
1864+
# Override with effective replicate df if available
1865+
if hasattr(self, '_replicate_df') and self._replicate_df is not None:
1866+
self.survey_df_ = self._replicate_df
18611867

18621868
return self
18631869

diff_diff/staggered_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def _compute_aggregated_se_with_wif(
476476
if resolved_survey is not None and hasattr(resolved_survey, "uses_replicate_variance") and resolved_survey.uses_replicate_variance:
477477
from diff_diff.survey import compute_replicate_if_variance
478478

479-
variance = compute_replicate_if_variance(psi_total, resolved_survey)
479+
variance, _n_valid = compute_replicate_if_variance(psi_total, resolved_survey)
480480
if np.isnan(variance):
481481
return np.nan
482482
return np.sqrt(max(variance, 0.0))

diff_diff/survey.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,9 +1308,6 @@ def compute_replicate_vcov(
13081308
stacklevel=2,
13091309
)
13101310
n_valid = int(np.sum(valid))
1311-
# Update effective replicate count so df_survey reflects valid replicates
1312-
if n_valid < R:
1313-
resolved.n_replicates = n_valid
13141311
if n_valid < 2:
13151312
if n_valid == 0:
13161313
warnings.warn(
@@ -1323,7 +1320,7 @@ def compute_replicate_vcov(
13231320
f"with fewer than 2. Returning NaN.",
13241321
UserWarning, stacklevel=2,
13251322
)
1326-
return np.full((k, k), np.nan)
1323+
return np.full((k, k), np.nan), n_valid
13271324
coef_valid = coef_reps[valid]
13281325
c = full_sample_coef
13291326

@@ -1333,7 +1330,7 @@ def compute_replicate_vcov(
13331330

13341331
if method in ("BRR", "Fay", "JK1"):
13351332
factor = _replicate_variance_factor(method, int(np.sum(valid)), resolved.fay_rho)
1336-
return factor * outer_sum
1333+
return factor * outer_sum, n_valid
13371334
elif method == "JKn":
13381335
# JKn: V = sum_h ((n_h-1)/n_h) * sum_{r in h} (c_r - c)(c_r - c)^T
13391336
rep_strata = resolved.replicate_strata
@@ -1348,15 +1345,15 @@ def compute_replicate_vcov(
13481345
continue
13491346
diffs_h = diffs[mask_h]
13501347
V += ((n_h - 1.0) / n_h) * (diffs_h.T @ diffs_h)
1351-
return V
1348+
return V, n_valid
13521349
else:
13531350
raise ValueError(f"Unknown replicate method: {method}")
13541351

13551352

13561353
def compute_replicate_if_variance(
13571354
psi: np.ndarray,
13581355
resolved: "ResolvedSurveyDesign",
1359-
) -> float:
1356+
) -> Tuple[float, int]:
13601357
"""Compute replicate-based variance for influence-function estimators.
13611358
13621359
Instead of re-running the full estimator, reweights the influence
@@ -1401,22 +1398,18 @@ def compute_replicate_if_variance(
14011398

14021399
valid = np.isfinite(theta_reps)
14031400
n_valid = int(np.sum(valid))
1404-
# Update effective replicate count so df_survey reflects valid replicates
1405-
if n_valid < R:
1406-
resolved.n_replicates = n_valid
14071401
if n_valid < 2:
1408-
return np.nan
1402+
return np.nan, n_valid
14091403
diffs = theta_reps[valid] - theta_full
14101404
ss = float(np.sum(diffs**2))
14111405

14121406
if method in ("BRR", "Fay", "JK1"):
1413-
factor = _replicate_variance_factor(method, int(np.sum(valid)), resolved.fay_rho)
1414-
return factor * ss
1407+
factor = _replicate_variance_factor(method, n_valid, resolved.fay_rho)
1408+
return factor * ss, n_valid
14151409
elif method == "JKn":
14161410
rep_strata = resolved.replicate_strata
14171411
if rep_strata is None:
14181412
raise ValueError("JKn requires replicate_strata")
1419-
# Filter to valid replicates
14201413
valid_strata = rep_strata[valid]
14211414
valid_diffs = diffs
14221415
result = 0.0
@@ -1426,7 +1419,7 @@ def compute_replicate_if_variance(
14261419
if n_h < 1:
14271420
continue
14281421
result += ((n_h - 1.0) / n_h) * float(np.sum(valid_diffs[mask_h] ** 2))
1429-
return result
1422+
return result, n_valid
14301423
else:
14311424
raise ValueError(f"Unknown replicate method: {method}")
14321425

diff_diff/triple_diff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1096,7 +1096,7 @@ def _estimate_ddd_decomposition(
10961096
psi_rep = inf_func / w_sum
10971097
else:
10981098
psi_rep = resolved_survey.weights * inf_func / w_sum
1099-
variance = compute_replicate_if_variance(psi_rep, resolved_survey)
1099+
variance, _nv = compute_replicate_if_variance(psi_rep, resolved_survey)
11001100
se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
11011101
else:
11021102
from diff_diff.survey import compute_survey_vcov

tests/test_survey_phase6.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def test_brr_vcov(self, replicate_data):
453453
X, y, weights=resolved.weights, weight_type="pweight",
454454
)
455455

456-
vcov = compute_replicate_vcov(X, y, coef, resolved)
456+
vcov, _nv = compute_replicate_vcov(X, y, coef, resolved)
457457
assert np.all(np.isfinite(np.diag(vcov)))
458458

459459
def test_fay_inflates_over_brr(self, replicate_data):
@@ -471,14 +471,14 @@ def test_fay_inflates_over_brr(self, replicate_data):
471471
)
472472
resolved_brr = sd_brr.resolve(data)
473473
coef, _, _ = solve_ols(X, y, weights=resolved_brr.weights)
474-
vcov_brr = compute_replicate_vcov(X, y, coef, resolved_brr)
474+
vcov_brr, _nv = compute_replicate_vcov(X, y, coef, resolved_brr)
475475

476476
sd_fay = SurveyDesign(
477477
weights="weight", replicate_weights=rep_cols,
478478
replicate_method="Fay", fay_rho=0.5,
479479
)
480480
resolved_fay = sd_fay.resolve(data)
481-
vcov_fay = compute_replicate_vcov(X, y, coef, resolved_fay)
481+
vcov_fay, _nv = compute_replicate_vcov(X, y, coef, resolved_fay)
482482

483483
# Fay variance = BRR variance / (1-rho)^2 > BRR variance
484484
assert np.all(np.diag(vcov_fay) > np.diag(vcov_brr))
@@ -534,7 +534,7 @@ def test_replicate_if_variance(self, replicate_data):
534534

535535
# Synthetic influence function
536536
psi = np.random.randn(len(data)) * 0.1
537-
var = compute_replicate_if_variance(psi, resolved)
537+
var, _nv = compute_replicate_if_variance(psi, resolved)
538538
assert np.isfinite(var)
539539
assert var >= 0
540540

@@ -583,7 +583,7 @@ def test_jkn_variance(self, replicate_data):
583583
X = np.column_stack([np.ones(len(data)), data["x"].values])
584584

585585
coef, _, _ = solve_ols(X, y, weights=resolved.weights)
586-
vcov = compute_replicate_vcov(X, y, coef, resolved)
586+
vcov, _nv = compute_replicate_vcov(X, y, coef, resolved)
587587
assert np.all(np.isfinite(np.diag(vcov)))
588588
assert np.all(np.diag(vcov) > 0)
589589

@@ -611,7 +611,7 @@ def test_replicate_if_scale_matches_analytical(self):
611611
)
612612
resolved = sd.resolve(data)
613613

614-
v_rep = compute_replicate_if_variance(psi, resolved)
614+
v_rep, _nv = compute_replicate_if_variance(psi, resolved)
615615
v_analytical = float(np.sum(psi**2))
616616

617617
# JK1 gives (n-1)/n * sum(...) which should approximate sum(psi^2)
@@ -665,7 +665,7 @@ def test_replicate_if_matches_survey_if_variance(self):
665665
replicate_method="JK1",
666666
n_replicates=n_psu,
667667
)
668-
v_rep = compute_replicate_if_variance(psi, resolved_rep)
668+
v_rep, _nv = compute_replicate_if_variance(psi, resolved_rep)
669669

670670
# Should be in the same ballpark (within 50% — different estimators
671671
# of the same quantity)
@@ -795,7 +795,7 @@ def test_replicate_if_no_divide_by_zero_warning(self):
795795
with warnings.catch_warnings():
796796
warnings.simplefilter("error", RuntimeWarning)
797797
# Should NOT raise RuntimeWarning for divide by zero
798-
v = compute_replicate_if_variance(psi, resolved)
798+
v, _nv = compute_replicate_if_variance(psi, resolved)
799799
assert np.isfinite(v)
800800

801801

0 commit comments

Comments
 (0)