Commit 0e6524e

igerber and claude committed
Fix rank-deficient matrix handling in OLS solver
MultiPeriodDiD was producing astronomically wrong estimates (~252 trillion instead of ~2-5) due to rank-deficient design matrices being solved incorrectly by the gelsy LAPACK driver.

Changes:
- Python: Switch from gelsy to gelsd driver (SVD-based with truncation)
- Rust: Replace least_squares() with explicit SVD + truncated pseudoinverse
- Add comprehensive tests for rank-deficient matrices in both backends
- Add Rust vs NumPy equivalence tests for rank-deficient cases
- Document NaN standard errors limitation in TODO.md

The gelsd driver properly handles rank-deficient matrices by truncating small singular values below rcond * max(S), producing valid minimum-norm solutions instead of garbage coefficients.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
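The driver switch can be sanity-checked outside the library. Below is a minimal SciPy-only sketch (not code from this repo) that forces rank deficiency by duplicating a column:

```python
import numpy as np
from scipy.linalg import lstsq

rng = np.random.default_rng(0)
a = rng.standard_normal((100, 2))
# Duplicating a column makes X rank-deficient (3 columns, rank 2)
X = np.column_stack([a, a[:, 0]])
y = X @ np.array([1.0, 2.0, 0.0]) + 0.1 * rng.standard_normal(100)

# 'gelsd' truncates small singular values and returns the finite,
# minimum-norm solution; it also reports the effective rank
coef, _, rank, sv = lstsq(X, y, lapack_driver="gelsd")
print(rank, np.all(np.isfinite(coef)))
```

Under 'gelsd' the minimum-norm solution splits the duplicated column's weight across both copies, so `coef[0] + coef[2]` recovers the combined effect of roughly 1.0 rather than producing the huge offsetting coefficients described above.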
1 parent 942feb7 commit 0e6524e

6 files changed

Lines changed: 367 additions & 30 deletions

CLAUDE.md

Lines changed: 2 additions & 2 deletions
@@ -119,7 +119,7 @@ pytest tests/test_rust_backend.py -v
   - Integrated with `TwoWayFixedEffects.decompose()` method

 - **`diff_diff/linalg.py`** - Unified linear algebra backend (v1.4.0+):
-  - `solve_ols()` - OLS solver using scipy's gelsy LAPACK driver (QR-based, faster than SVD)
+  - `solve_ols()` - OLS solver using scipy's gelsd LAPACK driver (SVD-based, handles rank-deficient matrices)
   - `compute_robust_vcov()` - Vectorized HC1 and cluster-robust variance-covariance estimation
   - `compute_r_squared()` - R-squared and adjusted R-squared computation
   - `LinearRegression` - High-level OLS helper class with unified coefficient extraction and inference
@@ -240,7 +240,7 @@ diff-diff achieved significant performance improvements in v1.4.0, now **faster

 All estimators use a single optimized OLS/SE implementation:

-- **scipy.linalg.lstsq with 'gelsy' driver**: QR-based solving, faster than NumPy's default SVD-based solver
+- **scipy.linalg.lstsq with 'gelsd' driver**: SVD-based solving that properly handles rank-deficient matrices (critical for MultiPeriodDiD and other estimators with potentially redundant columns)
 - **Vectorized cluster-robust SE**: Uses pandas groupby aggregation instead of O(n × clusters) Python loop
 - **Single optimization point**: Changes to `linalg.py` benefit all estimators

TODO.md

Lines changed: 15 additions & 0 deletions
@@ -12,9 +12,24 @@ Current limitations that may affect users:

 | Issue | Location | Priority | Notes |
 |-------|----------|----------|-------|
+| NaN standard errors for rank-deficient matrices | `linalg.py:330-345` | Medium | See details below |
 | MultiPeriodDiD wild bootstrap not supported | `estimators.py:1068-1074` | Low | Edge case |
 | `predict()` raises NotImplementedError | `estimators.py:532-554` | Low | Rarely needed |

+### NaN Standard Errors for Rank-Deficient Matrices
+
+**Problem**: When the design matrix is rank-deficient (e.g., MultiPeriodDiD with redundant period dummies + treatment interactions), the coefficients are now computed correctly via SVD truncation, but the variance-covariance matrix computation produces NaN values.
+
+**Root cause**: The vcov computation in `compute_robust_vcov()` computes `(X'X)^{-1}`, which doesn't exist for rank-deficient matrices. The current implementation uses Cholesky factorization, which fails silently, producing NaN values.
+
+**Affected estimators**:
+- `MultiPeriodDiD` - when the design matrix has redundant columns
+- Any estimator using `solve_ols()` with rank-deficient X
+
+**Potential fix**: Use the Moore-Penrose pseudoinverse `(X'X)^+` instead of `(X'X)^{-1}` for the bread matrix in the sandwich estimator. This would provide valid (though potentially conservative) standard errors for the identifiable parameters.
+
+**Workaround**: Users can use bootstrap inference, which doesn't rely on the analytical vcov.
+
 ---

 ## Code Quality
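The pseudoinverse fix proposed in the TODO entry above can be prototyped in a few lines of NumPy. This is an illustrative sketch, not the library's implementation; the helper name `hc1_vcov_pinv` is invented here, and only the plain HC1 (non-clustered) sandwich is shown:

```python
import numpy as np

def hc1_vcov_pinv(X, resid):
    """HC1 sandwich with a pseudoinverse bread, tolerant of rank deficiency."""
    n, k = X.shape
    bread = np.linalg.pinv(X.T @ X)               # (X'X)^+ instead of (X'X)^{-1}
    meat = (X * resid[:, None] ** 2).T @ X        # X' diag(e_i^2) X
    return (n / (n - k)) * bread @ meat @ bread   # HC1 small-sample scaling

rng = np.random.default_rng(0)
a = rng.standard_normal((50, 2))
# Duplicated column -> X has 4 columns but rank 3
X = np.column_stack([np.ones(50), a, a[:, 0]])
y = X[:, 1] + rng.standard_normal(50)

coef = np.linalg.pinv(X) @ y                      # minimum-norm coefficients
vcov = hc1_vcov_pinv(X, y - X @ coef)
print(np.all(np.isfinite(vcov)))                  # finite vcov despite rank deficiency
```

A plain inverse (or Cholesky solve) of `X.T @ X` would fail or return NaN here; the pseudoinverse yields finite, if conservative, variances for the identifiable directions.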

diff_diff/linalg.py

Lines changed: 14 additions & 14 deletions
@@ -5,7 +5,7 @@
 Rust backend for maximum performance.

 The key optimizations are:
-1. scipy.linalg.lstsq with 'gelsy' driver (QR-based, faster than SVD)
+1. scipy.linalg.lstsq with 'gelsd' driver (SVD-based, handles rank-deficient matrices)
 2. Vectorized cluster-robust SE via groupby (eliminates O(n*clusters) loop)
 3. Single interface for all estimators (reduces code duplication)
 4. Optional Rust backend for additional speedup (when available)
@@ -80,9 +80,9 @@ def solve_ols(
     Notes
     -----
-    This function uses scipy.linalg.lstsq with the 'gelsy' driver, which is
-    QR-based and typically faster than NumPy's default SVD-based solver for
-    well-conditioned matrices.
+    This function uses scipy.linalg.lstsq with the 'gelsd' driver, which is
+    SVD-based and handles rank-deficient matrices correctly by truncating
+    small singular values.

     The cluster-robust standard errors use the sandwich estimator with the
     standard small-sample adjustment: (G/(G-1)) * ((n-1)/(n-k)).
@@ -184,11 +184,11 @@ def _solve_ols_numpy(
     """
     NumPy/SciPy fallback implementation of solve_ols.

-    Uses scipy.linalg.lstsq with 'gelsy' driver (QR with column pivoting)
-    for numerically stable least squares solving. QR decomposition is preferred
-    over normal equations because it doesn't square the condition number of X,
-    making it more robust for ill-conditioned matrices common in DiD designs
-    (e.g., many unit/time fixed effects).
+    Uses scipy.linalg.lstsq with 'gelsd' driver (SVD-based with divide-and-conquer)
+    for numerically stable least squares solving. SVD decomposition properly handles
+    rank-deficient matrices by truncating small singular values, which is critical
+    for DiD designs that may have redundant columns (e.g., period dummies + treatment
+    interactions in MultiPeriodDiD).

     Parameters
     ----------
@@ -214,11 +214,11 @@ def _solve_ols_numpy(
     vcov : np.ndarray, optional
         Variance-covariance matrix if return_vcov=True.
     """
-    # Solve OLS using QR decomposition via scipy's optimized LAPACK routines
-    # 'gelsy' uses QR with column pivoting, which is numerically stable even
-    # for ill-conditioned matrices (doesn't square the condition number like
-    # normal equations would)
-    coefficients = scipy_lstsq(X, y, lapack_driver="gelsy", check_finite=False)[0]
+    # Solve OLS using SVD via scipy's optimized LAPACK routines
+    # 'gelsd' uses divide-and-conquer SVD, which properly handles rank-deficient
+    # matrices by truncating small singular values (unlike 'gelsy' which can
+    # produce garbage coefficients for nearly rank-deficient matrices)
+    coefficients = scipy_lstsq(X, y, lapack_driver="gelsd", check_finite=False)[0]

     # Compute residuals and fitted values
     fitted = X @ coefficients
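The cluster-robust sandwich with the small-sample adjustment (G/(G-1)) * ((n-1)/(n-k)) quoted in the docstring above can be sketched in plain NumPy. The library itself vectorizes via pandas groupby; the helper name `cluster_vcov` and the use of `np.add.at` here are illustrative choices, not the shipped code:

```python
import numpy as np

def cluster_vcov(X, resid, clusters):
    """Cluster-robust sandwich estimator, vectorized over clusters."""
    n, k = X.shape
    ids, inv = np.unique(clusters, return_inverse=True)
    G = len(ids)
    scores = X * resid[:, None]                    # per-observation scores X_i * e_i
    cluster_scores = np.zeros((G, k))
    np.add.at(cluster_scores, inv, scores)         # sum scores within each cluster
    meat = cluster_scores.T @ cluster_scores
    bread = np.linalg.pinv(X.T @ X)
    adj = (G / (G - 1)) * ((n - 1) / (n - k))      # small-sample adjustment
    return adj * bread @ meat @ bread

rng = np.random.default_rng(1)
X = np.column_stack([np.ones(120), rng.standard_normal(120)])
y = X @ np.array([0.5, 1.5]) + rng.standard_normal(120)
coef = np.linalg.lstsq(X, y, rcond=None)[0]
clusters = np.repeat(np.arange(12), 10)            # 12 clusters of 10 observations

V = cluster_vcov(X, y - X @ coef, clusters)
se = np.sqrt(np.diag(V))
```

Summing scores per cluster once and taking an outer product replaces the O(n × clusters) Python loop the commit's CLAUDE.md notes mention eliminating.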

rust/src/linalg.rs

Lines changed: 47 additions & 6 deletions
@@ -6,13 +6,22 @@
 //! - Cluster-robust variance-covariance estimation

 use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
-use ndarray_linalg::{FactorizeC, LeastSquaresSvd, Solve, SolveC, UPLO};
+use ndarray_linalg::{FactorizeC, Solve, SolveC, SVD, UPLO};
 use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
 use pyo3::prelude::*;
 use std::collections::HashMap;

 /// Solve OLS regression: β = (X'X)^{-1} X'y
 ///
+/// Uses SVD with truncation for rank-deficient matrices:
+/// - Computes SVD: X = U * S * V^T
+/// - Truncates singular values below rcond * max(S)
+/// - Computes solution: β = V * S^{-1}_truncated * U^T * y
+///
+/// This matches scipy's 'gelsd' driver behavior for handling rank-deficient
+/// design matrices that can occur in DiD estimation (e.g., MultiPeriodDiD
+/// with redundant period dummies + treatment interactions).
+///
 /// # Arguments
 /// * `x` - Design matrix (n, k)
 /// * `y` - Response vector (n,)
@@ -37,15 +46,47 @@ pub fn solve_ols<'py>(
     let x_arr = x.as_array();
     let y_arr = y.as_array();

-    // Solve least squares using SVD (more stable than normal equations)
+    let n = x_arr.nrows();
+    let k = x_arr.ncols();
+
+    // Solve using SVD with truncation for rank-deficient matrices
+    // This matches scipy's 'gelsd' behavior
     let x_owned = x_arr.to_owned();
     let y_owned = y_arr.to_owned();

-    let result = x_owned
-        .least_squares(&y_owned)
-        .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Least squares failed: {}", e)))?;
+    // Compute SVD: X = U * S * V^T
+    let (u_opt, s, vt_opt) = x_owned
+        .svd(true, true)
+        .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("SVD failed: {}", e)))?;
+
+    let u = u_opt.ok_or_else(|| {
+        PyErr::new::<pyo3::exceptions::PyValueError, _>("SVD did not return U matrix")
+    })?;
+    let vt = vt_opt.ok_or_else(|| {
+        PyErr::new::<pyo3::exceptions::PyValueError, _>("SVD did not return V^T matrix")
+    })?;
+
+    // Compute rcond threshold (matches numpy/scipy default)
+    // rcond = max(n, k) * machine_epsilon
+    let rcond = (n.max(k) as f64) * f64::EPSILON;
+    let s_max = s.iter().cloned().fold(0.0_f64, f64::max);
+    let threshold = s_max * rcond;
+
+    // Compute truncated pseudoinverse solution: β = V * S^{-1} * U^T * y
+    // Singular values below threshold are treated as zero (truncated)
+    let uty = u.t().dot(&y_owned); // (min(n,k),)
+
+    // Build S^{-1} with truncation
+    let mut s_inv_uty = Array1::<f64>::zeros(k);
+    for i in 0..s.len().min(k) {
+        if s[i] > threshold {
+            s_inv_uty[i] = uty[i] / s[i];
+        }
+        // else: leave as 0 (truncate this singular value)
+    }

-    let coefficients = result.solution;
+    // Compute coefficients: β = V * (S^{-1} * U^T * y)
+    let coefficients = vt.t().dot(&s_inv_uty);

     // Compute fitted values and residuals
     let fitted = x_arr.dot(&coefficients);
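The Rust routine above can be cross-checked with a NumPy mirror of the same truncation rule, assuming scipy's 'gelsd' uses the same max(n, k) * eps default threshold (a sketch for verification, not shipped code):

```python
import numpy as np
from scipy.linalg import lstsq

def svd_solve(X, y):
    """Minimum-norm OLS via truncated SVD, mirroring the Rust routine."""
    n, k = X.shape
    u, s, vt = np.linalg.svd(X, full_matrices=False)   # X = U S V^T
    threshold = max(n, k) * np.finfo(float).eps * s.max()
    s_inv = np.zeros_like(s)
    keep = s > threshold                               # truncate small singular values
    s_inv[keep] = 1.0 / s[keep]
    return vt.T @ (s_inv * (u.T @ y))                  # beta = V S^+ U^T y

rng = np.random.default_rng(2)
a = rng.standard_normal((80, 3))
# Fourth column is a linear combination of the others -> rank 3
X = np.column_stack([a, a @ np.array([1.0, -1.0, 2.0])])
y = rng.standard_normal(80)

beta = svd_solve(X, y)
beta_scipy = lstsq(X, y, lapack_driver="gelsd")[0]
print(np.allclose(beta, beta_scipy, atol=1e-6))
```

Since the redundant singular value is truncated in both implementations, both return the unique minimum-norm solution, so the two vectors should agree to rounding error.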

tests/test_linalg.py

Lines changed: 92 additions & 8 deletions
@@ -186,21 +186,32 @@ def test_inf_in_x_raises_error(self):
             solve_ols(X, y)

     def test_check_finite_false_skips_validation(self):
-        """Test that check_finite=False skips NaN/Inf validation."""
+        """Test that check_finite=False skips the upfront NaN/Inf validation.
+
+        Note: With the 'gelsd' driver, LAPACK may still error on NaN values
+        during computation, which is actually safer than producing garbage.
+        """
         X = np.random.randn(100, 2)
         X[50, 0] = np.nan
         y = np.random.randn(100)

-        # Should not raise, but will return garbage results
-        coef, resid, vcov = solve_ols(X, y, check_finite=False)
-        # Coefficients will contain NaN due to bad input
-        assert np.isnan(coef).any() or np.isinf(coef).any()
+        # The gelsd driver may raise an error when encountering NaN during
+        # computation, or produce garbage results. Either is acceptable
+        # (the key is that we don't raise the "X contains NaN" user-friendly error)
+        try:
+            coef, resid, vcov = solve_ols(X, y, check_finite=False)
+            # If it completed, coefficients should contain NaN/Inf due to bad input
+            assert np.isnan(coef).any() or np.isinf(coef).any()
+        except ValueError as e:
+            # LAPACK may raise an error on NaN values (gelsd behavior)
+            # This is acceptable - the key is we skipped our own validation
+            assert "X contains NaN" not in str(e) and "y contains NaN" not in str(e)

     def test_rank_deficient_still_solves(self):
-        """Test that rank-deficient matrix still returns a solution.
+        """Test that rank-deficient matrix returns a valid solution.

-        Note: The gelsy driver doesn't always detect rank deficiency,
-        but it still returns a valid least-squares solution.
+        The 'gelsd' driver uses SVD with truncation to properly handle
+        rank-deficient matrices, producing finite and reasonable coefficients.
         """
         np.random.seed(42)
         X = np.random.randn(100, 3)
@@ -212,9 +223,82 @@ def test_rank_deficient_still_solves(self):
         assert coef.shape == (3,)
         assert resid.shape == (100,)
+
+        # Coefficients must be finite (not NaN or Inf)
+        assert np.all(np.isfinite(coef)), f"Coefficients contain non-finite values: {coef}"
+
+        # Coefficients should be reasonable (not astronomically large)
+        # For a rank-deficient system, coefficients should still be bounded
+        assert np.all(np.abs(coef) < 1e6), f"Coefficients are unreasonably large: {coef}"
+
         # Residuals should still be valid (y - X @ coef)
         np.testing.assert_allclose(resid, y - X @ coef, rtol=1e-10)

+    def test_multiperiod_like_rank_deficiency(self):
+        """Test that MultiPeriodDiD-like design matrices are handled correctly.
+
+        MultiPeriodDiD creates design matrices with intercept + period dummies +
+        treatment × post interactions, which can have redundant columns and be
+        rank-deficient. This test mimics that structure.
+        """
+        np.random.seed(42)
+        n = 200
+        n_periods = 5
+
+        # Create a design matrix similar to MultiPeriodDiD:
+        # [intercept, period_1, period_2, ..., period_k, treated*post_1, ...]
+
+        # Intercept
+        intercept = np.ones(n)
+
+        # Period dummies (one-hot encoding for periods 1 to n_periods-1)
+        # Period 0 is the reference
+        period_assignment = np.random.randint(0, n_periods, n)
+        period_dummies = np.zeros((n, n_periods - 1))
+        for i in range(1, n_periods):
+            period_dummies[:, i - 1] = (period_assignment == i).astype(float)
+
+        # Treatment indicator
+        treated = np.random.binomial(1, 0.5, n)
+
+        # Post indicator (periods >= 3 are post)
+        post = (period_assignment >= 3).astype(float)
+
+        # Treatment × post interaction
+        treat_post = treated * post
+
+        # Build design matrix
+        # Note: This creates a rank-deficient matrix because the period dummies
+        # and treat_post are not all linearly independent when combined
+        X = np.column_stack([intercept, period_dummies, treat_post])
+
+        # True effect
+        true_effect = 2.5
+        y = (
+            1.0  # intercept effect
+            + 0.5 * period_dummies[:, 0]  # period 1 effect
+            + 0.3 * period_dummies[:, 1]  # period 2 effect
+            + 0.7 * period_dummies[:, 2]  # period 3 effect
+            + 0.9 * period_dummies[:, 3]  # period 4 effect
+            + true_effect * treat_post  # treatment effect
+            + np.random.randn(n) * 0.5  # noise
+        )
+
+        # Fit with solve_ols
+        coef, resid, vcov = solve_ols(X, y)
+
+        # Coefficients must be finite
+        assert np.all(np.isfinite(coef)), f"Coefficients contain non-finite values: {coef}"
+
+        # Coefficients should be reasonable (not trillions)
+        assert np.all(np.abs(coef) < 1e6), f"Coefficients are unreasonably large: {coef}"
+
+        # The treatment effect coefficient (last one) should be close to true effect
+        # Allow for sampling variation and potential multicollinearity effects
+        assert abs(coef[-1] - true_effect) < 2.0, (
+            f"Treatment effect coefficient {coef[-1]} is too far from true effect {true_effect}"
+        )
+
     def test_single_cluster_error(self):
         """Test that single cluster raises error."""
         X = np.random.randn(100, 2)
