@@ -928,7 +928,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
    result = result.unsqueeze_(-1);
  }

-  // lu_stub + lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
+  // lu_factor_stub + lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
  result.copy_(other_broadcasted);

  auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
@@ -945,7 +945,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
  auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec();  // input_broadcasted.shape[:-2]
  pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
  Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
-  lu_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
+  lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);

  // solve the linear system using the LU factorization
  lu_solve_stub(input.device().type(), result, input_working_copy, pivots);
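The two hunks above keep linalg_solve_out_info on the same factor-then-solve path and only rename the dispatch stub: the broadcast input is factorized once by lu_factor_stub into an in-place working copy, and lu_solve_stub then back-substitutes into 'result', which is why 'result' must start out as a copy of 'other_broadcasted'. The stubs are internal, in-place dispatch entry points; a rough caller-level sketch of the same flow through the public ATen ops (the helper name solve_via_lu and the shape assumptions are illustrative, not part of this patch):

#include <ATen/ATen.h>
#include <tuple>

// Sketch only: solve A x = b by factorizing once and reusing the factorization,
// mirroring what linalg_solve_out_info does with lu_factor_stub + lu_solve_stub.
// Assumes A is (*, n, n) and b is (*, n, k), already broadcast against each other.
at::Tensor solve_via_lu(const at::Tensor& A, const at::Tensor& b) {
  at::Tensor LU, pivots, info;
  std::tie(LU, pivots, info) =
      at::linalg_lu_factor_ex(A, /*pivot=*/true, /*check_errors=*/false);
  TORCH_CHECK((info == 0).all().item<bool>(), "LU factorization failed");
  return at::lu_solve(b, LU, pivots);  // back-substitution with the P, L, U factors
}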
@@ -1571,30 +1571,109 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
  return result;
}

- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- DEFINE_DISPATCH(lu_stub);
+ DEFINE_DISPATCH(lu_factor_stub);

- // TODO: remove check_errors argument
- // https://github.com/pytorch/pytorch/issues/64014
- std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) {
-   TORCH_CHECK(self.dim() >= 2,
-               "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
-               " instead");
-   auto m = self.size(-2);
-   auto n = self.size(-1);
-   auto req_size = self.sizes().vec();
+ std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
+                                                               bool pivot,
+                                                               bool check_errors,
+                                                               Tensor& LU,
+                                                               Tensor& pivots,
+                                                               Tensor& info) {
+   TORCH_CHECK(A.dim() >= 2,
+               "expected tensor with 2 or more dimensions, got size: ", A.sizes(), " instead");
+   auto req_size = A.sizes().vec();
+   const auto m = req_size.cend()[-2];
+   const auto n = req_size.cend()[-1];
+
+   // TODO reimplementation of resize_output with format F-contiguous
+   // We should make this a standalone function
+   if (resize_output_check(LU, req_size)) {
+     // Transpose size
+     std::iter_swap(req_size.end() - 1, req_size.end() - 2);
+     LU.resize_(req_size, MemoryFormat::Contiguous);
+     LU.transpose_(-2, -1);  // make 'LU' have Fortran contiguous memory
+   }
  req_size.pop_back();
  req_size.back() = std::min(m, n);
-   auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
+   at::native::resize_output(pivots, req_size);
  req_size.pop_back();
-   auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
+   at::native::resize_output(info, req_size);
+
+   const auto LU_f_contig = LU.transpose(-2, -1).is_contiguous();
+
+   if (LU_f_contig && !LU.is_same(A)) {
+     LU.copy_(A);
+   }
+   const auto LU_ = borrow_else_clone(LU_f_contig, LU, A, /*C-contig*/false);
+
+   const auto pivots_contig = pivots.is_contiguous();
+   const auto pivots_ = borrow_else_clone(pivots_contig, pivots, pivots, /*C-contig*/true);
+
+   const auto info_contig = info.is_contiguous();
+   const auto info_ = borrow_else_clone(info_contig, info, info, /*C-contig*/true);
+
+   lu_factor_stub(A.device().type(), *LU_, *pivots_, *info_, pivot);
+
+   if (!LU_f_contig) {
+     LU.copy_(*LU_);
+   }
+   if (!pivots_contig) {
+     pivots.copy_(*pivots_);
+   }
+   if (!info_contig) {
+     info.copy_(*info_);
+   }
+
+   if (check_errors) {
+     if (A.dim() > 2) {
+       batchCheckErrors(info, "torch.linalg.lu_factor_ex");
+     } else {
+       singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor_ex");
+     }
+   }
+
+   return std::tie(LU, pivots, info);
+ }
+
+ std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool pivot, bool check_errors) {
+   auto LU = at::empty({0}, A.options());
+   auto pivots = at::empty({0}, A.options().dtype(kInt));
+   auto info = at::empty({0}, A.options().dtype(kInt));
+   at::native::linalg_lu_factor_ex_out(A, pivot, check_errors, LU, pivots, info);
+   return std::make_tuple(std::move(LU), std::move(pivots), std::move(info));
+ }
+
+ std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
+   auto info = at::empty({0}, A.options().dtype(kInt));
+   // We pass check_errors=false and do the checks here so that the error messages mention lu_factor rather than lu_factor_ex
+   at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*check_errors=*/false);
+   if (A.dim() > 2) {
+     batchCheckErrors(info, "torch.linalg.lu_factor");
+   } else {
+     singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
+   }
+
+   return std::tie(LU, pivots);
+ }
+
+ std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
+   Tensor LU, pivots, info;
+   std::tie(LU, pivots, info) = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);
+
+   if (A.dim() > 2) {
+     batchCheckErrors(info, "torch.linalg.lu_factor");
+   } else {
+     singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
+   }
+
+   return std::make_tuple(std::move(LU), std::move(pivots));
+ }

-   // lu_stub (apply_lu) requires batched column major (Fortran-contiguous) tensors
-   // 'lu' tensor is modified in-place and must be a copy of 'self'
-   Tensor lu = cloneBatchedColumnMajor(self);
-   lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots);
-   return std::make_tuple(lu, pivots_tensor, infos_tensor);
+ // TODO Deprecate this function in favour of linalg_lu_factor_ex
+ std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
+   return at::linalg_lu_factor_ex(self, compute_pivots, false);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
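With this change _lu_with_info becomes a thin forwarder, and the supported entry points are linalg_lu_factor (which checks 'info' itself and throws) and linalg_lu_factor_ex (which returns 'info' and only checks it when check_errors is true). A minimal caller-side sketch of the two in ATen C++ (the example shape and variable names are assumptions, not taken from the patch):

#include <ATen/ATen.h>
#include <tuple>

void lu_factor_example() {
  at::Tensor A = at::randn({2, 3, 3}, at::kDouble);  // batch of two 3x3 matrices

  // lu_factor: throws on a failed factorization, returns only LU and pivots.
  at::Tensor LU, pivots;
  std::tie(LU, pivots) = at::linalg_lu_factor(A, /*pivot=*/true);

  // lu_factor_ex: error handling is left to the caller. Following the LAPACK
  // getrf convention, info == 0 means success and info == i > 0 means
  // U[i-1, i-1] is exactly zero (the factorization itself still completes).
  at::Tensor info;
  std::tie(LU, pivots, info) =
      at::linalg_lu_factor_ex(A, /*pivot=*/true, /*check_errors=*/false);
  TORCH_CHECK((info == 0).all().item<bool>(), "LU factorization reported a zero pivot");
}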
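The out-variant also shows how an F-contiguous (batched column-major) output is produced for LAPACK: when resize_output_check reports that the LU output needs resizing, the code resizes it to the transposed shape with a C-contiguous layout and then transposes the last two dimensions back in place. The same trick in isolation, as a sketch (make_f_contiguous is a hypothetical helper for illustration, not a function from the patch):

#include <ATen/ATen.h>
#include <utility>
#include <vector>

// Allocate a tensor of the requested sizes whose last two dimensions are
// column-major (Fortran-contiguous), the layout LAPACK routines expect.
at::Tensor make_f_contiguous(at::IntArrayRef sizes, at::TensorOptions options) {
  TORCH_CHECK(sizes.size() >= 2, "expected at least 2 dimensions");
  std::vector<int64_t> t_sizes(sizes.begin(), sizes.end());
  std::swap(t_sizes.end()[-1], t_sizes.end()[-2]);  // swap the last two sizes
  // Allocate row-major in the transposed shape, then transpose it back in place:
  // same storage, column-major strides over the last two dimensions.
  at::Tensor t = at::empty(t_sizes, options);
  return t.transpose_(-2, -1);
}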