Skip to content

Commit 099797e

Browse files
igerberclaude
andcommitted
Address PR #355 R1: weighted λ centering + weights=None survey designs
Fixes two P1 issues flagged by the CI reviewer on the initial submission of PR #352. P1 Methodology — `compute_time_weights_survey` was documented as solving the WLS-style weighted λ objective min Σ_u rw_u·(Σ_t λ_t·Y_u,pre[t] - Y_u,post_mean)² + ζ²·||λ||² but row-scaled Y by sqrt(rw) and then handed the scaled matrix to `_sc_weight_fw(intercept=True)`, whose column-centering uses an UNWEIGHTED mean across controls. That is not the weighted objective once rw varies, so non-uniform survey bootstrap draws were refitting λ on the wrong objective and could bias the reported SE. Fix: weighted-center `Y_time` BEFORE the sqrt(rw) row-scaling, using `col_weighted_mean = (Y_time * rw).sum(0) / rw.sum()`, and pass `intercept=False` to the kernel so no additional unweighted centering happens on the scaled matrix. Both two-pass calls updated. `compute_sdid_unit_weights_survey` is unchanged — its column-centering is PER-UNIT (time means within each control column), which is independent of rw. P1 Code Quality — `SurveyDesign(weights=None, strata=..., psu=...)` is a valid configuration (`SurveyDesign.resolve()` synthesizes ones when weights is None), but `_extract_unit_survey_weights` indexed `survey_design.weights` as if it were always a column name, so the groupby would fail with a KeyError before the bootstrap branch could run. Fix: `_extract_unit_survey_weights` now short-circuits to a vector of ones of length `len(unit_order)` when `survey_design.weights is None`, matching `SurveyDesign.resolve()`'s semantics. Regression tests: - `test_non_uniform_rw_beats_unweighted_centering_variant` (test_weighted_fw.py): reproduces the pre-fix buggy variant (row- scale Y by sqrt(rw), then call `_sc_weight_fw(intercept=True)`) and asserts the fixed path's weighted SSR is strictly ≤ the buggy variant's weighted SSR. If a future revert reintroduces intercept=True after the row-scaling, this test fails. - `test_bootstrap_full_design_without_explicit_weights` (test_methodology_sdid.py): `SurveyDesign(strata=..., psu=...)` with no explicit `weights` column now succeeds on the bootstrap path; survey_metadata populated with n_strata / n_psu. P3 Documentation: - `SyntheticDiD.fit()` docstring (survey_design parameter + Raises block): replace "bootstrap rejects all survey designs" language with the PR #352 support-matrix truth-table (bootstrap ✓ for both pweight- only and full design; placebo/jackknife ✓ pweight-only, ✗ full design). - `_placebo_variance_se` fallback-guidance messages (two sites): drop the "strata/PSU/FPC not yet supported by any SDID variance method" framing; recommend bootstrap for full-design survey fallback, jackknife for pweight-only, adding controls as the universal fallback. - `docs/survey-roadmap.md` Current Limitations table: collapse the two SDID bootstrap-rejection rows into a single row for placebo/ jackknife + full design (the bootstrap + full design row no longer applies). Verified: 75 targeted tests pass (test_weighted_fw + TestBootstrapSE + TestScaleEquivariance + TestCoverageMCArtifact + test_survey_phase5). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6d64ec1 commit 099797e

6 files changed

Lines changed: 171 additions & 39 deletions

File tree

