Skip to content

Commit f3b1566

Browse files
igerber and claude committed
Address PR #115 AI review: NaN propagation, fallback, and performance
P0: Return f64::NAN instead of 0.0 in TROP bootstrap when < 2 samples.
P1: Add Python fallback in _solve_ols_rust for numerical instability.
P2: Gate expensive O(n³) residual check behind LU pivot ratio detection.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 2341e4f commit f3b1566

3 files changed

Lines changed: 88 additions & 32 deletions

File tree

diff_diff/linalg.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,10 @@ def _solve_ols_rust(
251251
cluster_ids: Optional[np.ndarray] = None,
252252
return_vcov: bool = True,
253253
return_fitted: bool = False,
254-
) -> Union[
254+
) -> Optional[Union[
255255
Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
256256
Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
257-
]:
257+
]]:
258258
"""
259259
Rust backend implementation of solve_ols for full-rank matrices.
260260
@@ -296,15 +296,30 @@ def _solve_ols_rust(
296296
Fitted values if return_fitted=True.
297297
vcov : np.ndarray, optional
298298
Variance-covariance matrix if return_vcov=True.
299+
None
300+
If Rust backend detects numerical instability and caller should
301+
fall back to Python backend.
299302
"""
300303
# Convert cluster_ids to int64 for Rust (handles string/categorical IDs)
301304
if cluster_ids is not None:
302305
cluster_ids = _factorize_cluster_ids(cluster_ids)
303306

304-
# Call Rust backend
305-
coefficients, residuals, vcov = _rust_solve_ols(
306-
X, y, cluster_ids=cluster_ids, return_vcov=return_vcov
307-
)
307+
# Call Rust backend with fallback on numerical instability
308+
try:
309+
coefficients, residuals, vcov = _rust_solve_ols(
310+
X, y, cluster_ids=cluster_ids, return_vcov=return_vcov
311+
)
312+
except ValueError as e:
313+
error_msg = str(e).lower()
314+
if "numerically unstable" in error_msg or "singular" in error_msg:
315+
warnings.warn(
316+
f"Rust backend detected numerical instability: {e}. "
317+
"Falling back to Python backend.",
318+
UserWarning,
319+
stacklevel=3,
320+
)
321+
return None # Signal caller to use Python fallback
322+
raise
308323

