Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/TiledArray/math/linalg/non-distributed/qr.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ auto householder_qr(const ArrayV& V, TiledRange q_trange = TiledRange(),
}
}

template <typename ArrayA, typename ArrayB, typename T = ArrayB::numeric_type>
auto qr_solve(const ArrayA& A, const ArrayB& B,
const TiledArray::detail::real_t<T> cond = 1e8,
TiledRange x_trange = TiledRange()) {
(void)detail::array_traits<ArrayB>{};
auto& world = B.world();
auto A_eig = detail::make_matrix(A);
auto B_eig = detail::make_matrix(B);
TA_LAPACK_ON_RANK_ZERO(qr_solve, world, A_eig, B_eig, cond);
world.gop.broadcast_serializable(A_eig, 0);
world.gop.broadcast_serializable(B_eig, 0);
if (x_trange.rank() == 0) x_trange = B.trange();
auto X = eigen_to_array<ArrayB>(world, x_trange, B_eig);
return X;
}

} // namespace TiledArray::math::linalg::non_distributed

#endif
29 changes: 23 additions & 6 deletions src/TiledArray/math/linalg/rank-local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ void cholesky_lsolve(Op transpose, Matrix<T>& A, Matrix<T>& X) {
TA_LAPACK(trtrs, uplo, transpose, diag, n, nrhs, a, lda, b, ldb);
}

template <typename T>
void qr_solve(Matrix<T>& A, Matrix<T>& B,
const TiledArray::detail::real_t<T> cond) {
integer m = A.rows();
integer n = A.cols();
integer nrhs = B.cols();
T* a = A.data();
integer lda = A.rows();
T* b = B.data();
integer ldb = B.rows();
std::vector<integer> jpiv(n);
const TiledArray::detail::real_t<T> rcond = 1 / cond;
integer rank = -1;
TA_LAPACK(gelsy, m, n, nrhs, a, lda, b, ldb, jpiv.data(), rcond, &rank);
}

template <typename T>
void heig(Matrix<T>& A, std::vector<TiledArray::detail::real_t<T>>& W) {
auto jobz = lapack::Job::Vec;
Expand Down Expand Up @@ -250,7 +266,7 @@ void householder_qr(Matrix<T>& V, Matrix<T>& R) {
lapack::orgqr(m, n, k, v, ldv, tau.data());
}

#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \
#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR, DOUBLE) \
template void cholesky(MATRIX&); \
template void cholesky_linv(MATRIX&); \
template void cholesky_solve(MATRIX&, MATRIX&); \
Expand All @@ -261,11 +277,12 @@ void householder_qr(Matrix<T>& V, Matrix<T>& R) {
template void lu_solve(MATRIX&, MATRIX&); \
template void lu_inv(MATRIX&); \
template void householder_qr<true>(MATRIX&, MATRIX&); \
template void householder_qr<false>(MATRIX&, MATRIX&);
template void householder_qr<false>(MATRIX&, MATRIX&); \
template void qr_solve(MATRIX&, MATRIX&, DOUBLE)

TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>);
TA_LAPACK_EXPLICIT(Matrix<float>, std::vector<float>);
TA_LAPACK_EXPLICIT(Matrix<std::complex<double>>, std::vector<double>);
TA_LAPACK_EXPLICIT(Matrix<std::complex<float>>, std::vector<float>);
TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>, double );
TA_LAPACK_EXPLICIT(Matrix<float>, std::vector<float>, float);
TA_LAPACK_EXPLICIT(Matrix<std::complex<double>>, std::vector<double>, double);
TA_LAPACK_EXPLICIT(Matrix<std::complex<float>>, std::vector<float>, float);

} // namespace TiledArray::math::linalg::rank_local
4 changes: 4 additions & 0 deletions src/TiledArray/math/linalg/rank-local.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ void cholesky_solve(Matrix<T> &A, Matrix<T> &X);
template <typename T>
void cholesky_lsolve(Op transpose, Matrix<T> &A, Matrix<T> &X);

template <typename T>
void qr_solve(Matrix<T> &A, Matrix<T> &B,
const TiledArray::detail::real_t<T> cond = 1e8);

template <typename T>
void heig(Matrix<T> &A, std::vector<TiledArray::detail::real_t<T>> &W);

Expand Down
33 changes: 33 additions & 0 deletions tests/linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,39 @@ BOOST_AUTO_TEST_CASE(cholesky_lsolve) {
GlobalFixture::world->gop.fence();
}

BOOST_AUTO_TEST_CASE(qr_solve) {
GlobalFixture::world->gop.fence();

auto trange = gen_trange(N, {128ul});

auto ref_ta = TA::make_array<TA::TArray<double>>(
*GlobalFixture::world, trange,
[this](TA::Tensor<double>& t, TA::Range const& range) -> double {
return this->make_ta_reference(t, range);
});

auto iden = non_dist::qr_solve(ref_ta, ref_ta);

BOOST_CHECK(iden.trange() == ref_ta.trange());

TA::foreach_inplace(iden, [](TA::Tensor<double>& tile) {
auto range = tile.range();
auto lo = range.lobound_data();
auto up = range.upbound_data();
for (auto m = lo[0]; m < up[0]; ++m)
for (auto n = lo[1]; n < up[1]; ++n)
if (m == n) {
tile(m, n) -= 1.;
}
});

double epsilon = N * N * std::numeric_limits<double>::epsilon();
double norm = iden("i,j").norm(*GlobalFixture::world).get();

BOOST_CHECK_SMALL(norm, epsilon);
GlobalFixture::world->gop.fence();
}

BOOST_AUTO_TEST_CASE(lu_solve) {
GlobalFixture::world->gop.fence();

Expand Down
Loading