Skip to content

Commit d714ecf

Browse files
vishwakftw authored and facebook-github-bot committed
Rename potrf to cholesky (pytorch#12699)
Summary: This PR renames the function `potrf`, which computes the Cholesky decomposition of positive-definite matrices, to `cholesky`, matching the naming used by NumPy and TensorFlow. List of changes: (1) make potrf the cname for cholesky in Declarations.cwrap; (2) modify the function names in ATen/core; (3) modify the function names in the Python frontend; (4) issue warnings when potrf is called to notify users of the change. Reviewed By: soumith. Differential Revision: D10528361. Pulled By: zou3519. fbshipit-source-id: 19d9bcf8ffb38def698ae5acf30743884dda0d88
1 parent 26a8bb6 commit d714ecf

18 files changed

+118
-90
lines changed

aten/src/ATen/core/Tensor.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,7 @@ class CAFFE2_API Tensor {
747747
std::tuple<Tensor,Tensor> symeig(bool eigenvectors=false, bool upper=true) const;
748748
std::tuple<Tensor,Tensor> eig(bool eigenvectors=false) const;
749749
std::tuple<Tensor,Tensor,Tensor> svd(bool some=true, bool compute_uv=true) const;
750-
Tensor potrf(bool upper=true) const;
750+
Tensor cholesky(bool upper=false) const;
751751
Tensor potrs(const Tensor & input2, bool upper=true) const;
752752
Tensor potri(bool upper=true) const;
753753
std::tuple<Tensor,Tensor> pstrf(bool upper=true, Scalar tol=-1) const;

aten/src/ATen/core/TensorMethods.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1508,8 +1508,8 @@ inline std::tuple<Tensor,Tensor> Tensor::eig(bool eigenvectors) const {
15081508
inline std::tuple<Tensor,Tensor,Tensor> Tensor::svd(bool some, bool compute_uv) const {
15091509
return type().svd(*this, some, compute_uv);
15101510
}
1511-
inline Tensor Tensor::potrf(bool upper) const {
1512-
return type().potrf(*this, upper);
1511+
inline Tensor Tensor::cholesky(bool upper) const {
1512+
return type().cholesky(*this, upper);
15131513
}
15141514
inline Tensor Tensor::potrs(const Tensor & input2, bool upper) const {
15151515
return type().potrs(*this, input2, upper);

aten/src/ATen/core/Type.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ struct CAFFE2_API Type {
704704
virtual std::tuple<Tensor,Tensor> symeig(const Tensor & self, bool eigenvectors, bool upper) const = 0;
705705
virtual std::tuple<Tensor,Tensor> eig(const Tensor & self, bool eigenvectors) const = 0;
706706
virtual std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool compute_uv) const = 0;
707-
virtual Tensor potrf(const Tensor & self, bool upper) const = 0;
707+
virtual Tensor cholesky(const Tensor & self, bool upper) const = 0;
708708
virtual Tensor potrs(const Tensor & self, const Tensor & input2, bool upper) const = 0;
709709
virtual Tensor potri(const Tensor & self, bool upper) const = 0;
710710
virtual std::tuple<Tensor,Tensor> pstrf(const Tensor & self, bool upper, Scalar tol) const = 0;

aten/src/ATen/core/aten_interned_strings.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ _(aten, cauchy) \
236236
_(aten, ceil) \
237237
_(aten, celu) \
238238
_(aten, chain_matmul) \
239+
_(aten, cholesky) \
239240
_(aten, chunk) \
240241
_(aten, clamp) \
241242
_(aten, clamp_max) \
@@ -510,7 +511,6 @@ _(aten, pinverse) \
510511
_(aten, pixel_shuffle) \
511512
_(aten, poisson) \
512513
_(aten, polygamma) \
513-
_(aten, potrf) \
514514
_(aten, potri) \
515515
_(aten, potrs) \
516516
_(aten, pow) \

aten/src/ATen/native/LegacyDefinitions.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -491,11 +491,11 @@ std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool comput
491491
return at::_th_svd(self, some, compute_uv);
492492
}
493493

494-
Tensor & potrf_out(Tensor & result, const Tensor & self, bool upper) {
494+
Tensor & cholesky_out(Tensor & result, const Tensor & self, bool upper) {
495495
return at::_th_potrf_out(result, self, upper);
496496
}
497497

498-
Tensor potrf(const Tensor & self, bool upper) {
498+
Tensor cholesky(const Tensor & self, bool upper) {
499499
return at::_th_potrf(self, upper);
500500
}
501501

aten/src/ATen/native/native_functions.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -2785,10 +2785,10 @@
27852785
variants: method, function
27862786
device_guard: false
27872787

2788-
- func: potrf_out(Tensor result, Tensor self, bool upper=true) -> Tensor
2788+
- func: cholesky_out(Tensor result, Tensor self, bool upper=false) -> Tensor
27892789
device_guard: false
27902790

2791-
- func: potrf(Tensor self, bool upper=true) -> Tensor
2791+
- func: cholesky(Tensor self, bool upper=false) -> Tensor
27922792
variants: method, function
27932793
device_guard: false
27942794

docs/source/tensors.rst

+1
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ view of a storage and defines numeric operations on it.
180180
.. automethod:: ceil
181181
.. automethod:: ceil_
182182
.. automethod:: char
183+
.. automethod:: cholesky
183184
.. automethod:: chunk
184185
.. automethod:: clamp
185186
.. automethod:: clamp_

docs/source/torch.rst

+1
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ BLAS and LAPACK Operations
288288
.. autofunction:: btrisolve
289289
.. autofunction:: btriunpack
290290
.. autofunction:: chain_matmul
291+
.. autofunction:: cholesky
291292
.. autofunction:: dot
292293
.. autofunction:: eig
293294
.. autofunction:: gels

test/test_autograd.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2025,13 +2025,13 @@ def test_cat_empty(self):
20252025
True, f_args_variable, f_args_tensor)
20262026

20272027
@skipIfNoLapack
2028-
def test_potrf(self):
2029-
root = Variable(torch.tril(torch.rand(S, S)), requires_grad=True)
2028+
def test_cholesky(self):
2029+
root = torch.tril(torch.rand(S, S)).requires_grad_()
20302030

20312031
def run_test(upper):
20322032
def func(root):
20332033
x = torch.mm(root, root.t())
2034-
return torch.potrf(x, upper)
2034+
return torch.cholesky(x, upper)
20352035

20362036
gradcheck(func, [root])
20372037
gradgradcheck(func, [root])

test/test_distributions.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1721,14 +1721,14 @@ def test_multivariate_normal_shape(self):
17211721
tmp = torch.randn(3, 10)
17221722
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
17231723
prec = cov.inverse().requires_grad_()
1724-
scale_tril = torch.potrf(cov, upper=False).requires_grad_()
1724+
scale_tril = torch.cholesky(cov, upper=False).requires_grad_()
17251725

17261726
# construct batch of PSD covariances
17271727
tmp = torch.randn(6, 5, 3, 10)
17281728
cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
17291729
prec_batched = [C.inverse() for C in cov_batched.view((-1, 3, 3))]
17301730
prec_batched = torch.stack(prec_batched).view(cov_batched.shape)
1731-
scale_tril_batched = [torch.potrf(C, upper=False) for C in cov_batched.view((-1, 3, 3))]
1731+
scale_tril_batched = [torch.cholesky(C, upper=False) for C in cov_batched.view((-1, 3, 3))]
17321732
scale_tril_batched = torch.stack(scale_tril_batched).view(cov_batched.shape)
17331733

17341734
# ensure that sample, batch, event shapes all handled correctly
@@ -1764,7 +1764,7 @@ def test_multivariate_normal_log_prob(self):
17641764
tmp = torch.randn(3, 10)
17651765
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
17661766
prec = cov.inverse().requires_grad_()
1767-
scale_tril = torch.potrf(cov, upper=False).requires_grad_()
1767+
scale_tril = torch.cholesky(cov, upper=False).requires_grad_()
17681768

17691769
# check that logprob values match scipy logpdf,
17701770
# and that covariance and scale_tril parameters are equivalent
@@ -1802,7 +1802,7 @@ def test_multivariate_normal_sample(self):
18021802
tmp = torch.randn(3, 10)
18031803
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
18041804
prec = cov.inverse().requires_grad_()
1805-
scale_tril = torch.potrf(cov, upper=False).requires_grad_()
1805+
scale_tril = torch.cholesky(cov, upper=False).requires_grad_()
18061806

18071807
self._check_sampler_sampler(MultivariateNormal(mean, cov),
18081808
scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
@@ -1823,7 +1823,7 @@ def test_multivariate_normal_properties(self):
18231823
m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
18241824
self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
18251825
self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0]))
1826-
self.assertEqual(m.scale_tril, torch.potrf(m.covariance_matrix, upper=False))
1826+
self.assertEqual(m.scale_tril, torch.cholesky(m.covariance_matrix, upper=False))
18271827

18281828
def test_multivariate_normal_moments(self):
18291829
set_rng_seed(0) # see Note [Randomized statistical tests]

test/test_torch.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -5303,19 +5303,19 @@ def test_cholesky(self):
53035303
A = torch.mm(x, x.t())
53045304

53055305
# default Case
5306-
C = torch.potrf(A)
5307-
B = torch.mm(C.t(), C)
5306+
C = torch.cholesky(A)
5307+
B = torch.mm(C, C.t())
53085308
self.assertEqual(A, B, 1e-14)
53095309

53105310
# test Upper Triangular
5311-
U = torch.potrf(A, True)
5311+
U = torch.cholesky(A, True)
53125312
B = torch.mm(U.t(), U)
5313-
self.assertEqual(A, B, 1e-14, 'potrf (upper) did not allow rebuilding the original matrix')
5313+
self.assertEqual(A, B, 1e-14, 'cholesky (upper) did not allow rebuilding the original matrix')
53145314

53155315
# test Lower Triangular
5316-
L = torch.potrf(A, False)
5316+
L = torch.cholesky(A, False)
53175317
B = torch.mm(L, L.t())
5318-
self.assertEqual(A, B, 1e-14, 'potrf (lower) did not allow rebuilding the original matrix')
5318+
self.assertEqual(A, B, 1e-14, 'cholesky (lower) did not allow rebuilding the original matrix')
53195319

53205320
@skipIfNoLapack
53215321
def test_potrs(self):
@@ -5332,12 +5332,12 @@ def test_potrs(self):
53325332
a = torch.mm(a, a.t())
53335333

53345334
# upper Triangular Test
5335-
U = torch.potrf(a)
5336-
x = torch.potrs(b, U)
5335+
U = torch.cholesky(a, True)
5336+
x = torch.potrs(b, U, True)
53375337
self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
53385338

53395339
# lower Triangular Test
5340-
L = torch.potrf(a, False)
5340+
L = torch.cholesky(a, False)
53415341
x = torch.potrs(b, L, False)
53425342
self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
53435343

@@ -5356,17 +5356,17 @@ def test_potri(self):
53565356
inv0 = torch.inverse(a)
53575357

53585358
# default case
5359-
chol = torch.potrf(a)
5360-
inv1 = torch.potri(chol)
5359+
chol = torch.cholesky(a)
5360+
inv1 = torch.potri(chol, False)
53615361
self.assertLessEqual(inv0.dist(inv1), 1e-12)
53625362

53635363
# upper Triangular Test
5364-
chol = torch.potrf(a, True)
5364+
chol = torch.cholesky(a, True)
53655365
inv1 = torch.potri(chol, True)
53665366
self.assertLessEqual(inv0.dist(inv1), 1e-12)
53675367

53685368
# lower Triangular Test
5369-
chol = torch.potrf(a, False)
5369+
chol = torch.cholesky(a, False)
53705370
inv1 = torch.potri(chol, False)
53715371
self.assertLessEqual(inv0.dist(inv1), 1e-12)
53725372

tools/autograd/derivatives.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@
184184
- name: ceil(Tensor self)
185185
self: zeros_like(grad)
186186

187+
- name: cholesky(Tensor self, bool upper)
188+
self: cholesky_backward(grad, upper, result)
189+
187190
# For clamp, gradient is not defined at the boundaries. But empirically it's helpful
188191
# to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
189192
- name: clamp(Tensor self, Scalar? min, Scalar? max)
@@ -563,9 +566,6 @@
563566
- name: poisson(Tensor self, Generator generator)
564567
self: zeros_like(self)
565568

566-
- name: potrf(Tensor self, bool upper)
567-
self: potrf_backward(grad, upper, result)
568-
569569
- name: potri(Tensor self, bool upper)
570570
self: not_implemented("potri")
571571

tools/autograd/templates/Functions.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList
626626
return mask_selected.view(sizes);
627627
}
628628

629-
Tensor potrf_backward(Tensor grad, bool upper, Tensor L) {
629+
Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
630630
// cf. Iain Murray (2016); arXiv 1602.07527
631631
if (upper) {
632632
L = L.t();

torch/_tensor_docs.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,13 @@ def add_docstr_all(method, docstr):
525525
In-place version of :meth:`~Tensor.ceil`
526526
""")
527527

528+
add_docstr_all('cholesky',
529+
r"""
530+
cholesky(upper=False) -> Tensor
531+
532+
See :func:`torch.cholesky`
533+
""")
534+
528535
add_docstr_all('clamp',
529536
r"""
530537
clamp(min, max) -> Tensor
@@ -1619,13 +1626,6 @@ def callable(a, b) -> number
16191626
torch.Size([5, 2, 3])
16201627
""")
16211628

1622-
add_docstr_all('potrf',
1623-
r"""
1624-
potrf(upper=True) -> Tensor
1625-
1626-
See :func:`torch.potrf`
1627-
""")
1628-
16291629
add_docstr_all('potri',
16301630
r"""
16311631
potri(upper=True) -> Tensor

torch/_torch_docs.py

+47-47
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,51 @@ def parse_kwargs(desc):
825825
tensor([-2.1763, -0.4713, -0.6986, 1.3702])
826826
""")
827827

828+
add_docstr(torch.cholesky, r"""
829+
cholesky(a, upper=False, out=None) -> Tensor
830+
831+
Computes the Cholesky decomposition of a symmetric positive-definite
832+
matrix :math:`A`.
833+
834+
If :attr:`upper` is ``True``, the returned matrix `U` is upper-triangular, and
835+
the decomposition has the form:
836+
837+
.. math::
838+
839+
A = U^TU
840+
841+
If :attr:`upper` is ``False``, the returned matrix `L` is lower-triangular, and
842+
the decomposition has the form:
843+
844+
.. math::
845+
846+
A = LL^T
847+
848+
Args:
849+
a (Tensor): the input 2-D tensor, a symmetric positive-definite matrix
850+
upper (bool, optional): flag that indicates whether to return the
851+
upper or lower triangular matrix. Default: ``False``
852+
out (Tensor, optional): the output matrix
853+
854+
Example::
855+
856+
>>> a = torch.randn(3, 3)
857+
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
858+
>>> l = torch.cholesky(a)
859+
>>> a
860+
tensor([[ 2.4112, -0.7486, 1.4551],
861+
[-0.7486, 1.3544, 0.1294],
862+
[ 1.4551, 0.1294, 1.6724]])
863+
>>> l
864+
tensor([[ 1.5528, 0.0000, 0.0000],
865+
[-0.4821, 1.0592, 0.0000],
866+
[ 0.9371, 0.5487, 0.7023]])
867+
>>> torch.mm(l, l.t())
868+
tensor([[ 2.4112, -0.7486, 1.4551],
869+
[-0.7486, 1.3544, 0.1294],
870+
[ 1.4551, 0.1294, 1.6724]])
871+
""")
872+
828873
add_docstr(torch.clamp,
829874
r"""
830875
clamp(input, min, max, out=None) -> Tensor
@@ -3249,51 +3294,6 @@ def parse_kwargs(desc):
32493294
32503295
""")
32513296

3252-
add_docstr(torch.potrf, r"""
3253-
potrf(a, upper=True, out=None) -> Tensor
3254-
3255-
Computes the Cholesky decomposition of a symmetric positive-definite
3256-
matrix :math:`A`.
3257-
3258-
If :attr:`upper` is ``True``, the returned matrix `U` is upper-triangular, and
3259-
the decomposition has the form:
3260-
3261-
.. math::
3262-
3263-
A = U^TU
3264-
3265-
If :attr:`upper` is ``False``, the returned matrix `L` is lower-triangular, and
3266-
the decomposition has the form:
3267-
3268-
.. math::
3269-
3270-
A = LL^T
3271-
3272-
Args:
3273-
a (Tensor): the input 2-D tensor, a symmetric positive-definite matrix
3274-
upper (bool, optional): flag that indicates whether to return the
3275-
upper or lower triangular matrix
3276-
out (Tensor, optional): the output matrix
3277-
3278-
Example::
3279-
3280-
>>> a = torch.randn(3, 3)
3281-
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
3282-
>>> u = torch.potrf(a)
3283-
>>> a
3284-
tensor([[ 2.4112, -0.7486, 1.4551],
3285-
[-0.7486, 1.3544, 0.1294],
3286-
[ 1.4551, 0.1294, 1.6724]])
3287-
>>> u
3288-
tensor([[ 1.5528, -0.4821, 0.9371],
3289-
[ 0.0000, 1.0592, 0.5486],
3290-
[ 0.0000, 0.0000, 0.7023]])
3291-
>>> torch.mm(u.t(), u)
3292-
tensor([[ 2.4112, -0.7486, 1.4551],
3293-
[-0.7486, 1.3544, 0.1294],
3294-
[ 1.4551, 0.1294, 1.6724]])
3295-
""")
3296-
32973297
add_docstr(torch.potri, r"""
32983298
potri(u, upper=True, out=None) -> Tensor
32993299
@@ -3322,7 +3322,7 @@ def parse_kwargs(desc):
33223322
33233323
>>> a = torch.randn(3, 3)
33243324
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
3325-
>>> u = torch.potrf(a)
3325+
>>> u = torch.cholesky(a, upper=True)
33263326
>>> a
33273327
tensor([[ 0.9935, -0.6353, 1.5806],
33283328
[ -0.6353, 0.8769, -1.7183],
@@ -3367,7 +3367,7 @@ def parse_kwargs(desc):
33673367
33683368
>>> a = torch.randn(3, 3)
33693369
>>> a = torch.mm(a, a.t()) # make symmetric positive definite
3370-
>>> u = torch.potrf(a)
3370+
>>> u = torch.cholesky(a, upper=True)
33713371
>>> a
33723372
tensor([[ 0.7747, -1.9549, 1.3086],
33733373
[-1.9549, 6.7546, -5.4114],

0 commit comments

Comments
 (0)