Skip to content

Commit a35b4b4

Browse files
lezcano authored and facebook-github-bot committed
Add linalg.lu_factor (pytorch#66933)
Summary: Pull Request resolved: pytorch#66933 This PR exposes `torch.lu` as `torch.linalg.lu_factor` and `torch.linalg.lu_factor_ex`. This PR also adds support for matrices with zero elements both in the size of the matrix and the batch. Note that this function simply returns empty tensors of the correct size in this case. We add a test and an OpInfo for the new function. This PR also adds documentation for this new function in line with the documentation in the rest of `torch.linalg`. Fixes pytorch#56590 Fixes pytorch#64014 cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D32834069 Pulled By: mruberry fbshipit-source-id: 51ef12535fa91d292f419acf83b800b86ee9c7eb
1 parent 3f53365 commit a35b4b4

22 files changed

+597
-336
lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -928,7 +928,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
928928
result = result.unsqueeze_(-1);
929929
}
930930

931-
// lu_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
931+
// lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
932932
result.copy_(other_broadcasted);
933933

934934
auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
@@ -945,7 +945,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
945945
auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec(); // input_broadcasted.shape[:-2]
946946
pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
947947
Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
948-
lu_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
948+
lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
949949

950950
// solve the linear system using the LU factorization
951951
lu_solve_stub(input.device().type(), result, input_working_copy, pivots);
@@ -1571,30 +1571,109 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
15711571
return result;
15721572
}
15731573

1574-
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1574+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15751575

1576-
DEFINE_DISPATCH(lu_stub);
1576+
DEFINE_DISPATCH(lu_factor_stub);
15771577

