Skip to content

Commit e3b3619

Browse files
authored
Merge pull request #165 from mydatamodels/AddFunctionsForElasticNet
feat: add solve_cholesky and solve_triangular functions
2 parents e9f557d + 33ba8da commit e3b3619

File tree

3 files changed

+132
-0
lines changed

3 files changed

+132
-0
lines changed

include/xtensor-blas/xlapack.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,54 @@ namespace lapack
427427
return info;
428428
}
429429

430+
template <class E1, class E2>
431+
int potrs(E1& A, E2& b, char uplo = 'L')
432+
{
433+
XTENSOR_ASSERT(A.dimension() == 2);
434+
XTENSOR_ASSERT(A.layout() == layout_type::column_major);
435+
436+
XTENSOR_ASSERT(b.dimension() == 1);
437+
438+
XTENSOR_ASSERT(A.shape()[0] == A.shape()[1]);
439+
440+
int info = cxxlapack::potrs<blas_index_t>(
441+
uplo,
442+
static_cast<blas_index_t>(A.shape()[0]),
443+
1,
444+
A.data(),
445+
static_cast<blas_index_t>(A.shape()[0]),
446+
b.data(),
447+
static_cast<blas_index_t>(b.shape()[0])
448+
);
449+
450+
return info;
451+
}
452+
453+
template <class E1, class E2>
454+
int trtrs(E1& A, E2& b, char uplo = 'L', char trans = 'N', char diag = 'N')
455+
{
456+
XTENSOR_ASSERT(A.dimension() == 2);
457+
XTENSOR_ASSERT(A.layout() == layout_type::column_major);
458+
459+
XTENSOR_ASSERT(b.dimension() == 1);
460+
461+
XTENSOR_ASSERT(A.shape()[0] == A.shape()[1]);
462+
463+
int info = cxxlapack::trtrs<blas_index_t>(
464+
uplo,
465+
trans,
466+
diag,
467+
static_cast<blas_index_t>(A.shape()[0]),
468+
1,
469+
A.data(),
470+
static_cast<blas_index_t>(A.shape()[0]),
471+
b.data(),
472+
static_cast<blas_index_t>(b.shape()[0])
473+
);
474+
475+
return info;
476+
}
477+
430478
/**
431479
* Interface to LAPACK getri.
432480
*

include/xtensor-blas/xlinalg.hpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,50 @@ namespace linalg
12531253
return M;
12541254
}
12551255

1256+
/**
1257+
* Solves a system of linear equations M*X = B with a symmetric
1258+
* where M = A*A**T if uplo is L.
1259+
* Factorization of M can be computed with cholesky.
1260+
* @return solution X
1261+
*/
1262+
template <class T, class D>
1263+
auto solve_cholesky(const xexpression<T>& A, const xexpression<D>& b)
1264+
{
1265+
assert_nd_square(A);
1266+
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
1267+
auto p = copy_to_layout<layout_type::column_major>(b.derived_cast());
1268+
1269+
int info = lapack::potrs(M, p, 'L');
1270+
1271+
if (info > 0)
1272+
{
1273+
throw std::runtime_error("Cholesky decomposition failed.");
1274+
}
1275+
1276+
return p;
1277+
}
1278+
1279+
/**
1280+
* Solves Ax = b, where A is a lower triangular matrix
1281+
* @return solution x
1282+
*/
1283+
template <class T, class D>
1284+
auto solve_triangular(const xexpression<T>& A, const xexpression<D>& b)
1285+
{
1286+
assert_nd_square(A);
1287+
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
1288+
auto p = copy_to_layout<layout_type::column_major>(b.derived_cast());
1289+
1290+
int info = lapack::trtrs(M, p, 'L', 'N');
1291+
1292+
if (info > 0)
1293+
{
1294+
throw std::runtime_error("Cholesky decomposition failed.");
1295+
}
1296+
1297+
return p;
1298+
}
1299+
12561300
/**
12571301
* Compute the SVD decomposition of \em A.
12581302
* @return tuple containing S, V, and D

test/test_lapack.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,44 @@ namespace xt
147147
EXPECT_EQ(expected3, res3);
148148
}
149149

150+
TEST(xlapack, solveCholesky) {
151+
152+
xarray<double> A =
153+
{{ 1. , 0. , 0. , 0. , 0. },
154+
{ 0.44615865, 0.89495389, 0. , 0. , 0. },
155+
{ 0.39541532, 0.24253783, 0.88590187, 0. , 0. },
156+
{-0.36681098, -0.26249522, 0.0338034 , 0.89185386, 0. },
157+
{ 0.0881614 , 0.12356345, 0.19887529, -0.35996807, 0.89879433}};
158+
159+
xarray<double> b = {1, 1, 1, -1, -1};
160+
auto x = linalg::solve_cholesky(A, b);
161+
162+
const xarray<double> x_expected = { 0.13757507429403265, 0.26609253571318064, 1.03715526610177222,
163+
-1.3449222878385465 , -1.81183493755905478};
164+
165+
for (int i = 0; i < x_expected.shape()[0]; ++i) {
166+
EXPECT_DOUBLE_EQ(x_expected[i], x[i]);
167+
}
168+
}
169+
170+
TEST(xlapack, solveTriangular) {
171+
172+
const xt::xtensor<double, 2> A =
173+
{{ 1. , 0. , 0. , 0. , 0. },
174+
{ 0.44615865, 0.89495389, 0. , 0. , 0. },
175+
{ 0.39541532, 0.24253783, 0.88590187, 0. , 0. },
176+
{-0.36681098, -0.26249522, 0.0338034 , 0.89185386, 0. },
177+
{ 0.0881614 , 0.12356345, 0.19887529, -0.35996807, 0.89879433}};
178+
179+
const xt::xtensor<double, 1> b = {0.38867999, 0.46467046, 0.39042938, -0.2736973, 0.20813322};
180+
auto x = linalg::solve_triangular(A, b);
181+
182+
const xarray<double> x_expected = { 0.38867998999999998, 0.32544416381003327, 0.17813128230545805,
183+
-0.05799057434472885, 0.08606304705465571};
184+
185+
for (int i = 0; i < x_expected.shape()[0]; ++i) {
186+
EXPECT_DOUBLE_EQ(x_expected[i], x[i]);
187+
}
188+
}
189+
150190
}

0 commit comments

Comments
 (0)