@@ -286,8 +286,12 @@ fn ndarray_to_faer(arr: &Array2<f64>) -> faer::Mat<f64> {
286286/// Invert a symmetric positive-definite matrix.
287287///
288288/// Uses LU decomposition with partial pivoting. Includes both NaN/Inf check
289- /// and residual-based verification to catch near-singular matrices that
290- /// produce finite but numerically inaccurate results.
289+ /// and conditional residual-based verification to catch near-singular matrices
290+ /// that produce finite but numerically inaccurate results.
291+ ///
292+ /// Performance optimization: The expensive O(n³) residual check (A * A⁻¹ - I)
293+ /// is only performed when LU pivot ratios suggest potential instability. For
294+ /// well-conditioned matrices (the common case), this check is skipped.
291295fn invert_symmetric ( a : & Array2 < f64 > ) -> PyResult < Array2 < f64 > > {
292296 let n = a. nrows ( ) ;
293297
@@ -323,31 +327,51 @@ fn invert_symmetric(a: &Array2<f64>) -> PyResult<Array2<f64>> {
323327 ) ) ;
324328 }
325329
326- // Verify inversion accuracy by checking ||A * A^{-1} - I||_max
327- // For near-singular matrices, this residual will be large even if
328- // the result contains no NaN/Inf values
329- let a_times_inv = a_faer. as_ref ( ) * & x_faer;
330- let mut max_residual = 0.0_f64 ;
330+ // Check pivot ratio to detect potential instability.
331+ // The diagonal of U contains the pivots from LU factorization.
332+ // A small pivot ratio (min/max) indicates potential numerical instability.
333+ let u_factor = lu. U ( ) ;
334+ let mut max_pivot = 0.0_f64 ;
335+ let mut min_pivot = f64:: INFINITY ;
331336 for i in 0 ..n {
332- for j in 0 ..n {
333- let expected = if i == j { 1.0 } else { 0 .0 } ;
334- let residual = ( a_times_inv [ ( i , j ) ] - expected ) . abs ( ) ;
335- max_residual = max_residual . max ( residual ) ;
337+ let pivot = u_factor [ ( i , i ) ] . abs ( ) ;
338+ if pivot > 0 .0 {
339+ max_pivot = max_pivot . max ( pivot ) ;
340+ min_pivot = min_pivot . min ( pivot ) ;
336341 }
337342 }
343+ let pivot_ratio = if max_pivot > 0.0 { min_pivot / max_pivot } else { 0.0 } ;
344+
345+ // Only perform expensive residual check if pivots suggest potential instability.
346+ // Threshold of 1e-10 catches truly problematic matrices while avoiding
347+ // unnecessary O(n³) computation for well-conditioned cases.
348+ if pivot_ratio < 1e-10 {
349+ // Verify inversion accuracy by checking ||A * A^{-1} - I||_max
350+ // For near-singular matrices, this residual will be large even if
351+ // the result contains no NaN/Inf values
352+ let a_times_inv = a_faer. as_ref ( ) * & x_faer;
353+ let mut max_residual = 0.0_f64 ;
354+ for i in 0 ..n {
355+ for j in 0 ..n {
356+ let expected = if i == j { 1.0 } else { 0.0 } ;
357+ let residual = ( a_times_inv[ ( i, j) ] - expected) . abs ( ) ;
358+ max_residual = max_residual. max ( residual) ;
359+ }
360+ }
338361
339- // Threshold: detect truly singular matrices while allowing ill-conditioned ones
340- // Ill-conditioned matrices (high condition number) can have residuals up to ~1e-4
341- // while still producing usable results. Use 1e-4 * n as threshold.
342- let threshold = 1e-4 * ( n as f64 ) ;
343- if max_residual > threshold {
344- return Err ( PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > (
345- format ! (
346- "Matrix inversion numerically unstable (residual={:.2e} > threshold={:.2e}). \
347- Design matrix may be near-singular.",
348- max_residual, threshold
349- )
350- ) ) ;
362+ // Threshold: detect truly singular matrices while allowing ill-conditioned ones
363+ // Ill-conditioned matrices (high condition number) can have residuals up to ~1e-4
364+ // while still producing usable results. Use 1e-4 * n as threshold.
365+ let threshold = 1e-4 * ( n as f64 ) ;
366+ if max_residual > threshold {
367+ return Err ( PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > (
368+ format ! (
369+ "Matrix inversion numerically unstable (residual={:.2e} > threshold={:.2e}). \
370+ Design matrix may be near-singular.",
371+ max_residual, threshold
372+ )
373+ ) ) ;
374+ }
351375 }
352376
353377 // Convert back to ndarray
@@ -445,7 +469,8 @@ mod tests {
445469 [ 1.0 , 1.0 + 1e-15 ] , // Nearly identical rows
446470 ] ;
447471
448- // Should fail due to numerical instability
472+ // Should fail due to numerical instability (small pivot ratio triggers
473+ // residual check which detects the inversion error)
449474 let result = invert_symmetric ( & a) ;
450475 assert ! ( result. is_err( ) , "Near-singular matrix inversion should fail" ) ;
451476
0 commit comments