1578-
// TODO: remove check_errors argument
1579-
// https://github.com/pytorch/pytorch/issues/64014
1580-
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) {
1581-
TORCH_CHECK(self.dim() >= 2,
1582-
"expected tensor with 2 or more dimensions, got size: ", self.sizes(),
1583-
" instead");
1584-
auto m = self.size(-2);
1585-
auto n = self.size(-1);
1586-
auto req_size = self.sizes().vec();
1578+
std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
1579+
bool pivot,
1580+
bool check_errors,
1581+
Tensor& LU,
1582+
Tensor& pivots,
1583+
Tensor& info) {
1584+
TORCH_CHECK(A.dim() >= 2,
1585+
"expected tensor with 2 or more dimensions, got size: ", A.sizes(), " instead");
1586+
auto req_size = A.sizes().vec();
1587+
const auto m = req_size.cend()[-2];
1588+
const auto n = req_size.cend()[-1];
1589+
1590+
// TODO reimplementation of resize_output with format F-contiguous
1591+
// We should make this a standalone function
1592+
if (resize_output_check(LU, req_size)) {
1593+
// Transpose size
1594+
std::iter_swap(req_size.end() - 1, req_size.end() - 2);
1595+
LU.resize_(req_size, MemoryFormat::Contiguous);
1596+
LU.transpose_(-2, -1); // make 'LU' have Fortran contiguous memory
1597+
}
15871598
req_size.pop_back();
15881599
req_size.back() = std::min(m, n);
1589-
auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
1600+
at::native::resize_output(pivots, req_size);
15901601
req_size.pop_back();
1591-
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
1602+
at::native::resize_output(info, req_size);
1603+
1604+
const auto LU_f_contig = LU.transpose(-2, -1).is_contiguous() ;
1605+
1606+
if (LU_f_contig && !LU.is_same(A)) {
1607+
LU.copy_(A);
1608+
}
1609+
const auto LU_ = borrow_else_clone(LU_f_contig, LU, A, /*C-contig*/false);
1610+
1611+
const auto pivots_contig = pivots.is_contiguous();
1612+
const auto pivots_ = borrow_else_clone(pivots_contig, pivots, pivots, /*C-contig*/true);
1613+
1614+
const auto info_contig = info.is_contiguous();
1615+
const auto info_ = borrow_else_clone(info_contig, info, info, /*C-contig*/true);
1616+
1617+
lu_factor_stub(A.device().type(), *LU_, *pivots_, *info_, pivot);
1618+
1619+
if (!LU_f_contig) {
1620+
LU.copy_(*LU_);
1621+
}
1622+
if (!pivots_contig) {
1623+
pivots.copy_(*pivots_);
1624+
}
1625+
if (!info_contig) {
1626+
info.copy_(*info_);
1627+
}
1628+
1629+
if (check_errors) {
1630+
if (A.dim() > 2) {
1631+
batchCheckErrors(info, "torch.linalg.lu_factor_ex");
1632+
} else {
1633+
singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor_ex");
1634+
}
1635+
}
1636+
1637+
return std::tie(LU, pivots, info);
1638+
}
1639+
1640+
// Functional (allocating) variant of linalg_lu_factor_ex_out.
// Creates empty placeholder tensors for the three outputs and forwards them,
// together with `pivot` and `check_errors`, to the out variant, which resizes
// and fills them. Returns (LU, pivots, info) by value.
std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool pivot, bool check_errors) {
1641+
// LU shares A's options (dtype/device); it starts with size {0}
auto LU = at::empty({0}, A.options());
1642+
// pivots and info use A's options with the dtype overridden to kInt
auto pivots = at::empty({0}, A.options().dtype(kInt));
1643+
auto info = at::empty({0}, A.options().dtype(kInt));
1644+
at::native::linalg_lu_factor_ex_out(A, pivot, check_errors, LU, pivots, info);
1645+
return std::make_tuple(std::move(LU), std::move(pivots), std::move(info));
1646+
}
1647+
1648+
// Out variant of linalg.lu_factor: writes the factorization into `LU` and `pivots`.
// Delegates to at::linalg_lu_factor_ex_out with a temporary `info` tensor and then
// always checks `info`, reporting errors under the name "torch.linalg.lu_factor".
// Returns references to the caller-provided `LU` and `pivots`.
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor & LU, Tensor & pivots) {
1649+
// Temporary kInt tensor that receives the error codes; it is not returned to the caller
auto info = at::empty({0}, A.options().dtype(kInt));
1650+
// We pass check_errors=false so that the _ex variant does not check `info` itself;
// the check is done below so that errors mention lu_factor rather than lu_factor_ex
1651+
at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*check_errors=*/false);
1652+
// More than 2 dimensions means a batch of matrices, so use the batched checker
if (A.dim() > 2) {
1653+
batchCheckErrors(info, "torch.linalg.lu_factor");
1654+
} else {
1655+
singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
1656+
}
1657+
1658+
return std::tie(LU, pivots);
1659+
}
1660+
1661+
// Functional variant of linalg.lu_factor.
// Calls at::linalg_lu_factor_ex with error checking disabled, then checks the returned
// `info` here so that errors are reported under the name "torch.linalg.lu_factor".
// Returns (LU, pivots); `info` is consumed by the check and not returned.
std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
1662+
Tensor LU, pivots, info;
1663+
std::tie(LU, pivots, info) = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);
1664+
1665+
// More than 2 dimensions means a batch of matrices, so use the batched checker
if (A.dim() > 2) {
1666+
batchCheckErrors(info, "torch.linalg.lu_factor");
1667+
} else {
1668+
singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
1669+
}
1670+
1671+
return std::make_tuple(std::move(LU), std::move(pivots));
1672+
}
15921673

1593-
// lu_stub (apply_lu) requires batched column major (Fortran-contiguous) tensors
1594-
// 'lu' tensor is modified in-place and must be a copy of 'self'
1595-
Tensor lu = cloneBatchedColumnMajor(self);
1596-
lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots);
1597-
return std::make_tuple(lu, pivots_tensor, infos_tensor);
1674+
// TODO Deprecate this function in favour of linalg_lu_factor_ex
1675+
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
1676+
return at::linalg_lu_factor_ex(self, compute_pivots, false);
15981677
}
15991678

16001679
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

aten/src/ATen/native/BatchLinearAlgebra.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,12 @@ using triangular_solve_fn = void (*)(
219219
bool /*unitriangular*/);
220220
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
221221

