Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/integrals/ao_integrals/ao_integrals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,37 @@

namespace integrals::ao_integrals {

class LibIntVisitor : public chemist::qm_operator::OperatorVisitor {
public:
using t_e_type = simde::type::t_e_type;
using v_ee_type = simde::type::v_ee_type;
using v_en_type = simde::type::v_en_type;

LibIntVisitor(const std::vector<libint2::BasisSet>& bases, double thresh,
std::size_t deriv = 0) :
m_bases(bases), m_thresh(thresh), m_deriv(deriv){};

void run(const t_e_type& T_e) {
m_engine = detail_::make_engine(m_bases, T_e, m_thresh, m_deriv);
}

void run(const v_en_type& V_en) {
m_engine = detail_::make_engine(m_bases, V_en, m_thresh, m_deriv);
}

void run(const v_ee_type& V_ee) {
m_engine = detail_::make_engine(m_bases, V_ee, m_thresh, m_deriv);
}

libint2::Engine& engine() { return m_engine; }

private:
const std::vector<libint2::BasisSet>& m_bases;
double m_thresh;
std::size_t m_deriv;
libint2::Engine m_engine;
};

template<typename BraKetType>
TEMPLATED_MODULE_CTOR(AOIntegral, BraKetType) {
using my_pt = simde::EvaluateBraKet<BraKetType>;
Expand All @@ -44,7 +75,7 @@ TEMPLATED_MODULE_RUN(AOIntegral, BraKetType) {
auto thresh = inputs.at("Threshold").value<double>();
auto bra = braket.bra();
auto ket = braket.ket();
auto op = braket.op();
auto& op = braket.op();

// Gather information from Bra, Ket, and Op
auto basis_sets = detail_::get_basis_sets(bra, ket);
Expand Down Expand Up @@ -72,7 +103,9 @@ TEMPLATED_MODULE_RUN(AOIntegral, BraKetType) {
b.value().setZero();

// Make libint engine
auto engine = detail_::make_engine(basis_sets, op, thresh);
LibIntVisitor visitor(basis_sets, thresh);
op.visit(visitor);
auto engine = visitor.engine();
const auto& buf = engine.results();

// Fill in values
Expand Down Expand Up @@ -111,6 +144,10 @@ TEMPLATED_MODULE_RUN(AOIntegral, BraKetType) {
#define EXTERN_AOI(bra, op, ket) template struct AOI(bra, op, ket)
#define LOAD_AOI(bra, op, ket, key) mm.add_module<AOI(bra, op, ket)>(key)

EXTERN_AOI(aos, op_base_type, aos);
EXTERN_AOI(aos, op_base_type, aos_squared);
EXTERN_AOI(aos_squared, op_base_type, aos_squared);

EXTERN_AOI(aos, t_e_type, aos);
EXTERN_AOI(aos, v_en_type, aos);
EXTERN_AOI(aos, v_ee_type, aos);
Expand All @@ -122,6 +159,10 @@ void ao_integrals_set_defaults(pluginplay::ModuleManager& mm) {
}

void load_ao_integrals(pluginplay::ModuleManager& mm) {
LOAD_AOI(aos, op_base_type, aos, "Evaluate 2-Index BraKet");
LOAD_AOI(aos, op_base_type, aos_squared, "Evaluate 3-Index BraKet");
LOAD_AOI(aos_squared, op_base_type, aos_squared, "Evaluate 4-Index BraKet");

LOAD_AOI(aos, t_e_type, aos, "Kinetic");
LOAD_AOI(aos, v_en_type, aos, "Nuclear");
LOAD_AOI(aos, v_ee_type, aos, "ERI2");
Expand Down
5 changes: 5 additions & 0 deletions src/integrals/ao_integrals/ao_integrals.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ using simde::type::braket;
using simde::type::aos;
using simde::type::aos_squared;

using simde::type::op_base_type;
using simde::type::t_e_type;
using simde::type::v_ee_type;
using simde::type::v_en_type;
Expand Down Expand Up @@ -59,6 +60,10 @@ void ao_integrals_set_defaults(pluginplay::ModuleManager& mm);
// Forward External Template Declarations
#define EXTERN_AOI extern template struct AOIntegral

EXTERN_AOI<braket<aos, op_base_type, aos>>;
EXTERN_AOI<braket<aos, op_base_type, aos_squared>>;
EXTERN_AOI<braket<aos_squared, op_base_type, aos_squared>>;

EXTERN_AOI<braket<aos, t_e_type, aos>>;
EXTERN_AOI<braket<aos, v_en_type, aos>>;
EXTERN_AOI<braket<aos, v_ee_type, aos>>;
Expand Down
97 changes: 97 additions & 0 deletions tests/cxx/unit/integrals/ao_integrals/test_arbitrary_operator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright 2022 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "test_ao_integrals.hpp"

TEST_CASE("OperatorBase") {
using aos_t = simde::type::aos;
using aos_squared_t = simde::type::aos_squared;
using op_t = simde::type::v_ee_type;
using op_base_t = simde::type::op_base_type;

pluginplay::ModuleManager mm;
integrals::load_modules(mm);
REQUIRE(mm.count("Evaluate 2-Index BraKet"));
REQUIRE(mm.count("Evaluate 3-Index BraKet"));
REQUIRE(mm.count("Evaluate 4-Index BraKet"));

// Get basis set
auto mol = test::water_molecule();
auto aobs = test::water_sto3g_basis_set();

// Make AOS object
aos_t aos(aobs);
aos_squared_t aos_squared(aos, aos);

// Make Operator
op_t op{};
op_base_t& op_base = op;

SECTION("2-Index") {
using braket_t = simde::type::braket<aos_t, op_base_t, aos_t>;
using test_pt = simde::EvaluateBraKet<braket_t>;

// Make BraKet Input
braket_t braket(aos, op_base, aos);

// Call module
auto T = mm.at("Evaluate 2-Index BraKet").run_as<test_pt>(braket);

// Check output
auto t = test::eigen_buffer<2>(T.buffer());
REQUIRE(test::trace(t) ==
Catch::Approx(124.7011973877891364).margin(1.0e-16));
REQUIRE(test::norm(t) ==
Catch::Approx(90.2562579028763707).margin(1.0e-16));
}

SECTION("3-Index") {
using braket_t = simde::type::braket<aos_t, op_base_t, aos_squared_t>;
using test_pt = simde::EvaluateBraKet<braket_t>;

// Make BraKet Input
braket_t braket(aos, op_base, aos_squared);

// Call module
auto T = mm.at("Evaluate 3-Index BraKet").run_as<test_pt>(braket);

// Check output
auto t = test::eigen_buffer<3>(T.buffer());
REQUIRE(test::trace(t) ==
Catch::Approx(16.8245948391706577).margin(1.0e-16));
REQUIRE(test::norm(t) ==
Catch::Approx(20.6560572032543597).margin(1.0e-16));
}

SECTION("4-Index") {
using braket_t =
simde::type::braket<aos_squared_t, op_base_t, aos_squared_t>;
using test_pt = simde::EvaluateBraKet<braket_t>;

// Make BraKet Input
braket_t braket(aos_squared, op_base, aos_squared);

// Call module
auto T = mm.at("Evaluate 4-Index BraKet").run_as<test_pt>(braket);

// Check output
auto t = test::eigen_buffer<4>(T.buffer());
REQUIRE(test::trace(t) ==
Catch::Approx(9.7919608941952063).margin(1.0e-16));
REQUIRE(test::norm(t) ==
Catch::Approx(7.7796143419802553).margin(1.0e-16));
}
}
Loading