diff --git a/tests/test_eig_utils.py b/tests/test_eig_utils.py index 00834133..93df0285 100644 --- a/tests/test_eig_utils.py +++ b/tests/test_eig_utils.py @@ -180,6 +180,9 @@ def test_eigh_with_fallback_reconstruction_close_to_original( x = torch.randn(shape, device=self.device) x = x @ x.T # symmetric positive semi-definite + # Remove 0 to prevent test failure because reconstruction of 0 can be further off than other values. + x[x.abs() < 1e-8] = 0.25 + L, Q = eig_utils.eigh_with_fallback( x, force_double=force_double,