diff_diff/survey.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,12 @@ def _extract_unit_survey_weights(data, unit_col, survey_design, unit_order):
10581058
unit_col : str
10591059
Unit identifier column name.
10601060
survey_design : SurveyDesign
1061-
Survey design (uses ``weights`` column name).
1061+
Survey design. When ``survey_design.weights`` is a column name,
1062+
the weights are pulled from ``data``. When ``survey_design.weights
1063+
is None`` (a valid configuration — ``SurveyDesign.resolve()`` then
1064+
synthesizes ones), returns a vector of ones of length
1065+
``len(unit_order)`` so downstream estimators can treat all units
1066+
as having unit survey weight 1.
10621067
unit_order : array-like
10631068
Ordered sequence of unit identifiers to align weights to.
10641069
@@ -1067,6 +1072,14 @@ def _extract_unit_survey_weights(data, unit_col, survey_design, unit_order):
10671072
np.ndarray
10681073
Float64 array of unit-level weights, one per unit in ``unit_order``.
10691074
"""
1075+
if survey_design.weights is None:
1076+
# SurveyDesign(weights=None, strata=..., psu=...) is a valid
1077+
# configuration — the design element supplies clustering /
1078+
# stratification without explicit per-unit weights. Synthesize
1079+
# uniform unit weights of 1 to match SurveyDesign.resolve()'s
1080+
# behavior (which emits ones when weights is None). Without this
1081+
# branch the groupby below would raise a KeyError on ``None``.
1082+
return np.ones(len(unit_order), dtype=np.float64)
10701083
unit_w = data.groupby(unit_col)[survey_design.weights].first()
10711084
return np.array([unit_w[u] for u in unit_order], dtype=np.float64)
10721085

diff_diff/synthetic_did.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,19 @@ def fit( # type: ignore[override]
246246
List of covariate column names. Covariates are residualized
247247
out before computing the SDID estimator.
248248
survey_design : SurveyDesign, optional
249-
Survey design specification. Only pweight weight_type is supported.
250-
``variance_method='placebo'`` and ``variance_method='jackknife'``
251-
accept pweight-only surveys (composed via ``w_control`` /
252-
``w_treated``). ``variance_method='bootstrap'`` rejects all
253-
survey designs (including pweight-only) and strata/PSU/FPC are
254-
not supported by any variance method on this release —
255-
composing Rao-Wu rescaled weights with paper-faithful
256-
Frank-Wolfe re-estimation requires a separate derivation
257-
(tracked in TODO.md, sketched in REGISTRY.md §SyntheticDiD).
249+
Survey design specification. Only pweight weight_type is
250+
supported. Support matrix (PR #352):
251+
252+
method pweight-only strata/PSU/FPC
253+
bootstrap ✓ weighted FW ✓ weighted FW + Rao-Wu
254+
placebo ✓ ✗ NotImplementedError
255+
jackknife ✓ ✗ NotImplementedError
256+
257+
The bootstrap path composes Rao-Wu rescaled weights per draw
258+
with the weighted-Frank-Wolfe kernel; see REGISTRY.md
259+
§SyntheticDiD ``Note (survey + bootstrap composition)``.
260+
``placebo`` and ``jackknife`` still reject strata/PSU/FPC
261+
(separate methodology gap tracked in TODO.md).
258262
259263
Returns
260264
-------
@@ -268,9 +272,10 @@ def fit( # type: ignore[override]
268272
If required parameters are missing, data validation fails,
269273
or a non-pweight survey design is provided.
270274
NotImplementedError
271-
If ``survey_design`` is provided with strata/PSU/FPC, or if
272-
``variance_method='bootstrap'`` is provided with any survey
273-
design (including pweight-only).
275+
If ``survey_design`` with strata/PSU/FPC is provided with
276+
``variance_method='placebo'`` or ``'jackknife'``. Bootstrap
277+
+ any survey design (pweight-only or full design) is
278+
supported via PR #352's weighted-FW + Rao-Wu composition.
274279
"""
275280
# Validate inputs
276281
if outcome is None or treatment is None or unit is None or time is None:
@@ -1249,14 +1254,13 @@ def _placebo_variance_se(
12491254
# Ensure we have enough controls for the split
12501255
n_pseudo_control = n_control - n_treated
12511256
if n_pseudo_control < 1:
1252-
# Bootstrap rejects every survey design in this release, so
1253-
# steer survey users to jackknife (pweight-only only) or
1254-
# adding controls. Non-survey users can still fall back to
1255-
# bootstrap or jackknife.
1257+
# Fallback guidance. Placebo and jackknife reject strata/PSU/FPC,
1258+
# but bootstrap (PR #352) supports both pweight-only and
1259+
# full-design surveys, so it's always a valid fallback.
12561260
fallback = (
1257-
"variance_method='jackknife' or adding more control units "
1258-
"(strata/PSU/FPC are not yet supported by any SDID variance "
1259-
"method)"
1261+
"variance_method='bootstrap' (supports pweight-only and "
1262+
"strata/PSU/FPC survey designs), variance_method='jackknife' "
1263+
"(pweight-only only), or adding more control units"
12601264
if w_control is not None
12611265
else "variance_method='bootstrap', variance_method='jackknife', "
12621266
"or adding more control units"
@@ -1353,13 +1357,14 @@ def _placebo_variance_se(
13531357
n_successful = len(placebo_estimates)
13541358

13551359
if n_successful < 2:
1356-
# Same survey-awareness branch as the pre-replication guard
1357-
# above — bootstrap rejects every survey design in this
1358-
# release, so suggest jackknife for pweight-only fits.
1360+
# Same fallback guidance as the pre-replication guard above.
1361+
# Bootstrap (PR #352) supports pweight-only + strata/PSU/FPC
1362+
# survey designs, so it's always a valid fallback for survey
1363+
# users even when placebo fails.
13591364
fallback = (
1360-
"variance_method='jackknife' or increasing the number of "
1361-
"control units (strata/PSU/FPC are not yet supported by any "
1362-
"SDID variance method)"
1365+
"variance_method='bootstrap' (supports pweight-only and "
1366+
"strata/PSU/FPC survey designs), variance_method='jackknife' "
1367+
"(pweight-only only), or increasing the number of control units"
13631368
if w_control is not None
13641369
else "variance_method='bootstrap' or variance_method='jackknife' "
13651370
"or increasing the number of control units"

diff_diff/utils.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,14 +1979,18 @@ def compute_time_weights_survey(
19791979
Solves the WLS-style time-weight objective (PR #352 §2.2)::
19801980
19811981
min_{λ on simplex}
1982-
Σ_u rw_control[u]·(Σ_t λ[t]·Y_pre_control[t,u] - Y_post_mean[u]
1982+
Σ_u rw_control[u]·(Σ_t λ[t]·Y_u,pre-centered[t] - Y_u,post_mean-centered
19831983
+ ζ²·||λ||²
19841984
19851985
Regularization stays uniform on λ (rw is per-control, λ is per-period —
1986-
no alignment for per-λ reg weighting). Loss term gets per-row weighting,
1987-
implemented as a √rw row-scale of the (transposed) Y_time matrix before
1988-
passing to the unweighted Rust kernel — equivalent to running the
1989-
standard FW on ``diag(√rw)·Y``.
1986+
no alignment for per-λ reg weighting). The loss term uses WLS-style
1987+
row weights; when ``intercept=True``, the column-centering step is
1988+
*also* survey-weighted (weighted mean across controls, weights
1989+
``rw_control``) so the centered loss minimizes
1990+
``Σ_u rw_u·(A_u·λ - b_u)²`` on the rw-centered matrix — equivalent
1991+
to the stated weighted objective. The Rust kernel then sees the
1992+
weighted-centered + sqrt(rw)-row-scaled matrix with
1993+
``intercept=False`` (no additional unweighted centering).
19901994
19911995
The returned λ is on the standard simplex.
19921996
@@ -2030,16 +2034,33 @@ def compute_time_weights_survey(
20302034
post_means = np.mean(Y_post_control, axis=0)
20312035
Y_time = np.column_stack([Y_pre_control.T, post_means]) # (N_co, T_pre+1)
20322036

2033-
# Row-scale by sqrt(rw): each control unit's contribution to the loss
2034-
# is weighted by rw_control[u]. Reg on λ stays uniform (no reg_weights).
2037+
# Column-center the (N_co, T_pre+1) matrix using the SURVEY-WEIGHTED
2038+
# mean across control units when ``intercept=True``. Plain
2039+
# ``intercept=True`` inside the FW kernel would use an unweighted
2040+
# column mean which does not correspond to the stated weighted-loss
2041+
# objective once ``rw_control`` varies. Perform the weighted
2042+
# centering here and pass ``intercept=False`` below so the kernel
2043+
# does not re-center on the row-scaled matrix.
2044+
rw_sum = float(np.sum(rw_control))
2045+
if intercept and rw_sum > 0:
2046+
col_weighted_means = (
2047+
(Y_time * rw_control[:, np.newaxis]).sum(axis=0) / rw_sum
2048+
)
2049+
Y_time = Y_time - col_weighted_means[np.newaxis, :]
2050+
2051+
# Row-scale by sqrt(rw): after weighted centering (if any), each
2052+
# control unit's contribution to the loss is weighted by
2053+
# ``rw_control[u]`` via the sqrt(rw) row scaling, which reproduces
2054+
# ``||diag(sqrt(rw))·(A·λ - b)||²`` = ``Σ_u rw_u·(A_u·λ - b_u)²``.
2055+
# Reg on λ stays uniform (no reg_weights).
20352056
sqrt_rw = np.sqrt(np.maximum(rw_control, 0.0))
20362057
Y_weighted = Y_time * sqrt_rw[:, np.newaxis]
20372058

20382059
if return_convergence:
20392060
lam, conv1 = _sc_weight_fw(
20402061
Y_weighted,
20412062
zeta=zeta_lambda,
2042-
intercept=intercept,
2063+
intercept=False, # weighted centering already applied above
20432064
init_weights=init_weights,
20442065
min_decrease=min_decrease,
20452066
max_iter=max_iter_pre_sparsify,
@@ -2049,7 +2070,7 @@ def compute_time_weights_survey(
20492070
lam = _sc_weight_fw(
20502071
Y_weighted,
20512072
zeta=zeta_lambda,
2052-
intercept=intercept,
2073+
intercept=False, # weighted centering already applied above
20532074
init_weights=init_weights,
20542075
min_decrease=min_decrease,
20552076
max_iter=max_iter_pre_sparsify,
@@ -2061,7 +2082,7 @@ def compute_time_weights_survey(
20612082
lam, conv2 = _sc_weight_fw(
20622083
Y_weighted,
20632084
zeta=zeta_lambda,
2064-
intercept=intercept,
2085+
intercept=False, # weighted centering already applied above
20652086
init_weights=lam,
20662087
min_decrease=min_decrease,
20672088
max_iter=max_iter,
@@ -2072,7 +2093,7 @@ def compute_time_weights_survey(
20722093
return _sc_weight_fw(
20732094
Y_weighted,
20742095
zeta=zeta_lambda,
2075-
intercept=intercept,
2096+
intercept=False, # weighted centering already applied above
20762097
init_weights=lam,
20772098
min_decrease=min_decrease,
20782099
max_iter=max_iter,

docs/survey-roadmap.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,7 @@ the limitation and suggested alternative.
223223

224224
| Estimator | Limitation | Alternative |
225225
|-----------|-----------|-------------|
226-
| SyntheticDiD | Strata/PSU/FPC (any variance method) | No SDID variance option in this release. Pweight-only works with `variance_method='placebo'` or `'jackknife'`. Strata/PSU/FPC + SDID requires composing Rao-Wu rescaled weights with paper-faithful Frank-Wolfe re-estimation; sketch in REGISTRY.md §SyntheticDiD. |
227-
| SyntheticDiD | `variance_method='bootstrap'` + any survey design (including pweight-only) | Use `variance_method='placebo'` or `'jackknife'` for pweight-only surveys. Refit bootstrap composed with survey weights requires the same weighted-FW derivation noted above. |
226+
| SyntheticDiD | `variance_method='placebo'` or `'jackknife'` + strata/PSU/FPC | Use `variance_method='bootstrap'` for full-design surveys (PR #352 weighted-FW + Rao-Wu composition). Placebo's control-index permutation and jackknife's LOO allocator need their own weighted derivations on top of the weighted-FW kernel; tracked in TODO.md as a follow-up. |
228227
| SyntheticDiD | Replicate weights | Pre-existing limitation: no replicate-weight survey support on SDID. |
229228
| TROP | Replicate weights | Use strata/PSU/FPC design with Rao-Wu rescaled bootstrap |
230229
| BaconDecomposition | Replicate weights | Diagnostic only, no inference |

tests/test_methodology_sdid.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,35 @@ def test_bootstrap_summary_shows_replications(self, ci_params):
671671
assert "Bootstrap replications" in summary
672672
assert str(n_boot) in summary
673673

674+
def test_bootstrap_full_design_without_explicit_weights(self):
675+
"""SurveyDesign(strata=..., psu=..., weights=None) fits successfully.
676+
677+
Regression for PR #355 R1 code-quality finding: `SurveyDesign` allows
678+
`weights=None` (resolve() synthesizes unit weights of 1), but the
679+
SDID helper `_extract_unit_survey_weights` used to index
680+
`survey_design.weights` directly and would fail before bootstrap
681+
could run. The helper now returns ones for this configuration.
682+
"""
683+
from diff_diff.survey import SurveyDesign
684+
df = _make_panel(n_control=20, n_treated=3, seed=42)
685+
df["stratum"] = df["unit"] % 2
686+
df["psu"] = df["unit"]
687+
result = SyntheticDiD(
688+
variance_method="bootstrap", n_bootstrap=50, seed=1
689+
).fit(
690+
df, outcome="outcome", treatment="treated",
691+
unit="unit", time="period",
692+
post_periods=[5, 6, 7],
693+
survey_design=SurveyDesign(strata="stratum", psu="psu"), # weights=None
694+
)
695+
assert np.isfinite(result.att)
696+
assert np.isfinite(result.se)
697+
assert result.se > 0
698+
assert result.variance_method == "bootstrap"
699+
assert result.survey_metadata is not None
700+
assert result.survey_metadata.n_strata is not None
701+
assert result.survey_metadata.n_psu is not None
702+
674703
def test_bootstrap_single_psu_returns_nan(self):
675704
"""Unstratified single-PSU survey design returns NaN SE (PR #352).
676705

tests/test_weighted_fw.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,71 @@ def test_rw_shape_mismatch_raises(self, small_panel):
241241
Y_pre_c, Y_post_c, wrong_rw, zeta_lambda=0.05,
242242
)
243243

244+
def test_non_uniform_rw_beats_unweighted_centering_variant(self, small_panel):
245+
"""Non-uniform rw: the weighted-centering solution achieves strictly
246+
lower weighted SSR than the (buggy) unweighted-centering variant.
247+
248+
Verifies the PR #355 R1 fix — weighted centering + intercept=False
249+
— actually solves the stated weighted loss
250+
``Σ_u rw_u·(A_u·λ - b_u)²``. Reproduces the unweighted-centering
251+
pre-R1 path by hand (row-scale Y by sqrt(rw), then pass
252+
intercept=True to the kernel so it centers on unweighted column
253+
means) and asserts the correct path's weighted SSR is strictly
254+
better. If R1's fix regresses (someone reverts back to
255+
intercept=True after row-scaling), this test fails because the
256+
two solutions become identical.
257+
"""
258+
Y_pre_c = small_panel["Y_pre_control"]
259+
Y_post_c = small_panel["Y_post_control"]
260+
n_co = small_panel["n_control"]
261+
rng = np.random.default_rng(23)
262+
rw = np.where(rng.uniform(size=n_co) < 0.25, 5.0, 0.5)
263+
264+
# Correct path: what compute_time_weights_survey actually does.
265+
lam_correct = compute_time_weights_survey(
266+
Y_pre_c, Y_post_c, rw,
267+
zeta_lambda=0.05,
268+
min_decrease=1e-8,
269+
max_iter=10000,
270+
)
271+
272+
# Buggy variant: pre-R1 — row-scale by sqrt(rw) but let the kernel
273+
# do UNWEIGHTED centering (intercept=True on the row-scaled matrix).
274+
post_means = np.mean(Y_post_c, axis=0)
275+
Y_time_raw = np.column_stack([Y_pre_c.T, post_means])
276+
sqrt_rw = np.sqrt(np.maximum(rw, 0.0))
277+
Y_weighted_unweighted_center = Y_time_raw * sqrt_rw[:, None]
278+
lam_buggy = _sc_weight_fw(
279+
Y_weighted_unweighted_center, zeta=0.05, intercept=True,
280+
min_decrease=1e-8, max_iter=10000,
281+
)
282+
# Sparsify + refit second pass to match the two-pass shape.
283+
from diff_diff.utils import _sparsify
284+
lam_buggy = _sparsify(lam_buggy)
285+
lam_buggy = _sc_weight_fw(
286+
Y_weighted_unweighted_center, zeta=0.05, intercept=True,
287+
init_weights=lam_buggy, min_decrease=1e-8, max_iter=10000,
288+
)
289+
290+
# Compute the canonical (weighted-centered) objective on both.
291+
wc_mean_pre = (Y_pre_c.T * rw[:, None]).sum(axis=0) / rw.sum()
292+
wc_mean_post = (post_means * rw).sum() / rw.sum()
293+
A_wc = Y_pre_c.T - wc_mean_pre
294+
b_wc = post_means - wc_mean_post
295+
296+
def weighted_ssr(lam_val: np.ndarray) -> float:
297+
resid = A_wc @ lam_val - b_wc
298+
return float(np.sum(rw * resid ** 2))
299+
300+
ssr_correct = weighted_ssr(lam_correct)
301+
ssr_buggy = weighted_ssr(lam_buggy)
302+
assert ssr_correct <= ssr_buggy + 1e-6, (
303+
f"weighted-centering λ (SSR={ssr_correct:.4f}) must achieve at "
304+
f"least as low weighted SSR as the unweighted-centering variant "
305+
f"(SSR={ssr_buggy:.4f}). PR #355 R1 regression: weighted SSR is "
306+
"not being minimized by the survey λ helper."
307+
)
308+
244309
def test_zero_rw_subset_handled(self, small_panel):
245310
"""rw with some zeros (Rao-Wu draws units to zero weight) still yields
246311
a valid simplex λ — the FW just down-weights those rows in the loss.

0 commit comments

Comments
 (0)