diff --git a/src/integrals/ao_integrals/ao_integrals.cpp b/src/integrals/ao_integrals/ao_integrals.cpp index 07277c4c..833804d2 100644 --- a/src/integrals/ao_integrals/ao_integrals.cpp +++ b/src/integrals/ao_integrals/ao_integrals.cpp @@ -24,6 +24,37 @@ namespace integrals::ao_integrals { +class LibIntVisitor : public chemist::qm_operator::OperatorVisitor { +public: + using t_e_type = simde::type::t_e_type; + using v_ee_type = simde::type::v_ee_type; + using v_en_type = simde::type::v_en_type; + + LibIntVisitor(const std::vector& bases, double thresh, + std::size_t deriv = 0) : + m_bases(bases), m_thresh(thresh), m_deriv(deriv){}; + + void run(const t_e_type& T_e) { + m_engine = detail_::make_engine(m_bases, T_e, m_thresh, m_deriv); + } + + void run(const v_en_type& V_en) { + m_engine = detail_::make_engine(m_bases, V_en, m_thresh, m_deriv); + } + + void run(const v_ee_type& V_ee) { + m_engine = detail_::make_engine(m_bases, V_ee, m_thresh, m_deriv); + } + + libint2::Engine& engine() { return m_engine; } + +private: + const std::vector& m_bases; + double m_thresh; + std::size_t m_deriv; + libint2::Engine m_engine; +}; + template TEMPLATED_MODULE_CTOR(AOIntegral, BraKetType) { using my_pt = simde::EvaluateBraKet; @@ -44,7 +75,7 @@ TEMPLATED_MODULE_RUN(AOIntegral, BraKetType) { auto thresh = inputs.at("Threshold").value(); auto bra = braket.bra(); auto ket = braket.ket(); - auto op = braket.op(); + auto& op = braket.op(); // Gather information from Bra, Ket, and Op auto basis_sets = detail_::get_basis_sets(bra, ket); @@ -72,7 +103,9 @@ TEMPLATED_MODULE_RUN(AOIntegral, BraKetType) { b.value().setZero(); // Make libint engine - auto engine = detail_::make_engine(basis_sets, op, thresh); + LibIntVisitor visitor(basis_sets, thresh); + op.visit(visitor); + auto engine = visitor.engine(); const auto& buf = engine.results(); // Fill in values @@ -111,6 +144,10 @@ TEMPLATED_MODULE_RUN(AOIntegral, BraKetType) { #define EXTERN_AOI(bra, op, ket) template struct AOI(bra, op, ket) #define LOAD_AOI(bra, op, ket, key) mm.add_module(key) +EXTERN_AOI(aos, op_base_type, aos); +EXTERN_AOI(aos, op_base_type, aos_squared); +EXTERN_AOI(aos_squared, op_base_type, aos_squared); + EXTERN_AOI(aos, t_e_type, aos); EXTERN_AOI(aos, v_en_type, aos); EXTERN_AOI(aos, v_ee_type, aos); @@ -122,6 +159,10 @@ void ao_integrals_set_defaults(pluginplay::ModuleManager& mm) { } void load_ao_integrals(pluginplay::ModuleManager& mm) { + LOAD_AOI(aos, op_base_type, aos, "Evaluate 2-Index BraKet"); + LOAD_AOI(aos, op_base_type, aos_squared, "Evaluate 3-Index BraKet"); + LOAD_AOI(aos_squared, op_base_type, aos_squared, "Evaluate 4-Index BraKet"); + LOAD_AOI(aos, t_e_type, aos, "Kinetic"); LOAD_AOI(aos, v_en_type, aos, "Nuclear"); LOAD_AOI(aos, v_ee_type, aos, "ERI2"); diff --git a/src/integrals/ao_integrals/ao_integrals.hpp b/src/integrals/ao_integrals/ao_integrals.hpp index b040b1d3..7d30cb6d 100644 --- a/src/integrals/ao_integrals/ao_integrals.hpp +++ b/src/integrals/ao_integrals/ao_integrals.hpp @@ -29,6 +29,7 @@ using simde::type::braket; using simde::type::aos; using simde::type::aos_squared; +using simde::type::op_base_type; using simde::type::t_e_type; using simde::type::v_ee_type; using simde::type::v_en_type; @@ -59,6 +60,10 @@ void ao_integrals_set_defaults(pluginplay::ModuleManager& mm); // Forward External Template Declarations #define EXTERN_AOI extern template struct AOIntegral +EXTERN_AOI>; +EXTERN_AOI>; +EXTERN_AOI>; + EXTERN_AOI>; EXTERN_AOI>; EXTERN_AOI>; diff --git a/tests/cxx/unit/integrals/ao_integrals/test_arbitrary_operator.cpp b/tests/cxx/unit/integrals/ao_integrals/test_arbitrary_operator.cpp new file mode 100644 index 00000000..1b63467f --- /dev/null +++ b/tests/cxx/unit/integrals/ao_integrals/test_arbitrary_operator.cpp @@ -0,0 +1,97 @@ +/* + * Copyright 2022 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "test_ao_integrals.hpp" + +TEST_CASE("OperatorBase") { + using aos_t = simde::type::aos; + using aos_squared_t = simde::type::aos_squared; + using op_t = simde::type::v_ee_type; + using op_base_t = simde::type::op_base_type; + + pluginplay::ModuleManager mm; + integrals::load_modules(mm); + REQUIRE(mm.count("Evaluate 2-Index BraKet")); + REQUIRE(mm.count("Evaluate 3-Index BraKet")); + REQUIRE(mm.count("Evaluate 4-Index BraKet")); + + // Get basis set + auto mol = test::water_molecule(); + auto aobs = test::water_sto3g_basis_set(); + + // Make AOS object + aos_t aos(aobs); + aos_squared_t aos_squared(aos, aos); + + // Make Operator + op_t op{}; + op_base_t& op_base = op; + + SECTION("2-Index") { + using braket_t = simde::type::braket; + using test_pt = simde::EvaluateBraKet; + + // Make BraKet Input + braket_t braket(aos, op_base, aos); + + // Call module + auto T = mm.at("Evaluate 2-Index BraKet").run_as(braket); + + // Check output + auto t = test::eigen_buffer<2>(T.buffer()); + REQUIRE(test::trace(t) == + Catch::Approx(124.7011973877891364).margin(1.0e-16)); + REQUIRE(test::norm(t) == + Catch::Approx(90.2562579028763707).margin(1.0e-16)); + } + + SECTION("3-Index") { + using braket_t = simde::type::braket; + using test_pt = simde::EvaluateBraKet; + + // Make BraKet Input + braket_t braket(aos, op_base, aos_squared); + + // Call module + auto T = mm.at("Evaluate 3-Index BraKet").run_as(braket); + + // Check output + auto t = test::eigen_buffer<3>(T.buffer()); + REQUIRE(test::trace(t) == + Catch::Approx(16.8245948391706577).margin(1.0e-16)); + REQUIRE(test::norm(t) == + Catch::Approx(20.6560572032543597).margin(1.0e-16)); + } + + SECTION("4-Index") { + using braket_t = + simde::type::braket; + using test_pt = simde::EvaluateBraKet; + + // Make BraKet Input + braket_t braket(aos_squared, op_base, aos_squared); + + // Call module + auto T = mm.at("Evaluate 4-Index BraKet").run_as(braket); + + // Check output + auto t = test::eigen_buffer<4>(T.buffer()); + REQUIRE(test::trace(t) == + Catch::Approx(9.7919608941952063).margin(1.0e-16)); + REQUIRE(test::norm(t) == + Catch::Approx(7.7796143419802553).margin(1.0e-16)); + } +}