|
9 | 9 | re-exported here for backward compatibility. |
10 | 10 | """ |
11 | 11 |
|
| 12 | +import warnings |
12 | 13 | from typing import Any, Dict, List, Optional, Tuple, Union |
13 | 14 |
|
14 | 15 | import numpy as np |
|
36 | 37 | compute_replicate_if_variance, |
37 | 38 | compute_survey_if_variance, |
38 | 39 | ) |
39 | | -from diff_diff.utils import compute_synthetic_weights |
| 40 | +from diff_diff.utils import _compute_noise_level, _sc_weight_fw |
40 | 41 |
|
41 | 42 | # Constants for rank_control_units |
42 | 43 | _SIMILARITY_THRESHOLD_SD = 0.5 # Controls within this many SDs are "similar" |
@@ -837,7 +838,10 @@ def rank_control_units( |
837 | 838 | - quality_score: Combined quality score (0-1, higher is better) |
838 | 839 | - outcome_trend_score: Pre-treatment outcome trend similarity |
839 | 840 | - covariate_score: Covariate match score (NaN if no covariates) |
840 | | - - synthetic_weight: Weight from synthetic control optimization |
| 841 | + - synthetic_weight: Informational heuristic weight from a single-pass |
| 842 | + uncentered Frank-Wolfe solve; does NOT factor into ``quality_score`` |
| 843 | + (ranking) and is NOT the canonical SDID unit weight. For canonical |
| 844 | + SDID weights use ``SyntheticDiD.fit()``. |
841 | 845 | - pre_trend_rmse: RMSE of pre-treatment outcome vs treated mean |
842 | 846 | - is_required: Whether unit was in require_units |
843 | 847 |
|
@@ -989,8 +993,74 @@ def rank_control_units( |
989 | 993 | # ------------------------------------------------------------------------- |
990 | 994 | # Compute outcome trend scores |
991 | 995 | # ------------------------------------------------------------------------- |
992 | | - # Synthetic weights (higher = better match) |
993 | | - synthetic_weights = compute_synthetic_weights(Y_control, Y_treated_mean, lambda_reg=lambda_reg) |
| 996 | + # Informational `synthetic_weight` column. This is a RANKING HEURISTIC, |
| 997 | + # not an estimator: it gives a rough "which controls would a synthetic |
| 998 | + # regression weight heavily" signal that's reported alongside RMSE and |
| 999 | + # covariate distance. The actual ranking (`quality_score`) is computed |
| 1000 | + # below from `outcome_trend_score` (RMSE-based) + `covariate_score`; the |
| 1001 | + # `synthetic_weight` column does NOT factor into the ranking decision. |
| 1002 | + # |
| 1003 | + # Solver choice. We use a single-pass uncentered Frank-Wolfe via the |
| 1004 | + # shared `_sc_weight_fw` dispatcher to solve: |
| 1005 | + # |
| 1006 | + # min_w ||Y_treated_mean - Y_control @ w||^2 + lambda_reg * ||w||^2 |
| 1007 | + # s.t. w >= 0, sum(w) = 1 |
| 1008 | + # |
| 1009 | + # Mapped to the FW objective `zeta^2 ||w||^2 + (1/N) ||Aw - b||^2` via |
| 1010 | + # `zeta = sqrt(lambda_reg / N)`. intercept=False because this QP does |
| 1011 | + # no column-centering, max_iter=1000 to bound ranking-loop cost, |
| 1012 | + # min_weight=1e-6 post-processing for interpretability. |
| 1013 | + # |
| 1014 | + # NOTE — this is INTENTIONALLY NOT the canonical SDID / R |
| 1015 | + # `synthdid::sc.weight.fw` two-pass unit-weight procedure (that uses |
| 1016 | + # intercept=TRUE, 100-iter -> sparsify -> 10000-iter). SDID estimation |
| 1017 | + # still uses that canonical path in `_sc_weight_fw_numpy` at |
| 1018 | + # `utils.py:_sc_weight_fw_numpy` via `compute_sdid_unit_weights`; this |
| 1019 | + # ranking heuristic uses a simpler single-pass call to the same solver |
| 1020 | + # for a cheap diagnostic score. |
| 1021 | + # |
| 1022 | + # Replaces the former `compute_synthetic_weights` wrapper whose Rust |
| 1023 | + # and Python backends had divergent PGD implementations (audit |
| 1024 | + # finding #22). Net effect: users on default `lambda_reg=0` with |
| 1025 | + # typical data see `synthetic_weight` values that agree with the old |
| 1026 | + # code to ~1e-7; extreme Y or `lambda_reg > 0` cases produce values |
| 1027 | + # that differ from the old code (which was mathematically wrong). |
| 1028 | + _Y_control = np.ascontiguousarray(Y_control, dtype=np.float64) |
| 1029 | + _Y_treated_mean = np.ascontiguousarray(Y_treated_mean, dtype=np.float64) |
| 1030 | + _n_pre, _n_control = _Y_control.shape |
| 1031 | + if _n_control == 0: |
| 1032 | + synthetic_weights = np.array([], dtype=np.float64) |
| 1033 | + elif _n_control == 1: |
| 1034 | + synthetic_weights = np.array([1.0]) |
| 1035 | + else: |
| 1036 | + _zeta = float(np.sqrt(lambda_reg / _n_pre)) if lambda_reg > 0 else 0.0 |
| 1037 | + # Scale stopping threshold by noise level so convergence stays |
| 1038 | + # meaningful at any data magnitude. |
| 1039 | + _sigma = _compute_noise_level(_Y_control) |
| 1040 | + _min_decrease = 1e-5 * max(_sigma, 1e-12) |
| 1041 | + _Y_fw = np.column_stack([_Y_control, _Y_treated_mean]) |
| 1042 | + with warnings.catch_warnings(): |
| 1043 | + warnings.filterwarnings( |
| 1044 | + "ignore", |
| 1045 | + message=r".*did not converge.*", |
| 1046 | + category=UserWarning, |
| 1047 | + ) |
| 1048 | + synthetic_weights = _sc_weight_fw( |
| 1049 | + _Y_fw, |
| 1050 | + zeta=_zeta, |
| 1051 | + intercept=False, |
| 1052 | + min_decrease=_min_decrease, |
| 1053 | + max_iter=1000, |
| 1054 | + ) |
| 1055 | + # Set small weights to zero for interpretability, then renormalize. |
| 1056 | + synthetic_weights = np.asarray(synthetic_weights, dtype=np.float64) |
| 1057 | + _min_weight = 1e-6 |
| 1058 | + synthetic_weights[synthetic_weights < _min_weight] = 0.0 |
| 1059 | + _total = float(np.sum(synthetic_weights)) |
| 1060 | + if _total > 0: |
| 1061 | + synthetic_weights = synthetic_weights / _total |
| 1062 | + else: |
| 1063 | + synthetic_weights = np.ones(_n_control) / _n_control |
994 | 1064 |
|
995 | 1065 | # RMSE for each control vs treated mean (use nanmean to handle missing data) |
996 | 1066 | rmse_scores = [] |
|
0 commit comments