Commit 9146f1e

Merge pull request igerber#344 from igerber/fix/delete-synthetic-weights-helper
Delete compute_synthetic_weights shim; inline Frank-Wolfe in rank_control_units
2 parents 9c787f0 + 5dc6ba0 commit 9146f1e

9 files changed

Lines changed: 221 additions & 370 deletions

TODO.md

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ Deferred items from PR reviews that were not addressed before merge.
 | Weighted CR2 Bell-McCaffrey cluster-robust (`vcov_type="hc2_bm"` + `cluster_ids` + `weights`) currently raises `NotImplementedError`. Weighted hat matrix and residual rebalancing need threading per clubSandwich WLS handling. | `linalg.py::_compute_cr2_bm` | Phase 1a | Medium |
 | Regenerate `benchmarks/data/clubsandwich_cr2_golden.json` from R (`Rscript benchmarks/R/generate_clubsandwich_golden.R`). Current JSON has `source: python_self_reference` as a stability anchor until an authoritative R run. | `benchmarks/R/generate_clubsandwich_golden.R` | Phase 1a | Medium |
 | `honest_did.py:1907` `np.linalg.solve(A_sys, b_sys) / except LinAlgError: continue` is a silent basis-rejection in the vertex-enumeration loop that is algorithmically intentional (try the next basis). Consider surfacing a count of rejected bases as a diagnostic when ARP enumeration exhausts, so users see when the vertex search was heavily constrained. Not a silent failure in the sense of the Phase 2 audit (the algorithm is supposed to skip), but the diagnostic would help debug borderline cases. | `honest_did.py` | #334 | Low |
-| `compute_synthetic_weights` backend algorithm mismatch: Rust path uses Frank-Wolfe (`_rust_synthetic_weights` in `utils.py:1184`); Python fallback uses projected gradient descent (`_compute_synthetic_weights_numpy` in `utils.py:1228`). Both solve the same constrained QP but converge to different simplex vertices on near-degenerate / extreme-scale inputs (e.g. `Y~1e9`, or near-singular `Y'Y`). Unified backend (one algorithm) would close the parity gap surfaced by audit finding #22. Two `@pytest.mark.xfail(strict=True)` tests in `tests/test_rust_backend.py::TestSyntheticWeightsBackendParity` baseline the divergence so we notice when/if the algorithms align. | `utils.py`, `rust/` | follow-up | Medium |
+| Rust `compute_synthetic_weights` + `compute_synthetic_weights_internal` (now dead code) can be removed from `rust/src/weights.rs:43-117` in a future Rust-cleanup PR. Python-side wrapper was deleted (post-audit cleanup for finding #22) and its sole caller now inlines Frank-Wolfe via `_sc_weight_fw`. The Rust symbol remains callable via `from diff_diff._rust_backend import compute_synthetic_weights` but no Python code calls it. Removal requires `maturin develop` rebuild. No functional impact of leaving it. | `rust/src/weights.rs` | follow-up | Low |
 | TROP Rust vs Python grid-search divergence on rank-deficient Y: on two near-parallel control units, LOOCV grid-search ATT diverges ~6% between Rust (`trop_global.py:688`) and Python fallback (`trop_global.py:753`). Either grid-winner ties are broken differently or the per-λ solver reaches different stationary points under rank deficiency. Audit finding #23 flagged this surface. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_grid_search_rank_deficient_Y` baselines the gap. | `trop_global.py`, `rust/` | follow-up | Medium |
 | TROP Rust vs Python bootstrap SE divergence under fixed seed: `seed=42` on a tiny panel produces ~28% bootstrap-SE gap. Root cause: Rust bootstrap uses its own RNG (`rand` crate) while Python uses `numpy.random.default_rng`; same seed value maps to different bytestreams across backends. Audit axis-H (RNG/seed) adjacent. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_bootstrap_seed_reproducibility` baselines the gap. Unifying RNG (threading a numpy-generated seed-sequence into Rust, or porting Python to ChaCha) would close it. | `trop_global.py`, `rust/` | follow-up | Medium |
 | `bias_corrected_local_linear`: extend golden parity to `kernel="triangular"` and `kernel="uniform"` (currently epa-only; all three kernels share `kernel_W` and the `lprobust` math, so parity is expected but not separately asserted). | `benchmarks/R/generate_nprobust_lprobust_golden.R`, `tests/test_bias_corrected_lprobust.py` | Phase 1c | Low |
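
The bootstrap-SE row above attributes the gap to the same seed value feeding two different generators. A NumPy-only analogy of that effect is sketched below; it does not exercise the Rust `rand` crate, and the choice of `PCG64` versus `MT19937` is an assumption made purely for illustration. The last line shows one way seed material could be derived for another backend, which is only a guess at what "threading a numpy-generated seed-sequence" might look like.

```python
import numpy as np

seed = 42

# Same integer seed, two different bit generators that both ship with NumPy.
pcg_draws = np.random.Generator(np.random.PCG64(seed)).standard_normal(5)
mt_draws = np.random.Generator(np.random.MT19937(seed)).standard_normal(5)

# Each generator is reproducible on its own ...
pcg_again = np.random.Generator(np.random.PCG64(seed)).standard_normal(5)
same_generator_matches = bool(np.array_equal(pcg_draws, pcg_again))

# ... but the two streams are unrelated, so an equal seed does not imply
# equal draws across generators.
cross_generator_matches = bool(np.allclose(pcg_draws, mt_draws))

# Hypothetical hand-off: derive raw 32-bit words from one SeedSequence so a
# second backend could be seeded from the same entropy source.
seed_words = np.random.SeedSequence(seed).generate_state(8)
```

Sharing seed words alone would not make two backends agree draw for draw unless both also run the same generator algorithm, which is why the row lists porting to a common generator as the alternative.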

diff_diff/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@
     _rust_compute_robust_vcov,
     _rust_project_simplex,
     _rust_solve_ols,
-    _rust_synthetic_weights,
 )

 from diff_diff.bacon import (

diff_diff/_backend.py

Lines changed: 0 additions & 4 deletions
@@ -19,7 +19,6 @@
 try:
     from diff_diff._rust_backend import (
         generate_bootstrap_weights_batch as _rust_bootstrap_weights,
-        compute_synthetic_weights as _rust_synthetic_weights,
         project_simplex as _rust_project_simplex,
         solve_ols as _rust_solve_ols,
         compute_robust_vcov as _rust_compute_robust_vcov,
@@ -43,7 +42,6 @@
 except ImportError:
     _rust_available = False
     _rust_bootstrap_weights = None
-    _rust_synthetic_weights = None
     _rust_project_simplex = None
     _rust_solve_ols = None
     _rust_compute_robust_vcov = None
@@ -66,7 +64,6 @@
     # Force pure Python mode - disable Rust even if available
     HAS_RUST_BACKEND = False
     _rust_bootstrap_weights = None
-    _rust_synthetic_weights = None
     _rust_project_simplex = None
     _rust_solve_ols = None
     _rust_compute_robust_vcov = None
@@ -115,7 +112,6 @@ def rust_backend_info():
     "HAS_RUST_BACKEND",
     "rust_backend_info",
     "_rust_bootstrap_weights",
-    "_rust_synthetic_weights",
     "_rust_project_simplex",
     "_rust_solve_ols",
     "_rust_compute_robust_vcov",

diff_diff/prep.py

Lines changed: 74 additions & 4 deletions
@@ -9,6 +9,7 @@
 re-exported here for backward compatibility.
 """

+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -36,7 +37,7 @@
     compute_replicate_if_variance,
     compute_survey_if_variance,
 )
-from diff_diff.utils import compute_synthetic_weights
+from diff_diff.utils import _compute_noise_level, _sc_weight_fw

 # Constants for rank_control_units
 _SIMILARITY_THRESHOLD_SD = 0.5  # Controls within this many SDs are "similar"
@@ -837,7 +838,10 @@ def rank_control_units(
         - quality_score: Combined quality score (0-1, higher is better)
         - outcome_trend_score: Pre-treatment outcome trend similarity
         - covariate_score: Covariate match score (NaN if no covariates)
-        - synthetic_weight: Weight from synthetic control optimization
+        - synthetic_weight: Informational heuristic weight from a single-pass
+          uncentered Frank-Wolfe solve; does NOT factor into ``quality_score``
+          (ranking) and is NOT the canonical SDID unit weight. For canonical
+          SDID weights use ``SyntheticDiD.fit()``.
         - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean
         - is_required: Whether unit was in require_units

@@ -989,8 +993,74 @@ def rank_control_units(
     # -------------------------------------------------------------------------
     # Compute outcome trend scores
     # -------------------------------------------------------------------------
-    # Synthetic weights (higher = better match)
-    synthetic_weights = compute_synthetic_weights(Y_control, Y_treated_mean, lambda_reg=lambda_reg)
+    # Informational `synthetic_weight` column. This is a RANKING HEURISTIC,
+    # not an estimator: it gives a rough "which controls would a synthetic
+    # regression weight heavily" signal that's reported alongside RMSE and
+    # covariate distance. The actual ranking (`quality_score`) is computed
+    # below from `outcome_trend_score` (RMSE-based) + `covariate_score`; the
+    # `synthetic_weight` column does NOT factor into the ranking decision.
+    #
+    # Solver choice. We use a single-pass uncentered Frank-Wolfe via the
+    # shared `_sc_weight_fw` dispatcher to solve:
+    #
+    #     min_w  ||Y_treated_mean - Y_control @ w||^2 + lambda_reg * ||w||^2
+    #     s.t.   w >= 0, sum(w) = 1
+    #
+    # Mapped to the FW objective `zeta^2 ||w||^2 + (1/N) ||Aw - b||^2` via
+    # `zeta = sqrt(lambda_reg / N)`. intercept=False because this QP does
+    # no column-centering, max_iter=1000 to bound ranking-loop cost,
+    # min_weight=1e-6 post-processing for interpretability.
+    #
+    # NOTE — this is INTENTIONALLY NOT the canonical SDID / R
+    # `synthdid::sc.weight.fw` two-pass unit-weight procedure (that uses
+    # intercept=TRUE, 100-iter -> sparsify -> 10000-iter). SDID estimation
+    # still uses that canonical path in `_sc_weight_fw_numpy` at
+    # `utils.py:_sc_weight_fw_numpy` via `compute_sdid_unit_weights`; this
+    # ranking heuristic uses a simpler single-pass call to the same solver
+    # for a cheap diagnostic score.
+    #
+    # Replaces the former `compute_synthetic_weights` wrapper whose Rust
+    # and Python backends had divergent PGD implementations (audit
+    # finding #22). Net effect: users on default `lambda_reg=0` with
+    # typical data see `synthetic_weight` values that agree with the old
+    # code to ~1e-7; extreme Y or `lambda_reg > 0` cases produce values
+    # that differ from the old code (which was mathematically wrong).
+    _Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
+    _Y_treated_mean = np.ascontiguousarray(Y_treated_mean, dtype=np.float64)
+    _n_pre, _n_control = _Y_control.shape
+    if _n_control == 0:
+        synthetic_weights = np.array([], dtype=np.float64)
+    elif _n_control == 1:
+        synthetic_weights = np.array([1.0])
+    else:
+        _zeta = float(np.sqrt(lambda_reg / _n_pre)) if lambda_reg > 0 else 0.0
+        # Scale stopping threshold by noise level so convergence stays
+        # meaningful at any data magnitude.
+        _sigma = _compute_noise_level(_Y_control)
+        _min_decrease = 1e-5 * max(_sigma, 1e-12)
+        _Y_fw = np.column_stack([_Y_control, _Y_treated_mean])
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message=r".*did not converge.*",
+                category=UserWarning,
+            )
+            synthetic_weights = _sc_weight_fw(
+                _Y_fw,
+                zeta=_zeta,
+                intercept=False,
+                min_decrease=_min_decrease,
+                max_iter=1000,
+            )
+        # Set small weights to zero for interpretability, then renormalize.
+        synthetic_weights = np.asarray(synthetic_weights, dtype=np.float64)
+        _min_weight = 1e-6
+        synthetic_weights[synthetic_weights < _min_weight] = 0.0
+        _total = float(np.sum(synthetic_weights))
+        if _total > 0:
+            synthetic_weights = synthetic_weights / _total
+        else:
+            synthetic_weights = np.ones(_n_control) / _n_control

     # RMSE for each control vs treated mean (use nanmean to handle missing data)
     rmse_scores = []
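
For readers unfamiliar with the solver family, the following is a minimal sketch of a single-pass Frank-Wolfe solve of the QP described in the comment above, followed by the same threshold-and-renormalize post-processing. It is not the library's `_sc_weight_fw`: the helper names `fw_objective`, `fw_simplex_sketch`, and `threshold_and_renormalize` are hypothetical, the exact line search is an assumption about step-size choice, and the stopping rule is a plain objective-decrease check rather than the noise-scaled threshold used in the hunk.

```python
import numpy as np


def fw_objective(A, b, w, zeta):
    """zeta^2 * ||w||^2 + (1/N) * ||A @ w - b||^2, with N = number of rows of A."""
    r = A @ w - b
    return zeta**2 * (w @ w) + (r @ r) / A.shape[0]


def fw_simplex_sketch(A, b, zeta=0.0, max_iter=1000, min_decrease=1e-10):
    """Hypothetical Frank-Wolfe on {w >= 0, sum(w) = 1}; illustration only."""
    A = np.asarray(A, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    n_rows, n_cols = A.shape
    w = np.full(n_cols, 1.0 / n_cols)  # start at the simplex centre
    prev = fw_objective(A, b, w, zeta)
    for _ in range(max_iter):
        grad = 2.0 * zeta**2 * w + (2.0 / n_rows) * (A.T @ (A @ w - b))
        # Linear minimization over the simplex picks a single vertex e_i,
        # so the search direction is e_i - w.
        d = -w
        d[int(np.argmin(grad))] += 1.0
        slope = grad @ d
        Ad = A @ d
        curvature = zeta**2 * (d @ d) + (Ad @ Ad) / n_rows
        if slope >= 0.0 or curvature <= 0.0:
            break  # no descent direction left on the simplex
        # Exact line search for a quadratic, clipped so w stays feasible.
        w = w + min(1.0, -slope / (2.0 * curvature)) * d
        cur = fw_objective(A, b, w, zeta)
        if prev - cur < min_decrease:
            break
        prev = cur
    return w


def threshold_and_renormalize(w, min_weight=1e-6):
    """Zero out tiny weights, then rescale to sum to 1 (uniform if all vanish)."""
    w = np.array(w, dtype=np.float64)
    w[w < min_weight] = 0.0
    total = float(w.sum())
    return w / total if total > 0 else np.full(w.size, 1.0 / w.size)


rng = np.random.default_rng(0)
Y_control = rng.normal(size=(8, 4))  # (n_pre, n_control), synthetic data
y_treated = rng.normal(size=8)
lambda_reg = 0.5
zeta = float(np.sqrt(lambda_reg / Y_control.shape[0]))  # zeta = sqrt(lambda_reg / N)

w_raw = fw_simplex_sketch(Y_control, y_treated, zeta=zeta)
w = threshold_and_renormalize(w_raw)
uniform = np.full(4, 0.25)
```

Every iterate is a convex combination of simplex points, so feasibility holds without a projection step; that is the main structural difference from the projected-gradient fallback this commit removes.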

diff_diff/utils.py

Lines changed: 8 additions & 110 deletions
@@ -17,7 +17,6 @@
 from diff_diff._backend import (
     HAS_RUST_BACKEND,
     _rust_project_simplex,
-    _rust_synthetic_weights,
     _rust_sdid_unit_weights,
     _rust_compute_time_weights,
     _rust_compute_noise_level,
@@ -1131,115 +1130,14 @@ def equivalence_test_trends(
     }


-def compute_synthetic_weights(
-    Y_control: np.ndarray, Y_treated: np.ndarray, lambda_reg: float = 0.0, min_weight: float = 1e-6
-) -> np.ndarray:
-    """
-    Compute synthetic control unit weights using constrained optimization.
-
-    Finds weights ω that minimize the squared difference between the
-    weighted average of control unit outcomes and the treated unit outcomes
-    during pre-treatment periods.
-
-    Parameters
-    ----------
-    Y_control : np.ndarray
-        Control unit outcomes matrix of shape (n_pre_periods, n_control_units).
-        Each column is a control unit, each row is a pre-treatment period.
-    Y_treated : np.ndarray
-        Treated unit mean outcomes of shape (n_pre_periods,).
-        Average across treated units for each pre-treatment period.
-    lambda_reg : float, default=0.0
-        L2 regularization parameter. Larger values shrink weights toward
-        uniform (1/n_control). Helps prevent overfitting when n_pre < n_control.
-    min_weight : float, default=1e-6
-        Minimum weight threshold. Weights below this are set to zero.
-
-    Returns
-    -------
-    np.ndarray
-        Unit weights of shape (n_control_units,) that sum to 1.
-
-    Notes
-    -----
-    Solves the quadratic program:
-
-        min_ω ||Y_treated - Y_control @ ω||² + λ||ω - 1/n||²
-        s.t. ω >= 0, sum(ω) = 1
-
-    Uses a simplified coordinate descent approach with projection onto simplex.
-    """
-    n_pre, n_control = Y_control.shape
-
-    if n_control == 0:
-        return np.asarray([])
-
-    if n_control == 1:
-        return np.asarray([1.0])
-
-    # Use Rust backend if available
-    if HAS_RUST_BACKEND:
-        Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
-        Y_treated = np.ascontiguousarray(Y_treated, dtype=np.float64)
-        weights = _rust_synthetic_weights(
-            Y_control, Y_treated, lambda_reg, _OPTIMIZATION_MAX_ITER, _OPTIMIZATION_TOL
-        )
-    else:
-        # Fallback to NumPy implementation
-        weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)
-
-    # Set small weights to zero for interpretability
-    weights[weights < min_weight] = 0
-    if np.sum(weights) > 0:
-        weights = weights / np.sum(weights)
-    else:
-        # Fallback to uniform if all weights are zeroed
-        weights = np.ones(n_control) / n_control
-
-    return np.asarray(weights)
-
-
-def _compute_synthetic_weights_numpy(
-    Y_control: np.ndarray,
-    Y_treated: np.ndarray,
-    lambda_reg: float = 0.0,
-) -> np.ndarray:
-    """NumPy fallback implementation of compute_synthetic_weights."""
-    n_pre, n_control = Y_control.shape
-
-    # Initialize with uniform weights
-    weights = np.ones(n_control) / n_control
-
-    # Precompute matrices for optimization
-    # Objective: ||Y_treated - Y_control @ w||^2 + lambda * ||w - w_uniform||^2
-    # = w' @ (Y_control' @ Y_control + lambda * I) @ w - 2 * (Y_control' @ Y_treated + lambda * w_uniform)' @ w + const
-    YtY = Y_control.T @ Y_control
-    YtT = Y_control.T @ Y_treated
-    w_uniform = np.ones(n_control) / n_control
-
-    # Add regularization
-    H = YtY + lambda_reg * np.eye(n_control)
-    f = YtT + lambda_reg * w_uniform
-
-    # Solve with projected gradient descent
-    # Project onto probability simplex
-    step_size = 1.0 / (np.linalg.norm(H, 2) + _NUMERICAL_EPS)
-
-    for _ in range(_OPTIMIZATION_MAX_ITER):
-        weights_old = weights.copy()
-
-        # Gradient step: minimize ||Y - Y_control @ w||^2
-        grad = H @ weights - f
-        weights = weights - step_size * grad
-
-        # Project onto simplex (sum to 1, non-negative)
-        weights = _project_simplex(weights)
-
-        # Check convergence
-        if np.linalg.norm(weights - weights_old) < _OPTIMIZATION_TOL:
-            break
-
-    return weights
+# compute_synthetic_weights and _compute_synthetic_weights_numpy removed in the
+# silent-failures audit post-cleanup (finding #22). The one caller
+# (`diff_diff.prep.rank_control_units`) inlines a single-pass, uncentered
+# Frank-Wolfe via the shared `_sc_weight_fw` dispatcher — a ranking heuristic,
+# NOT the canonical SDID/R `synthdid::sc.weight.fw` two-pass procedure
+# (intercept=True, 100-iter -> sparsify -> 10000-iter). Canonical SDID unit
+# weights go through `compute_sdid_unit_weights` (see `_sc_weight_fw_numpy`
+# below and REGISTRY.md SDID section).


 def _project_simplex(v: np.ndarray) -> np.ndarray:
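
A short derivation may help when comparing the penalty in the removed docstring, λ||ω - 1/n||², with the `lambda_reg * ||w||^2` penalty in the new `prep.py` comment. It assumes the docstring's "1/n" denotes the uniform vector u = (1/n)·1, and it concerns the stated objectives only, not what either removed implementation actually computed.

```latex
% Assumption: u = (1/n) \mathbf{1}, and \omega is feasible, so \mathbf{1}^\top \omega = 1.
\begin{aligned}
\lVert \omega - u \rVert^2
  &= \lVert \omega \rVert^2 - 2\, u^\top \omega + \lVert u \rVert^2 \\
  &= \lVert \omega \rVert^2 - \frac{2}{n}\, \mathbf{1}^\top \omega + n \cdot \frac{1}{n^2} \\
  &= \lVert \omega \rVert^2 - \frac{1}{n}.
\end{aligned}
```

On the simplex the two penalties therefore differ by the constant λ/n, so the two stated objectives share the same minimizers for any fixed λ.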

tests/test_estimators.py

Lines changed: 3 additions & 19 deletions
@@ -3124,25 +3124,9 @@ def test_project_simplex(self):
         assert abs(np.sum(projected) - 1.0) < 1e-6
         assert np.all(projected >= 0)

-    def test_compute_synthetic_weights(self):
-        """Test synthetic weight computation."""
-        from diff_diff.utils import compute_synthetic_weights
-
-        np.random.seed(42)
-        n_pre = 5
-        n_control = 10
-
-        Y_control = np.random.randn(n_pre, n_control)
-        Y_treated = np.random.randn(n_pre)
-
-        weights = compute_synthetic_weights(Y_control, Y_treated)
-
-        # Weights should sum to 1
-        assert abs(np.sum(weights) - 1.0) < 1e-6
-        # Weights should be non-negative
-        assert np.all(weights >= 0)
-        # Should have correct length
-        assert len(weights) == n_control
+    # test_compute_synthetic_weights removed in the silent-failures audit
+    # post-cleanup (finding #22). Helper deleted; behavior now covered via
+    # tests/test_prep.py::TestRankControlUnits (its sole caller).

     def test_compute_time_weights(self):
         """Test time weight computation with Frank-Wolfe solver."""
