
Commit 31371fa

igerber and claude committed
Implement R-style rank deficiency handling instead of silent SVD truncation
Replace silent SVD truncation with R's lm() approach for rank-deficient matrices:
- Detect rank deficiency using pivoted QR decomposition
- Warn users with clear message listing dropped columns
- Set NaN for coefficients of linearly dependent columns
- Compute valid SEs for identified coefficients only
- Expand vcov matrix with NaN for dropped rows/columns

Add rank_deficient_action parameter ("warn", "error", "silent") to control behavior.

Hybrid Rust/Python routing:
- Full-rank matrices use fast Rust backend (when available)
- Rank-deficient matrices use Python backend for proper NA handling
  (ndarray-linalg doesn't support QR with pivoting)

Also fixes tutorial notebook 02 to avoid rank deficiency by including both treated cohorts in the MultiPeriodDiD example.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 0e6524e commit 31371fa
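
For readers unfamiliar with the approach named in the commit message, the following is a minimal sketch of detecting linearly dependent columns with pivoted QR and setting NaN for their coefficients. It is not the code in `diff_diff/linalg.py`: the helper names `detect_dependent_columns` and `ols_with_nan_for_dropped`, and the tolerance `rtol=1e-7`, are assumptions made for illustration.

```python
import numpy as np
from scipy.linalg import qr


def detect_dependent_columns(X, rtol=1e-7):
    """Return (rank, kept, dropped) column indices via column-pivoted QR.

    Hypothetical helper; `rtol` is an assumed relative tolerance.
    """
    # Pivoted QR gives X[:, piv] = Q @ R with |diag(R)| non-increasing.
    _, R, piv = qr(X, mode="economic", pivoting=True)
    diag = np.abs(np.diag(R))
    if diag.size == 0 or diag[0] == 0:
        rank = 0
    else:
        # Count diagonal entries that are not negligible relative to the largest.
        rank = int(np.sum(diag > rtol * diag[0]))
    kept = np.sort(piv[:rank])
    dropped = np.sort(piv[rank:])
    return rank, kept, dropped


def ols_with_nan_for_dropped(X, y, rtol=1e-7):
    """Fit OLS on the identified columns; dropped columns get NaN coefficients."""
    _, kept, dropped = detect_dependent_columns(X, rtol)
    beta = np.full(X.shape[1], np.nan)
    solution, *_ = np.linalg.lstsq(X[:, kept], y, rcond=None)
    beta[kept] = solution
    return beta, dropped
```

Which member of a dependent set ends up dropped depends on the pivot order, so this sketch may not drop the same column that R's `lm()` would.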

7 files changed

Lines changed: 652 additions & 177 deletions


CLAUDE.md

Lines changed: 14 additions & 3 deletions
@@ -119,13 +119,20 @@ pytest tests/test_rust_backend.py -v
 - Integrated with `TwoWayFixedEffects.decompose()` method
 
 - **`diff_diff/linalg.py`** - Unified linear algebra backend (v1.4.0+):
-- `solve_ols()` - OLS solver using scipy's gelsd LAPACK driver (SVD-based, handles rank-deficient matrices)
+- `solve_ols()` - OLS solver with R-style rank deficiency handling
+- `_detect_rank_deficiency()` - Detect linearly dependent columns via pivoted QR
 - `compute_robust_vcov()` - Vectorized HC1 and cluster-robust variance-covariance estimation
 - `compute_r_squared()` - R-squared and adjusted R-squared computation
 - `LinearRegression` - High-level OLS helper class with unified coefficient extraction and inference
 - `InferenceResult` - Dataclass container for coefficient-level inference (SE, t-stat, p-value, CI)
 - Single optimization point for all estimators (reduces code duplication)
 - Cluster-robust SEs use pandas groupby instead of O(n × clusters) loop
+- **Rank deficiency handling** (R-style):
+- Detects rank-deficient matrices using pivoted QR decomposition
+- `rank_deficient_action` parameter: "warn" (default), "error", or "silent"
+- Dropped columns have NaN coefficients (like R's `lm()`)
+- VCoV matrix has NaN for rows/cols of dropped coefficients
+- Warnings include column names when provided
 
 - **`diff_diff/_backend.py`** - Backend detection and configuration (v2.0.0):
 - Detects optional Rust backend availability
@@ -240,16 +247,20 @@ diff-diff achieved significant performance improvements in v1.4.0, now **faster
 
 All estimators use a single optimized OLS/SE implementation:
 
-- **scipy.linalg.lstsq with 'gelsd' driver**: SVD-based solving that properly handles rank-deficient matrices (critical for MultiPeriodDiD and other estimators with potentially redundant columns)
+- **R-style rank deficiency handling**: Uses pivoted QR to detect linearly dependent columns, drops them, sets NaN for their coefficients, and emits informative warnings (following R's `lm()` approach)
 - **Vectorized cluster-robust SE**: Uses pandas groupby aggregation instead of O(n × clusters) Python loop
 - **Single optimization point**: Changes to `linalg.py` benefit all estimators
 
 ```python
 # All estimators import from linalg.py
 from diff_diff.linalg import solve_ols, compute_robust_vcov
 
-# Example usage
+# Example usage (warns on rank deficiency, sets NaN for dropped coefficients)
 coefficients, residuals, vcov = solve_ols(X, y, cluster_ids=cluster_ids)
+
+# Suppress warning or raise error:
+coefficients, residuals, vcov = solve_ols(X, y, rank_deficient_action="silent") # no warning
+coefficients, residuals, vcov = solve_ols(X, y, rank_deficient_action="error") # raises ValueError
 ```
 
 #### CallawaySantAnna Optimizations (`staggered.py`)
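
The CLAUDE.md changes above document two behaviors: NaN rows and columns in the VCoV matrix for dropped coefficients, and the three `rank_deficient_action` modes. A rough sketch of both follows. The helper names `report_rank_deficiency` and `expand_vcov` and the message wording are hypothetical and are not taken from `diff_diff/linalg.py`; only the mode names and the `ValueError` for `"error"` come from the diff.

```python
import warnings

import numpy as np


def report_rank_deficiency(dropped, column_names=None, action="warn"):
    """Hypothetical dispatcher for the three documented modes."""
    if action not in ("warn", "error", "silent"):
        raise ValueError(f"unknown rank_deficient_action: {action!r}")
    if len(dropped) == 0 or action == "silent":
        return
    # Use column names in the message when the caller supplies them.
    labels = [
        column_names[i] if column_names is not None else str(i)
        for i in dropped
    ]
    message = "Rank-deficient design matrix; dropped columns: " + ", ".join(labels)
    if action == "error":
        raise ValueError(message)
    warnings.warn(message, UserWarning, stacklevel=2)


def expand_vcov(vcov_kept, kept, n_cols):
    """Embed the VCoV of the kept columns in an n_cols x n_cols NaN matrix."""
    full = np.full((n_cols, n_cols), np.nan)
    # Fill only the rows/columns of identified coefficients; the rest stay NaN.
    full[np.ix_(kept, kept)] = vcov_kept
    return full
```

With NaN-filled rows and columns, standard errors taken from the diagonal are NaN for dropped coefficients, so they cannot be mistaken for valid estimates.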
