Skip to content

Commit 5ddc8b1

Browse files
igerber and claude committed
Fix underdetermined SVD in Rust OLS (n < k case)
For underdetermined systems (n < k), thin SVD returns:

- U: (n, n)
- S: (n,) - only n singular values
- V: (k, n) - only n right singular vectors

The bug was creating s_inv_uty with size k instead of s.len() = min(n, k), causing a dimension mismatch in the dot product vt.t().dot(&s_inv_uty).

Fix: Use s.len() for the s_inv_uty array size, which correctly handles both the overdetermined (n >= k) and underdetermined (n < k) cases.

Addresses the P1 bug identified in the PR #115 code review.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f771663 commit 5ddc8b1

1 file changed

Lines changed: 31 additions & 2 deletions

File tree

rust/src/linalg.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,10 @@ pub fn solve_ols<'py>(
119119
let uty = u.t().dot(&y_owned); // (min(n,k),)
120120

121121
// Build S^{-1} with truncation and count effective rank
122-
let mut s_inv_uty = Array1::<f64>::zeros(k);
122+
// Note: s.len() = min(n, k) from thin SVD, so this handles underdetermined (n < k) correctly
123+
let mut s_inv_uty = Array1::<f64>::zeros(s.len());
123124
let mut rank = 0usize;
124-
for i in 0..s.len().min(k) {
125+
for i in 0..s.len() {
125126
if s[i] > threshold {
126127
s_inv_uty[i] = uty[i] / s[i];
127128
rank += 1;
@@ -360,4 +361,32 @@ mod tests {
360361
assert_eq!(faer_mat[(1, 0)], 3.0);
361362
assert_eq!(faer_mat[(1, 1)], 4.0);
362363
}
364+
365+
/// Shape check for the thin SVD of an underdetermined (n < k) design matrix.
///
/// Decomposes a 2x3 matrix and asserts the dimensions of U, S and V, i.e. the
/// min(n, k) length that the `s_inv_uty` buffer must be sized to.
/// NOTE(review): only the SVD shapes are pinned here; the solver itself is
/// never called, so the dimension fix is not exercised end to end.
#[test]
366+
fn test_svd_underdetermined_dimensions() {
367+
// Underdetermined system: n=2 observations, k=3 coefficients
368+
// X is (2, 3), y is (2,)
369+
// This test verifies that thin SVD returns the expected dimensions
370+
// for underdetermined systems (the OLS solver is not invoked here)
371+
let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
372+
// NOTE(review): `_y` is never read below; it only illustrates the
// response vector that would accompany `x`.
let _y = array![7.0, 8.0];
373+
374+
// Convert to faer and compute thin SVD
375+
let x_faer = ndarray_to_faer(&x);
376+
let svd = x_faer.thin_svd().unwrap();
377+
378+
// For n=2 < k=3: U is (2, 2), S has 2 values, V is (3, 2)
379+
assert_eq!(svd.U().nrows(), 2, "U should have n=2 rows");
380+
assert_eq!(svd.U().ncols(), 2, "U should have min(n,k)=2 cols");
381+
assert_eq!(svd.S().column_vector().nrows(), 2, "S should have min(n,k)=2 singular values");
382+
assert_eq!(svd.V().nrows(), 3, "V should have k=3 rows");
383+
assert_eq!(svd.V().ncols(), 2, "V should have min(n,k)=2 cols");
384+
385+
// Verify the length that s_inv_uty must be allocated with
386+
let s_len = svd.S().column_vector().nrows();
387+
assert_eq!(s_len, 2, "s.len() should be min(n,k)=2, not k=3");
388+
389+
// This is the key fix: s_inv_uty must have dimension s.len()=min(n,k),
390+
// not k, otherwise vt.t().dot(&s_inv_uty) will have mismatched dimensions
391+
}
363392
}

0 commit comments

Comments
 (0)