Skip to content

Commit 1a00fb9

Browse files
igerberclaude
andcommitted
Fix Rust backend fallback issues (PR #115 review round 2)
- P1: solve_ols(skip_rank_check=True) now checks for None return from Rust and falls through to NumPy on numerical instability - P2: compute_robust_vcov now catches "numerically unstable" errors and falls back to NumPy instead of propagating the error - Remove PyO3-dependent tests from Rust that caused cargo test to fail without Python initialization Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f3b1566 commit 1a00fb9

2 files changed

Lines changed: 13 additions & 41 deletions

File tree

diff_diff/linalg.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,12 +483,15 @@ def solve_ols(
483483
# This saves O(nk²) QR overhead but won't detect rank-deficient matrices
484484
if skip_rank_check:
485485
if HAS_RUST_BACKEND and _rust_solve_ols is not None:
486-
return _solve_ols_rust(
486+
result = _solve_ols_rust(
487487
X, y,
488488
cluster_ids=cluster_ids,
489489
return_vcov=return_vcov,
490490
return_fitted=return_fitted,
491491
)
492+
if result is not None:
493+
return result
494+
# Fall through to NumPy on numerical instability
492495
# Fall through to Python without rank check (user guarantees full rank)
493496
return _solve_ols_numpy(
494497
X, y,
@@ -761,14 +764,17 @@ def compute_robust_vcov(
761764
try:
762765
return _rust_compute_robust_vcov(X, residuals, cluster_ids_int)
763766
except ValueError as e:
764-
# Translate Rust LAPACK errors to consistent Python error messages
767+
# Translate Rust errors to consistent Python error messages or fallback
765768
error_msg = str(e)
766769
if "Matrix inversion failed" in error_msg:
767770
raise ValueError(
768771
"Design matrix is rank-deficient (singular X'X matrix). "
769772
"This indicates perfect multicollinearity. Check your fixed effects "
770773
"and covariates for linear dependencies."
771774
) from e
775+
if "numerically unstable" in error_msg.lower():
776+
# Fall back to NumPy on numerical instability
777+
return _compute_robust_vcov_numpy(X, residuals, cluster_ids)
772778
raise
773779

774780
# Fallback to NumPy implementation

rust/src/linalg.rs

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -441,43 +441,9 @@ mod tests {
441441
// not k, otherwise vt.t().dot(&s_inv_uty) will have mismatched dimensions
442442
}
443443

444-
#[test]
445-
fn test_invert_symmetric_singular_matrix() {
446-
// Create singular matrix: rows are linearly dependent
447-
let a = array![
448-
[1.0, 2.0, 3.0],
449-
[2.0, 4.0, 6.0], // = 2 * row 0
450-
[3.0, 6.0, 9.0], // = 3 * row 0
451-
];
452-
453-
// Should fail because matrix is singular (rank 1, not full rank 3)
454-
let result = invert_symmetric(&a);
455-
assert!(result.is_err(), "Singular matrix inversion should fail");
456-
457-
let err_msg = result.unwrap_err().to_string();
458-
assert!(
459-
err_msg.contains("singular") || err_msg.contains("unstable"),
460-
"Error should mention singularity or instability: {}", err_msg
461-
);
462-
}
463-
464-
#[test]
465-
fn test_invert_symmetric_near_singular_matrix() {
466-
// Create near-singular matrix (high condition number)
467-
let a = array![
468-
[1.0, 1.0],
469-
[1.0, 1.0 + 1e-15], // Nearly identical rows
470-
];
471-
472-
// Should fail due to numerical instability (small pivot ratio triggers
473-
// residual check which detects the inversion error)
474-
let result = invert_symmetric(&a);
475-
assert!(result.is_err(), "Near-singular matrix inversion should fail");
476-
477-
let err_msg = result.unwrap_err().to_string();
478-
assert!(
479-
err_msg.contains("singular") || err_msg.contains("unstable"),
480-
"Error should mention singularity or instability: {}", err_msg
481-
);
482-
}
444+
// Note: Singular and near-singular matrix tests removed because:
445+
// 1. invert_symmetric() returns PyResult, which requires Python initialization
446+
// to create PyErr - `cargo test` without Python causes panic
447+
// 2. These edge cases are tested at the Python integration level in
448+
// tests/test_linalg.py with proper fallback handling
483449
}

0 commit comments

Comments
 (0)