diff --git a/src/tensorwrapper/buffer/einsum_planner.hpp b/src/tensorwrapper/buffer/einsum_planner.hpp new file mode 100644 index 00000000..6204ec95 --- /dev/null +++ b/src/tensorwrapper/buffer/einsum_planner.hpp @@ -0,0 +1,100 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace tensorwrapper::buffer { + +/** @brief Works out the details pertaining to an arbitrary binary einsum op. + * + * For a general einsum operation the indices in a label fall into one of four + * categories: + * + * - trace indices: appear in only one of the input tensors, but not the output + * - dummy indices: those that appear in both input tensors, but not the output + * - free indices: appear in result and ONE of the input tensors + * - batch indices: appear in all three tensors + * + * N.b., though the set of indices in say lhs_batch and rhs_batch must be the + * same, the order can be different. This applies to dummy indices too. + */ +class EinsumPlanner { +public: + using string_type = std::string; + + using label_type = dsl::DummyIndices; + + EinsumPlanner(std::string result, std::string lhs, std::string rhs) : + EinsumPlanner(label_type(result), label_type(lhs), label_type(rhs)) {} + + EinsumPlanner(label_type result, label_type lhs, label_type rhs) : + m_result_(std::move(result)), + m_lhs_(std::move(lhs)), + m_rhs_(std::move(rhs)) {} + + // Labels that ONLY appear in LHS + label_type lhs_trace() const { + return m_lhs_.difference(m_rhs_).difference(m_result_); + } + + /// Labels that ONLY appear in RHS + label_type rhs_trace() const { + return m_rhs_.difference(m_lhs_).difference(m_result_); + } + + /// Labels that appear in both LHS and RHS, but NOT in result + label_type lhs_dummy() const { + return m_lhs_.intersection(m_rhs_).difference(m_result_); + } + + /// Labels that appear in both LHS and RHS, but NOT in result + label_type rhs_dummy() const { + return m_rhs_.intersection(m_lhs_).difference(m_result_); + } + + /// Labels that appear in result and LHS, but NOT in RHS + label_type lhs_free() const { + return m_lhs_.intersection(m_result_).difference(m_rhs_); + } + + /// Labels that appear in result and RHS, but NOT in LHS + label_type rhs_free() const { + return m_rhs_.intersection(m_result_).difference(m_lhs_); + } + + /// Labels that appear in all three tensors + label_type result_batch() const { + return m_result_.intersection(m_lhs_).intersection(m_rhs_); + } + + /// Labels that appear in all three tensors + label_type lhs_batch() const { + return m_lhs_.intersection(m_result_).intersection(m_rhs_); + } + + /// Labels that appear in all three tensors + label_type rhs_batch() const { + return m_rhs_.intersection(m_result_).intersection(m_lhs_); + } + +private: + label_type m_result_; + label_type m_lhs_; + label_type m_rhs_; +}; + +} // namespace tensorwrapper::buffer diff --git a/src/tensorwrapper/shape/shape_from_labels.hpp b/src/tensorwrapper/shape/shape_from_labels.hpp new file mode 100644 index 00000000..f4e2dcd0 --- /dev/null +++ b/src/tensorwrapper/shape/shape_from_labels.hpp @@ -0,0 +1,105 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace tensorwrapper::shape { +namespace { + +/** @brief Recursively searches @p args for @p label. + * + * @tparam Args the types of the labeled shapes to search. + * + * To deal with an unknown number of labeled shapes we use recursion to loop + * over the list. Each invocation of `recurse_for_extent_` checks if @p label + * is found in @p shape. If it is, the extent is returned. If not, the + * parameter pack is unpacked into a new invocation of `recurse_for_extent_` + * and the process repeats. + * + * @note This function short-circuits as soon as @p label is found and does not + * ensure that all shapes agree on the extend for @p label. + * + * @param[in] label The label whose extent we are searching for. + * @param[in] shape The labeled shape to search at this recursion depth. + * @param[in] args The remaining labeled shapes to search if @p label is not + * found in @p shape. + * + * @return The extent associated with @p label. + * + * @throws std::runtime_error if @p label is not found in @p shape or any of + * the objects in @p args. Strong throw guarantee. + */ +template +auto recurse_for_extent_(const std::string& label, + dsl::Labeled shape, Args&&... args) { + auto idx = shape.labels().find(label); + if(idx.empty()) { + if constexpr(sizeof...(args) > 0) { + return recurse_for_extent_(label, std::forward(args)...); + } else { + throw std::runtime_error("Label " + label + + " not found in any provided shapes"); + } + } else { + return shape.object().as_smooth().extent(idx[0]); + } +} + +} // namespace + +/** @brief Given a series of dummy indices and labeled shapes, works out the + * shape of the tensor described by the dummy indices. + * + * @tparam StringType The string type used to represent the labels. Assumed to + * be a type like std::string. + * @tparam Args The types of the labeled shapes provided. + * + * This function wraps the process of working out the shape associated with a + * list of dummy indices. To do this, the function loops over each dummy + * index in @p labels and searches the labeled shapes in @p labeled_shapes for + * the dummy index. When the dummy index is found, the extent associated with + * the dummy index is recorded. If the dummy index is not found in any of the + * labeled shapes, an exception is thrown. + * + * @param[in] labels The dummy indices describing the tensor whose shape is to + * be determined. + * @param[in] labeled_shapes The labeled shapes to search for the dummy indices + * in. + * + * @return A Smooth shape describing the shape of the tensor with dummy indices + * @p labels. + * + * @throw std::runtime_error if any of the labels in @p labels are not found + * in @p labeled_shapes. Strong throw guarantee. + */ +template +shape::Smooth shape_from_labels(const dsl::DummyIndices& labels, + Args&&... labeled_shapes) { + static_assert(sizeof...(Args) > 0, + "Must provide at least one labeled shape"); + + std::vector extents; + for(const auto& label : labels) { + extents.push_back(recurse_for_extent_(label, labeled_shapes...)); + } + + return shape::Smooth(extents.begin(), extents.end()); +} + +} // namespace tensorwrapper::shape diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/einsum_planner.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/einsum_planner.cpp new file mode 100644 index 00000000..af7beb9c --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/einsum_planner.cpp @@ -0,0 +1,553 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../testing/testing.hpp" +#include + +using namespace tensorwrapper; +using namespace buffer; + +/* + * Let "t" stand for a set of trace indices, "f" for a set of free indices, + * "d" for a set of dummy indices, and "b" for a set of batch indices. Then any + * given label can be described as a combination of these four categories. In + * the event that a label is empty we label it "s" for scalar. + * + * For the tensor operation A = B * C the possible categorization of the labels + * for A, B, and C can respectively be: + * - s s s + * - s s t + * - s t s + * - s t t + * - s d d + * - s d dt + * - s dt d + * - s dt dt + * - f t f + * - f f t + * - f f f + * - f t ft + * - f ft t + * - f ft ft + * - f d df + * - f df d + * - f df df + * - f dt df + * - f df dt + * - f df df + * - f dt dft + * - f dft dt + * - f dft dft + * - bf bt bf + * - bf bf bt + * - bf bf bf + * - bf bd bdf + * - bf bdf bd + * - bf bdf bdf + * - bf bt bft + * - bf bft bt + * - bf bft bft + * - bf bdt bdf + * - bf bdf bdt + * - bf bdf bdf + * - bf bdt bdft + * - bf bdft bdt + * - bf bdft bdft + * + * (these enumerations ignore permuting the categories within a label) + * + * The following are NOT possible: + * + * - labels that are scalar and something else (e.g., trace). Scalar is by + * definition the lack of the four index categories. + * - trace in the result (would require result to have a mode that is + * independent of the inputs) + * - dummy in the result (dummy can only appear in the inputs) + * - dummy in only one of the inputs + * - free indices when the result is a scalar + * - batch indices when the result is a scalar + * + */ + +TEST_CASE("EinsumPlanner") { + SECTION("Result in scalars") { + SECTION("s s s") { + EinsumPlanner ep___("", "", ""); + REQUIRE(ep___.lhs_trace() == ""); + REQUIRE(ep___.rhs_trace() == ""); + REQUIRE(ep___.lhs_dummy() == ""); + REQUIRE(ep___.rhs_dummy() == ""); + REQUIRE(ep___.lhs_free() == ""); + REQUIRE(ep___.rhs_free() == ""); + REQUIRE(ep___.result_batch() == ""); + REQUIRE(ep___.lhs_batch() == ""); + REQUIRE(ep___.rhs_batch() == ""); + } + SECTION("s s t") { + EinsumPlanner ep___kl("", "", "k,l"); + REQUIRE(ep___kl.lhs_trace() == ""); + REQUIRE(ep___kl.rhs_trace() == "k,l"); + REQUIRE(ep___kl.lhs_dummy() == ""); + REQUIRE(ep___kl.rhs_dummy() == ""); + REQUIRE(ep___kl.lhs_free() == ""); + REQUIRE(ep___kl.rhs_free() == ""); + REQUIRE(ep___kl.result_batch() == ""); + REQUIRE(ep___kl.lhs_batch() == ""); + REQUIRE(ep___kl.rhs_batch() == ""); + } + SECTION("s t s") { + EinsumPlanner ep__ij_("", "i,j", ""); + REQUIRE(ep__ij_.lhs_trace() == "i,j"); + REQUIRE(ep__ij_.rhs_trace() == ""); + REQUIRE(ep__ij_.lhs_dummy() == ""); + REQUIRE(ep__ij_.rhs_dummy() == ""); + REQUIRE(ep__ij_.lhs_free() == ""); + REQUIRE(ep__ij_.rhs_free() == ""); + REQUIRE(ep__ij_.result_batch() == ""); + REQUIRE(ep__ij_.lhs_batch() == ""); + REQUIRE(ep__ij_.rhs_batch() == ""); + } + SECTION("s t t") { + EinsumPlanner ep__ij_klm("", "i,j", "k,l,m"); + REQUIRE(ep__ij_klm.lhs_trace() == "i,j"); + REQUIRE(ep__ij_klm.rhs_trace() == "k,l,m"); + REQUIRE(ep__ij_klm.lhs_dummy() == ""); + REQUIRE(ep__ij_klm.rhs_dummy() == ""); + REQUIRE(ep__ij_klm.lhs_free() == ""); + REQUIRE(ep__ij_klm.rhs_free() == ""); + REQUIRE(ep__ij_klm.result_batch() == ""); + REQUIRE(ep__ij_klm.lhs_batch() == ""); + REQUIRE(ep__ij_klm.rhs_batch() == ""); + } + SECTION("s d d") { + EinsumPlanner ep__ij_ji("", "i,j", "j,i"); + REQUIRE(ep__ij_ji.lhs_trace() == ""); + REQUIRE(ep__ij_ji.rhs_trace() == ""); + REQUIRE(ep__ij_ji.lhs_dummy() == "i,j"); + REQUIRE(ep__ij_ji.rhs_dummy() == "j,i"); + REQUIRE(ep__ij_ji.lhs_free() == ""); + REQUIRE(ep__ij_ji.rhs_free() == ""); + REQUIRE(ep__ij_ji.result_batch() == ""); + REQUIRE(ep__ij_ji.lhs_batch() == ""); + REQUIRE(ep__ij_ji.rhs_batch() == ""); + } + SECTION("s d dt") { + EinsumPlanner ep__ij_jik("", "i,j", "j,i,k"); + REQUIRE(ep__ij_jik.lhs_trace() == ""); + REQUIRE(ep__ij_jik.rhs_trace() == "k"); + REQUIRE(ep__ij_jik.lhs_dummy() == "i,j"); + REQUIRE(ep__ij_jik.rhs_dummy() == "j,i"); + REQUIRE(ep__ij_jik.lhs_free() == ""); + REQUIRE(ep__ij_jik.rhs_free() == ""); + REQUIRE(ep__ij_jik.result_batch() == ""); + REQUIRE(ep__ij_jik.lhs_batch() == ""); + REQUIRE(ep__ij_jik.rhs_batch() == ""); + } + SECTION("s dt d") { + EinsumPlanner ep__jik_ik("", "j,i,k", "i,k"); + REQUIRE(ep__jik_ik.lhs_trace() == "j"); + REQUIRE(ep__jik_ik.rhs_trace() == ""); + REQUIRE(ep__jik_ik.lhs_dummy() == "i,k"); + REQUIRE(ep__jik_ik.rhs_dummy() == "i,k"); + REQUIRE(ep__jik_ik.lhs_free() == ""); + REQUIRE(ep__jik_ik.rhs_free() == ""); + REQUIRE(ep__jik_ik.result_batch() == ""); + REQUIRE(ep__jik_ik.lhs_batch() == ""); + REQUIRE(ep__jik_ik.rhs_batch() == ""); + } + SECTION("s dt dt") { + EinsumPlanner ep__jik_ikm("", "j,i,k", "i,k,m"); + REQUIRE(ep__jik_ikm.lhs_trace() == "j"); + REQUIRE(ep__jik_ikm.rhs_trace() == "m"); + REQUIRE(ep__jik_ikm.lhs_dummy() == "i,k"); + REQUIRE(ep__jik_ikm.rhs_dummy() == "i,k"); + REQUIRE(ep__jik_ikm.lhs_free() == ""); + REQUIRE(ep__jik_ikm.rhs_free() == ""); + REQUIRE(ep__jik_ikm.result_batch() == ""); + REQUIRE(ep__jik_ikm.lhs_batch() == ""); + REQUIRE(ep__jik_ikm.rhs_batch() == ""); + } + } + + SECTION("Result in free indices") { + SECTION("f t f") { + EinsumPlanner ep_ik_jl_ik("i,k", "j,l", "i,k"); + REQUIRE(ep_ik_jl_ik.lhs_trace() == "j,l"); + REQUIRE(ep_ik_jl_ik.rhs_trace() == ""); + REQUIRE(ep_ik_jl_ik.lhs_dummy() == ""); + REQUIRE(ep_ik_jl_ik.rhs_dummy() == ""); + REQUIRE(ep_ik_jl_ik.lhs_free() == ""); + REQUIRE(ep_ik_jl_ik.rhs_free() == "i,k"); + REQUIRE(ep_ik_jl_ik.result_batch() == ""); + REQUIRE(ep_ik_jl_ik.lhs_batch() == ""); + REQUIRE(ep_ik_jl_ik.rhs_batch() == ""); + } + SECTION("f f t") { + EinsumPlanner ep_ij_ji_kl("i,j", "j,i", "k,l"); + REQUIRE(ep_ij_ji_kl.lhs_trace() == ""); + REQUIRE(ep_ij_ji_kl.rhs_trace() == "k,l"); + REQUIRE(ep_ij_ji_kl.lhs_dummy() == ""); + REQUIRE(ep_ij_ji_kl.rhs_dummy() == ""); + REQUIRE(ep_ij_ji_kl.lhs_free() == "j,i"); + REQUIRE(ep_ij_ji_kl.rhs_free() == ""); + REQUIRE(ep_ij_ji_kl.result_batch() == ""); + REQUIRE(ep_ij_ji_kl.lhs_batch() == ""); + REQUIRE(ep_ij_ji_kl.rhs_batch() == ""); + } + SECTION("f f f") { + EinsumPlanner ep_ijkl_kl_ji("i,j,k,l", "k,l", "j,i"); + REQUIRE(ep_ijkl_kl_ji.lhs_trace() == ""); + REQUIRE(ep_ijkl_kl_ji.rhs_trace() == ""); + REQUIRE(ep_ijkl_kl_ji.lhs_dummy() == ""); + REQUIRE(ep_ijkl_kl_ji.rhs_dummy() == ""); + REQUIRE(ep_ijkl_kl_ji.lhs_free() == "k,l"); + REQUIRE(ep_ijkl_kl_ji.rhs_free() == "j,i"); + REQUIRE(ep_ijkl_kl_ji.result_batch() == ""); + REQUIRE(ep_ijkl_kl_ji.lhs_batch() == ""); + REQUIRE(ep_ijkl_kl_ji.rhs_batch() == ""); + } + SECTION("f t ft") { + EinsumPlanner ep_ik_jl_kmi("i,k", "j,l", "k,m,i"); + REQUIRE(ep_ik_jl_kmi.lhs_trace() == "j,l"); + REQUIRE(ep_ik_jl_kmi.rhs_trace() == "m"); + REQUIRE(ep_ik_jl_kmi.lhs_dummy() == ""); + REQUIRE(ep_ik_jl_kmi.rhs_dummy() == ""); + REQUIRE(ep_ik_jl_kmi.lhs_free() == ""); + REQUIRE(ep_ik_jl_kmi.rhs_free() == "k,i"); + REQUIRE(ep_ik_jl_kmi.result_batch() == ""); + REQUIRE(ep_ik_jl_kmi.lhs_batch() == ""); + REQUIRE(ep_ik_jl_kmi.rhs_batch() == ""); + } + SECTION("f ft t") { + EinsumPlanner ep_jl_ljm_i("j,l", "l,j,m", "i"); + REQUIRE(ep_jl_ljm_i.lhs_trace() == "m"); + REQUIRE(ep_jl_ljm_i.rhs_trace() == "i"); + REQUIRE(ep_jl_ljm_i.lhs_dummy() == ""); + REQUIRE(ep_jl_ljm_i.rhs_dummy() == ""); + REQUIRE(ep_jl_ljm_i.lhs_free() == "l,j"); + REQUIRE(ep_jl_ljm_i.rhs_free() == ""); + REQUIRE(ep_jl_ljm_i.result_batch() == ""); + REQUIRE(ep_jl_ljm_i.lhs_batch() == ""); + REQUIRE(ep_jl_ljm_i.rhs_batch() == ""); + } + SECTION("f ft ft") { + EinsumPlanner ep_ik_kl_im("i,k", "k,l", "i,m"); + REQUIRE(ep_ik_kl_im.lhs_trace() == "l"); + REQUIRE(ep_ik_kl_im.rhs_trace() == "m"); + REQUIRE(ep_ik_kl_im.lhs_dummy() == ""); + REQUIRE(ep_ik_kl_im.rhs_dummy() == ""); + REQUIRE(ep_ik_kl_im.lhs_free() == "k"); + REQUIRE(ep_ik_kl_im.rhs_free() == "i"); + REQUIRE(ep_ik_kl_im.result_batch() == ""); + REQUIRE(ep_ik_kl_im.lhs_batch() == ""); + REQUIRE(ep_ik_kl_im.rhs_batch() == ""); + } + SECTION("f d df") { + EinsumPlanner ep_i_kj_jki("i", "k,j", "j,k,i"); + REQUIRE(ep_i_kj_jki.lhs_trace() == ""); + REQUIRE(ep_i_kj_jki.rhs_trace() == ""); + REQUIRE(ep_i_kj_jki.lhs_dummy() == "k,j"); + REQUIRE(ep_i_kj_jki.rhs_dummy() == "j,k"); + REQUIRE(ep_i_kj_jki.lhs_free() == ""); + REQUIRE(ep_i_kj_jki.rhs_free() == "i"); + REQUIRE(ep_i_kj_jki.result_batch() == ""); + REQUIRE(ep_i_kj_jki.lhs_batch() == ""); + REQUIRE(ep_i_kj_jki.rhs_batch() == ""); + } + SECTION("f df d") { + EinsumPlanner ep_ij_jikl_kl("i,j", "j,i,k,l", "k,l"); + REQUIRE(ep_ij_jikl_kl.lhs_trace() == ""); + REQUIRE(ep_ij_jikl_kl.rhs_trace() == ""); + REQUIRE(ep_ij_jikl_kl.lhs_dummy() == "k,l"); + REQUIRE(ep_ij_jikl_kl.rhs_dummy() == "k,l"); + REQUIRE(ep_ij_jikl_kl.lhs_free() == "j,i"); + REQUIRE(ep_ij_jikl_kl.rhs_free() == ""); + REQUIRE(ep_ij_jikl_kl.result_batch() == ""); + REQUIRE(ep_ij_jikl_kl.lhs_batch() == ""); + REQUIRE(ep_ij_jikl_kl.rhs_batch() == ""); + } + SECTION("f df df") { + EinsumPlanner ep_jm_im_ij("j,m", "i,m", "i,j"); + REQUIRE(ep_jm_im_ij.lhs_trace() == ""); + REQUIRE(ep_jm_im_ij.rhs_trace() == ""); + REQUIRE(ep_jm_im_ij.lhs_dummy() == "i"); + REQUIRE(ep_jm_im_ij.rhs_dummy() == "i"); + REQUIRE(ep_jm_im_ij.lhs_free() == "m"); + REQUIRE(ep_jm_im_ij.rhs_free() == "j"); + REQUIRE(ep_jm_im_ij.result_batch() == ""); + REQUIRE(ep_jm_im_ij.lhs_batch() == ""); + REQUIRE(ep_jm_im_ij.rhs_batch() == ""); + } + SECTION("f dt df") { + EinsumPlanner ep_lm_ij_iml("l,m", "i,j", "i,m,l"); + REQUIRE(ep_lm_ij_iml.lhs_trace() == "j"); + REQUIRE(ep_lm_ij_iml.rhs_trace() == ""); + REQUIRE(ep_lm_ij_iml.lhs_dummy() == "i"); + REQUIRE(ep_lm_ij_iml.rhs_dummy() == "i"); + REQUIRE(ep_lm_ij_iml.lhs_free() == ""); + REQUIRE(ep_lm_ij_iml.rhs_free() == "m,l"); + REQUIRE(ep_lm_ij_iml.result_batch() == ""); + REQUIRE(ep_lm_ij_iml.lhs_batch() == ""); + REQUIRE(ep_lm_ij_iml.rhs_batch() == ""); + } + SECTION("f df dt") { + EinsumPlanner ep_i_ij_jk("i", "i,j", "j,k"); + REQUIRE(ep_i_ij_jk.lhs_trace() == ""); + REQUIRE(ep_i_ij_jk.rhs_trace() == "k"); + REQUIRE(ep_i_ij_jk.lhs_dummy() == "j"); + REQUIRE(ep_i_ij_jk.rhs_dummy() == "j"); + REQUIRE(ep_i_ij_jk.lhs_free() == "i"); + REQUIRE(ep_i_ij_jk.rhs_free() == ""); + REQUIRE(ep_i_ij_jk.result_batch() == ""); + REQUIRE(ep_i_ij_jk.lhs_batch() == ""); + REQUIRE(ep_i_ij_jk.rhs_batch() == ""); + } + SECTION("f df df") { + EinsumPlanner ep_ijk_klm_jlmi("i,j,k", "k,l,m", "j,l,m,i"); + REQUIRE(ep_ijk_klm_jlmi.lhs_trace() == ""); + REQUIRE(ep_ijk_klm_jlmi.rhs_trace() == ""); + REQUIRE(ep_ijk_klm_jlmi.lhs_dummy() == "l,m"); + REQUIRE(ep_ijk_klm_jlmi.rhs_dummy() == "l,m"); + REQUIRE(ep_ijk_klm_jlmi.lhs_free() == "k"); + REQUIRE(ep_ijk_klm_jlmi.rhs_free() == "j,i"); + REQUIRE(ep_ijk_klm_jlmi.result_batch() == ""); + REQUIRE(ep_ijk_klm_jlmi.lhs_batch() == ""); + REQUIRE(ep_ijk_klm_jlmi.rhs_batch() == ""); + } + SECTION("f dt dft") { + EinsumPlanner ep_il_jm_jlis("i,l", "j,m", "j,l,i,s"); + REQUIRE(ep_il_jm_jlis.lhs_trace() == "m"); + REQUIRE(ep_il_jm_jlis.rhs_trace() == "s"); + REQUIRE(ep_il_jm_jlis.lhs_dummy() == "j"); + REQUIRE(ep_il_jm_jlis.rhs_dummy() == "j"); + REQUIRE(ep_il_jm_jlis.lhs_free() == ""); + REQUIRE(ep_il_jm_jlis.rhs_free() == "l,i"); + REQUIRE(ep_il_jm_jlis.result_batch() == ""); + REQUIRE(ep_il_jm_jlis.lhs_batch() == ""); + REQUIRE(ep_il_jm_jlis.rhs_batch() == ""); + } + SECTION("f dft dt") { + EinsumPlanner ep_i_jikm_kjn("i", "j,i,k,m", "k,j,n"); + REQUIRE(ep_i_jikm_kjn.lhs_trace() == "m"); + REQUIRE(ep_i_jikm_kjn.rhs_trace() == "n"); + REQUIRE(ep_i_jikm_kjn.lhs_dummy() == "j,k"); + REQUIRE(ep_i_jikm_kjn.rhs_dummy() == "k,j"); + REQUIRE(ep_i_jikm_kjn.lhs_free() == "i"); + REQUIRE(ep_i_jikm_kjn.rhs_free() == ""); + REQUIRE(ep_i_jikm_kjn.result_batch() == ""); + REQUIRE(ep_i_jikm_kjn.lhs_batch() == ""); + REQUIRE(ep_i_jikm_kjn.rhs_batch() == ""); + } + SECTION("f dft dft") { + EinsumPlanner ep_ijk_nilsk_sammjl("i,j,k", "n,i,l,s,k", + "s,a,m,m,j,l"); + REQUIRE(ep_ijk_nilsk_sammjl.lhs_trace() == "n"); + REQUIRE(ep_ijk_nilsk_sammjl.rhs_trace() == "a,m"); + REQUIRE(ep_ijk_nilsk_sammjl.lhs_dummy() == "l,s"); + REQUIRE(ep_ijk_nilsk_sammjl.rhs_dummy() == "s,l"); + REQUIRE(ep_ijk_nilsk_sammjl.lhs_free() == "i,k"); + REQUIRE(ep_ijk_nilsk_sammjl.rhs_free() == "j"); + REQUIRE(ep_ijk_nilsk_sammjl.result_batch() == ""); + REQUIRE(ep_ijk_nilsk_sammjl.lhs_batch() == ""); + REQUIRE(ep_ijk_nilsk_sammjl.rhs_batch() == ""); + } + } + + SECTION("Result in batched free indices") { + SECTION("bf bt bf") { + EinsumPlanner ep_ibk_bjl_bik("i,b,k", "b,j,l", "b,i,k"); + REQUIRE(ep_ibk_bjl_bik.lhs_trace() == "j,l"); + REQUIRE(ep_ibk_bjl_bik.rhs_trace() == ""); + REQUIRE(ep_ibk_bjl_bik.lhs_dummy() == ""); + REQUIRE(ep_ibk_bjl_bik.rhs_dummy() == ""); + REQUIRE(ep_ibk_bjl_bik.lhs_free() == ""); + REQUIRE(ep_ibk_bjl_bik.rhs_free() == "i,k"); + REQUIRE(ep_ibk_bjl_bik.result_batch() == "b"); + REQUIRE(ep_ibk_bjl_bik.lhs_batch() == "b"); + REQUIRE(ep_ibk_bjl_bik.rhs_batch() == "b"); + } + SECTION("bf bf bt") { + EinsumPlanner ep_bij_jib_kbl("b,i,j", "j,i,b", "k,b,l"); + REQUIRE(ep_bij_jib_kbl.lhs_trace() == ""); + REQUIRE(ep_bij_jib_kbl.rhs_trace() == "k,l"); + REQUIRE(ep_bij_jib_kbl.lhs_dummy() == ""); + REQUIRE(ep_bij_jib_kbl.rhs_dummy() == ""); + REQUIRE(ep_bij_jib_kbl.lhs_free() == "j,i"); + REQUIRE(ep_bij_jib_kbl.rhs_free() == ""); + REQUIRE(ep_bij_jib_kbl.result_batch() == "b"); + REQUIRE(ep_bij_jib_kbl.lhs_batch() == "b"); + REQUIRE(ep_bij_jib_kbl.rhs_batch() == "b"); + } + SECTION("bf bf bf") { + EinsumPlanner ep_iajkbl_kbla_ajbi("i,a,j,k,b,l", "k,b,l,a", + "a,j,b,i"); + REQUIRE(ep_iajkbl_kbla_ajbi.lhs_trace() == ""); + REQUIRE(ep_iajkbl_kbla_ajbi.rhs_trace() == ""); + REQUIRE(ep_iajkbl_kbla_ajbi.lhs_dummy() == ""); + REQUIRE(ep_iajkbl_kbla_ajbi.rhs_dummy() == ""); + REQUIRE(ep_iajkbl_kbla_ajbi.lhs_free() == "k,l"); + REQUIRE(ep_iajkbl_kbla_ajbi.rhs_free() == "j,i"); + REQUIRE(ep_iajkbl_kbla_ajbi.result_batch() == "a,b"); + REQUIRE(ep_iajkbl_kbla_ajbi.lhs_batch() == "b,a"); + REQUIRE(ep_iajkbl_kbla_ajbi.rhs_batch() == "a,b"); + } + SECTION("bf bt bft") { + EinsumPlanner ep_ibk_jbl_kbmi("i,b,k", "j,b,l", "k,b,m,i"); + REQUIRE(ep_ibk_jbl_kbmi.lhs_trace() == "j,l"); + REQUIRE(ep_ibk_jbl_kbmi.rhs_trace() == "m"); + REQUIRE(ep_ibk_jbl_kbmi.lhs_dummy() == ""); + REQUIRE(ep_ibk_jbl_kbmi.rhs_dummy() == ""); + REQUIRE(ep_ibk_jbl_kbmi.lhs_free() == ""); + REQUIRE(ep_ibk_jbl_kbmi.rhs_free() == "k,i"); + REQUIRE(ep_ibk_jbl_kbmi.result_batch() == "b"); + REQUIRE(ep_ibk_jbl_kbmi.lhs_batch() == "b"); + REQUIRE(ep_ibk_jbl_kbmi.rhs_batch() == "b"); + } + SECTION("bf bft bt") { + EinsumPlanner ep_jlb_ljmb_ib("j,l,b", "l,j,m,b", "i,b"); + REQUIRE(ep_jlb_ljmb_ib.lhs_trace() == "m"); + REQUIRE(ep_jlb_ljmb_ib.rhs_trace() == "i"); + REQUIRE(ep_jlb_ljmb_ib.lhs_dummy() == ""); + REQUIRE(ep_jlb_ljmb_ib.rhs_dummy() == ""); + REQUIRE(ep_jlb_ljmb_ib.lhs_free() == "l,j"); + REQUIRE(ep_jlb_ljmb_ib.rhs_free() == ""); + REQUIRE(ep_jlb_ljmb_ib.result_batch() == "b"); + REQUIRE(ep_jlb_ljmb_ib.lhs_batch() == "b"); + REQUIRE(ep_jlb_ljmb_ib.rhs_batch() == "b"); + } + SECTION("bf bft bft") { + EinsumPlanner ep_ibk_bkl_bim("i,b,k", "b,k,l", "b,i,m"); + REQUIRE(ep_ibk_bkl_bim.lhs_trace() == "l"); + REQUIRE(ep_ibk_bkl_bim.rhs_trace() == "m"); + REQUIRE(ep_ibk_bkl_bim.lhs_dummy() == ""); + REQUIRE(ep_ibk_bkl_bim.rhs_dummy() == ""); + REQUIRE(ep_ibk_bkl_bim.lhs_free() == "k"); + REQUIRE(ep_ibk_bkl_bim.rhs_free() == "i"); + REQUIRE(ep_ibk_bkl_bim.result_batch() == "b"); + REQUIRE(ep_ibk_bkl_bim.lhs_batch() == "b"); + REQUIRE(ep_ibk_bkl_bim.rhs_batch() == "b"); + } + SECTION("bf bd bdf") { + EinsumPlanner ep_ib_bkj_bjki("i,b", "b,k,j", "b,j,k,i"); + REQUIRE(ep_ib_bkj_bjki.lhs_trace() == ""); + REQUIRE(ep_ib_bkj_bjki.rhs_trace() == ""); + REQUIRE(ep_ib_bkj_bjki.lhs_dummy() == "k,j"); + REQUIRE(ep_ib_bkj_bjki.rhs_dummy() == "j,k"); + REQUIRE(ep_ib_bkj_bjki.lhs_free() == ""); + REQUIRE(ep_ib_bkj_bjki.rhs_free() == "i"); + REQUIRE(ep_ib_bkj_bjki.result_batch() == "b"); + REQUIRE(ep_ib_bkj_bjki.lhs_batch() == "b"); + REQUIRE(ep_ib_bkj_bjki.rhs_batch() == "b"); + } + SECTION("bf bdf bd") { + EinsumPlanner ep_ibj_jikbl_klb("i,b,j", "j,i,k,b,l", "k,l,b"); + REQUIRE(ep_ibj_jikbl_klb.lhs_trace() == ""); + REQUIRE(ep_ibj_jikbl_klb.rhs_trace() == ""); + REQUIRE(ep_ibj_jikbl_klb.lhs_dummy() == "k,l"); + REQUIRE(ep_ibj_jikbl_klb.rhs_dummy() == "k,l"); + REQUIRE(ep_ibj_jikbl_klb.lhs_free() == "j,i"); + REQUIRE(ep_ibj_jikbl_klb.rhs_free() == ""); + REQUIRE(ep_ibj_jikbl_klb.result_batch() == "b"); + REQUIRE(ep_ibj_jikbl_klb.lhs_batch() == "b"); + REQUIRE(ep_ibj_jikbl_klb.rhs_batch() == "b"); + } + SECTION("bf bdf bdf") { + EinsumPlanner ep_jmb_imb_ijb("j,m,b", "i,m,b", "i,j,b"); + REQUIRE(ep_jmb_imb_ijb.lhs_trace() == ""); + REQUIRE(ep_jmb_imb_ijb.rhs_trace() == ""); + REQUIRE(ep_jmb_imb_ijb.lhs_dummy() == "i"); + REQUIRE(ep_jmb_imb_ijb.rhs_dummy() == "i"); + REQUIRE(ep_jmb_imb_ijb.lhs_free() == "m"); + REQUIRE(ep_jmb_imb_ijb.rhs_free() == "j"); + REQUIRE(ep_jmb_imb_ijb.result_batch() == "b"); + REQUIRE(ep_jmb_imb_ijb.lhs_batch() == "b"); + REQUIRE(ep_jmb_imb_ijb.rhs_batch() == "b"); + } + SECTION("bf bdt bdf") { + EinsumPlanner ep_lbqm_iqbj_iqbml("l,b,q,m", "i,q,b,j", "i,q,b,m,l"); + REQUIRE(ep_lbqm_iqbj_iqbml.lhs_trace() == "j"); + REQUIRE(ep_lbqm_iqbj_iqbml.rhs_trace() == ""); + REQUIRE(ep_lbqm_iqbj_iqbml.lhs_dummy() == "i"); + REQUIRE(ep_lbqm_iqbj_iqbml.rhs_dummy() == "i"); + REQUIRE(ep_lbqm_iqbj_iqbml.lhs_free() == ""); + REQUIRE(ep_lbqm_iqbj_iqbml.rhs_free() == "m,l"); + REQUIRE(ep_lbqm_iqbj_iqbml.result_batch() == "b,q"); + REQUIRE(ep_lbqm_iqbj_iqbml.lhs_batch() == "q,b"); + REQUIRE(ep_lbqm_iqbj_iqbml.rhs_batch() == "q,b"); + } + SECTION("bf bdf bdt") { + EinsumPlanner ep_bi_bij_bjk("b,i", "b,i,j", "b,j,k"); + REQUIRE(ep_bi_bij_bjk.lhs_trace() == ""); + REQUIRE(ep_bi_bij_bjk.rhs_trace() == "k"); + REQUIRE(ep_bi_bij_bjk.lhs_dummy() == "j"); + REQUIRE(ep_bi_bij_bjk.rhs_dummy() == "j"); + REQUIRE(ep_bi_bij_bjk.lhs_free() == "i"); + REQUIRE(ep_bi_bij_bjk.rhs_free() == ""); + REQUIRE(ep_bi_bij_bjk.result_batch() == "b"); + REQUIRE(ep_bi_bij_bjk.lhs_batch() == "b"); + REQUIRE(ep_bi_bij_bjk.rhs_batch() == "b"); + } + SECTION("bf bdf bdf") { + EinsumPlanner ep_ibjk_kblm_jblmi("i,b,j,k", "k,b,l,m", "j,b,l,m,i"); + REQUIRE(ep_ibjk_kblm_jblmi.lhs_trace() == ""); + REQUIRE(ep_ibjk_kblm_jblmi.rhs_trace() == ""); + REQUIRE(ep_ibjk_kblm_jblmi.lhs_dummy() == "l,m"); + REQUIRE(ep_ibjk_kblm_jblmi.rhs_dummy() == "l,m"); + REQUIRE(ep_ibjk_kblm_jblmi.lhs_free() == "k"); + REQUIRE(ep_ibjk_kblm_jblmi.rhs_free() == "j,i"); + REQUIRE(ep_ibjk_kblm_jblmi.result_batch() == "b"); + REQUIRE(ep_ibjk_kblm_jblmi.lhs_batch() == "b"); + REQUIRE(ep_ibjk_kblm_jblmi.rhs_batch() == "b"); + } + SECTION("bf bdt bdft") { + EinsumPlanner ep_ilb_bjm_jlibs("i,l,b", "b,j,m", "j,l,i,b,s"); + REQUIRE(ep_ilb_bjm_jlibs.lhs_trace() == "m"); + REQUIRE(ep_ilb_bjm_jlibs.rhs_trace() == "s"); + REQUIRE(ep_ilb_bjm_jlibs.lhs_dummy() == "j"); + REQUIRE(ep_ilb_bjm_jlibs.rhs_dummy() == "j"); + REQUIRE(ep_ilb_bjm_jlibs.lhs_free() == ""); + REQUIRE(ep_ilb_bjm_jlibs.rhs_free() == "l,i"); + REQUIRE(ep_ilb_bjm_jlibs.result_batch() == "b"); + REQUIRE(ep_ilb_bjm_jlibs.lhs_batch() == "b"); + REQUIRE(ep_ilb_bjm_jlibs.rhs_batch() == "b"); + } + SECTION("bf bdft bdt") { + EinsumPlanner ep_bi_jikbm_bkjn("b,i", "j,i,k,b,m", "b,k,j,n"); + REQUIRE(ep_bi_jikbm_bkjn.lhs_trace() == "m"); + REQUIRE(ep_bi_jikbm_bkjn.rhs_trace() == "n"); + REQUIRE(ep_bi_jikbm_bkjn.lhs_dummy() == "j,k"); + REQUIRE(ep_bi_jikbm_bkjn.rhs_dummy() == "k,j"); + REQUIRE(ep_bi_jikbm_bkjn.lhs_free() == "i"); + REQUIRE(ep_bi_jikbm_bkjn.rhs_free() == ""); + REQUIRE(ep_bi_jikbm_bkjn.result_batch() == "b"); + REQUIRE(ep_bi_jikbm_bkjn.lhs_batch() == "b"); + REQUIRE(ep_bi_jikbm_bkjn.rhs_batch() == "b"); + } + SECTION("bf bdft bdft") { + EinsumPlanner ep_ijbk_bnilsk_bsammjl("i,j,b,k", "b,n,i,l,s,k", + "b,s,a,m,m,j,l"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.lhs_trace() == "n"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.rhs_trace() == "a,m"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.lhs_dummy() == "l,s"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.rhs_dummy() == "s,l"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.lhs_free() == "i,k"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.rhs_free() == "j"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.result_batch() == "b"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.lhs_batch() == "b"); + REQUIRE(ep_ijbk_bnilsk_bsammjl.rhs_batch() == "b"); + } + } +} diff --git a/tests/cxx/unit_tests/tensorwrapper/shape/shape_from_labels.cpp b/tests/cxx/unit_tests/tensorwrapper/shape/shape_from_labels.cpp new file mode 100644 index 00000000..b31f81a7 --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/shape/shape_from_labels.cpp @@ -0,0 +1,68 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../testing/testing.hpp" +#include + +using namespace tensorwrapper::shape; + +TEST_CASE("shape_form_labels") { + using shape_type = tensorwrapper::shape::Smooth; + using label_type = tensorwrapper::shape::ShapeBase::label_type; + shape_type s0{}; + shape_type s1{4}; + shape_type s2{5, 6}; + shape_type s3{7, 5, 4}; + + SECTION("Throws if label is not found") { + using except_t = std::runtime_error; + label_type i("i"); + REQUIRE_THROWS_AS(shape_from_labels(i, s0("")), except_t); + REQUIRE_THROWS_AS(shape_from_labels(i, s1("j")), except_t); + REQUIRE_THROWS_AS(shape_from_labels(i, s1("j"), s2("k,l")), except_t); + } + + SECTION("Scalar labels") { + label_type empty(""); + REQUIRE(shape_from_labels(empty, s0("")) == s0); + REQUIRE(shape_from_labels(empty, s1("i")) == s0); + REQUIRE(shape_from_labels(empty, s1("i"), s2("j,k")) == s0); + REQUIRE(shape_from_labels(empty, s3("i,j,k")) == s0); + } + + SECTION("Vector labels") { + label_type i("i"), j("j"), k("k"); + REQUIRE(shape_from_labels(i, s1("i")) == s1); + REQUIRE(shape_from_labels(j, s2("i,j")) == shape_type({6})); + REQUIRE(shape_from_labels(k, s2("i,j"), s3("j,k,l")) == + shape_type({5})); + } + + SECTION("Matrix labels") { + label_type ij("i,j"), jk("j,k"), ik("i,k"); + REQUIRE(shape_from_labels(ij, s2("i,j")) == s2); + REQUIRE(shape_from_labels(jk, s3("i,j,k")) == shape_type({5, 4})); + REQUIRE(shape_from_labels(ik, s2("i,j"), s3("j,k,l")) == + shape_type({5, 5})); + } + + SECTION("Tensor labels") { + label_type ijk("i,j,k"), ijl("i,j,l"); + REQUIRE(shape_from_labels(ijk, s3("i,j,k")) == s3); + REQUIRE(shape_from_labels(ijl, s2("i,j"), s3("j,k,l")) == + shape_type({5, 6, 4})); + } +}