Commit 69d3df3

igerber and claude committed
Delete compute_synthetic_weights shim; inline Frank-Wolfe in rank_control_units
Post-audit cleanup closing finding #22 from the Phase 2 silent-failures audit. Replaces the previous "fix both backends" approach with "delete the helper and inline correctly in its single caller."

Why the scope change. The audit found that the Rust and Python backends of `compute_synthetic_weights` solved different QPs (Python shrink-to-uniform, Rust shrink-to-zero) with different step sizes (Python adaptive, Rust a broken constant 0.1 that diverges at large Y). Closer inspection revealed:

- `compute_synthetic_weights` was private (not in `__all__`).
- Its only non-test caller was `rank_control_units` in prep.py, a user-facing diagnostic that ranks control units by trend similarity.
- The `synthetic_weight` column it feeds into the `rank_control_units` DataFrame is informational only: `quality_score` is computed from RMSE + covariate distance, NOT from `synthetic_weight`. So the bug did not affect donor-selection ranking at all.
- `SyntheticDiD.fit()` does NOT use this helper. It uses the Frank-Wolfe solver (`compute_sdid_unit_weights` / `_sc_weight_fw_numpy`) directly on a completely independent code path that already matches R synthdid.

Maintaining two broken PGD implementations of a one-caller helper was not worth the cost. Delete the shim and inline correctly.

Changes:

- Delete `compute_synthetic_weights` (utils.py:1134) and `_compute_synthetic_weights_numpy` (utils.py:1202).
- Remove `_rust_synthetic_weights` from the `_backend.py` import and None-fallback, and from the `__init__.py` re-export.
- Remove the import from `prep.py`.
- Inline a Frank-Wolfe computation in `rank_control_units` (prep.py:990) using the shared `_sc_weight_fw` dispatcher: the same solver SDID uses, matching R `synthdid::sc.weight.fw`. Threads `zeta=sqrt(lambda_reg/N)` (N = number of pre-treatment periods) to absorb FW's (1/N) objective scaling, a noise-level-scaled `min_decrease` per R convention, and `max_iter=1000` to preserve the existing cost envelope.
- Short-circuit `n_control in {0, 1}` to avoid FW loop overhead on trivially-solvable cases.
- Suppress non-convergence warnings inside the inline block, because the caller uses the output as a heuristic score, not a statistical estimate.

Tests:

- Delete `TestSyntheticWeightsBackendParity` (3 tests, 2 xfailed) and the 3 direct-Rust-import tests in `test_rust_backend.py`.
- Delete `TestComputeSyntheticWeightsEdgeCases` (4 tests) in test_utils.py.
- Delete `test_compute_synthetic_weights` in test_estimators.py.
- Add `test_extreme_Y_scale_synthetic_weight_column` in test_prep.py: a regression guard asserting that the `synthetic_weight` column produces a valid non-degenerate simplex vector at `Y ~ 1e9` (the exact input that the old Rust PGD mishandled by collapsing onto a single vertex).

SDID invariance. Verified via bit-identical pre/post numerical baseline on a fixed-seed `SyntheticDiD.fit()` (n_units=30, n_pre=8, n_post=2, n_treated=3, n_bootstrap=20, seed=42):

  before: att=1.1895009159752075 se=0.2576531609311485 weight_sum=1.000000000000002e+00 weight_max=6.820411397386437e-02
  after:  att=1.1895009159752075 se=0.2576531609311485 weight_sum=1.000000000000002e+00 weight_max=6.820411397386437e-02

Full SDID test suite: 40/40 pass before and after (test outcome diff empty). All 305 tests in the touched files pass.

Rust-side `compute_synthetic_weights` PyO3 binding in `rust/src/weights.rs` is now dead code (no Python code calls it). Tracked as a low-priority cleanup follow-up in TODO.md; removal requires `maturin develop` rebuild.

User-visible impact:

- `rank_control_units` public signature and return schema: unchanged.
- `quality_score` values and ranking: unchanged.
- `synthetic_weight` column: values agree with old code at ~1e-7 on default-parameter typical data; differ only in edge cases where the old code was mathematically wrong (extreme Y scale, lambda_reg > 0).
- Private imports break: `from diff_diff.utils import compute_synthetic_weights` and `from diff_diff._backend import _rust_synthetic_weights` (the alias is also gone from `diff_diff/__init__.py`). The Rust binding itself remains importable via `from diff_diff._rust_backend import compute_synthetic_weights` until the dead code is removed (see TODO.md). None of these were part of the documented stable API.

Net diff: +110 / -367 lines.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
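
The `zeta = sqrt(lambda_reg / N)` choice described in the commit message can be checked numerically. The following is a minimal sketch, not the library's `_sc_weight_fw`: `fw_objective`, `intended_objective`, and `frank_wolfe_simplex` are hypothetical names introduced here, and the toy solver uses an exact line search that the real dispatcher may not use. It only illustrates that the two objectives differ by the constant factor N, so they share a minimizer on the simplex.

```python
import numpy as np


def fw_objective(w, A, b, zeta):
    """zeta^2 ||w||^2 + (1/N) ||A w - b||^2, with N = number of rows of A."""
    n_rows = A.shape[0]
    resid = A @ w - b
    return zeta**2 * (w @ w) + (resid @ resid) / n_rows


def intended_objective(w, A, b, lambda_reg):
    """||A w - b||^2 + lambda_reg ||w||^2."""
    resid = A @ w - b
    return resid @ resid + lambda_reg * (w @ w)


def frank_wolfe_simplex(A, b, zeta, max_iter=1000, min_decrease=1e-12):
    """Toy Frank-Wolfe on the simplex (hypothetical; not `_sc_weight_fw`)."""
    n_rows, n_cols = A.shape
    w = np.full(n_cols, 1.0 / n_cols)  # start at the uniform weights
    prev = fw_objective(w, A, b, zeta)
    for _ in range(max_iter):
        grad = 2.0 * zeta**2 * w + (2.0 / n_rows) * (A.T @ (A @ w - b))
        # Move toward the vertex with the smallest gradient component.
        d = -w.copy()
        d[int(np.argmin(grad))] += 1.0
        Ad = A @ d
        curvature = zeta**2 * (d @ d) + (Ad @ Ad) / n_rows
        slope = grad @ d
        if slope >= 0.0 or curvature <= 0.0:
            break  # no descent direction left
        # Exact line search for a quadratic, clipped to stay on the simplex.
        step = min(1.0, -slope / (2.0 * curvature))
        w = w + step * d
        cur = fw_objective(w, A, b, zeta)
        if prev - cur < min_decrease:
            break
        prev = cur
    return w


rng = np.random.default_rng(0)
A = rng.normal(size=(8, 5))  # 8 pre-periods, 5 control units (made-up data)
b = rng.normal(size=8)
lambda_reg = 0.5
zeta = np.sqrt(lambda_reg / A.shape[0])
w_fw = frank_wolfe_simplex(A, b, zeta)

# N * (FW objective) equals the intended objective at any w.
scaled = A.shape[0] * fw_objective(w_fw, A, b, zeta)
print(np.isclose(scaled, intended_objective(w_fw, A, b, lambda_reg)))
```

Because multiplying an objective by a positive constant does not move its minimizer, any solver that minimizes the FW form under this `zeta` also minimizes the intended form.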
1 parent cbb8814 commit 69d3df3

9 files changed

Lines changed: 110 additions & 367 deletions

TODO.md

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ Deferred items from PR reviews that were not addressed before merge.
 | Weighted CR2 Bell-McCaffrey cluster-robust (`vcov_type="hc2_bm"` + `cluster_ids` + `weights`) currently raises `NotImplementedError`. Weighted hat matrix and residual rebalancing need threading per clubSandwich WLS handling. | `linalg.py::_compute_cr2_bm` | Phase 1a | Medium |
 | Regenerate `benchmarks/data/clubsandwich_cr2_golden.json` from R (`Rscript benchmarks/R/generate_clubsandwich_golden.R`). Current JSON has `source: python_self_reference` as a stability anchor until an authoritative R run. | `benchmarks/R/generate_clubsandwich_golden.R` | Phase 1a | Medium |
 | `honest_did.py:1907` `np.linalg.solve(A_sys, b_sys) / except LinAlgError: continue` is a silent basis-rejection in the vertex-enumeration loop that is algorithmically intentional (try the next basis). Consider surfacing a count of rejected bases as a diagnostic when ARP enumeration exhausts, so users see when the vertex search was heavily constrained. Not a silent failure in the sense of the Phase 2 audit (the algorithm is supposed to skip), but the diagnostic would help debug borderline cases. | `honest_did.py` | #334 | Low |
-| `compute_synthetic_weights` backend algorithm mismatch: Rust path uses Frank-Wolfe (`_rust_synthetic_weights` in `utils.py:1184`); Python fallback uses projected gradient descent (`_compute_synthetic_weights_numpy` in `utils.py:1228`). Both solve the same constrained QP but converge to different simplex vertices on near-degenerate / extreme-scale inputs (e.g. `Y~1e9`, or near-singular `Y'Y`). Unified backend (one algorithm) would close the parity gap surfaced by audit finding #22. Two `@pytest.mark.xfail(strict=True)` tests in `tests/test_rust_backend.py::TestSyntheticWeightsBackendParity` baseline the divergence so we notice when/if the algorithms align. | `utils.py`, `rust/` | follow-up | Medium |
+| Rust `compute_synthetic_weights` + `compute_synthetic_weights_internal` (now dead code) can be removed from `rust/src/weights.rs:43-117` in a future Rust-cleanup PR. Python-side wrapper was deleted (post-audit cleanup for finding #22) and its sole caller now inlines Frank-Wolfe via `_sc_weight_fw`. The Rust symbol remains callable via `from diff_diff._rust_backend import compute_synthetic_weights` but no Python code calls it. Removal requires `maturin develop` rebuild. No functional impact of leaving it. | `rust/src/weights.rs` | follow-up | Low |
 | TROP Rust vs Python grid-search divergence on rank-deficient Y: on two near-parallel control units, LOOCV grid-search ATT diverges ~6% between Rust (`trop_global.py:688`) and Python fallback (`trop_global.py:753`). Either grid-winner ties are broken differently or the per-λ solver reaches different stationary points under rank deficiency. Audit finding #23 flagged this surface. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_grid_search_rank_deficient_Y` baselines the gap. | `trop_global.py`, `rust/` | follow-up | Medium |
 | TROP Rust vs Python bootstrap SE divergence under fixed seed: `seed=42` on a tiny panel produces ~28% bootstrap-SE gap. Root cause: Rust bootstrap uses its own RNG (`rand` crate) while Python uses `numpy.random.default_rng`; same seed value maps to different bytestreams across backends. Audit axis-H (RNG/seed) adjacent. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_bootstrap_seed_reproducibility` baselines the gap. Unifying RNG (threading a numpy-generated seed-sequence into Rust, or porting Python to ChaCha) would close it. | `trop_global.py`, `rust/` | follow-up | Medium |
 

diff_diff/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@
     _rust_compute_robust_vcov,
     _rust_project_simplex,
     _rust_solve_ols,
-    _rust_synthetic_weights,
 )
 
 from diff_diff.bacon import (

diff_diff/_backend.py

Lines changed: 0 additions & 2 deletions
@@ -19,7 +19,6 @@
 try:
     from diff_diff._rust_backend import (
         generate_bootstrap_weights_batch as _rust_bootstrap_weights,
-        compute_synthetic_weights as _rust_synthetic_weights,
         project_simplex as _rust_project_simplex,
         solve_ols as _rust_solve_ols,
         compute_robust_vcov as _rust_compute_robust_vcov,
@@ -43,7 +42,6 @@
 except ImportError:
     _rust_available = False
     _rust_bootstrap_weights = None
-    _rust_synthetic_weights = None
     _rust_project_simplex = None
     _rust_solve_ols = None
     _rust_compute_robust_vcov = None

diff_diff/prep.py

Lines changed: 50 additions & 3 deletions
@@ -9,6 +9,7 @@
 re-exported here for backward compatibility.
 """
 
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -33,7 +34,7 @@
     compute_replicate_if_variance,
     compute_survey_if_variance,
 )
-from diff_diff.utils import compute_synthetic_weights
+from diff_diff.utils import _compute_noise_level, _sc_weight_fw
 
 # Constants for rank_control_units
 _SIMILARITY_THRESHOLD_SD = 0.5  # Controls within this many SDs are "similar"
@@ -986,8 +987,54 @@ def rank_control_units(
     # -------------------------------------------------------------------------
     # Compute outcome trend scores
     # -------------------------------------------------------------------------
-    # Synthetic weights (higher = better match)
-    synthetic_weights = compute_synthetic_weights(Y_control, Y_treated_mean, lambda_reg=lambda_reg)
+    # Synthetic weights (higher = better match). Inlined Frank-Wolfe:
+    # routes through the shared `_sc_weight_fw` dispatcher (same solver
+    # SyntheticDiD.fit uses, matching R synthdid::sc.weight.fw). Replaces
+    # the former `compute_synthetic_weights` wrapper, which had divergent
+    # PGD implementations across Rust/Python backends (audit finding #22).
+    _Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
+    _Y_treated_mean = np.ascontiguousarray(Y_treated_mean, dtype=np.float64)
+    _n_pre, _n_control = _Y_control.shape
+    if _n_control == 0:
+        synthetic_weights = np.array([], dtype=np.float64)
+    elif _n_control == 1:
+        synthetic_weights = np.array([1.0])
+    else:
+        # FW objective `zeta² ||w||² + (1/N) ||Aw - b||²` matches the
+        # intended `||Aw - b||² + lambda_reg ||w||²` on the simplex under
+        # `zeta = sqrt(lambda_reg / N)`. intercept=False because this QP
+        # does no column-centering.
+        _zeta = float(np.sqrt(lambda_reg / _n_pre)) if lambda_reg > 0 else 0.0
+        # Scale stopping threshold by noise level (R synthdid convention)
+        # so the relative convergence criterion stays meaningful at any
+        # data magnitude.
+        _sigma = _compute_noise_level(_Y_control)
+        _min_decrease = 1e-5 * max(_sigma, 1e-12)
+        _Y_fw = np.column_stack([_Y_control, _Y_treated_mean])
+        # Suppress non-convergence warnings: this is a heuristic ranking
+        # score, not a statistical estimate.
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message=r".*did not converge.*",
+                category=UserWarning,
+            )
+            synthetic_weights = _sc_weight_fw(
+                _Y_fw,
+                zeta=_zeta,
+                intercept=False,
+                min_decrease=_min_decrease,
+                max_iter=1000,
+            )
+        # Set small weights to zero for interpretability, then renormalize.
+        synthetic_weights = np.asarray(synthetic_weights, dtype=np.float64)
+        _min_weight = 1e-6
+        synthetic_weights[synthetic_weights < _min_weight] = 0.0
+        _total = float(np.sum(synthetic_weights))
+        if _total > 0:
+            synthetic_weights = synthetic_weights / _total
+        else:
+            synthetic_weights = np.ones(_n_control) / _n_control
 
     # RMSE for each control vs treated mean (use nanmean to handle missing data)
     rmse_scores = []
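
The two housekeeping steps in the hunk above, scoped warning suppression and threshold-then-renormalize, can be tried in isolation. This is a minimal sketch under stated assumptions: `noisy_solver` is a hypothetical stand-in that only imitates a solver emitting a "did not converge" `UserWarning`, and `threshold_and_renormalize` is an illustrative helper name, not a function in `diff_diff`. The `record=True` and `simplefilter("always")` calls exist only so the example can show which warnings get through.

```python
import warnings

import numpy as np


def noisy_solver():
    """Hypothetical solver stand-in that warns about non-convergence."""
    warnings.warn("Frank-Wolfe did not converge in 1000 iterations", UserWarning)
    return np.array([0.7, 0.3 - 5e-7, 5e-7])


def threshold_and_renormalize(w, min_weight=1e-6):
    """Zero out tiny weights, then rescale so the result sums to 1."""
    w = np.asarray(w, dtype=np.float64).copy()
    w[w < min_weight] = 0.0
    total = float(w.sum())
    if total > 0:
        return w / total
    # Everything was zeroed: fall back to uniform weights.
    return np.full(w.size, 1.0 / max(w.size, 1))


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Only messages matching this regex are dropped; the filter is undone
    # automatically when the `with` block exits.
    warnings.filterwarnings(
        "ignore",
        message=r".*did not converge.*",
        category=UserWarning,
    )
    raw = noisy_solver()
    warnings.warn("unrelated warning", UserWarning)

weights = threshold_and_renormalize(raw)
print(len(caught), weights)
```

The message filter is deliberately narrow: an unrelated `UserWarning` raised inside the same block still reaches the caller, which is the property that makes this kind of suppression safe to leave in a diagnostic path.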

diff_diff/utils.py

Lines changed: 4 additions & 110 deletions
@@ -17,7 +17,6 @@
 from diff_diff._backend import (
     HAS_RUST_BACKEND,
     _rust_project_simplex,
-    _rust_synthetic_weights,
     _rust_sdid_unit_weights,
     _rust_compute_time_weights,
     _rust_compute_noise_level,
@@ -1131,115 +1130,10 @@ def equivalence_test_trends(
     }
 
 
-def compute_synthetic_weights(
-    Y_control: np.ndarray, Y_treated: np.ndarray, lambda_reg: float = 0.0, min_weight: float = 1e-6
-) -> np.ndarray:
-    """
-    Compute synthetic control unit weights using constrained optimization.
-
-    Finds weights ω that minimize the squared difference between the
-    weighted average of control unit outcomes and the treated unit outcomes
-    during pre-treatment periods.
-
-    Parameters
-    ----------
-    Y_control : np.ndarray
-        Control unit outcomes matrix of shape (n_pre_periods, n_control_units).
-        Each column is a control unit, each row is a pre-treatment period.
-    Y_treated : np.ndarray
-        Treated unit mean outcomes of shape (n_pre_periods,).
-        Average across treated units for each pre-treatment period.
-    lambda_reg : float, default=0.0
-        L2 regularization parameter. Larger values shrink weights toward
-        uniform (1/n_control). Helps prevent overfitting when n_pre < n_control.
-    min_weight : float, default=1e-6
-        Minimum weight threshold. Weights below this are set to zero.
-
-    Returns
-    -------
-    np.ndarray
-        Unit weights of shape (n_control_units,) that sum to 1.
-
-    Notes
-    -----
-    Solves the quadratic program:
-
-        min_ω ||Y_treated - Y_control @ ω||² + λ||ω - 1/n||²
-        s.t. ω >= 0, sum(ω) = 1
-
-    Uses a simplified coordinate descent approach with projection onto simplex.
-    """
-    n_pre, n_control = Y_control.shape
-
-    if n_control == 0:
-        return np.asarray([])
-
-    if n_control == 1:
-        return np.asarray([1.0])
-
-    # Use Rust backend if available
-    if HAS_RUST_BACKEND:
-        Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
-        Y_treated = np.ascontiguousarray(Y_treated, dtype=np.float64)
-        weights = _rust_synthetic_weights(
-            Y_control, Y_treated, lambda_reg, _OPTIMIZATION_MAX_ITER, _OPTIMIZATION_TOL
-        )
-    else:
-        # Fallback to NumPy implementation
-        weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)
-
-    # Set small weights to zero for interpretability
-    weights[weights < min_weight] = 0
-    if np.sum(weights) > 0:
-        weights = weights / np.sum(weights)
-    else:
-        # Fallback to uniform if all weights are zeroed
-        weights = np.ones(n_control) / n_control
-
-    return np.asarray(weights)
-
-
-def _compute_synthetic_weights_numpy(
-    Y_control: np.ndarray,
-    Y_treated: np.ndarray,
-    lambda_reg: float = 0.0,
-) -> np.ndarray:
-    """NumPy fallback implementation of compute_synthetic_weights."""
-    n_pre, n_control = Y_control.shape
-
-    # Initialize with uniform weights
-    weights = np.ones(n_control) / n_control
-
-    # Precompute matrices for optimization
-    # Objective: ||Y_treated - Y_control @ w||^2 + lambda * ||w - w_uniform||^2
-    # = w' @ (Y_control' @ Y_control + lambda * I) @ w - 2 * (Y_control' @ Y_treated + lambda * w_uniform)' @ w + const
-    YtY = Y_control.T @ Y_control
-    YtT = Y_control.T @ Y_treated
-    w_uniform = np.ones(n_control) / n_control
-
-    # Add regularization
-    H = YtY + lambda_reg * np.eye(n_control)
-    f = YtT + lambda_reg * w_uniform
-
-    # Solve with projected gradient descent
-    # Project onto probability simplex
-    step_size = 1.0 / (np.linalg.norm(H, 2) + _NUMERICAL_EPS)
-
-    for _ in range(_OPTIMIZATION_MAX_ITER):
-        weights_old = weights.copy()
-
-        # Gradient step: minimize ||Y - Y_control @ w||^2
-        grad = H @ weights - f
-        weights = weights - step_size * grad
-
-        # Project onto simplex (sum to 1, non-negative)
-        weights = _project_simplex(weights)
-
-        # Check convergence
-        if np.linalg.norm(weights - weights_old) < _OPTIMIZATION_TOL:
-            break
-
-    return weights
+# compute_synthetic_weights and _compute_synthetic_weights_numpy removed in the
+# silent-failures audit post-cleanup (finding #22). The one caller
+# (`diff_diff.prep.rank_control_units`) inlines Frank-Wolfe directly via
+# `_sc_weight_fw`, matching R `synthdid::sc.weight.fw`. See `prep.py:990`.
 
 
 def _project_simplex(v: np.ndarray) -> np.ndarray:
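
The step-size issue the commit message attributes to the old Rust path can be reproduced on a toy problem. The sketch below is not the deleted Rust or NumPy code: `project_simplex` and `pgd` are hypothetical re-implementations, the data is a made-up stacked identity matrix, and the scale is 1e6 rather than 1e9 so that the toy projection stays well within float64 precision. It compares a constant step of 0.1 against a step of 1/L, where L is the spectral norm of `Y.T @ Y` (the rule the deleted NumPy fallback used).

```python
import numpy as np


def project_simplex(v):
    """Euclidean projection onto {w >= 0, sum(w) = 1} (sort-based)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u) - 1.0
    ks = np.arange(1, v.size + 1)
    rho = np.nonzero(u - css / ks > 0)[0][-1]
    return np.maximum(v - css[rho] / (rho + 1), 0.0)


def pgd(Y, target, step_size, n_iter=50):
    """Projected gradient descent for the gradient H w - f, from uniform."""
    H = Y.T @ Y
    f = Y.T @ target
    w = np.full(Y.shape[1], 1.0 / Y.shape[1])
    for _ in range(n_iter):
        w = project_simplex(w - step_size * (H @ w - f))
    return w


# Toy large-scale data: the target is an exact interior mix of the columns.
Y = 1e6 * np.vstack([np.eye(4), np.eye(4)])
w_true = np.array([0.4, 0.3, 0.2, 0.1])
target = Y @ w_true

lipschitz = np.linalg.norm(Y.T @ Y, 2)  # largest eigenvalue of H
w_const = pgd(Y, target, step_size=0.1)
w_adapt = pgd(Y, target, step_size=1.0 / lipschitz)

print("constant step:", w_const)
print("1/L step:     ", w_adapt)
```

Because the projection keeps every iterate on the simplex, an oversized step does not overflow. Instead each update overshoots so far that the projection snaps to a vertex, and the iterates hop between vertices without approaching the interior solution. The 1/L step recovers `w_true` on this toy problem.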

tests/test_estimators.py

Lines changed: 3 additions & 19 deletions
@@ -3124,25 +3124,9 @@ def test_project_simplex(self):
         assert abs(np.sum(projected) - 1.0) < 1e-6
         assert np.all(projected >= 0)
 
-    def test_compute_synthetic_weights(self):
-        """Test synthetic weight computation."""
-        from diff_diff.utils import compute_synthetic_weights
-
-        np.random.seed(42)
-        n_pre = 5
-        n_control = 10
-
-        Y_control = np.random.randn(n_pre, n_control)
-        Y_treated = np.random.randn(n_pre)
-
-        weights = compute_synthetic_weights(Y_control, Y_treated)
-
-        # Weights should sum to 1
-        assert abs(np.sum(weights) - 1.0) < 1e-6
-        # Weights should be non-negative
-        assert np.all(weights >= 0)
-        # Should have correct length
-        assert len(weights) == n_control
+    # test_compute_synthetic_weights removed in the silent-failures audit
+    # post-cleanup (finding #22). Helper deleted; behavior now covered via
+    # tests/test_prep.py::TestRankControlUnits (its sole caller).
 
     def test_compute_time_weights(self):
         """Test time weight computation with Frank-Wolfe solver."""

tests/test_prep.py

Lines changed: 37 additions & 0 deletions
@@ -766,6 +766,43 @@ def test_single_control_unit(self):
         # Single control should get score of 1.0 (best possible)
         assert result["quality_score"].iloc[0] == 1.0
 
+    def test_extreme_Y_scale_synthetic_weight_column(self):
+        """Finding #22 (post-audit cleanup): `synthetic_weight` column must
+        remain a valid non-degenerate simplex vector even at extreme Y
+        scale (Y ~ 1e9). The previous `compute_synthetic_weights` wrapper
+        had two bugs here: Rust PGD collapsed to a single vertex, Python
+        PGD stalled at uniform. The inlined Frank-Wolfe solver in
+        ``rank_control_units`` handles both cases correctly."""
+        from diff_diff.prep import rank_control_units
+
+        data = generate_did_data(n_units=12, n_periods=8, seed=42)
+        # Shift outcomes to extreme scale — the exact condition the deleted
+        # wrapper mishandled.
+        data = data.copy()
+        data["outcome"] = data["outcome"] + 1e9
+
+        result = rank_control_units(
+            data,
+            unit_column="unit",
+            time_column="period",
+            outcome_column="outcome",
+            treatment_column="treated",
+        )
+
+        weights = result["synthetic_weight"].to_numpy()
+        # Valid simplex: non-negative, sums to 1.
+        assert np.all(weights >= 0), "synthetic_weight must be non-negative"
+        assert abs(weights.sum() - 1.0) < 1e-10, (
+            f"synthetic_weight should sum to 1.0, got {weights.sum()}"
+        )
+        # Non-degenerate: at least 2 controls receive non-trivial weight.
+        # This guards the Rust-PGD collapse-to-one-vertex bug that
+        # previously fired at Y ~ 1e9 under the deleted wrapper.
+        assert int(np.sum(weights > 1e-6)) >= 2, (
+            f"synthetic_weight collapsed to a single vertex at extreme Y "
+            f"scale; n_nonzero={int(np.sum(weights > 1e-6))}. weights={weights}"
+        )
+
 
 class TestGenerateStaggeredData:
     """Tests for generate_staggered_data function."""
