
Commit baeca11

lezcano authored and facebook-github-bot committed

Remove random_fullrank_matrix_distinct_singular_value (pytorch#68183)

Summary:
Pull Request resolved: pytorch#68183

We do so in favour of `make_fullrank_matrices_with_distinct_singular_values`. This latter helper not only has an even longer name, but also generates inputs correctly so that they work with the PR later in this stack that tests noncontiguous inputs.

We also heavily simplified the generation of samples for the SVD, as it was fairly convoluted and was not generating the inputs correctly for the noncontiguous test.

To make the transition, we also needed to fix the following issue, as it was popping up in the tests:

Fixes pytorch#66856

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D32684853

Pulled By: mruberry

fbshipit-source-id: e88189c8b67dbf592eccdabaf2aa6d2e2f7b95a4
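For reference, a minimal sketch of the call-convention change this commit applies throughout the tests: the old helper took the matrix size first and the batch dimensions after it, while the new one takes the full output shape `(*batch_dims, n, n)` and is typically bound to `device`/`dtype` with `functools.partial`. The shapes and the rank check below are illustrative, not taken from the diff.

    from functools import partial

    import torch
    from torch.testing._internal.common_utils import (
        make_fullrank_matrices_with_distinct_singular_values,
    )

    # Old, removed helper (for comparison): matrix size first, then batch dims, e.g.
    #   random_fullrank_matrix_distinct_singular_value(5, 3, 4, ...)  -> shape (3, 4, 5, 5)

    # New helper: receives the full shape (*batch_dims, n, n)
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_A = partial(make_fullrank, device="cpu", dtype=torch.float64)

    A = make_A(3, 4, 5, 5)  # a 3x4 batch of full-rank 5x5 matrices
    assert A.shape == (3, 4, 5, 5)
    assert torch.linalg.matrix_rank(A).eq(5).all()  # every matrix in the batch is invertible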
1 parent 08ef4ae commit baeca11

3 files changed: +134 -190 lines changed

test/test_linalg.py

Lines changed: 61 additions & 46 deletions
@@ -17,7 +17,8 @@
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
      TEST_WITH_ASAN, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
-     iter_indices, gradcheck, gradgradcheck)
+     iter_indices, gradcheck, gradgradcheck,
+     make_fullrank_matrices_with_distinct_singular_values)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes,
      onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@@ -3213,7 +3214,8 @@ def test_cholesky_solve_out_errors_and_warnings(self, device, dtype):
     @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
                         torch.float64: 1e-8, torch.complex128: 1e-8})
     def test_inverse(self, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_arg = partial(make_fullrank, device=device, dtype=dtype)
 
         def run_test(torch_inverse, matrix, batches, n):
             matrix_inverse = torch_inverse(matrix)
@@ -3265,15 +3267,15 @@ def test_inv_ex(input, out=None):
             [[], [0], [2], [2, 1]],
             [0, 5]
         ):
-            matrices = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype, device=device)
+            matrices = make_arg(*batches, n, n)
             run_test(torch_inverse, matrices, batches, n)
 
             # test non-contiguous input
             run_test(torch_inverse, matrices.mT, batches, n)
             if n > 0:
                 run_test(
                     torch_inverse,
-                    random_fullrank_matrix_distinct_singular_value(n * 2, *batches, dtype=dtype, device=device)
+                    make_arg(*batches, 2 * n, 2 * n)
                     .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n),
                     batches, n
                 )
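Note on the non-contiguous cases exercised above: they are built either by transposing the matrix dimensions with `.mT` or by taking a stride-2 slice out of a larger matrix, both of which return strided views rather than copies. A small illustrative sketch (shapes are arbitrary):

    import torch

    A = torch.randn(2, 4, 4, dtype=torch.double)
    # .mT swaps the strides of the last two dimensions, no copy is made
    assert not A.mT.is_contiguous()

    B = torch.randn(2, 8, 8, dtype=torch.double)
    C = B[:, ::2, ::2]  # every other row and column: still a view, shape (2, 4, 4)
    assert C.shape == (2, 4, 4)
    assert not C.is_contiguous()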
@@ -3321,10 +3323,11 @@ def test_inv_ex_singular(self, device, dtype):
     @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
                         torch.float64: 1e-5, torch.complex128: 1e-5})
     def test_inverse_many_batches(self, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_arg = partial(make_fullrank, device=device, dtype=dtype)
 
         def test_inverse_many_batches_helper(torch_inverse, b, n):
-            matrices = random_fullrank_matrix_distinct_singular_value(b, n, n, dtype=dtype, device=device)
+            matrices = make_arg(b, n, n)
             matrices_inverse = torch_inverse(matrices)
 
             # Compare against NumPy output
@@ -3542,10 +3545,11 @@ def run_test_singular_input(batch_dim, n):
             self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
 
     def solve_test_helper(self, A_dims, b_dims, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_A = partial(make_fullrank, device=device, dtype=dtype)
 
         b = torch.randn(*b_dims, dtype=dtype, device=device)
-        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
+        A = make_A(*A_dims)
         return b, A
 
     @skipCUDAIfNoMagma
@@ -3554,7 +3558,7 @@ def solve_test_helper(self, A_dims, b_dims, device, dtype):
     @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
     def test_solve(self, device, dtype):
         def run_test(n, batch, rhs):
-            A_dims = (n, *batch)
+            A_dims = (*batch, n, n)
             b_dims = (*batch, n, *rhs)
             b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
 
@@ -3600,8 +3604,10 @@ def run_test(n, batch, rhs):
     @dtypes(*floating_and_complex_types())
     @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
     def test_solve_batched_non_contiguous(self, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
-        A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device=device).permute(1, 0, 2)
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_A = partial(make_fullrank, device=device, dtype=dtype)
+
+        A = make_A(2, 2, 2).permute(1, 0, 2)
         b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0)
         self.assertFalse(A.is_contiguous())
         self.assertFalse(b.is_contiguous())
@@ -3680,7 +3686,7 @@ def run_test_singular_input(batch_dim, n):
     @dtypes(*floating_and_complex_types())
     def test_old_solve(self, device, dtype):
         for (k, n) in zip([2, 3, 5], [3, 5, 7]):
-            b, A = self.solve_test_helper((n,), (n, k), device, dtype)
+            b, A = self.solve_test_helper((n, n), (n, k), device, dtype)
             x = torch.solve(b, A)[0]
             self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
 
@@ -3700,15 +3706,18 @@ def solve_batch_helper(A_dims, b_dims):
             self.assertEqual(b, Ax)
 
         for batchsize in [1, 3, 4]:
-            solve_batch_helper((5, batchsize), (batchsize, 5, 10))
+            solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10))
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(*floating_and_complex_types())
     def test_old_solve_batched_non_contiguous(self, device, dtype):
         from numpy.linalg import solve
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
-        A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device=device).permute(1, 0, 2)
+
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_A = partial(make_fullrank, device=device, dtype=dtype)
+
+        A = make_A(2, 2, 2).permute(1, 0, 2)
         b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0)
         x, _ = torch.solve(b, A)
         x_exp = solve(A.cpu().numpy(), b.cpu().numpy())
@@ -3719,7 +3728,7 @@ def test_old_solve_batched_non_contiguous(self, device, dtype):
     @skipCPUIfNoLapack
     @dtypes(*floating_and_complex_types())
     def test_old_solve_batched_many_batches(self, device, dtype):
-        for A_dims, b_dims in zip([(5, 256, 256), (3, )], [(5, 1), (512, 512, 3, 1)]):
+        for A_dims, b_dims in zip([(256, 256, 5, 5), (3, 3)], [(5, 1), (512, 512, 3, 1)]):
            b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
            x, _ = torch.solve(b, A)
            Ax = torch.matmul(A, x)
@@ -3734,7 +3743,7 @@ def test_old_solve_batched_broadcasting(self, device, dtype):
         def run_test(A_dims, b_dims):
             A_matrix_size = A_dims[-1]
             A_batch_dims = A_dims[:-2]
-            b, A = self.solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, device, dtype)
+            b, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), b_dims, device, dtype)
             x, _ = torch.solve(b, A)
             x_exp = solve(A.cpu().numpy(), b.cpu().numpy())
             self.assertEqual(x, x_exp)
@@ -4196,26 +4205,27 @@ def run_test_atol(shape0, shape1, batch):
     @skipCPUIfNoLapack
     @dtypes(torch.float64)
     def test_matrix_rank_atol_rtol(self, device, dtype):
-        from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_arg = partial(make_fullrank, device=device, dtype=dtype)
 
-        # creates a matrix with singular values arange(1/(n+1), 1, 1/(n+1)) and rank=n
+        # creates a matrix with rank=n and singular values in range [2/3, 3/2]
+        # the singular values are 1 + 1/2, 1 - 1/3, 1 + 1/4, 1 - 1/5, ...
         n = 9
-        a = make_fullrank_matrices_with_distinct_singular_values(n, n, dtype=dtype, device=device)
+        a = make_arg(n, n)
 
         # test float and tensor variants
-        for tol_value in [0.51, torch.tensor(0.51, device=device)]:
-            # using rtol (relative tolerance) takes into account the largest singular value (0.9 in this case)
+        for tol_value in [0.81, torch.tensor(0.81, device=device)]:
+            # using rtol (relative tolerance) takes into account the largest singular value (1.5 in this case)
             result = torch.linalg.matrix_rank(a, rtol=tol_value)
-            self.assertEqual(result, 5)  # there are 5 singular values above 0.9*0.51=0.459
+            self.assertEqual(result, 2)  # there are 2 singular values above 1.5*0.81 = 1.215
 
             # atol is used directly to compare with singular values
             result = torch.linalg.matrix_rank(a, atol=tol_value)
-            self.assertEqual(result, 4)  # there are 4 singular values above 0.51
+            self.assertEqual(result, 7)  # there are 7 singular values above 0.81
 
             # when both are specified the maximum tolerance is used
             result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value)
-            self.assertEqual(result, 4)  # there are 4 singular values above max(0.51, 0.9*0.51)
-
+            self.assertEqual(result, 2)  # there are 2 singular values above max(0.81, 1.5*0.81)
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
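The new expected ranks follow from how `torch.linalg.matrix_rank` combines its two tolerances: a singular value counts toward the rank only if it exceeds `max(atol, rtol * largest_singular_value)`. A sketch of that rule on an arbitrary matrix (the explicit SVD comparison is illustrative, not part of the test):

    import torch

    torch.manual_seed(0)
    a = torch.randn(9, 9, dtype=torch.float64)  # stand-in for make_arg(9, 9)
    atol, rtol = 0.81, 0.81

    s = torch.linalg.svdvals(a)
    threshold = max(atol, rtol * s.max().item())  # rank counts singular values above this
    expected = (s > threshold).sum()

    assert torch.linalg.matrix_rank(a, atol=atol, rtol=rtol) == expected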
@@ -6832,7 +6842,8 @@ def test_solve_methods_arg_device(self, device):
     @skipCPUIfNoLapack
     @dtypes(*floating_and_complex_types())
     def test_pinverse(self, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_arg = partial(make_fullrank, device=device, dtype=dtype)
 
         def run_test(M):
             # Testing against definition for pseudo-inverses
@@ -6857,7 +6868,7 @@ def run_test(M):
         for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]:
             matsize = sizes[-1]
             batchdims = sizes[:-2]
-            M = fullrank(matsize, *batchdims, dtype=dtype, device=device)
+            M = make_arg(*batchdims, matsize, matsize)
             self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
                              atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')
 
@@ -6884,21 +6895,22 @@ def check(*size, noncontiguous=False):
     @skipCUDAIfNoMagmaAndNoCusolver
     @dtypes(torch.double, torch.cdouble)
     def test_matrix_power_negative(self, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_arg = partial(make_fullrank, device=device, dtype=dtype)
 
         def check(*size):
-            t = random_fullrank_matrix_distinct_singular_value(*size, dtype=dtype, device=device)
+            t = make_arg(*size)
             for n in range(-7, 0):
                 res = torch.linalg.matrix_power(t, n)
                 ref = np.linalg.matrix_power(t.cpu().numpy(), n)
                 self.assertEqual(res.cpu(), torch.from_numpy(ref))
 
-        check(0)
-        check(5)
-        check(0, 2)
-        check(3, 0)
-        check(3, 2)
-        check(5, 2, 3)
+        check(0, 0)
+        check(5, 5)
+        check(2, 0, 0)
+        check(0, 3, 3)
+        check(2, 3, 3)
+        check(2, 3, 5, 5)
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
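The `check` calls now spell out square matrices explicitly, e.g. `check(2, 3, 5, 5)` builds a 2x3 batch of 5x5 matrices. Full-rank inputs matter here because a negative exponent makes `torch.linalg.matrix_power` invert the matrix first; a short illustrative check of that behaviour (the input matrix is arbitrary but well conditioned):

    import torch

    t = torch.eye(3, dtype=torch.double) + 0.1 * torch.randn(3, 3, dtype=torch.double)

    # With n < 0, matrix_power(t, n) equals the inverse of t raised to abs(n)
    res = torch.linalg.matrix_power(t, -3)
    ref = torch.linalg.matrix_power(torch.linalg.inv(t), 3)
    torch.testing.assert_close(res, ref)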
@@ -7761,9 +7773,10 @@ def maybe_squeeze_result(l, r, result):
     @skipCPUIfNoLapack
     @dtypes(*floating_and_complex_types())
     def test_lu_solve_batched_non_contiguous(self, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_A = partial(make_fullrank, device=device, dtype=dtype)
 
-        A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device=device)
+        A = make_A(2, 2, 2)
         b = torch.randn(2, 2, 2, dtype=dtype, device=device)
         x_exp = np.linalg.solve(A.cpu().permute(0, 2, 1).numpy(), b.cpu().permute(2, 1, 0).numpy())
         A = A.permute(0, 2, 1)
@@ -7774,10 +7787,11 @@ def test_lu_solve_batched_non_contiguous(self, device, dtype):
         self.assertEqual(x, x_exp)
 
     def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_A = partial(make_fullrank, device=device, dtype=dtype)
 
         b = torch.randn(*b_dims, dtype=dtype, device=device)
-        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
+        A = make_A(*A_dims)
         LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
         self.assertEqual(info, torch.zeros_like(info))
         return b, A, LU_data, LU_pivots
@@ -7790,7 +7804,7 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
     def test_lu_solve(self, device, dtype):
         def sub_test(pivot):
             for k, n in zip([2, 3, 5], [3, 5, 7]):
-                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
+                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype)
                 x = torch.lu_solve(b, LU_data, LU_pivots)
                 self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
 
@@ -7817,7 +7831,7 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
                 self.assertEqual(b, Ax)
 
             for batchsize in [1, 3, 4]:
-                lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
+                lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot)
 
             # Tests tensors with 0 elements
             b = torch.randn(3, 0, 3, dtype=dtype, device=device)
@@ -7840,19 +7854,20 @@ def run_test(A_dims, b_dims):
             Ax = torch.matmul(A, x)
             self.assertEqual(Ax, b.expand_as(Ax))
 
-        run_test((5, 65536), (65536, 5, 10))
-        run_test((5, 262144), (262144, 5, 10))
+        run_test((65536, 5, 5), (65536, 5, 10))
+        run_test((262144, 5, 5), (262144, 5, 10))
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(*floating_and_complex_types())
     def test_lu_solve_batched_broadcasting(self, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+        make_A = partial(make_fullrank, device=device, dtype=dtype)
 
         def run_test(A_dims, b_dims, pivot=True):
             A_matrix_size = A_dims[-1]
             A_batch_dims = A_dims[:-2]
-            A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype, device=device)
+            A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
             b = make_tensor(b_dims, dtype=dtype, device=device)
             x_exp = np.linalg.solve(A.cpu(), b.cpu())
             LU_data, LU_pivots = torch.lu(A, pivot=pivot)
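For context on the helpers above: they factorize `A` once with `torch.lu` and then reuse the factorization for every right-hand side via `torch.lu_solve`. A minimal sketch of that flow with arbitrary batch and matrix sizes:

    import torch

    A = torch.randn(4, 5, 5, dtype=torch.double)   # batch of 5x5 systems
    b = torch.randn(4, 5, 10, dtype=torch.double)  # 10 right-hand sides per system

    LU_data, LU_pivots, info = torch.lu(A, get_infos=True)
    assert (info == 0).all()  # factorization succeeded for every matrix in the batch

    x = torch.lu_solve(b, LU_data, LU_pivots)
    torch.testing.assert_close(A @ x, b)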