222-
using lu_fn = void (*)(
222+
using lu_factor_fn = void (*)(
223223
const Tensor& /*input*/,
224224
const Tensor& /*pivots*/,
225225
const Tensor& /*infos*/,
226226
bool /*compute_pivots*/);
227-
DECLARE_DISPATCH(lu_fn, lu_stub);
227+
DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
228228

229229
using lu_solve_fn = void (*)(
230230
const Tensor& /*b*/,

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -847,14 +847,14 @@ void triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool u
847847
For further details, please see the LAPACK documentation for GETRF.
848848
*/
849849
template <typename scalar_t>
850-
void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
850+
void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
851851
#if !AT_BUILD_WITH_LAPACK()
852852
TORCH_CHECK(
853853
false,
854854
"Calling torch.lu on a CPU tensor requires compiling ",
855855
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
856856
#else
857-
TORCH_CHECK(compute_pivots, "lu without pivoting is not implemented on the CPU");
857+
TORCH_CHECK(compute_pivots, "linalg.lu_factor: LU without pivoting is not implemented on the CPU");
858858

859859
auto input_data = input.data_ptr<scalar_t>();
860860
auto pivots_data = pivots.data_ptr<int>();
@@ -876,9 +876,9 @@ void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bo
876876
}
877877

878878
// This is a type dispatching helper function for 'apply_lu_factor'
879-
void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
879+
void lu_factor_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
880880
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
881-
apply_lu<scalar_t>(input, pivots, infos, compute_pivots);
881+
apply_lu_factor<scalar_t>(input, pivots, infos, compute_pivots);
882882
});
883883
}
884884

@@ -890,8 +890,8 @@ void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, b
890890
Args:
891891
* `b` - [in] the right hand side matrix B
892892
[out] the solution matrix X
893-
* `lu` - [in] the LU factorization of matrix A (see at::_lu_with_info)
894-
* `pivots` - [in] the pivot indices (see at::_lu_with_info)
893+
* `lu` - [in] the LU factorization of matrix A (see at::linalg_lu_factor)
894+
* `pivots` - [in] the pivot indices (see at::linalg_lu_factor)
895895
896896
For further details, please see the LAPACK documentation for GETRS.
897897
*/
@@ -1005,11 +1005,11 @@ REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
10051005
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
10061006
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
10071007

1008-
REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
1009-
REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel);
1010-
REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
1011-
REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
1012-
REGISTER_ZVECTOR_DISPATCH(lu_stub, &lu_kernel);
1008+
REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel);
1009+
REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel);
1010+
REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel);
1011+
REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel);
1012+
REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel);
10131013

10141014
REGISTER_ARCH_DISPATCH(lu_solve_trans_stub, DEFAULT, &lu_solve_trans_kernel);
10151015
REGISTER_AVX512_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ DEFINE_DISPATCH(linalg_vector_norm_stub);
119119
// where info helps us identify singular matrices.
120120
static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<Tensor>> _lu_det_P_diag_U(const Tensor& self) {
121121
Tensor pivs, lu, infos;
122-
std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
122+
std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
123123
TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "Invalid argument passed to lu");
124124
auto n = self.size(-1);
125125
auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs)
@@ -135,7 +135,7 @@ static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<Te
135135
// det(A) = ([is P odd] * -2 + 1) * prod(diag(U))
136136
std::tuple<Tensor, Tensor, Tensor> _det_lu_based_helper(const Tensor& self) {
137137
Tensor lu, pivs, infos;
138-
std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors*/false);
138+
std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
139139
TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "at::_det_lu_based_helper(): Invalid argument passed to LU");
140140

141141
// find det(P)

aten/src/ATen/native/LinearAlgebraUtils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
5959
return result;
6060
}
6161

62+
/*
63+
 * Returns `borrow` wrapped as a borrowed MaybeOwned when `cond` is true;
 * otherwise returns an owned MaybeOwned holding a fresh copy of `clone`.
 *
 * contig chooses between C-contig (true) and F-contig (false) for that copy:
 * true uses clone(MemoryFormat::Contiguous), false uses cloneBatchedColumnMajor.
 * `contig` is only consulted when `cond` is false.
64+
*/
65+
static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
66+
return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
67+
: c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
68+
: cloneBatchedColumnMajor(clone));
69+
}
70+
6271
/*
6372
* This method is designed to be a faster alternative to
6473
* `cloneBatchedColumnMajor` with some additional features,
@@ -280,6 +289,11 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t bat
280289
} else if (strstr(name, "lstsq")) {
281290
TORCH_CHECK_LINALG(false, name, batch_string,
282291
": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
292+
} else if (strstr(name, "lu_factor")) {
293+
TORCH_CHECK(false, name, batch_string,
294+
": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. "
295+
"If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or "
296+
"linalg.lu_factor_ex(A, pivot)");
283297
} else {
284298
TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
285299
}

0 commit comments

Comments
 (0)