Commit e59610b

igerber and claude committed
Unify Rust TROP inner solver to SVD; close finding #23 grid-search divergence
Closes the grid-search half of silent-failures finding #23 (TODO row 87). The `xfail(strict=True)` regression `test_grid_search_rank_deficient_Y` baselined a ~6% ATT divergence between Rust and Python on two near-parallel control units.

Root cause: Rust's `solve_joint_no_lowrank` used iterative block coordinate descent (50 iter, tol=1e-8) while Python used SVD-based minimum-norm least squares. On rank-deficient Y the two solvers converge to different stationary points of the same objective.

Python is canonical (SVD / minimum-norm least squares per Golub & Van Loan). Rust's iterative solver was a speed optimization, not a methodology choice. Port the Rust inner TWFE step to SVD-based WLS that mirrors Python's `np.linalg.lstsq(rcond=None)` step-for-step, with numpy-compatible `rcond = eps * max(n, k)`.

Changes
- rust/src/linalg.rs: promote ndarray_to_faer to pub(crate) so trop.rs can reuse it.
- rust/src/trop.rs: new module-private solve_wls_svd helper (thin SVD + rcond truncation) that matches numpy's minimum-norm semantics. Rewrite solve_joint_no_lowrank body to flatten y/weights row-major, build the [intercept | unit_dummies[1..] | time_dummies[1..]] design matrix, apply sqrt-weights, and solve via solve_wls_svd. Function signature unchanged; all 4 call sites (LOOCV, FISTA TWFE step x2, bootstrap) benefit transitively.
- tests/test_rust_backend.py: remove @pytest.mark.xfail from test_grid_search_rank_deficient_Y; the gap is closed. Bootstrap-seed test retains its xfail (row 87 RNG mismatch, out of scope).
- docs/methodology/REGISTRY.md: update the TROP Global Estimation bullet at the existing `np.linalg.lstsq` line to note Rust and Python now both use SVD-based minimum-norm WLS with numpy-compatible rcond.
- TODO.md: delete row 87 (grid-search divergence entry).

Verification
- maturin develop --release --features accelerate: clean build, no warnings.
- pytest tests/test_rust_backend.py::TestTROPRustEdgeCaseParity: grid-search test now passes; bootstrap-seed test correctly xfails.
- pytest tests/test_rust_backend.py -k TROP -m '': 23 passed, 1 xfailed, no regressions.
- pytest tests/test_trop.py: 83 passed, 37 deselected (slow).
- TestTROPGlobalRustVsNumpy (incl. lambda_nn=0 low-rank FISTA path): 8 passed; FISTA TWFE step unchanged in behavior on well-conditioned data.
- grep for other 'for _ in 0..50' coordinate-descent patterns in rust/src/*.rs: none found.

Non-goals
- No changes to row 87 (bootstrap RNG mismatch: Rust rand crate vs numpy default_rng, ~28% SE gap on seed=42). Separate PR.
- No changes to linalg.rs::solve_ols (rcond=1e-7 is load-bearing for MultiPeriodDiD / DiD / TWFE).
- No public API changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6b24c42 commit e59610b
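
To make the root cause concrete, here is a minimal sketch of how two solvers can both minimize the same least-squares objective and still return different coefficients when the design is rank-deficient. It is a toy problem, not the TROP code: a single observation `y` is fit by `a + b`, so the design row is `[1, 1]` and every pair with `a + b = y` is a minimizer. The helper names `min_norm_solution`, `coordinate_descent`, and `residual` are hypothetical and exist only for this illustration.

```rust
/// Minimum-norm minimizer of (y - a - b)^2.
/// The design row is [1, 1]; its pseudo-inverse is [1/2, 1/2]^T,
/// so the minimum-norm solution splits y equally between a and b.
fn min_norm_solution(y: f64) -> (f64, f64) {
    (y / 2.0, y / 2.0)
}

/// Block coordinate descent on the same objective, started from (0, 0).
/// Each update is the exact minimizer in one coordinate with the other fixed.
fn coordinate_descent(y: f64, max_iter: usize, tol: f64) -> (f64, f64) {
    let (mut a, mut b) = (0.0_f64, 0.0_f64);
    for _ in 0..max_iter {
        let (a_old, b_old) = (a, b);
        a = y - b; // minimize over a with b fixed
        b = y - a; // minimize over b with a fixed
        if (a - a_old).abs().max((b - b_old).abs()) < tol {
            break;
        }
    }
    (a, b)
}

/// Residual of the single observation for a given coefficient pair.
fn residual(y: f64, (a, b): (f64, f64)) -> f64 {
    y - a - b
}
```

With `y = 4.0`, the minimum-norm solution is `(2.0, 2.0)` while coordinate descent from the origin stops at `(4.0, 0.0)`. Both have zero residual, but only the first has minimum Euclidean norm. Any quantity that depends on `a` and `b` separately, rather than on `a + b`, differs between the two, which is one plausible way a coefficient-level disagreement could surface downstream.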

5 files changed

Lines changed: 148 additions & 96 deletions

TODO.md

Lines changed: 0 additions & 1 deletion
@@ -83,7 +83,6 @@ Deferred items from PR reviews that were not addressed before merge.
 | Weighted CR2 Bell-McCaffrey cluster-robust (`vcov_type="hc2_bm"` + `cluster_ids` + `weights`) currently raises `NotImplementedError`. Weighted hat matrix and residual rebalancing need threading per clubSandwich WLS handling. | `linalg.py::_compute_cr2_bm` | Phase 1a | Medium |
 | Regenerate `benchmarks/data/clubsandwich_cr2_golden.json` from R (`Rscript benchmarks/R/generate_clubsandwich_golden.R`). Current JSON has `source: python_self_reference` as a stability anchor until an authoritative R run. | `benchmarks/R/generate_clubsandwich_golden.R` | Phase 1a | Medium |
 | `honest_did.py:1907` `np.linalg.solve(A_sys, b_sys) / except LinAlgError: continue` is a silent basis-rejection in the vertex-enumeration loop that is algorithmically intentional (try the next basis). Consider surfacing a count of rejected bases as a diagnostic when ARP enumeration exhausts, so users see when the vertex search was heavily constrained. Not a silent failure in the sense of the Phase 2 audit (the algorithm is supposed to skip), but the diagnostic would help debug borderline cases. | `honest_did.py` | #334 | Low |
-| TROP Rust vs Python grid-search divergence on rank-deficient Y: on two near-parallel control units, LOOCV grid-search ATT diverges ~6% between Rust (`trop_global.py:688`) and Python fallback (`trop_global.py:753`). Either grid-winner ties are broken differently or the per-λ solver reaches different stationary points under rank deficiency. Audit finding #23 flagged this surface. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_grid_search_rank_deficient_Y` baselines the gap. | `trop_global.py`, `rust/` | follow-up | Medium |
 | TROP Rust vs Python bootstrap SE divergence under fixed seed: `seed=42` on a tiny panel produces ~28% bootstrap-SE gap. Root cause: Rust bootstrap uses its own RNG (`rand` crate) while Python uses `numpy.random.default_rng`; same seed value maps to different bytestreams across backends. Audit axis-H (RNG/seed) adjacent. `@pytest.mark.xfail(strict=True)` in `tests/test_rust_backend.py::TestTROPRustEdgeCaseParity::test_bootstrap_seed_reproducibility` baselines the gap. Unifying RNG (threading a numpy-generated seed-sequence into Rust, or porting Python to ChaCha) would close it. | `trop_global.py`, `rust/` | follow-up | Medium |
 | `bias_corrected_local_linear`: extend golden parity to `kernel="triangular"` and `kernel="uniform"` (currently epa-only; all three kernels share `kernel_W` and the `lprobust` math, so parity is expected but not separately asserted). | `benchmarks/R/generate_nprobust_lprobust_golden.R`, `tests/test_bias_corrected_lprobust.py` | Phase 1c | Low |
 | `bias_corrected_local_linear`: expose `vce in {"hc0", "hc1", "hc2", "hc3"}` on the public wrapper once R parity goldens exist (currently raises `NotImplementedError`). The port-level `lprobust` and `lprobust_res` already support all four; expanding the public surface requires a golden generator for each hc mode and a decision on hc2/hc3 q-fit leverage (R reuses p-fit `hii` for q-fit residuals; whether to match that or stage-match deserves a derivation before the wrapper advertises CCT-2014 conformance). | `diff_diff/local_linear.py::bias_corrected_local_linear`, `benchmarks/R/generate_nprobust_lprobust_golden.R`, `tests/test_bias_corrected_lprobust.py` | Phase 1c | Medium |

docs/methodology/REGISTRY.md

Lines changed: 4 additions & 0 deletions
@@ -2059,6 +2059,10 @@ Treatment effects are **heterogeneous** per-observation values. ATT is their mea
 1. **Without low-rank (λ_nn = ∞)**: Standard weighted least squares
    - Build design matrix with unit/time dummies (no treatment indicator)
    - Solve via np.linalg.lstsq for (μ, α, β) using (1-W)-masked weights
+   - Both the Python fallback and the Rust acceleration path use SVD-based
+     minimum-norm least squares with numpy-compatible rcond = eps × max(n, k),
+     so they return the canonical minimum-norm solution on rank-deficient Y
+     (e.g., two near-parallel control units)
 
 2. **With low-rank (finite λ_nn)**: Alternating minimization
    - Alternate between:
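
The design-matrix step in the bullets above can be sketched without any linear-algebra crate. The following is an illustration only, assuming a non-empty panel stored as nested `Vec`s indexed `[t][i]`; the function name `build_weighted_design` is hypothetical and is not part of the repository. It builds `[intercept | unit dummies | time dummies]` with unit 0 and period 0 dropped, flattens rows as `idx = t * n_units + i`, and scales each row by the square root of its weight so that ordinary least squares on the result is weighted least squares on the original.

```rust
/// Build the sqrt-weighted design matrix and response for
/// Y_it = mu + alpha_i + beta_t, dropping unit 0 and period 0.
/// `y` and `w` are indexed [t][i]; rows are flattened as idx = t * n_units + i.
/// Non-finite outcomes or weights get zero weight; negative weights clamp to 0.
/// Assumes at least one period and one unit (illustration only).
fn build_weighted_design(y: &[Vec<f64>], w: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>) {
    let n_periods = y.len();
    let n_units = y[0].len();
    let n_params = 1 + (n_units - 1) + (n_periods - 1);
    let mut x = vec![vec![0.0; n_params]; n_periods * n_units];
    let mut rhs = vec![0.0; n_periods * n_units];
    for t in 0..n_periods {
        for i in 0..n_units {
            let idx = t * n_units + i;
            let valid = y[t][i].is_finite() && w[t][i].is_finite();
            let s = if valid { w[t][i].max(0.0).sqrt() } else { 0.0 };
            x[idx][0] = s; // intercept column
            if i >= 1 {
                x[idx][i] = s; // unit dummy (unit 0 dropped)
            }
            if t >= 1 {
                x[idx][(n_units - 1) + t] = s; // time dummy (period 0 dropped)
            }
            rhs[idx] = if valid { y[t][i] * s } else { 0.0 };
        }
    }
    (x, rhs)
}
```

A row whose weight is zero, or whose outcome is NaN, becomes an all-zero row, so it contributes nothing to the fit. When enough rows are zeroed out, columns of the weighted matrix can become linearly dependent, which is the situation where the minimum-norm convention matters.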

rust/src/linalg.rs

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ fn compute_robust_vcov_internal(
 }
 
 /// Convert ndarray Array2 to faer Mat
-fn ndarray_to_faer(arr: &Array2<f64>) -> faer::Mat<f64> {
+pub(crate) fn ndarray_to_faer(arr: &Array2<f64>) -> faer::Mat<f64> {
     let nrows = arr.nrows();
     let ncols = arr.ncols();
     faer::Mat::from_fn(nrows, ncols, |i, j| arr[[i, j]])

rust/src/trop.rs

Lines changed: 135 additions & 80 deletions
@@ -14,6 +14,8 @@ use numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2, ToPyArray};
 use pyo3::prelude::*;
 use rayon::prelude::*;
 
+use crate::linalg::ndarray_to_faer;
+
 /// Minimum chunk size for parallel distance computation.
 /// Reduces scheduling overhead for small matrices.
 const MIN_CHUNK_SIZE: usize = 16;
@@ -1233,111 +1235,164 @@ fn solve_joint_no_lowrank(
     y: &ArrayView2<f64>,
     delta: &ArrayView2<f64>,
 ) -> Option<(f64, Array1<f64>, Array1<f64>)> {
+    // SVD-based minimum-norm weighted least-squares fit — mirrors Python's
+    // `_solve_global_no_lowrank` at `diff_diff/trop_global.py:340-412`
+    // step-for-step so Rust and Python paths produce the same canonical
+    // solution on rank-deficient Y (silent-failures finding #23).
+    //
+    // Model: Y_it = μ + α_i + β_t + ε_it, with α_0 = β_0 = 0 for
+    // identification. Weights: δ_it. Flatten row-major with
+    // idx = t * n_units + i (matches Python's Y.flatten() C-order).
     let n_periods = y.nrows();
     let n_units = y.ncols();
+    let n_obs = n_periods * n_units;
+    let n_params = 1 + (n_units - 1) + (n_periods - 1);
 
-    // We solve using normal equations with the design matrix structure
-    // Rather than build full X matrix, use block structure for efficiency
-    //
-    // The model: Y_it = μ + α_i + β_t + ε_it
-    // With identification: α_0 = β_0 = 0
-
-    // Compute weighted sums needed for normal equations
+    // Flatten y + weights with NaN masking — matches trop_global.py:354-360.
+    let mut y_flat = Array1::<f64>::zeros(n_obs);
+    let mut w_flat = Array1::<f64>::zeros(n_obs);
     let mut sum_w = 0.0;
-    let mut sum_wy = 0.0;
-
-    // Per-unit and per-period weighted sums
-    let mut sum_w_by_unit = Array1::<f64>::zeros(n_units);
-    let mut sum_wy_by_unit = Array1::<f64>::zeros(n_units);
-    let mut sum_w_by_period = Array1::<f64>::zeros(n_periods);
-    let mut sum_wy_by_period = Array1::<f64>::zeros(n_periods);
-
     for t in 0..n_periods {
         for i in 0..n_units {
-            // NaN outcomes get zero weight (not imputed to 0.0 with active weight)
-            let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
-            let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
-
+            let idx = t * n_units + i;
+            let y_ti = y[[t, i]];
+            let w_ti = delta[[t, i]];
+            let valid = y_ti.is_finite() && w_ti.is_finite();
+            let w = if valid { w_ti.max(0.0) } else { 0.0 };
+            y_flat[idx] = if valid { y_ti } else { 0.0 };
+            w_flat[idx] = w;
             sum_w += w;
-            sum_wy += w * y_ti;
-
-            sum_w_by_unit[i] += w;
-            sum_wy_by_unit[i] += w * y_ti;
-            sum_w_by_period[t] += w;
-            sum_wy_by_period[t] += w * y_ti;
         }
     }
 
+    // All-zero weights short-circuit — matches trop_global.py:366.
     if sum_w < 1e-10 {
         return None;
     }
 
-    // Use iterative approach: alternate between (alpha, beta) and mu
-    // until convergence (simpler than full normal equations)
-    let mut mu = sum_wy / sum_w;
+    // Build design matrix X = [intercept | unit_dummies[1..] | time_dummies[1..]]
+    // — matches trop_global.py:374-385. Explicit nested loops so the
+    // index correspondence with Python is unambiguous.
+    let mut x = Array2::<f64>::zeros((n_obs, n_params));
+    for t in 0..n_periods {
+        for i in 0..n_units {
+            let idx = t * n_units + i;
+            x[[idx, 0]] = 1.0; // intercept
+            if i >= 1 {
+                x[[idx, i]] = 1.0; // unit dummy (unit 0 dropped)
+            }
+            if t >= 1 {
+                x[[idx, (n_units - 1) + t]] = 1.0; // time dummy (period 0 dropped)
+            }
+        }
+    }
+
+    // Apply sqrt-weights: X_w = X * sqrt(w)[:, None], y_w = y * sqrt(w).
+    // Matches trop_global.py:387-389.
+    let sqrt_w: Array1<f64> = w_flat.mapv(|w| w.sqrt());
+    for r in 0..n_obs {
+        let s = sqrt_w[r];
+        for c in 0..n_params {
+            x[[r, c]] *= s;
+        }
+        y_flat[r] *= s;
+    }
+
+    // Solve via SVD with numpy-compatible rcond truncation.
+    let coeffs = solve_wls_svd(&x.view(), &y_flat.view())?;
+
+    // Unpack: mu = coeffs[0], alpha[1..] = coeffs[1..n_units],
+    // beta[1..] = coeffs[n_units..]. Matches trop_global.py:406-410.
+    let mu = coeffs[0];
     let mut alpha = Array1::<f64>::zeros(n_units);
+    for i in 1..n_units {
+        alpha[i] = coeffs[i];
+    }
     let mut beta = Array1::<f64>::zeros(n_periods);
+    for t in 1..n_periods {
+        beta[t] = coeffs[(n_units - 1) + t];
+    }
 
-    for _ in 0..50 {
-        let mu_old = mu;
-        let alpha_old = alpha.clone();
-        let beta_old = beta.clone();
+    Some((mu, alpha, beta))
+}
 
-        // Update alpha (fixing beta, mu)
-        for i in 1..n_units { // α_0 = 0 for identification
-            if sum_w_by_unit[i] > 1e-10 {
-                let mut num = 0.0;
-                for t in 0..n_periods {
-                    // NaN outcomes get zero weight
-                    let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
-                    let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
-                    num += w * (y_ti - mu - beta[t]);
-                }
-                alpha[i] = num / sum_w_by_unit[i];
-            }
-        }
+/// Minimum-norm least-squares solution via faer thin SVD with rcond truncation.
+///
+/// Mirrors `np.linalg.lstsq(A, b, rcond=None)` from numpy: singular values
+/// below `rcond * max(S)` with `rcond = eps * max(n_rows, n_cols)` are
+/// treated as zero. On rank-deficient A this returns the unique
+/// minimum-norm least-squares solution.
+///
+/// This helper intentionally does NOT reuse `rust/src/linalg.rs::solve_ols`
+/// because `solve_ols` hard-codes `rcond = 1e-7` (R's `lm()` default) which
+/// would truncate singular values that numpy's default keeps. TROP's
+/// canonical Python path is numpy-compatible; Rust must match.
+///
+/// Returns `None` only when the SVD itself fails (rare on finite inputs);
+/// the caller (LOOCV / FISTA / bootstrap) interprets `None` as an
+/// unsuccessful fit.
+fn solve_wls_svd(a: &ArrayView2<f64>, b: &ArrayView1<f64>) -> Option<Array1<f64>> {
+    let n_rows = a.nrows();
+    let n_cols = a.ncols();
+    let a_owned = a.to_owned();
+    let b_owned = b.to_owned();
 
-        // Update beta (fixing alpha, mu)
-        for t in 1..n_periods { // β_0 = 0 for identification
-            if sum_w_by_period[t] > 1e-10 {
-                let mut num = 0.0;
-                for i in 0..n_units {
-                    // NaN outcomes get zero weight
-                    let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
-                    let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
-                    num += w * (y_ti - mu - alpha[i]);
-                }
-                beta[t] = num / sum_w_by_period[t];
-            }
+    // Convert ndarray -> faer for SVD.
+    let a_faer = ndarray_to_faer(&a_owned);
+
+    let svd = a_faer.thin_svd().ok()?;
+
+    let u_faer = svd.U();
+    let s_diag = svd.S();
+    let s_col = s_diag.column_vector();
+    let v_faer = svd.V();
+
+    // Extract U (n_rows x min(n,k)) back to ndarray.
+    let u_rows = u_faer.nrows();
+    let u_cols = u_faer.ncols();
+    let mut u = Array2::<f64>::zeros((u_rows, u_cols));
+    for i in 0..u_rows {
+        for j in 0..u_cols {
+            u[[i, j]] = u_faer[(i, j)];
         }
+    }
 
-        // Update mu (fixing alpha, beta)
-        let mut num_mu = 0.0;
-        for t in 0..n_periods {
-            for i in 0..n_units {
-                // NaN outcomes get zero weight
-                let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
-                let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
-                num_mu += w * (y_ti - alpha[i] - beta[t]);
-            }
+    // Extract singular values.
+    let s_len = s_col.nrows();
+    let mut s = Array1::<f64>::zeros(s_len);
+    for i in 0..s_len {
+        s[i] = s_col[i];
+    }
+
+    // Extract V (k x min(n,k)) back to ndarray. faer's V is not V^T.
+    let v_rows = v_faer.nrows();
+    let v_cols = v_faer.ncols();
+    let mut v = Array2::<f64>::zeros((v_rows, v_cols));
+    for i in 0..v_rows {
+        for j in 0..v_cols {
+            v[[i, j]] = v_faer[(i, j)];
         }
-        mu = num_mu / sum_w;
-
-        // Check convergence across ALL parameters (not just mu)
-        let mu_diff = (mu - mu_old).abs();
-        let alpha_diff = alpha.iter().zip(alpha_old.iter())
-            .map(|(a, b)| (a - b).abs())
-            .fold(0.0_f64, f64::max);
-        let beta_diff = beta.iter().zip(beta_old.iter())
-            .map(|(a, b)| (a - b).abs())
-            .fold(0.0_f64, f64::max);
-        let max_diff = mu_diff.max(alpha_diff).max(beta_diff);
-        if max_diff < 1e-8 {
-            break;
+    }
+
+    // numpy rcond = eps * max(n_rows, n_cols); truncate s[i] <= rcond * max(s).
+    let rcond = f64::EPSILON * (n_rows.max(n_cols) as f64);
+    let s_max = s.iter().cloned().fold(0.0_f64, f64::max);
+    let threshold = s_max * rcond;
+
+    // Compute β = V * S^{-1}_truncated * U^T * y.
+    let uty = u.t().dot(&b_owned); // (min(n,k),)
+    let mut s_inv_uty = Array1::<f64>::zeros(s_len);
+    for i in 0..s_len {
+        if s[i] > threshold {
+            s_inv_uty[i] = uty[i] / s[i];
         }
+        // else: leave 0 — this is the pseudo-inverse / minimum-norm step
+        // that also covers Python's `except LinAlgError: pinv(...)` fallback
+        // tier, since faer thin_svd is numerically robust on finite inputs.
     }
+    let coeffs = v.dot(&s_inv_uty);
 
-    Some((mu, alpha, beta))
+    Some(coeffs)
 }
 
 /// Solve global TWFE + low-rank via alternating minimization (no tau).
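
The truncated pseudo-inverse step at the end of `solve_wls_svd` can be checked by hand on a tiny rank-deficient system. The sketch below is std-only and takes the SVD factors as inputs instead of computing them, so it does not call faer or ndarray; the function name `min_norm_from_svd` and the row-wise `Vec` layout are assumptions made for this illustration.

```rust
/// Apply x = V * diag(1/s_k for s_k > threshold) * U^T * b, where
/// threshold = eps * max(n_rows, n_cols) * max(s).
/// `u` is n_rows x r and `v` is n_cols x r, both stored row by row;
/// `s` holds the r singular values.
fn min_norm_from_svd(u: &[Vec<f64>], s: &[f64], v: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n_rows = u.len();
    let n_cols = v.len();
    let rcond = f64::EPSILON * (n_rows.max(n_cols) as f64);
    let s_max = s.iter().cloned().fold(0.0_f64, f64::max);
    let threshold = rcond * s_max;
    let mut x = vec![0.0; n_cols];
    for (k, &sk) in s.iter().enumerate() {
        if sk <= threshold {
            continue; // treated as zero: this direction contributes nothing
        }
        // k-th entry of U^T b, divided by the k-th singular value.
        let utb: f64 = (0..n_rows).map(|i| u[i][k] * b[i]).sum();
        let coef = utb / sk;
        for j in 0..n_cols {
            x[j] += v[j][k] * coef;
        }
    }
    x
}
```

For `A = [[1, 1], [1, 1]]` the factors are `U = V = [[h, h], [h, -h]]` with `h = 1/sqrt(2)` and singular values `2` and (numerically) about `1e-17`. With `b = [1, 3]` the function returns approximately `[1, 1]`, the minimum-norm least-squares solution. Without the truncation, the second direction would be divided by `1e-17` and dominate the result.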

tests/test_rust_backend.py

Lines changed: 8 additions & 14 deletions
@@ -2204,23 +2204,17 @@ def _make_correlated_panel(n_units=6, n_periods=5, n_treated=2):
                 data.append({"unit": i, "time": t, "outcome": y, "treated": treated})
         return pd.DataFrame(data)
 
-    @pytest.mark.xfail(
-        strict=True,
-        reason="TROP Rust grid-search at trop_global.py:688 and Python fallback "
-        "at trop_global.py:753 diverge on rank-deficient Y: empirical ATT "
-        "gap ~6% on two near-parallel control units. Either (a) grid-winner "
-        "ties break differently between backends, or (b) the per-λ solver "
-        "itself reaches different stationary points under rank deficiency. "
-        "Finding #23 flagged this exact surface as the Phase-2 gap. Rust "
-        "vs Python unification is a P1 follow-up (TODO.md). This xfail "
-        "baselines the divergence so we notice when/if the backends align.",
-    )
     def test_grid_search_rank_deficient_Y(self):
         """Grid-search ATT parity on rank-deficient Y.
 
-        Known to fail: two near-parallel control units produce ~6% ATT
-        divergence between Rust and Python LOOCV grid-search. See xfail
-        reason for follow-up plan.
+        Finding #23 / TODO row 87 regression guard. Previously a ~6%
+        ATT divergence on two near-parallel control units because the
+        Rust inner solver used iterative block coordinate descent while
+        the Python fallback used SVD-based minimum-norm least squares.
+        Fixed by porting the Rust inner solver to an SVD-based WLS path
+        (numpy-compatible rcond = eps*max(n,k)) that mirrors Python's
+        `np.linalg.lstsq(rcond=None)` step-for-step. This test asserts
+        the backends now agree at atol=1e-6 on rank-deficient Y.
         """
         import sys
         from unittest.mock import patch
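
The parity condition named in the docstring can be read as an absolute-tolerance comparison of the two ATT estimates. The sketch below assumes a pure absolute tolerance with no relative term, which may not match the exact assertion used in the test body (not shown in this hunk); `within_atol` is a hypothetical helper and the numbers in the usage note are made up.

```rust
/// True when |a - b| <= atol. Returns false if either value is NaN,
/// because comparisons involving NaN are always false.
fn within_atol(a: f64, b: f64, atol: f64) -> bool {
    (a - b).abs() <= atol
}
```

Under this reading, two estimates such as `2.0` and `2.0000005` pass at `atol = 1e-6`, while a gap of several percent, such as `2.0` versus `2.12`, fails.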