309324
# Convert to numpy arrays
310325
coefficients = np.asarray(coefficients)
@@ -499,6 +514,7 @@ def solve_ols(
499514
# Routing strategy:
500515
# - Full-rank + Rust available → fast Rust backend (SVD-based solve)
501516
# - Rank-deficient → Python backend (proper NA handling, valid SEs)
517+
# - Rust numerical instability → Python fallback (via None return)
502518
# - No Rust → Python backend (works for all cases)
503519
if HAS_RUST_BACKEND and _rust_solve_ols is not None and not is_rank_deficient:
504520
result = _solve_ols_rust(
@@ -508,6 +524,19 @@ def solve_ols(
508524
return_fitted=return_fitted,
509525
)
510526

527+
# Check for None: Rust backend detected numerical instability and
528+
# signaled us to fall back to Python backend
529+
if result is None:
530+
return _solve_ols_numpy(
531+
X, y,
532+
cluster_ids=cluster_ids,
533+
return_vcov=return_vcov,
534+
return_fitted=return_fitted,
535+
rank_deficient_action=rank_deficient_action,
536+
column_names=column_names,
537+
_precomputed_rank_info=None, # Force fresh rank detection
538+
)
539+
511540
# Check for NaN vcov: Rust SVD may detect rank-deficiency that QR missed
512541
# for ill-conditioned matrices (QR and SVD have different numerical properties).
513542
# When this happens, fall back to Python's R-style handling.

rust/src/linalg.rs

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,12 @@ fn ndarray_to_faer(arr: &Array2<f64>) -> faer::Mat<f64> {
286286
/// Invert a symmetric positive-definite matrix.
287287
///
288288
/// Uses LU decomposition with partial pivoting. Includes both NaN/Inf check
289-
/// and residual-based verification to catch near-singular matrices that
290-
/// produce finite but numerically inaccurate results.
289+
/// and conditional residual-based verification to catch near-singular matrices
290+
/// that produce finite but numerically inaccurate results.
291+
///
292+
/// Performance optimization: The expensive O(n³) residual check (A * A⁻¹ - I)
293+
/// is only performed when LU pivot ratios suggest potential instability. For
294+
/// well-conditioned matrices (the common case), this check is skipped.
291295
fn invert_symmetric(a: &Array2<f64>) -> PyResult<Array2<f64>> {
292296
let n = a.nrows();
293297

@@ -323,31 +327,51 @@ fn invert_symmetric(a: &Array2<f64>) -> PyResult<Array2<f64>> {
323327
));
324328
}
325329

326-
// Verify inversion accuracy by checking ||A * A^{-1} - I||_max
327-
// For near-singular matrices, this residual will be large even if
328-
// the result contains no NaN/Inf values
329-
let a_times_inv = a_faer.as_ref() * &x_faer;
330-
let mut max_residual = 0.0_f64;
330+
// Check pivot ratio to detect potential instability.
331+
// The diagonal of U contains the pivots from LU factorization.
332+
// A small pivot ratio (min/max) indicates potential numerical instability.
333+
let u_factor = lu.U();
334+
let mut max_pivot = 0.0_f64;
335+
let mut min_pivot = f64::INFINITY;
331336
for i in 0..n {
332-
for j in 0..n {
333-
let expected = if i == j { 1.0 } else { 0.0 };
334-
let residual = (a_times_inv[(i, j)] - expected).abs();
335-
max_residual = max_residual.max(residual);
337+
let pivot = u_factor[(i, i)].abs();
338+
if pivot > 0.0 {
339+
max_pivot = max_pivot.max(pivot);
340+
min_pivot = min_pivot.min(pivot);
336341
}
337342
}
343+
let pivot_ratio = if max_pivot > 0.0 { min_pivot / max_pivot } else { 0.0 };
344+
345+
// Only perform expensive residual check if pivots suggest potential instability.
346+
// Threshold of 1e-10 catches truly problematic matrices while avoiding
347+
// unnecessary O(n³) computation for well-conditioned cases.
348+
if pivot_ratio < 1e-10 {
349+
// Verify inversion accuracy by checking ||A * A^{-1} - I||_max
350+
// For near-singular matrices, this residual will be large even if
351+
// the result contains no NaN/Inf values
352+
let a_times_inv = a_faer.as_ref() * &x_faer;
353+
let mut max_residual = 0.0_f64;
354+
for i in 0..n {
355+
for j in 0..n {
356+
let expected = if i == j { 1.0 } else { 0.0 };
357+
let residual = (a_times_inv[(i, j)] - expected).abs();
358+
max_residual = max_residual.max(residual);
359+
}
360+
}
338361

339-
// Threshold: detect truly singular matrices while allowing ill-conditioned ones
340-
// Ill-conditioned matrices (high condition number) can have residuals up to ~1e-4
341-
// while still producing usable results. Use 1e-4 * n as threshold.
342-
let threshold = 1e-4 * (n as f64);
343-
if max_residual > threshold {
344-
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
345-
format!(
346-
"Matrix inversion numerically unstable (residual={:.2e} > threshold={:.2e}). \
347-
Design matrix may be near-singular.",
348-
max_residual, threshold
349-
)
350-
));
362+
// Threshold: detect truly singular matrices while allowing ill-conditioned ones
363+
// Ill-conditioned matrices (high condition number) can have residuals up to ~1e-4
364+
// while still producing usable results. Use 1e-4 * n as threshold.
365+
let threshold = 1e-4 * (n as f64);
366+
if max_residual > threshold {
367+
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
368+
format!(
369+
"Matrix inversion numerically unstable (residual={:.2e} > threshold={:.2e}). \
370+
Design matrix may be near-singular.",
371+
max_residual, threshold
372+
)
373+
));
374+
}
351375
}
352376

353377
// Convert back to ndarray
@@ -445,7 +469,8 @@ mod tests {
445469
[1.0, 1.0 + 1e-15], // Nearly identical rows
446470
];
447471

448-
// Should fail due to numerical instability
472+
// Should fail due to numerical instability (small pivot ratio triggers
473+
// residual check which detects the inversion error)
449474
let result = invert_symmetric(&a);
450475
assert!(result.is_err(), "Near-singular matrix inversion should fail");
451476

rust/src/trop.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,8 +1036,9 @@ pub fn bootstrap_trop_variance<'py>(
10361036
.collect();
10371037

10381038
// Compute standard error
1039+
// Return NaN when < 2 samples to properly propagate undefined inference
10391040
let se = if bootstrap_estimates.len() < 2 {
1040-
0.0
1041+
f64::NAN
10411042
} else {
10421043
let n = bootstrap_estimates.len() as f64;
10431044
let mean = bootstrap_estimates.iter().sum::<f64>() / n;
@@ -1701,8 +1702,9 @@ pub fn bootstrap_trop_variance_joint<'py>(
17011702
.collect();
17021703

17031704
// Compute standard error
1705+
// Return NaN when < 2 samples to properly propagate undefined inference
17041706
let se = if bootstrap_estimates.len() < 2 {
1705-
0.0
1707+
f64::NAN
17061708
} else {
17071709
let n = bootstrap_estimates.len() as f64;
17081710
let mean = bootstrap_estimates.iter().sum::<f64>() / n;

0 commit comments

Comments
 (0)