diff --git a/src/integrals/libint/libint.cpp b/src/integrals/libint/libint.cpp index 2f3639c4..583ed776 100644 --- a/src/integrals/libint/libint.cpp +++ b/src/integrals/libint/libint.cpp @@ -26,39 +26,37 @@ namespace integrals::libint { namespace { -template +template auto build_eigen_buffer(const std::vector& basis_sets, - double thresh) { + parallelzone::runtime::RuntimeView& rv, double thresh) { FloatType initial_value; if constexpr(std::is_same_v) { initial_value = 0.0; } else { // Presumably sigma::UDouble initial_value = FloatType(0.0, thresh); } - Eigen::array dims_bfs; - for(decltype(N) i = 0; i < N; ++i) dims_bfs[i] = basis_sets[i].nbf(); + auto N = basis_sets.size(); + std::vector dims(N); + for(decltype(N) i = 0; i < N; ++i) dims[i] = basis_sets[i].nbf(); using shape_t = tensorwrapper::shape::Smooth; using layout_t = tensorwrapper::layout::Physical; - using buffer_t = tensorwrapper::buffer::Eigen; - using data_t = typename buffer_t::data_type; - shape_t s{dims_bfs.begin(), dims_bfs.end()}; + shape_t s{dims.begin(), dims.end()}; layout_t l(s); - data_t d(dims_bfs); - buffer_t b{d, l}; - b.value().setConstant(initial_value); - return b; + tensorwrapper::allocator::Eigen alloc(rv); + return alloc.construct(l, initial_value); } -template +template auto fill_tensor(const std::vector& basis_sets, - const chemist::qm_operator::OperatorBase& op, double thresh) { + const chemist::qm_operator::OperatorBase& op, + parallelzone::runtime::RuntimeView& rv, double thresh) { // Dimensional information std::vector dims_shells(N); for(decltype(N) i = 0; i < N; ++i) dims_shells[i] = basis_sets[i].size(); - auto b = build_eigen_buffer(basis_sets, thresh); + auto pbuffer = build_eigen_buffer(basis_sets, rv, thresh); // Make libint engine LibintVisitor visitor(basis_sets, thresh); @@ -77,7 +75,7 @@ auto fill_tensor(const std::vector& basis_sets, auto ord = detail_::shells2ord(basis_sets, shells); auto n_ord = ord.size(); for(decltype(n_ord) i_ord = 0; i_ord < n_ord; ++i_ord) { - b.value().data()[ord[i_ord]] += vals[i_ord]; + pbuffer->data()[ord[i_ord]] += vals[i_ord]; } } @@ -93,7 +91,8 @@ auto fill_tensor(const std::vector& basis_sets, } } - return simde::type::tensor(b.layout().shape().clone(), b); + auto pshape = pbuffer->layout().shape().clone(); + return simde::type::tensor(std::move(pshape), std::move(pbuffer)); } } // namespace @@ -122,6 +121,7 @@ TEMPLATED_MODULE_RUN(Libint, BraKetType) { auto bra = braket.bra(); auto ket = braket.ket(); auto& op = braket.op(); + auto& rv = this->get_runtime(); // Gather information from Bra, Ket, and Op auto basis_sets = detail_::get_basis_sets(bra, ket); @@ -130,16 +130,17 @@ TEMPLATED_MODULE_RUN(Libint, BraKetType) { simde::type::tensor t; if(with_uq) { if constexpr(integrals::type::has_sigma()) { - t = fill_tensor(basis_sets, op, thresh); + t = fill_tensor(basis_sets, op, rv, + thresh); } else { throw std::runtime_error("Sigma support not enabled!"); } } else { - t = fill_tensor(basis_sets, op, thresh); + t = fill_tensor(basis_sets, op, rv, thresh); } - auto rv = results(); - return my_pt::wrap_results(rv, t); + auto result = results(); + return my_pt::wrap_results(result, t); } #define LIBINT(bra, op, ket) Libint> diff --git a/tests/cxx/unit/integrals/ao_integrals/ao_integrals_driver.cpp b/tests/cxx/unit/integrals/ao_integrals/ao_integrals_driver.cpp index 6d6a1247..35a394e0 100644 --- a/tests/cxx/unit/integrals/ao_integrals/ao_integrals_driver.cpp +++ b/tests/cxx/unit/integrals/ao_integrals/ao_integrals_driver.cpp @@ -21,17 +21,20 @@ using simde::type::tensor; namespace { void compare_matrices(const tensor& A, const tensor& A_corr) { - using alloc_type = tensorwrapper::allocator::Eigen; + using alloc_type = tensorwrapper::allocator::Eigen; const auto& A_buffer = alloc_type::rebind(A.buffer()); const auto& A_corr_buffer = alloc_type::rebind(A_corr.buffer()); - const auto& A_eigen = A_buffer.value(); - const auto& A_corr_eigen = A_corr_buffer.value(); const auto tol = 1E-6; - REQUIRE(A_eigen(0, 0) == Catch::Approx(A_corr_eigen(0, 0)).margin(tol)); - REQUIRE(A_eigen(0, 1) == Catch::Approx(A_corr_eigen(0, 1)).margin(tol)); - REQUIRE(A_eigen(1, 0) == Catch::Approx(A_corr_eigen(1, 0)).margin(tol)); - REQUIRE(A_eigen(1, 1) == Catch::Approx(A_corr_eigen(1, 1)).margin(tol)); + auto A00 = A_buffer.at(0, 0); + auto A01 = A_buffer.at(0, 1); + auto A10 = A_buffer.at(1, 0); + auto A11 = A_buffer.at(1, 1); + + REQUIRE(A00 == Catch::Approx(A_corr_buffer.at(0, 0)).margin(tol)); + REQUIRE(A01 == Catch::Approx(A_corr_buffer.at(0, 1)).margin(tol)); + REQUIRE(A10 == Catch::Approx(A_corr_buffer.at(1, 0)).margin(tol)); + REQUIRE(A11 == Catch::Approx(A_corr_buffer.at(1, 1)).margin(tol)); } } // namespace diff --git a/tests/cxx/unit/integrals/ao_integrals/j_four_center.cpp b/tests/cxx/unit/integrals/ao_integrals/j_four_center.cpp index da0d66b2..011b4452 100644 --- a/tests/cxx/unit/integrals/ao_integrals/j_four_center.cpp +++ b/tests/cxx/unit/integrals/ao_integrals/j_four_center.cpp @@ -39,9 +39,9 @@ TEST_CASE("Four center J builder") { // Call module const auto& T = mm.at("Four center J builder").run_as(braket); - auto t = test::eigen_buffer<2>(T.buffer()); - REQUIRE(t.value()(0, 0) == Catch::Approx(0.56044143).margin(1E-6)); - REQUIRE(t.value()(0, 1) == Catch::Approx(0.24704427).margin(1E-6)); - REQUIRE(t.value()(1, 0) == Catch::Approx(0.24704427).margin(1E-6)); - REQUIRE(t.value()(1, 1) == Catch::Approx(0.56044143).margin(1E-6)); + auto t = test::eigen_tensor<2>(T.buffer()); + REQUIRE(t(0, 0) == Catch::Approx(0.56044143).margin(1E-6)); + REQUIRE(t(0, 1) == Catch::Approx(0.24704427).margin(1E-6)); + REQUIRE(t(1, 0) == Catch::Approx(0.24704427).margin(1E-6)); + REQUIRE(t(1, 1) == Catch::Approx(0.56044143).margin(1E-6)); } \ No newline at end of file diff --git a/tests/cxx/unit/integrals/ao_integrals/k_four_center.cpp b/tests/cxx/unit/integrals/ao_integrals/k_four_center.cpp index f001ea16..2f0c80da 100644 --- a/tests/cxx/unit/integrals/ao_integrals/k_four_center.cpp +++ b/tests/cxx/unit/integrals/ao_integrals/k_four_center.cpp @@ -39,9 +39,9 @@ TEST_CASE("Four center K builder") { // Call module const auto& T = mm.at("Four center K builder").run_as(braket); - auto t = test::eigen_buffer<2>(T.buffer()); - REQUIRE(t.value()(0, 0) == Catch::Approx(0.45617623).margin(1E-6)); - REQUIRE(t.value()(0, 1) == Catch::Approx(0.35130947).margin(1E-6)); - REQUIRE(t.value()(1, 0) == Catch::Approx(0.35130947).margin(1E-6)); - REQUIRE(t.value()(1, 1) == Catch::Approx(0.45617623).margin(1E-6)); + auto t = test::eigen_tensor<2>(T.buffer()); + REQUIRE(t(0, 0) == Catch::Approx(0.45617623).margin(1E-6)); + REQUIRE(t(0, 1) == Catch::Approx(0.35130947).margin(1E-6)); + REQUIRE(t(1, 0) == Catch::Approx(0.35130947).margin(1E-6)); + REQUIRE(t(1, 1) == Catch::Approx(0.45617623).margin(1E-6)); } \ No newline at end of file diff --git a/tests/cxx/unit/integrals/libint/test_arbitrary_operator.cpp b/tests/cxx/unit/integrals/libint/test_arbitrary_operator.cpp index 1bf408eb..9c1e5680 100644 --- a/tests/cxx/unit/integrals/libint/test_arbitrary_operator.cpp +++ b/tests/cxx/unit/integrals/libint/test_arbitrary_operator.cpp @@ -76,10 +76,9 @@ TEST_CASE("OperatorBase") { auto T = mod.run_as(braket); // Check output - auto t = test::eigen_buffer<2>(T.buffer()); - REQUIRE(test::trace(t) == + REQUIRE(test::trace<2>(T.buffer()) == Catch::Approx(124.7011973877891364).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<2>(T.buffer()) == Catch::Approx(90.2562579028763707).margin(1.0e-16)); } @@ -89,14 +88,13 @@ TEST_CASE("OperatorBase") { auto T = mod.run_as(braket); // Check output - auto t = test::eigen_buffer<2, udouble>(T.buffer()); - REQUIRE(unwrap_mean(test::trace(t)) == + REQUIRE(unwrap_mean(test::trace<2, udouble>(T.buffer())) == Catch::Approx(124.7011973877891364).margin(1.0e-16)); - REQUIRE(unwrap_sd(test::trace(t)) == + REQUIRE(unwrap_sd(test::trace<2, udouble>(T.buffer())) == Catch::Approx(7e-16).margin(1.0e-16)); - REQUIRE(unwrap_mean(test::norm(t)) == + REQUIRE(unwrap_mean(test::norm<2, udouble>(T.buffer())) == Catch::Approx(90.2562579028763707).margin(1.0e-16)); - REQUIRE(unwrap_sd(test::norm(t)) == + REQUIRE(unwrap_sd(test::norm<2, udouble>(T.buffer())) == Catch::Approx(3e-16).margin(1.0e-16)); } } @@ -115,10 +113,9 @@ TEST_CASE("OperatorBase") { auto T = mod.run_as(braket); // Check output - auto t = test::eigen_buffer<3>(T.buffer()); - REQUIRE(test::trace(t) == + REQUIRE(test::trace<3>(T.buffer()) == Catch::Approx(16.8245948391706577).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<3>(T.buffer()) == Catch::Approx(20.6560572032543597).margin(1.0e-16)); } @@ -129,14 +126,14 @@ TEST_CASE("OperatorBase") { auto T = mod.run_as(braket); // Check output - auto t = test::eigen_buffer<3, udouble>(T.buffer()); - REQUIRE(unwrap_mean(test::trace(t)) == + auto& t = T.buffer(); + REQUIRE(unwrap_mean(test::trace<3, udouble>(t)) == Catch::Approx(16.8245948391706577).margin(1.0e-16)); - REQUIRE(unwrap_sd(test::trace(t)) == + REQUIRE(unwrap_sd(test::trace<3, udouble>(t)) == Catch::Approx(7e-16).margin(1.0e-16)); - REQUIRE(unwrap_mean(test::norm(t)) == + REQUIRE(unwrap_mean(test::norm<3, udouble>(t)) == Catch::Approx(20.6560572032543597).margin(1.0e-16)); - REQUIRE(unwrap_sd(test::norm(t)) == + REQUIRE(unwrap_sd(test::norm<3, udouble>(t)) == Catch::Approx(7e-16).margin(1.0e-16)); } } @@ -157,10 +154,10 @@ TEST_CASE("OperatorBase") { auto T = mod.run_as(braket); // Check output - auto t = test::eigen_buffer<4>(T.buffer()); - REQUIRE(test::trace(t) == + auto& t = T.buffer(); + REQUIRE(test::trace<4>(t) == Catch::Approx(9.7919608941952063).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<4>(t) == Catch::Approx(7.7796143419802553).margin(1.0e-16)); } @@ -171,14 +168,14 @@ TEST_CASE("OperatorBase") { auto T = mod.run_as(braket); // Check output - auto t = test::eigen_buffer<4, udouble>(T.buffer()); - REQUIRE(unwrap_mean(test::trace(t)) == + auto& t = T.buffer(); + REQUIRE(unwrap_mean(test::trace<4, udouble>(t)) == Catch::Approx(9.7919608941952063).margin(1.0e-16)); - REQUIRE(unwrap_sd(test::trace(t)) == + REQUIRE(unwrap_sd(test::trace<4, udouble>(t)) == Catch::Approx(7e-16).margin(1.0e-16)); - REQUIRE(unwrap_mean(test::norm(t)) == + REQUIRE(unwrap_mean(test::norm<4, udouble>(t)) == Catch::Approx(7.7796143419802553).margin(1.0e-16)); - REQUIRE(unwrap_sd(test::norm(t)) == + REQUIRE(unwrap_sd(test::norm<4, udouble>(t)) == Catch::Approx(11e-16).margin(1.0e-16)); } } diff --git a/tests/cxx/unit/integrals/libint/test_eri2.cpp b/tests/cxx/unit/integrals/libint/test_eri2.cpp index 4b1e8c09..b98afd13 100644 --- a/tests/cxx/unit/integrals/libint/test_eri2.cpp +++ b/tests/cxx/unit/integrals/libint/test_eri2.cpp @@ -40,9 +40,9 @@ TEST_CASE("ERI2") { auto T = mm.at("ERI2").run_as(braket); // Check output - auto t = test::eigen_buffer<2>(T.buffer()); - REQUIRE(test::trace(t) == + auto& t = T.buffer(); + REQUIRE(test::trace<2>(t) == Catch::Approx(124.7011973877891364).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<2>(t) == Catch::Approx(90.2562579028763707).margin(1.0e-16)); } diff --git a/tests/cxx/unit/integrals/libint/test_eri3.cpp b/tests/cxx/unit/integrals/libint/test_eri3.cpp index ebca158e..43d4287c 100644 --- a/tests/cxx/unit/integrals/libint/test_eri3.cpp +++ b/tests/cxx/unit/integrals/libint/test_eri3.cpp @@ -41,9 +41,9 @@ TEST_CASE("ERI3") { auto T = mm.at("ERI3").run_as(braket); // Check output - auto t = test::eigen_buffer<3>(T.buffer()); - REQUIRE(test::trace(t) == + auto& t = T.buffer(); + REQUIRE(test::trace<3>(t) == Catch::Approx(16.8245948391706577).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<3>(t) == Catch::Approx(20.6560572032543597).margin(1.0e-16)); } diff --git a/tests/cxx/unit/integrals/libint/test_eri4.cpp b/tests/cxx/unit/integrals/libint/test_eri4.cpp index 58268edd..43881954 100644 --- a/tests/cxx/unit/integrals/libint/test_eri4.cpp +++ b/tests/cxx/unit/integrals/libint/test_eri4.cpp @@ -41,8 +41,9 @@ TEST_CASE("ERI4") { auto T = mm.at("ERI4").run_as(braket); // Check output - auto t = test::eigen_buffer<4>(T.buffer()); - REQUIRE(test::trace(t) == + auto& t = T.buffer(); + REQUIRE(test::trace<4>(t) == Catch::Approx(9.7919608941952063).margin(1.0e-16)); - REQUIRE(test::norm(t) == Catch::Approx(7.7796143419802553).margin(1.0e-16)); + REQUIRE(test::norm<4>(t) == + Catch::Approx(7.7796143419802553).margin(1.0e-16)); } diff --git a/tests/cxx/unit/integrals/libint/test_kinetic.cpp b/tests/cxx/unit/integrals/libint/test_kinetic.cpp index 9d9d93f2..9f000c25 100644 --- a/tests/cxx/unit/integrals/libint/test_kinetic.cpp +++ b/tests/cxx/unit/integrals/libint/test_kinetic.cpp @@ -40,9 +40,9 @@ TEST_CASE("Kinetic") { auto T = mm.at("Kinetic").run_as(braket); // Check output - auto t = test::eigen_buffer<2>(T.buffer()); - REQUIRE(test::trace(t) == + auto& t = T.buffer(); + REQUIRE(test::trace<2>(t) == Catch::Approx(38.9175852621874441).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<2>(t) == Catch::Approx(29.3665362218072552).margin(1.0e-16)); } diff --git a/tests/cxx/unit/integrals/libint/test_nuclear.cpp b/tests/cxx/unit/integrals/libint/test_nuclear.cpp index 26c34a93..9f3d8360 100644 --- a/tests/cxx/unit/integrals/libint/test_nuclear.cpp +++ b/tests/cxx/unit/integrals/libint/test_nuclear.cpp @@ -40,9 +40,9 @@ TEST_CASE("Nuclear") { auto T = mm.at("Nuclear").run_as(braket); // Check output - auto t = test::eigen_buffer<2>(T.buffer()); - REQUIRE(test::trace(t) == + auto& t = T.buffer(); + REQUIRE(test::trace<2>(t) == Catch::Approx(-111.9975421879705664).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<2>(t) == Catch::Approx(66.4857539908047528).margin(1.0e-16)); } diff --git a/tests/cxx/unit/integrals/libint/test_overlap.cpp b/tests/cxx/unit/integrals/libint/test_overlap.cpp index 535c9afc..bece7ae3 100644 --- a/tests/cxx/unit/integrals/libint/test_overlap.cpp +++ b/tests/cxx/unit/integrals/libint/test_overlap.cpp @@ -40,9 +40,9 @@ TEST_CASE("Overlap") { auto S = mm.at("Overlap").run_as(braket); // Check output - auto t = test::eigen_buffer<2>(S.buffer()); - REQUIRE(test::trace(t) == + auto& t = S.buffer(); + REQUIRE(test::trace<2>(t) == Catch::Approx(7.00000000000000266).margin(1.0e-16)); - REQUIRE(test::norm(t) == + REQUIRE(test::norm<2>(t) == Catch::Approx(2.87134497074907324).margin(1.0e-16)); } diff --git a/tests/cxx/unit/integrals/testing.hpp b/tests/cxx/unit/integrals/testing.hpp index 402dbc50..6f7258dd 100644 --- a/tests/cxx/unit/integrals/testing.hpp +++ b/tests/cxx/unit/integrals/testing.hpp @@ -24,25 +24,39 @@ #include #include +#include + namespace test { +template +auto eigen_tensor_(const tensorwrapper::buffer::BufferBase& buffer, + std::array extents, std::index_sequence) { + const auto& b = tensorwrapper::allocator::Eigen::rebind(buffer); + using eigen_type = Eigen::Tensor; + return Eigen::TensorMap(b.data(), extents[Is]...); +} + // Checking eigen outputs template -auto eigen_buffer(const tensorwrapper::buffer::BufferBase& buffer) { - return static_cast&>( - buffer); +auto eigen_tensor(const tensorwrapper::buffer::BufferBase& buffer) { + std::array extents; + auto shape = buffer.layout().shape().as_smooth(); + for(std::size_t i = 0; i < N; ++i) extents[i] = shape.extent(i); + return eigen_tensor_(buffer, extents, + std::make_index_sequence()); } -template -auto trace(const tensorwrapper::buffer::Eigen& t) { - Eigen::Tensor trace = t.value().trace(); +template +auto trace(const tensorwrapper::buffer::BufferBase& buffer) { + auto t = eigen_tensor(buffer); + Eigen::Tensor trace = t.trace(); return trace.coeff(); } -template -auto norm(const tensorwrapper::buffer::Eigen& t) { - Eigen::Tensor norm = - t.value().square().sum().sqrt(); +template +auto norm(const tensorwrapper::buffer::BufferBase& buffer) { + auto t = eigen_tensor(buffer); + Eigen::Tensor norm = t.square().sum().sqrt(); return norm.coeff(); }