diff --git a/include/tensorwrapper/buffer/buffer_base.hpp b/include/tensorwrapper/buffer/buffer_base.hpp index 52f812ab..c0801afc 100644 --- a/include/tensorwrapper/buffer/buffer_base.hpp +++ b/include/tensorwrapper/buffer/buffer_base.hpp @@ -59,18 +59,6 @@ class BufferBase : public detail_::PolymorphicBase { /// Type of a pointer to the layout using layout_pointer = typename layout_type::layout_pointer; - /// Type of labels for making a labeled buffer - using label_type = std::string; - - /// Type of a labeled buffer - using labeled_buffer_type = dsl::Labeled; - - /// Type of a labeled read-only buffer (n.b. labels are mutable) - using labeled_const_buffer_type = dsl::Labeled; - - /// Type of a read-only reference to a labeled_buffer_type object - using const_labeled_buffer_reference = const labeled_const_buffer_type&; - // ------------------------------------------------------------------------- // -- Accessors // ------------------------------------------------------------------------- @@ -102,128 +90,10 @@ class BufferBase : public detail_::PolymorphicBase { return *m_layout_; } - // ------------------------------------------------------------------------- - // -- BLAS Operations - // ------------------------------------------------------------------------- - - /** @brief Set this to the result of *this + rhs. - * - * This method will overwrite the state of *this with the result of - * adding the original state of *this to that of @p rhs. Depending on the - * value @p this_labels compared to the labels associated with @p rhs, - * it may be a permutation of @p rhs that is added to *this. - * - * @param[in] this_labels The labels to associate with the modes of *this. - * @param[in] rhs The buffer to add into *this. - * - * @throws ??? Throws if the derived class's implementation throws. Same - * throw guarantee. - */ - buffer_base_reference addition_assignment( - label_type this_labels, const_labeled_buffer_reference rhs) { - return addition_assignment_(std::move(this_labels), rhs); - } - - /** @brief Returns the result of *this + rhs. - * - * This method is the same as addition_assignment except that the result - * is returned in a newly allocated buffer instead of overwriting *this. - * - * @param[in] this_labels the labels for the modes of *this. - * @param[in] rhs The buffer to add to *this. - * - * @return The buffer resulting from adding *this to @p rhs. - * - * @throw std::bad_alloc if there is a problem copying *this. Strong throw - * guarantee. - * @throw ??? If addition_assignment throws when adding @p rhs to the - * copy of *this. Same throw guarantee. - */ - buffer_base_pointer addition(label_type this_labels, - const_labeled_buffer_reference rhs) const { - auto pthis = clone(); - pthis->addition_assignment(std::move(this_labels), rhs); - return pthis; - } - - /** @brief Sets *this to a permutation of @p rhs. - * - * `rhs.rhs()` are the dummy indices associated with the modes of the - * buffer in @p rhs and @p this_labels are the dummy indices associated - * with the buffer in *this. This method will permute @p rhs so that the - * resulting buffer's modes are ordered consistently with @p this_labels, - * i.e. the permutation is FROM the `rhs.rhs()` order TO the - * @p this_labels order. This is seemingly backwards when described out, - * but consistent with the intent of a DSL expression like - * `t("i,j") = x("j,i");` where the intent is to set `t` equal to the - * transpose of `x`. - * - * @param[in] this_labels the dummy indices for the modes of *this. - * @param[in] rhs The tensor to permute. - * - * @return *this after setting it equal to a permutation of @p rhs. - * - * @throw ??? If the derived class's implementation of permute_assignment_ - * throws. Same throw guarantee. - */ - buffer_base_reference permute_assignment( - label_type this_labels, const_labeled_buffer_reference rhs) { - return permute_assignment_(std::move(this_labels), rhs); - } - - /** @brief Returns a copy of *this obtained by permuting *this. - * - * This method simply calls permute_assignment on a copy of *this. See the - * description of permute_assignment for more details. - * - * @param[in] this_labels dummy indices representing the modes of *this in - * its current state. - * @param[in] out_labels how the user wants the modes of *this to be - * ordered. - * - * @throw std::bad_alloc if there is a problem allocating the copy. Strong - * throw guarantee. - * @throw ??? If the derived class's implementation of permute_assignment_ - * throws. Same throw guarantee. - */ - buffer_base_pointer permute(label_type this_labels, - label_type out_labels) const { - auto pthis = clone(); - pthis->permute_assignment(std::move(out_labels), (*this)(this_labels)); - return pthis; - } - // ------------------------------------------------------------------------- // -- Utility methods // ------------------------------------------------------------------------- - /** @brief Associates labels with the modes of *this. - * - * This method is used to create a labeled buffer object by pairing *this - * with the provided labels. The resulting object is capable of being - * composed via the DSL. - * - * @param[in] labels The indices to associate with the modes of *this. - * - * @return A DSL term pairing *this with @p labels. - * - * @throw None No throw guarantee. - */ - labeled_buffer_type operator()(label_type labels); - - /** @brief Associates labels with the modes of *this. - * - * This method is the same as the non-const version except that the result - * contains a read-only reference to *this. - * - * @param[in] labels The labels to associate with *this. - * - * @return A DSL term pairing *this with @p labels. - * - * @throw None No throw guarantee. - */ - labeled_const_buffer_type operator()(label_type labels) const; - /** @brief Is *this value equal to @p rhs? * * Two BufferBase objects are value equal if the layouts they contain are @@ -321,18 +191,6 @@ class BufferBase : public detail_::PolymorphicBase { return *this; } - /// Derived class should overwrite to implement addition_assignment - virtual buffer_base_reference addition_assignment_( - label_type this_labels, const_labeled_buffer_reference rhs) { - throw std::runtime_error("Addition assignment NYI"); - } - - /// Derived class should overwrite to implement permute_assignment - virtual buffer_base_reference permute_assignment_( - label_type this_labels, const_labeled_buffer_reference rhs) { - throw std::runtime_error("Permute assignment NYI"); - } - private: /// Throws std::runtime_error when there is no layout void assert_layout_() const { diff --git a/include/tensorwrapper/buffer/eigen.hpp b/include/tensorwrapper/buffer/eigen.hpp index 4f80a3e0..b14ee00f 100644 --- a/include/tensorwrapper/buffer/eigen.hpp +++ b/include/tensorwrapper/buffer/eigen.hpp @@ -39,9 +39,7 @@ class Eigen : public Replicated { /// Pull in base class's types using typename my_base_type::buffer_base_pointer; using typename my_base_type::const_buffer_base_reference; - using typename my_base_type::const_labeled_buffer_reference; using typename my_base_type::const_layout_reference; - using typename my_base_type::label_type; /// Type of a rank @p Rank tensor using floats of type @p FloatType using data_type = eigen::data_type; @@ -182,14 +180,6 @@ class Eigen : public Replicated { return my_base_type::are_equal_impl_(rhs); } - /// Implements addition_assignment by rebinding rhs to an Eigen buffer - buffer_base_reference addition_assignment_( - label_type this_labels, const_labeled_buffer_reference rhs) override; - - /// Implements permute assignment by deferring to Eigen's shuffle command. - buffer_base_reference permute_assignment_( - label_type this_labels, const_labeled_buffer_reference rhs) override; - /// Implements to_string typename my_base_type::string_type to_string_() const override; diff --git a/include/tensorwrapper/detail_/dsl_base.hpp b/include/tensorwrapper/detail_/dsl_base.hpp new file mode 100644 index 00000000..4e3bccff --- /dev/null +++ b/include/tensorwrapper/detail_/dsl_base.hpp @@ -0,0 +1,294 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +namespace tensorwrapper::detail_ { + +/** @brief Code factorization for objects that are composable via the DSL. + * + * @tparam DerivedType the type of the object which wants to interact with the + * DSL. @p DerivedType is assumed to have a clone method. + * + * This class defines the API parsers of the abstract syntax tree can interact + * with to interact with labeled objects generically. Most operations defined + * by *this have defaults (which just throw with a "not yet implemented" + * error) so that derived classes do not have to override all methods all at + * once. + */ +template +class DSLBase { +public: + /// Type of the derived class + using dsl_value_type = DerivedType; + + /// Type of a read-only object of type dsl_value_type + using dsl_const_value_type = const dsl_value_type; + + /// Type of a reference to an object of type dsl_value_type + using dsl_reference = dsl_value_type&; + + /// Type of a read-only reference to an object of type dsl_value_type + using dsl_const_reference = const dsl_value_type&; + + /// Type of a pointer to an object of type dsl_value_type + using dsl_pointer = std::unique_ptr; + + /// Type used for representing the dummy indices as a string + using string_type = StringType; + + /// Type of a labeled object + using labeled_type = dsl::Labeled; + + /// Type of a labeled read-only object (n.b. labels are mutable) + using labeled_const_type = dsl::Labeled; + + /// Type of parsed labels + using label_type = typename labeled_type::label_type; + + /// Type of a read-only reference to a labeled_type object + using const_labeled_reference = const labeled_const_type&; + + /// Polymorphic no-throw defaulted dtor + virtual ~DSLBase() noexcept = default; + + /** @brief Associates labels with the modes of *this. + * + * @tparam LabelType The type of @p labels. Assumed to be explicitly + * convertible to label_type. + * + * This method is used to create a labeled object by pairing *this + * with the provided labels. The resulting object is capable of being + * composed via the DSL. + * + * N.b., the resulting term aliases *this and the user is responsible for + * ensuring that *this is not deallocated. + * + * @param[in] labels The indices to associate with the modes of *this. + * + * @return A DSL term pairing *this with @p labels. + * + * @throw None No throw guarantee. + */ + template + labeled_type operator()(LabelType&& labels) { + label_type this_labels(std::forward(labels)); + return labeled_type(downcast_(), std::move(this_labels)); + } + + /** @brief Associates labels with the modes of *this. + * + * @tparam LabelType The type of @p labels. Assumed to be explicitly + * convertible to label_type. + * + * This method is the same as the non-const version except that the result + * contains a read-only reference to *this. + * + * @param[in] labels The labels to associate with *this. + * + * @return A DSL term pairing *this with @p labels. + * + * @throw None No throw guarantee. + */ + template + labeled_const_type operator()(LabelType&& labels) const { + label_type this_labels(std::forward(labels)); + return labeled_const_type(downcast_(), std::move(this_labels)); + } + + // ------------------------------------------------------------------------- + // -- BLAS-Like Operations + // ------------------------------------------------------------------------- + + /** @brief Set this to the result of @p lhs + @p rhs. + * + * This method will overwrite the state of *this with the result of + * adding @p lhs to @p rhs. + * + * @param[in] this_labels The labels to associate with the modes of *this. + * @param[in] lhs The object to add to @p rhs + * @param[in] rhs The object to add to @p lhs. + * + * @return *this after assigning the sum of @p lhs plus @p rhs to *this. + * + * @throws ??? Throws if the derived class's implementation throws. Same + * throw guarantee. + */ + template + dsl_reference addition_assignment(LabelType&& this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs); + + /** @brief Set this to the result of @p lhs - @p rhs. + * + * This method will overwrite the state of *this with the result of + * subtracting @p rhs from @p lhs. + * + * @param[in] this_labels The labels to associate with the modes of *this. + * @param[in] lhs The object to subtract from. + * @param[in] rhs The object to be subtracted. + * + * @return *this after assigning the difference of @p lhs and @p rhs to + * *this. + * + * @throws ??? Throws if the derived class's implementation throws. Same + * throw guarantee. + */ + template + dsl_reference subtraction_assignment(LabelType&& this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs); + + /** @brief Set this to the result of @p lhs * @p rhs. + * + * This method will overwrite the state of *this with the result of + * multiplying @p lhs with @p rhs. This method is responsible for + * element-wise multiplication, contraction, and mixed operations. + * + * @param[in] this_labels The labels to associate with the modes of *this. + * @param[in] lhs The object to subtract from. + * @param[in] rhs The object to be subtracted. + * + * @return *this after assigning the product of @p lhs and @p rhs to + * *this. + * + * @throws ??? Throws if the derived class's implementation throws. Same + * throw guarantee. + */ + template + dsl_reference multiplication_assignment(LabelType&& this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs); + + /** @brief Sets *this to a permutation of @p rhs. + * + * `rhs.labels()` are the dummy indices associated with the modes of the + * object in @p rhs and @p this_labels are the dummy indices associated + * with the object in *this. This method will permute @p rhs so that the + * resulting object's modes are ordered consistently with @p this_labels, + * i.e. the permutation is FROM the `rhs.labels()` order TO the + * @p this_labels order. This is seemingly backwards when described out, + * but consistent with the intent of a DSL expression like + * `t("i,j") = x("j,i");` where the intent is to set `t` equal to the + * transpose of `x`. + * + * @param[in] this_labels the dummy indices for the modes of *this. + * @param[in] rhs The object to permute. + * + * @return *this after setting it equal to a permutation of @p rhs. + * + * @throw std::runtime_error if @p this_labels does not contain the same + * number of indices as *this does modes. Strong + * throw guarantee. + * @throw std::runtime_error if @p this_labels contains more dummy indices + * than @p rhs. Strong throw guarantee. + * @throw ??? If the derived class's implementation of permute_assignment_ + * throws. Same throw guarantee. + */ + template + dsl_reference permute_assignment(LabelType&& this_labels, + const_labeled_reference rhs); + + /** @brief Scales *this by @p scalar. + * + * @tparam ScalarType The type of @p scalar. Assumed to be a floating- + * point type. + * + * This method is responsible for scaling @p *this by @p scalar. + * + * @note This method is templated on the scalar type to avoid limiting the + * API. That said, at present the backend converts @p scalar to + * double precision. + */ + template + dsl_reference scalar_multiplication(ScalarType&& scalar) { + return scalar_multiplication_(std::forward(scalar)); + } + +protected: + /// Derived class should overwrite to implement addition_assignment + virtual dsl_reference addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + throw std::runtime_error("Addition assignment NYI"); + } + + /// Derived class should overwrite to implement subtraction_assignment + virtual dsl_reference subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + throw std::runtime_error("Subtraction assignment NYI"); + } + + /// Derived class should overwrite to implement multiplication_assignment + virtual dsl_reference multiplication_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) { + throw std::runtime_error("Multiplication assignment NYI"); + } + + /// Derived class should overwrite to implement permute_assignment + virtual dsl_reference permute_assignment_(label_type this_labels, + const_labeled_reference rhs) { + throw std::runtime_error("Permute assignment NYI"); + } + + /// Derived class should overwrite to implement scalar_multiplication + dsl_reference scalar_multiplication_(double scalar) { + throw std::runtime_error("Scalar multiplication NYI"); + } + +private: + /// Checks that the dummy indices on an object are consistent with its rank + void assert_indices_match_rank_(const_labeled_reference other) const { + const auto rank = other.object().rank(); + const auto n = other.labels().size(); + if(rank == n) return; + throw std::runtime_error( + std::to_string(n) + " dummy indices is incompatible with an object" + " with rank " = std::to_string(rank)); + } + + /// Checks that @p output is a subset of @p input + void assert_is_subset_(const label_type& output, + const label_type& input) const { + if(output.intersection(input).size() < output.unique_index_size()) + throw std::runtime_error( + "Output indices must be a subset of input indices"); + } + + /// Asserts that @p lhs is a permutation of @p rhs + void assert_is_permutation_(const label_type& lhs, + const label_type& rhs) const { + if(lhs.is_permutation(rhs)) return; + throw std::runtime_error( + "Dummy indices are not related via permutation."); + } + + /// Wraps getting a mutable reference to the derived class + decltype(auto) downcast_() { return static_cast(*this); } + + /// Wraps getting a read-only reference to the derived class + decltype(auto) downcast_() const { + return static_cast(*this); + } +}; + +} // namespace tensorwrapper::detail_ + +#include "dsl_base.ipp" \ No newline at end of file diff --git a/include/tensorwrapper/detail_/dsl_base.ipp b/include/tensorwrapper/detail_/dsl_base.ipp new file mode 100644 index 00000000..6e2db6a1 --- /dev/null +++ b/include/tensorwrapper/detail_/dsl_base.ipp @@ -0,0 +1,90 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file dsl_base.ipp + * + * Contains inline implementations for the DSLBase class. This file is meant + * only for inclusion by dsl_base.hpp. + */ + +namespace tensorwrapper::detail_ { + +#define TPARAMS template +#define DSL_BASE DSLBase + +TPARAMS +template +typename DSL_BASE::dsl_reference DSL_BASE::addition_assignment( + LabelType&& this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) { + assert_indices_match_rank_(lhs); + assert_indices_match_rank_(rhs); + assert_is_permutation_(lhs.labels(), rhs.labels()); + + label_type result_labels(std::forward(this_labels)); + auto lr_labels = lhs.labels().concatenation(rhs.labels()); + assert_is_subset_(result_labels, lr_labels); + + return addition_assignment_(std::move(result_labels), lhs, rhs); +} + +TPARAMS +template +typename DSL_BASE::dsl_reference DSL_BASE::subtraction_assignment( + LabelType&& this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) { + assert_indices_match_rank_(lhs); + assert_indices_match_rank_(rhs); + assert_is_permutation_(lhs.labels(), rhs.labels()); + + label_type result_labels(std::forward(this_labels)); + auto lr_labels = lhs.labels().concatenation(rhs.labels()); + assert_is_subset_(result_labels, lr_labels); + + return subtraction_assignment_(std::move(result_labels), lhs, rhs); +} + +TPARAMS +template +typename DSL_BASE::dsl_reference DSL_BASE::multiplication_assignment( + LabelType&& this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) { + assert_indices_match_rank_(lhs); + assert_indices_match_rank_(rhs); + + label_type result_labels(std::forward(this_labels)); + auto lr_labels = lhs.labels().concatenation(rhs.labels()); + assert_is_subset_(result_labels, lr_labels); + + return multiplication_assignment_(std::move(result_labels), lhs, rhs); +} + +TPARAMS +template +typename DSL_BASE::dsl_reference DSL_BASE::permute_assignment( + LabelType&& this_labels, const_labeled_reference rhs) { + assert_indices_match_rank_(rhs); + + label_type lhs_labels(std::forward(this_labels)); + assert_is_subset_(lhs_labels, rhs.labels()); + + return permute_assignment_(std::move(lhs_labels), rhs); +} + +#undef DSL_BASE +#undef TPARAMS + +} // namespace tensorwrapper::detail_ \ No newline at end of file diff --git a/include/tensorwrapper/detail_/polymorphic_base.hpp b/include/tensorwrapper/detail_/polymorphic_base.hpp index f6a7c152..ccbd094b 100644 --- a/include/tensorwrapper/detail_/polymorphic_base.hpp +++ b/include/tensorwrapper/detail_/polymorphic_base.hpp @@ -41,6 +41,9 @@ class PolymorphicBase { /// Read-only reference to an object of type base_type using const_base_reference = const base_type&; + /// Mutable rvalue reference to an object of type base_type + using base_rvalue = base_type&&; + /// Pointer to an object of type base_type using base_pointer = std::unique_ptr; @@ -93,13 +96,13 @@ class PolymorphicBase { /** @brief Determines if *this and @p rhs are polymorphically equal. * * Calling operator== on an object of type T is supposed to compare the - * state defined in class T as well as all state defined in parent classes. - * If there is other classes which derived from T that possess state, use - * of T::operator== will not consider such state in the comparison. This - * method casts both *this and @p rhs to their most derived class and then - * performs the value comparison to ensure that all state is considered. If - * *this and @p rhs have different most derived classes this comparison - * returns false. + * state defined in class T as well as all state defined in parent + * classes. If there is other classes which derived from T that possess + * state, use of T::operator== will not consider such state in the + * comparison. This method casts both *this and @p rhs to their most + * derived class and then performs the value comparison to ensure that all + * state is considered. If *this and @p rhs have different most derived + * classes this comparison returns false. * * Derived classes should override are_equal_ to implement this method. * @@ -222,6 +225,7 @@ class PolymorphicBase { */ virtual bool are_equal_(const_base_reference rhs) const noexcept = 0; + /// Should be overridden by the derived class to provide logging details. virtual string_type to_string_() const { return "{?}"; } }; diff --git a/include/tensorwrapper/dsl/dsl_forward.hpp b/include/tensorwrapper/dsl/dsl_forward.hpp index c756ebd7..340b12bc 100644 --- a/include/tensorwrapper/dsl/dsl_forward.hpp +++ b/include/tensorwrapper/dsl/dsl_forward.hpp @@ -21,7 +21,6 @@ namespace tensorwrapper::dsl { template class Labeled; -template class PairwiseParser; } // namespace tensorwrapper::dsl \ No newline at end of file diff --git a/include/tensorwrapper/dsl/dummy_indices.hpp b/include/tensorwrapper/dsl/dummy_indices.hpp index 73b50808..81af1a31 100644 --- a/include/tensorwrapper/dsl/dummy_indices.hpp +++ b/include/tensorwrapper/dsl/dummy_indices.hpp @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include #include @@ -139,6 +140,29 @@ class DummyIndices return unique_index_size() != this->size(); } + /** @brief Determines if *this is a permutation of @p other. + * + * *this is a permutation of @p other if both *this and @p other contain + * the same number of dummy indices and if each unique index in *this + * appears the same number of times in *this and @p other. + * + * @param[in] other The set of dummy indices to compare against. + * + * @return True if *this is a permutation of @p other and false otherwise. + * + * @throw None No throw guarantee. + */ + bool is_permutation(const DummyIndices& other) const noexcept { + // Must be same size + if(this->size() != other.size()) return false; + + // Each index in *this must show up the same number of times in other + for(const auto& index : *this) { + if(count(index) != other.count(index)) return false; + } + return true; + } + /** @brief Computes the permutation needed to convert *this into @p other. * * Each DummyIndices object is viewed as an ordered set of objects. If @@ -202,6 +226,146 @@ class DummyIndices return rv; } + /** @brief Determines how many times @p index_to_find occurs in *this. + * + * @param[in] index_to_find The dummy index to find. + * + * @return The number of times dummy index occurs in *this. + * + * @throw None No throw guarantee. + */ + size_type count(const_reference index_to_find) const noexcept { + size_type rv = 0; + for(const auto& x : *this) + if(x == index_to_find) ++rv; + return rv; + } + + /** @brief Determines if *this is value equal to @p rhs. + * + * *this and @p rhs are value equal if they contain the same number of + * indices and if the i-th dummy index of *this is value equal + * to the i-th dummy index of @p rhs for all i in the range [0, size()). + * + * @param[in] rhs The object to compare to. + * + * @return True if *this is value equal to @p rhs and false otherwise. + * + * @throw None No throw guarantee. + */ + bool operator==(const DummyIndices& rhs) const noexcept { + return m_dummy_indices_ == rhs.m_dummy_indices_; + } + + /** @brief Determines if *this is value equal to @p s. + * + * This method is useful for comparing an already parsed DummyIndices + * object to a string literal. This method works by converting @p s into a + * DummyIndex object and then comparing the resulting DummyIndices object + * to *this. See operator==(const DummyIndices&) for more details on the + * comparison. + * + * @param[in] s The string to compare to *this. + * + * @return True if *this is value equal to @p s as a DummyIndices object + * and false otherwise. + * + * @throw std::bad_alloc if parsing @p s encounters an allocation problem. + * Strong throw guara + * @throw std::runtime_error if parsing @p s throws. Strong throw + * guarantee. + */ + bool operator==(const_reference s) const { + return operator==(DummyIndices(s)); + } + + /** @brief Determines if *this is different from @p rhs. + * + * Two DummyIndices objects are different if they are not value equal. See + * the description of operator==(const DummyIndices&) for the definition + * of value equal. + * + * @param[in] rhs The object to compare to. + * + * @return false if *this and @p rhs are value equal and true otherwise. + * + * @throw None No throw guarantee. + */ + bool operator!=(const DummyIndices& rhs) const noexcept { + return !((*this) == rhs); + } + + /** @brief Determines if *this is different than @p s. + * + * This method is useful for comparing an already parsed DummyIndices + * object to a string literal. It works by converting @p s into a + * DummyIndices object and then comparing the resulting DummyIndices object + * to *this. See operator!=(const DummyIndices&) for more details on the + * comparison. + * + * @param[in] s The string to compare to *this. + * + * @return False if *this is value equal to @p s as a DummyIndices object + * and true otherwise. + * + * @throw std::bad_alloc if parsing @p s encounters an allocation problem. + * Strong throw guara + * @throw std::runtime_error if parsing @p s throws. Strong throw + * guarantee. + */ + bool operator!=(const_reference s) const { return !((*this) == s); } + + /** @brief Computes the DummyIndices object formed by concatenating *this + * with @p other. + * + * This method will create a new DummyIndices object which contains + * `this->size()` plus `other.size()` indices. The first `this->size()` + * indices will be the indices of *this and the next `other.size()` indices + * will be the indices of @p other. + * + * @note This is in general NOT the union of *this with @p other, in + * particular repeat indices may appear. + * + * @param[in] other The indices to concatenate onto *this. + * + * @return A new DummyIndices object formed by concatenating *this with + * @p other. + * + * @throw std::bad_alloc if allocating the return fails. Strong throw + * guarantee. + */ + DummyIndices concatenation(const DummyIndices& other) const { + DummyIndices rv(*this); + for(const auto& x : other) rv.m_dummy_indices_.push_back(x); + return rv; + } + + /** @brief Returns the unique indices of *this which appear in @p other. + * + * This method retruns a new DummyIndices object containing the set of + * indices in *this which also appear in @p other. The indices in the + * result are unique (i.e., if an index is repeated in *this it is only + * added to result once). + * + * @param[in] other The object to compare *this to. + * + * @return A DummyIndices object which contains the intersection of *this + * with @p other. + * + * @throw std::bad_alloc if allocation of the return fails. Strong throw + * guarantee. + */ + DummyIndices intersection(const DummyIndices& other) const { + DummyIndices rv; + std::set seen; + for(const auto& x : *this) { + if(seen.count(x)) continue; + seen.insert(x); + if(other.count(x)) rv.m_dummy_indices_.push_back(x); + } + return rv; + } + protected: /// Main ctor for setting the value, throws if any index is empty explicit DummyIndices(split_string_type split_dummy_indices) : diff --git a/include/tensorwrapper/dsl/labeled.hpp b/include/tensorwrapper/dsl/labeled.hpp index 7a4d4f89..ac83532b 100644 --- a/include/tensorwrapper/dsl/labeled.hpp +++ b/include/tensorwrapper/dsl/labeled.hpp @@ -17,22 +17,37 @@ #pragma once #include +#include #include #include #include + namespace tensorwrapper::dsl { /** @brief Represents an object whose modes are assigned dummy indices. + * + * @tparam ObjectType the type of the object. Assumed to be a class from the + * shape, symmetry, sparsity, layout, buffer, or tensor + * class hierarchies. + * @tparam StringType the type used for string literals. Default is + * std::string. + * + * This class is used to promote TensorWrapper objects into the DSL layer. + * Users will interact with this class somewhat transparently (usually via + * unnamed temporary objects). + * */ -template -class Labeled : public utilities::dsl::BinaryOp, - ObjectType, LabelType> { +template +class Labeled : public utilities::dsl::Term> { private: /// Type of *this - using my_type = Labeled; + using my_type = Labeled; + + /// Type of *this if ObjectType is const + using const_my_type = Labeled, StringType>; /// Type *this inherits from - using op_type = utilities::dsl::BinaryOp; + using op_type = utilities::dsl::Term; /// Is T cv-qualified? template @@ -41,6 +56,10 @@ class Labeled : public utilities::dsl::BinaryOp, /// Is ObjectType cv-qualified? static constexpr bool has_cv_object_v = is_cv_v; + /// Shorthand for type @p T if ObjectType is const, and @p U otherwise + template + using if_cv_t = std::conditional_t; + /// Does *this have a cv-qualified object and T is mutable? template static constexpr bool is_cv_conversion_v = has_cv_object_v && !is_cv_v; @@ -50,27 +69,41 @@ class Labeled : public utilities::dsl::BinaryOp, using enable_if_cv_conversion_t = std::enable_if_t>; public: + // ------------------------------------------------------------------------- + // -- Types associated with the object + // ------------------------------------------------------------------------- + /// Type of the object (useful for TMP) using object_type = std::decay_t; - /// Type of the labels (useful for TMP) - using label_type = LabelType; + /// Type of a read-only reference to an object of object_type + using const_object_reference = const object_type&; - /** @brief Creates a Labeled object that does not alias an object or labels. - * - * This ctor is needed because the base classes assume it is present. - * Users shouldn't actually need it. - * - * @throw None No throw guarantee. - */ - Labeled() = default; + /// Type of a (possibly) mutable reference to an object of object_type + using object_reference = if_cv_t; + + // ------------------------------------------------------------------------- + // -- Types associated with the labels + // ------------------------------------------------------------------------- + + /// Type of the string literal used for index labels + using string_type = StringType; + + /// Type of the object managing the parsed index labels + using label_type = dsl::DummyIndices; - /** @brief Ensures labels are stored correctly. + /// Mutable reference to an object of type label_type + using label_reference = label_type&; + + /// Read-only reference to an object of type label_type + using const_label_reference = const label_type&; + + /** @brief Associates a set of dummy indices with an object. * * @tparam ObjectType2 The type of @p object. Must be implicitly * convertible to @p ObjectType. - * @tparam LabelType2 The type of @p labels. Must be implicitly - * convertible to @p LabelType. + * @tparam LabelType2 The type of @p labels. Assumed to be implicitly + * convertible to either StringType or label_type. * * It is common for the labels to actually be a string literal, e.g., * code like `"i,j"`. Type detection for such a type will not match it @@ -83,47 +116,243 @@ class Labeled : public utilities::dsl::BinaryOp, * @throw std::bad_alloc if converting @p labels to LabelType throws. * Strong throw guarantee. */ - template - Labeled(ObjectType2&& object, LabelType2&& labels) : - op_type(std::forward(object), - LabelType(std::forward(labels))) {} + template + Labeled(ObjectType2&& object, LabelType&& labels) : + m_object_(&object), m_labels_(std::forward(labels)) {} /** @brief Allows implicit conversion from mutable objects to const objects - * - * @p ObjectType may have cv-qualifiers. This ctor allows Labeled instances - * aliasing mutable objects to be used when Labeled instances aliasing - * read-only objects are needed. * * @tparam ObjectType2 The object type stored in @p input. Must be * equivalent to `const ObjectType`. * @tparam Used to disable this overload via SFINAE if * ObjectType2 != `const ObjectType` or if + * `ObjectType` is not mutable. + * + * @p ObjectType may have cv-qualifiers. This ctor allows Labeled instances + * aliasing mutable objects to be used when Labeled instances aliasing + * read-only objects are needed. + * + * @param[in] input The Labeled object to convert. */ template> - Labeled(const Labeled& input) : - Labeled(input.lhs(), input.rhs()) {} + Labeled(const Labeled& input) : + Labeled(input.object(), input.labels()) {} + + /** @brief Creates a new Labeled object by copying @p other. + * + * The Labeled object created with this ctor will alias the same object + * as @p other did, but contain a deep copy of the labels associated with + * @p other. + * + * @param[in] other The Labeled object to copy. + * + * @throw std::bad_alloc if there's a problem allocating memory for the + * copy. Strong throw guarantee. + */ + Labeled(const Labeled& other) = default; + + /** @brief Creates a new Labeled object by taking the state from @p other. + * + * The Labeled object created with this ctor will alias the same object + * as @p other did and take ownership of the labels which were previously + * associated with @p other. + * + * @param[in,out] other The Labeled object to take the state from. After + * this operation @p other is in a valid, but + * otherwise undefined state. + * + * @throw None No throw guarantee. + */ + Labeled(Labeled&& other) noexcept = default; + + /** @brief Sets *this equal to @p rhs. + * + * This method can be used as a copy assignment for the object *this + * aliases, but it is NOT copy assignment for *this. More specifically + * this method will call assign_ to do the actual assignment, which may + * result in permutations and or traces being taken of @p rhs before the + * assignment happens. Whether permutations/traces occur depends on the + * indices of *this. + * + * @note This method is needed because the compiler prefers the + * compiler generated version over the function template overload. + * + * @param[in] rhs The object to assign to *this. + * + * @return *this after assigning @p rhs to it and performing any operations + * specified by the dummy indices. + * + * @throw ??? Throws if assign_ throws. Same throw guarantee. + * + */ + Labeled& operator=(const Labeled& rhs) { return assign_(rhs); }; + + /** @brief Sets *this equal to @p rhs. + * + * This method behaves similar to operator=(const Labeled&) except that + * the parser has the option of reusing @p rhs in the operation instead + * of copying it. + * + * @note This method is needed because the compiler prefers the + * compiler generated version over the function template overload. + * + * @param[in,out] rhs The object to assign to *this. After this operation + * @p rhs is in a valid, but otherwise undefined state. + * + * @throw ??? Throws if assign_ throws. Same throw guarantee. + */ + Labeled& operator=(Labeled&& rhs) { return assign_(std::move(rhs)); } /** @brief Assigns a DSL term to *this. * * @tparam TermType The type of the expression being assigned to *this. * - * Under most circumstances execution of the DSL happens when an - * expression is assigned to Labeled object. The assignment happens via - * this method. + * This method is the generalization of operator=(const Labeled&) to + * other leaves of the AST. Like the other operator= methods it is + * implemented by calling assign_. * - * @param[in] other The expression to assign to *this. + * @param[in] other The object containing the AST to assign to *this. * * @return *this after assigning @p other to *this. + * + * @throw ??? If assign_ throws. Same throw guarantee. */ template my_type& operator=(TermType&& other) { - // TODO: other should be rolled into a tensor graph object that can be - // manipulated at runtime. Parser is then moved to the backend - PairwiseParser p; - *this = p.dispatch(std::move(*this), std::forward(other)); + return assign_(std::forward(other)); + } + + /** @brief Returns a (possibly) read-only reference to the object. + * + * This method is used to access the object associated with the dummy + * indices. The object is mutable if @p ObjectType is a mutable type and + * read-only if @p ObjectType is cv-qualified. + * + * @return A reference to the object. + * + * @throw std::runtime_error if *this does not have an object associated + * with it. Strong throw guarantee. + */ + object_reference object() { + assert_has_object_(); + return *m_object_; + } + + /** @brief Returns a read-only reference to the labeled object. + * + * This method is identical to the non-const version except that the + * resulting object is guarantee to be read-only. See the description for + * the non-const version for more details. + * + * @return A read-only reference to the labeled object. + * + * @throw std::runtime_error if *this does not have an object associated + * with it. Strong throw guarantee. + */ + const_object_reference object() const { + assert_has_object_(); + return *m_object_; + } + + /** @brief The dummy indices associated with the object. + * + * This method is used to retrieve the dummy indices associated with the + * object. + * + * @return A mutable reference to the dummy indices. + * + * @throw None No throw guarantee. + * + */ + label_reference labels() noexcept { return m_labels_; } + + /** @brief Returns a read-only reference to the dummy labels. + * + * This method is identical to the non-const version, except that the + * resulting indices are guaranteed to be read-only. See the description + * for the non-const version for more details. + * + * @return A read-only reference to the dummy indices. + * + * @throw None No throw guarantee. + */ + const_label_reference labels() const noexcept { return m_labels_; } + + /** @brief Does *this have an object? + * + * Under most circumstances *this will be associated with an object. This + * method can be used to explicitly test that *this does have an object. + * + * @return True if *this has an object and false otherwise. + * + * @throw None No throw guarantee. + */ + bool has_object() const noexcept { return static_cast(m_object_); } + + /** @brief Determines if *this is value equal to @p rhs. + * + * Two Labeled objects are value equal if their labels compare value equal + * and if they both: + * 1. Contain objects which compare (polymorphically) value equal, or + * 2. Do not contain objects. + * + * @param[in] rhs The object to compare to. + * + * @return True if *this is value equal to @p rhs and false otherwise. + * + * @throw None No throw guarantee. + */ + bool operator==(const const_my_type& rhs) const noexcept { + if(has_object() != rhs.has_object()) return false; + if(labels() != rhs.labels()) return false; + if(!has_object()) return true; + return object().are_equal(rhs.object()); + } + + /** @brief Is *this different from @p rhs? + * + * This method simply negates operator==. See operator== for the definition + * of value equal. + * + * @param[in] rhs The object to compare to. + * + * @return False if *this is value equal to @p rhs and true otherwise. + * + * @throw None No throw guarantee. + */ + bool operator!=(const const_my_type& rhs) const noexcept { + return !((*this) == rhs); + } + +private: + /// Type of a pointer to a read-only object of object_type + using const_object_pointer = const object_type*; + + /// Type of a pointer to a (possibly) mutable object of object_type + using object_pointer = if_cv_t; + + /// Asserts that m_object_ is non-null + void assert_has_object_() const { + if(has_object()) return; + throw std::runtime_error("Object is null. Was it moved from?"); + } + + /// Common implementation for assigning other to *this. + template + Labeled& assign_(TermType&& other) { + // TODO: other should be rolled into a tensor graph object that can + // be manipulated at runtime. Parser is then moved to the backend + PairwiseParser p; + p.dispatch(*this, std::forward(other)); return *this; } + + /// The object whose modes are indexed. + object_pointer m_object_ = nullptr; + + /// The dummy indices associated with m_object_ + label_type m_labels_; }; } // namespace tensorwrapper::dsl \ No newline at end of file diff --git a/include/tensorwrapper/dsl/pairwise_parser.hpp b/include/tensorwrapper/dsl/pairwise_parser.hpp index c8834ebe..0209dcab 100644 --- a/include/tensorwrapper/dsl/pairwise_parser.hpp +++ b/include/tensorwrapper/dsl/pairwise_parser.hpp @@ -16,71 +16,129 @@ #pragma once #include +#include #include -namespace tensorwrapper { -class Tensor; -namespace dsl { +namespace tensorwrapper::dsl { /** @brief Object which evaluates the AST of an expression pairwise. * - * @tparam ObjectType The type of the objects associated with the dummy - * indices. Expected to be possibly cv-qualified versions - * of Tensor, buffers, shapes, etc. - * @tparam LabelType The type of object used for the dummy indices. + * The easiest way to evaluate a tensor network is as a series of assignments, + * i.e., things that look like `A = B` and binary operations coupled to + * assignments, i.e., things that look like `C = A + B`. That's what this + * parser does. It should be noted that this is not necessarily the most + * performant way to evaluate the AST, e.g., this prohibits detection of + * common intermediates across multiple equations. * - * The easiest way to evaluate an abstract syntax tree which contains - * operations involving at most two objects is by splitting it into subtrees - * which contain at most two connected nodes, i.e., considering each operation - * pairwise. That's what this parser does. + * @note The + * @code + * auto pA = lhs.object().clone(); + * auto pB = lhs.object().clone(); + * auto labels = lhs.labels(); + * auto lA = (*pA)(labels); + * auto lB = (*pB)(labels); + * dispatch(lA, rhs.lhs()); + * dispatch(lB, rhs.rhs()); + * @endcode + * are repetitive, but we need to keep pA and pB alive which inhibits + * factorization. */ -template class PairwiseParser { public: - /// Type of a leaf in the AST - using labeled_type = Labeled; - /** @brief Recursion end-point * - * Evaluates @p rhs given that it will be evaluated into lhs. - * This is the natural end-point for recursion down a branch of the AST. - * - * N.b., this overload is only responsible for evaluating @p rhs NOT for - * assigning it to @p lhs. + * Ternary operations like `C = A + B` are ultimately evaluated by + * assigning `A` and `B` to temporaries and then summing the temporaries. + * The assignment to the temporary ensures that if `A` or `B` is itself a + * term it gets evaluated down to an object before the addition happens. + * The assignment calls this overload of dispatch. * - * @param[in] lhs The object that @p rhs will ultimately be assigned to. + * @param[in] lhs The object to assign @p rhs to. * @param[in] rhs The "expression" that needs to be evaluated. * - * @return @p rhs untouched. - * - * @throw None No throw guarantee. */ - auto dispatch(labeled_type lhs, labeled_type rhs) { return rhs; } + template + void dispatch(LHSType&& lhs, const RHSType& rhs) { + if constexpr(std::is_floating_point_v>) { + lhs.object().scalar_multiplication(rhs); + } else { + lhs.object().permute_assignment(lhs.labels(), rhs); + } + } /** @brief Handles adding two expressions together. * + * @tparam LHSType The type to assign the sum of @p lhs and @p rhs to. * @tparam T The type of the expression on the left side of the "+" sign. * @tparam U The type of the expression on the right side of the "+" sign. * * @param[in] lhs The object that @p rhs will ultimately be assigned to. * @param[in] rhs The expression to evaluate. * + * @throw std::runtime_error if there is a problem doing the operation. + * Strong throw guarantee. + */ + template + void dispatch(LHSType&& lhs, const utilities::dsl::Add& rhs) { + auto pA = lhs.object().clone(); + auto pB = lhs.object().clone(); + auto labels = lhs.labels(); + auto lA = (*pA)(labels); + auto lB = (*pB)(labels); + dispatch(lA, rhs.lhs()); + dispatch(lB, rhs.rhs()); + lhs.object().addition_assignment(labels, lA, lB); + } + + /** @brief Handles subtracting two expressions together. + * + * @tparam LHSType The type of the object the expression will be evaluated + * into. + * @tparam T The type of the expression on the left side of the "-" sign. + * @tparam U The type of the expression on the right side of the "-" sign. + * + * @param[in] lhs The object that @p rhs will ultimately be assigned to. + * @param[in] rhs The expression to evaluate. * + * @throw std::runtime_error if there is a problem doing the operation. + * Strong throw guarantee. */ - template - auto dispatch(labeled_type lhs, const utilities::dsl::Add& rhs) { - // TODO: This shouldn't be assigning to lhs, but letting the layer up - // do that - auto lA = dispatch(lhs, rhs.lhs()); - auto lB = dispatch(lhs, rhs.rhs()); - return add(std::move(lhs), std::move(lA), std::move(lB)); + template + void dispatch(LHSType&& lhs, const utilities::dsl::Subtract& rhs) { + auto pA = lhs.object().clone(); + auto pB = lhs.object().clone(); + auto labels = lhs.labels(); + auto lA = (*pA)(labels); + auto lB = (*pB)(labels); + dispatch(lA, rhs.lhs()); + dispatch(lB, rhs.rhs()); + lhs.object().subtraction_assignment(labels, lA, lB); } -protected: - labeled_type add(labeled_type result, labeled_type lhs, labeled_type rhs); + /** @brief Handles multiplying two expressions together. + * + * @tparam LHSType The type of the object the expression will be evaluated + * into. + * @tparam T The type of the expression on the left side of the "*" sign. + * @tparam U The type of the expression on the right side of the "*" sign. + * + * @param[in] lhs The object that @p rhs will ultimately be assigned to. + * @param[in] rhs The expression to evaluate. + * + * @throw std::runtime_error if there is a problem doing the operation. + * Strong throw guarantee. + */ + template + void dispatch(LHSType&& lhs, const utilities::dsl::Multiply& rhs) { + auto pA = lhs.object().clone(); + auto pB = lhs.object().clone(); + auto labels = lhs.labels(); + auto lA = (*pA)(labels); + auto lB = (*pB)(labels); + dispatch(lA, rhs.lhs()); + dispatch(lB, rhs.rhs()); + lhs.object().multiplication_assignment(labels, lA, lB); + } }; -extern template class PairwiseParser; - -} // namespace dsl -} // namespace tensorwrapper \ No newline at end of file +} // namespace tensorwrapper::dsl \ No newline at end of file diff --git a/include/tensorwrapper/shape/shape_base.hpp b/include/tensorwrapper/shape/shape_base.hpp index ce352e61..33725798 100644 --- a/include/tensorwrapper/shape/shape_base.hpp +++ b/include/tensorwrapper/shape/shape_base.hpp @@ -17,6 +17,7 @@ #pragma once #include #include +#include #include #include #include @@ -37,11 +38,20 @@ namespace tensorwrapper::shape { * - get_rank_() * - get_size_() */ -class ShapeBase : public tensorwrapper::detail_::PolymorphicBase { +class ShapeBase : public tensorwrapper::detail_::PolymorphicBase, + public tensorwrapper::detail_::DSLBase { private: /// Type implementing the traits of this using traits_type = ShapeTraits; +protected: + /// Typedef of the PolymorphicBase class of *this + using polymorphic_base_type = + tensorwrapper::detail_::PolymorphicBase; + + /// Typedef of the DSLBase class of *this + using dsl_base_type = tensorwrapper::detail_::PolymorphicBase; + public: /// Type all shapes inherit from using shape_base = typename traits_type::shape_base; diff --git a/include/tensorwrapper/shape/smooth.hpp b/include/tensorwrapper/shape/smooth.hpp index 95315f2e..32d167de 100644 --- a/include/tensorwrapper/shape/smooth.hpp +++ b/include/tensorwrapper/shape/smooth.hpp @@ -177,6 +177,34 @@ class Smooth : public ShapeBase { return are_equal_impl_(rhs); } + /// Implements addition_assignment via permute_assignment + dsl_reference addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements subtraction_assignment via permute_assignment + dsl_reference subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements multiplication_assignment via permute_assignment + dsl_reference multiplication_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Implements permute_assignment by permuting the extents in @p rhs. + dsl_reference permute_assignment_(label_type this_labels, + const_labeled_reference rhs) override; + + /// Implements to_string + typename polymorphic_base_type::string_type to_string_() const override { + using str_type = typename polymorphic_base_type::string_type; + str_type buffer("{"); + for(auto x : m_extents_) buffer += str_type(" ") + std::to_string(x); + buffer += str_type("}"); + return buffer; + } + private: /// Type used to hold the extents of *this using extents_type = std::vector; diff --git a/include/tensorwrapper/tensor/tensor_class.hpp b/include/tensorwrapper/tensor/tensor_class.hpp index e67407aa..334b7047 100644 --- a/include/tensorwrapper/tensor/tensor_class.hpp +++ b/include/tensorwrapper/tensor/tensor_class.hpp @@ -96,20 +96,6 @@ class Tensor { /// Type of an initializer list if *this is a rank 4 tensor using tensor4_il_type = std::initializer_list; - /// Type of a label - using label_type = std::string; - - /// Type of a read-only reference to an object of type label_type - using const_label_reference = const label_type&; - - /// Type of a labeled tensor - using labeled_tensor_type = dsl::Labeled; - - /// Type of a read-only labeled tensor - using const_labeled_tensor_type = dsl::Labeled; - - // Tensor() : Tensor(input_type{}) {} - /** @brief Initializes *this by processing the input provided in @p input. * * This ctor is only public to facilitate unit testing of the library. @@ -313,42 +299,6 @@ class Tensor { */ const_buffer_reference buffer() const; - /** @brief Associates @p labels with the modes of *this. - * - * Expressing tensor operations is easier with the use of the Einstein - * summation convention. Usage of this convention requires the user to be - * able to associate dummy indices with the modes of the tensor. This - * function pairs @p labels with the modes of *this such that the i-th - * dummy index of @p labels is paired with the i-th mode of *this. - * - * See dsl::DummyIndices for how the string is interpreted. - * - * Note that if *this is a rank 0 tensor @p labels should be the empty - * string. - * - * @param[in] labels The dummy indices to associate with each mode. - * - * @return A DSL term pairing *this with @p labels. - * - */ - labeled_tensor_type operator()(const_label_reference labels) { - return labeled_tensor_type(*this, labels); - } - - /** @brief Associates @p labels with the modes of *this. - * - * This method is the same as the non-const version except that the - * resulting DSL term contains a reference to an immutable tensor. - * - * @param[in] labels The dummy indices to associate with each mode. - * - * @return A DSL term pairing *this with @p labels. - * - */ - const_labeled_tensor_type operator()(const_label_reference labels) const { - return const_labeled_tensor_type(*this, labels); - } - // ------------------------------------------------------------------------- // -- Utility methods // ------------------------------------------------------------------------- diff --git a/src/tensorwrapper/buffer/buffer_base.cpp b/src/tensorwrapper/buffer/buffer_base.cpp index e40c30cc..e82461fb 100644 --- a/src/tensorwrapper/buffer/buffer_base.cpp +++ b/src/tensorwrapper/buffer/buffer_base.cpp @@ -16,16 +16,4 @@ #include -namespace tensorwrapper::buffer { - -typename BufferBase::labeled_buffer_type BufferBase::operator()( - label_type labels) { - return labeled_buffer_type(*this, std::move(labels)); -} - -typename BufferBase::labeled_const_buffer_type BufferBase::operator()( - label_type labels) const { - return labeled_const_buffer_type(*this, std::move(labels)); -} - -} // namespace tensorwrapper::buffer \ No newline at end of file +namespace tensorwrapper::buffer {} // namespace tensorwrapper::buffer \ No newline at end of file diff --git a/src/tensorwrapper/buffer/eigen.cpp b/src/tensorwrapper/buffer/eigen.cpp index 78c3a2d8..c5d90da4 100644 --- a/src/tensorwrapper/buffer/eigen.cpp +++ b/src/tensorwrapper/buffer/eigen.cpp @@ -26,52 +26,52 @@ using dummy_indices_type = dsl::DummyIndices; #define TPARAMS template #define EIGEN Eigen -TPARAMS -typename EIGEN::buffer_base_reference EIGEN::addition_assignment_( - label_type this_labels, const_labeled_buffer_reference rhs) { - // TODO layouts - if(layout() != rhs.lhs().layout()) - throw std::runtime_error("Layouts must be the same (for now)"); - - dummy_indices_type llabels(this_labels); - dummy_indices_type rlabels(rhs.rhs()); - - using allocator_type = allocator::Eigen; - const auto& rhs_downcasted = allocator_type::rebind(rhs.lhs()); - - if(llabels != rlabels) { - auto r_to_l = rlabels.permutation(llabels); - std::vector r_to_l2(r_to_l.begin(), r_to_l.end()); - m_tensor_ += rhs_downcasted.value().shuffle(r_to_l2); - } else { - m_tensor_ += rhs_downcasted.value(); - } - - return *this; -} - -TPARAMS -typename EIGEN::buffer_base_reference EIGEN::permute_assignment_( - label_type this_labels, const_labeled_buffer_reference rhs) { - dummy_indices_type llabels(this_labels); - dummy_indices_type rlabels(rhs.rhs()); - - using allocator_type = allocator::Eigen; - const auto& rhs_downcasted = allocator_type::rebind(rhs.lhs()); - - if(llabels != rlabels) { // We need to permute rhs before assignment - auto r_to_l = rlabels.permutation(llabels); - // Eigen wants int objects - std::vector r_to_l2(r_to_l.begin(), r_to_l.end()); - m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2); - } else { - m_tensor_ = rhs_downcasted.value(); - } - - // TODO: permute layout - - return *this; -} +// TPARAMS +// typename EIGEN::buffer_base_reference EIGEN::addition_assignment_( +// label_type this_labels, const_labeled_buffer_reference rhs) { +// // TODO layouts +// if(layout() != rhs.lhs().layout()) +// throw std::runtime_error("Layouts must be the same (for now)"); + +// dummy_indices_type llabels(this_labels); +// dummy_indices_type rlabels(rhs.rhs()); + +// using allocator_type = allocator::Eigen; +// const auto& rhs_downcasted = allocator_type::rebind(rhs.lhs()); + +// if(llabels != rlabels) { +// auto r_to_l = rlabels.permutation(llabels); +// std::vector r_to_l2(r_to_l.begin(), r_to_l.end()); +// m_tensor_ += rhs_downcasted.value().shuffle(r_to_l2); +// } else { +// m_tensor_ += rhs_downcasted.value(); +// } + +// return *this; +// } + +// TPARAMS +// typename EIGEN::buffer_base_reference EIGEN::permute_assignment_( +// label_type this_labels, const_labeled_buffer_reference rhs) { +// dummy_indices_type llabels(this_labels); +// dummy_indices_type rlabels(rhs.rhs()); + +// using allocator_type = allocator::Eigen; +// const auto& rhs_downcasted = allocator_type::rebind(rhs.lhs()); + +// if(llabels != rlabels) { // We need to permute rhs before assignment +// auto r_to_l = rlabels.permutation(llabels); +// // Eigen wants int objects +// std::vector r_to_l2(r_to_l.begin(), r_to_l.end()); +// m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2); +// } else { +// m_tensor_ = rhs_downcasted.value(); +// } + +// // TODO: permute layout + +// return *this; +// } TPARAMS typename EIGEN::string_type EIGEN::to_string_() const { diff --git a/src/tensorwrapper/dsl/pairwise_parser.cpp b/src/tensorwrapper/dsl/pairwise_parser.cpp index 8a2c9189..049b76d2 100644 --- a/src/tensorwrapper/dsl/pairwise_parser.cpp +++ b/src/tensorwrapper/dsl/pairwise_parser.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2024 NWChemEx-Project + * Copyright 2025 NWChemEx-Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,70 +13,3 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include -#include -#include - -namespace tensorwrapper::dsl { -namespace { -struct CallAddition { - template - static decltype(auto) run(LHSType&& lhs, RHSType&& rhs) { - const auto& llabels = lhs.rhs(); - return lhs.lhs().addition(llabels, std::forward(rhs)); - } -}; - -template -decltype(auto) binary_op(ResultType&& result, LHSType&& lhs, RHSType&& rhs) { - auto& rv_object = result.lhs(); - const auto& lhs_object = lhs.lhs(); - const auto& rhs_object = rhs.lhs(); - - const auto& lhs_labels = lhs.rhs(); - const auto& rhs_labels = rhs.rhs(); - - using object_type = typename std::decay_t::object_type; - - if constexpr(std::is_same_v) { - if(rv_object == Tensor{}) { - const auto& llayout = lhs_object.logical_layout(); - // const auto& rlayout = rhs_object.logical_layout(); - std::decay_t rv_layout( - llayout); // FunctorType::run(llayout(lhs_labels), - // rlayout(rhs_labels)); - - auto lbuffer = lhs_object.buffer()(lhs_labels); - auto rbuffer = rhs_object.buffer()(rhs_labels); - auto buffer = FunctorType::run(lbuffer, rbuffer); - - // TODO figure out permutation - Tensor(std::move(rv_layout), std::move(buffer)).swap(rv_object); - } else { - throw std::runtime_error("Hints are not allowed yet!"); - } - } else { - // Getting here means the assert will fail - static_assert(std::is_same_v, "NYI"); - } - return result; -} -} // namespace - -#define TPARAMS template -#define PARSER PairwiseParser -#define LABELED_TYPE typename PARSER::labeled_type - -TPARAMS LABELED_TYPE PARSER::add(labeled_type result, labeled_type lhs, - labeled_type rhs) { - return binary_op(result, lhs, rhs); -} - -#undef PARSER -#undef TPARAMS - -template class PairwiseParser; - -} // namespace tensorwrapper::dsl \ No newline at end of file diff --git a/src/tensorwrapper/shape/smooth.cpp b/src/tensorwrapper/shape/smooth.cpp new file mode 100644 index 00000000..d004a8ec --- /dev/null +++ b/src/tensorwrapper/shape/smooth.cpp @@ -0,0 +1,89 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace tensorwrapper::shape { + +using dsl_reference = typename Smooth::dsl_reference; + +dsl_reference Smooth::addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + // Ultimately addition doesn't change the shape unless there's a trace + // or permutation. permute_assignment_ will take care of both scenarios. + + // The base class ensured that lhs and rhs are related by a permutation. + // So all we have to do is permute either lhs or rhs into the final shape + return permute_assignment(this_labels, rhs); +} + +dsl_reference Smooth::subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + // Ultimately subtraction doesn't change the shape unless there's a trace + // or permutation. permute_assignment_ will take care of both scenarios. + + // The base class ensured that lhs and rhs are related by a permutation. + // So all we have to do is permute either lhs or rhs into the final shape + return permute_assignment(this_labels, rhs); +} + +dsl_reference Smooth::multiplication_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) { + const auto& labels_lhs = lhs.labels(); + const auto& labels_rhs = rhs.labels(); + auto smooth_lhs = lhs.object().as_smooth(); + auto smooth_rhs = rhs.object().as_smooth(); + + // For each label + // we will be able to find it in either lhs or rhs and then set temp[i] to + // the corresponding extent + extents_type temp(this_labels.size()); + for(size_type i = 0; i < this_labels.size(); ++i) { + const auto& label_i = this_labels.at(i); + + if(labels_lhs.count(label_i)) { + temp[i] = smooth_lhs.extent(labels_lhs.find(label_i)[0]); + } else { + // Base verified this_labels is a subset of lhs + rhs, so must be + // in rhs + temp[i] = smooth_rhs.extent(labels_rhs.find(label_i)[0]); + } + } + m_extents_.swap(temp); + return *this; +} + +dsl_reference Smooth::permute_assignment_(label_type this_labels, + const_labeled_reference rhs) { + if(this_labels.size() != rhs.labels().size()) + throw std::runtime_error("Trace NYI"); + + auto p = rhs.labels().permutation(this_labels); + auto smooth_rhs = rhs.object().as_smooth(); + + extents_type temp(p.size()); + for(typename extents_type::size_type i = 0; i < p.size(); ++i) + temp[p[i]] = smooth_rhs.extent(i); + m_extents_.swap(temp); + + return *this; +} + +} // namespace tensorwrapper::shape \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp index b82bbcee..74e415f6 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/buffer_base.cpp @@ -70,38 +70,6 @@ TEST_CASE("BufferBase") { REQUIRE(vector_base.layout().are_equal(vector_layout)); } - SECTION("addition") { - scalar_buffer scalar2(eigen_scalar, scalar_layout); - scalar2.value()() = 42.0; - - auto s = scalar(""); - auto pscalar2 = scalar2.addition("", s); - - scalar_buffer scalar_corr(eigen_scalar, scalar_layout); - scalar_corr.value()() = 43.0; - REQUIRE(*pscalar2 == scalar_corr); - } - - SECTION("operator()(std::string)") { - auto labeled_scalar = scalar_base(""); - REQUIRE(labeled_scalar.lhs().are_equal(scalar_base)); - REQUIRE(labeled_scalar.rhs() == ""); - - auto labeled_vector = vector_base("i"); - REQUIRE(labeled_vector.lhs().are_equal(vector_base)); - REQUIRE(labeled_vector.rhs() == "i"); - } - - SECTION("operator()(std::string) const") { - auto labeled_scalar = std::as_const(scalar_base)(""); - REQUIRE(labeled_scalar.lhs().are_equal(scalar_base)); - REQUIRE(labeled_scalar.rhs() == ""); - - auto labeled_vector = std::as_const(vector_base)("i"); - REQUIRE(labeled_vector.lhs().are_equal(vector_base)); - REQUIRE(labeled_vector.rhs() == "i"); - } - SECTION("operator==") { // Defaulted layout == defaulted layout REQUIRE(defaulted_base == scalar_buffer()); diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp index 9d23cad9..73572c8e 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp @@ -159,98 +159,6 @@ TEMPLATE_TEST_CASE("Eigen", "", float, double) { REQUIRE(pscalar.are_equal(scalar2)); REQUIRE_FALSE(pmatrix.are_equal(scalar2)); } - - SECTION("addition_assignment") { - SECTION("scalar") { - scalar_buffer scalar2(eigen_scalar, scalar_layout); - scalar2.value()() = 42.0; - - auto s = scalar(""); - auto pscalar2 = &(scalar2.addition_assignment("", s)); - - scalar_buffer scalar_corr(eigen_scalar, scalar_layout); - scalar_corr.value()() = 43.0; - REQUIRE(pscalar2 == &scalar2); - REQUIRE(scalar2 == scalar_corr); - } - - SECTION("vector") { - vector_buffer vector2(eigen_vector, vector_layout); - - auto vi = vector("i"); - auto pvector2 = &(vector2.addition_assignment("i", vi)); - - vector_buffer vector_corr(eigen_vector, vector_layout); - vector_corr.value()(0) = 2.0; - vector_corr.value()(1) = 4.0; - - REQUIRE(pvector2 == &vector2); - REQUIRE(vector2 == vector_corr); - } - - SECTION("matrix") { - matrix_buffer matrix2(eigen_matrix, matrix_layout); - - auto mij = matrix("i,j"); - auto pmatrix2 = &(matrix2.addition_assignment("i,j", mij)); - - matrix_buffer matrix_corr(eigen_matrix, matrix_layout); - - matrix_corr.value()(0, 0) = 2.0; - matrix_corr.value()(0, 1) = 4.0; - matrix_corr.value()(0, 2) = 6.0; - matrix_corr.value()(1, 0) = 8.0; - matrix_corr.value()(1, 1) = 10.0; - matrix_corr.value()(1, 2) = 12.0; - - REQUIRE(pmatrix2 == &matrix2); - REQUIRE(matrix2 == matrix_corr); - - // SECTION("permutation") { - // layout::Physical l(shape::Smooth{3, 2}, g, p); - // std::array p10{1, 0}; - // auto eigen_matrix_t = eigen_matrix.shuffle(p10); - // matrix_buffer matrix3(eigen_matrix_t, l); - - // auto pmatrix3 = - // &(matrix3.addition_assignment("j,i", mij)); - - // matrix_buffer corr(eigen_matrix_t, l); - // corr.value()(0, 0) = 3.0; - // corr.value()(0, 1) = 6.0; - // corr.value()(1, 0) = 9.0; - // corr.value()(1, 1) = 12.0; - // corr.value()(2, 0) = 15.0; - // corr.value()(2, 1) = 18.0; - - // REQUIRE(pmatrix3 == &matrix3); - // REQUIRE(matrix3 == corr); - // } - } - - // Can't cast - REQUIRE_THROWS_AS(vector.addition_assignment("", scalar("")), - std::runtime_error); - - // Labels must match - REQUIRE_THROWS_AS(vector.addition_assignment("j", vector("i")), - std::runtime_error); - } - - SECTION("permute_assignment") { - // layout::Physical l(shape::Smooth{3, 2}, g, p); - // std::array p10{1, 0}; - // auto eigen_matrix_t = eigen_matrix.shuffle(p10); - // matrix_buffer corr(eigen_matrix_t, l); - - // matrix_buffer matrix2; - - // auto& mij = matrix("i,j"); - // auto pmatrix2 = &(matrix2.permute_assignment("j,i", mij)); - - // REQUIRE(pmatrix2 == &matrix2); - // REQUIRE(matrix2 == corr); - } } } } diff --git a/tests/cxx/unit_tests/tensorwrapper/detail_/dsl_base.cpp b/tests/cxx/unit_tests/tensorwrapper/detail_/dsl_base.cpp new file mode 100644 index 00000000..3e640339 --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/detail_/dsl_base.cpp @@ -0,0 +1,180 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../testing/testing.hpp" + +/* Testing Strategy. + * + * Derived classes are responsible for overriding the virtual methods of the + * DSLBase class and testing that their overloads work by going through at + * least one public API member. The tests here assume that the virtual method + * implementations work and test that the various public APIs to access those + * virtual methods work. For example, both `addition` and `addition_assignment` + * are + * implemented in terms of `addition_assignment_`. The derived class should test + * that `addition_assignment_` works by going through `addition_assignment`, but + * doesn't need to test that `addition` works because this test case will test + * that. + * + * - The tests here also test assertions that can be caught without knowing more + * about the objects, e.g., permute assignment must result in an object with + * the same (or fewer) modes. + */ + +using namespace tensorwrapper; + +using test_types = std::tuple; + +TEMPLATE_LIST_TEST_CASE("DSLBase", "", test_types) { + using object_type = TestType; + using label_type = typename object_type::label_type; + + test_types default_values{shape::Smooth{}}; + test_types values{test_tensorwrapper::smooth_matrix()}; + + auto default_value = std::get(default_values); + auto value = std::get(values); + + SECTION("operator()()") { + SECTION("string labels") { + auto ldefaulted = default_value(""); + REQUIRE(&ldefaulted.object() == &default_value); + REQUIRE(ldefaulted.labels() == ""); + + auto lvalue = value("i,j"); + REQUIRE(&lvalue.object() == &value); + REQUIRE(lvalue.labels() == "i,j"); + } + + SECTION("DummyIndices") { + auto ldefaulted = default_value(label_type("")); + REQUIRE(&ldefaulted.object() == &default_value); + REQUIRE(ldefaulted.labels() == ""); + + auto lvalue = value(label_type("i,j")); + REQUIRE(&lvalue.object() == &value); + REQUIRE(lvalue.labels() == "i,j"); + } + } + + SECTION("operator()() const") { + SECTION("string labels") { + auto ldefaulted = std::as_const(default_value)(""); + REQUIRE(&ldefaulted.object() == &default_value); + REQUIRE(ldefaulted.labels() == ""); + + auto lvalue = std::as_const(value)("i,j"); + REQUIRE(&lvalue.object() == &value); + REQUIRE(lvalue.labels() == "i,j"); + } + + SECTION("DummyIndices") { + auto ldefaulted = std::as_const(default_value)(label_type("")); + REQUIRE(&ldefaulted.object() == &default_value); + REQUIRE(ldefaulted.labels() == ""); + + auto lvalue = std::as_const(value)(label_type("i,j")); + REQUIRE(&lvalue.object() == &value); + REQUIRE(lvalue.labels() == "i,j"); + } + } + + SECTION("addition_assignment") { + // N.b., does error checks before calling addition_assignment_. We + // assume addition_assignment_ works and focus on the error checks + using error_t = std::runtime_error; + auto s = default_value(""); + auto sij = default_value("i,j"); + auto mij = value("i,j"); + auto mik = value("i,k"); + + // LHS's indices must match rank + REQUIRE_THROWS_AS(value.addition_assignment("i,j", sij, s), error_t); + + // RHS's indices must match rank + REQUIRE_THROWS_AS(value.addition_assignment("i,j", s, sij), error_t); + + // LHS and RHS must be related by a permutation + REQUIRE_THROWS_AS(value.addition_assignment("i,j", mij, mik), error_t); + + // Output must have <= number of dummy indices + REQUIRE_THROWS_AS(value.addition_assignment("i,j", s, s), error_t); + } + + SECTION("subtraction_assignment") { + // N.b., does error checks before calling addition_assignment_. We + // assume addition_assignment_ works and focus on the error checks + using error_t = std::runtime_error; + auto s = default_value(""); + auto sij = default_value("i,j"); + auto mij = value("i,j"); + auto mik = value("i,k"); + + // LHS's indices must match rank + REQUIRE_THROWS_AS(value.subtraction_assignment("i,j", sij, s), error_t); + + // RHS's indices must match rank + REQUIRE_THROWS_AS(value.subtraction_assignment("i,j", s, sij), error_t); + + // LHS and RHS must be related by a permutation + REQUIRE_THROWS_AS(value.subtraction_assignment("i,j", mij, mik), + error_t); + + // Output must have <= number of dummy indices + REQUIRE_THROWS_AS(value.subtraction_assignment("i,j", s, s), error_t); + } + + SECTION("multiplication_assignment") { + // N.b., does error checks before calling addition_assignment_. We + // assume addition_assignment_ works and focus on the error checks + using error_t = std::runtime_error; + auto s = default_value(""); + auto sij = default_value("i,j"); + auto mij = value("i,j"); + auto mik = value("i,k"); + + // LHS's indices must match rank + REQUIRE_THROWS_AS(value.multiplication_assignment("i,j", sij, s), + error_t); + + // RHS's indices must match rank + REQUIRE_THROWS_AS(value.multiplication_assignment("i,j", s, sij), + error_t); + } + + SECTION("permute_assignment") { + // N.b., does error checks before calling permute_assignment_. We assume + // permute_assignment_ works and focus on the error checks + using error_t = std::runtime_error; + auto s = default_value(""); + auto sij = default_value("i,j"); + + // Input's indices must match rank + REQUIRE_THROWS_AS(value.permute_assignment("i,j", sij), error_t); + + // Output must have <= number of dummy indices + REQUIRE_THROWS_AS(value.permute_assignment("i,j", s), error_t); + } + + SECTION("scalar_multiplication") { + // N.b., only tensor and buffer will override so here we're checking + // that other objects throw + using error_t = std::runtime_error; + + // Input's indices must match rank + REQUIRE_THROWS_AS(value.scalar_multiplication(1.0), error_t); + } +} \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp new file mode 100644 index 00000000..68c5d9d4 --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp @@ -0,0 +1,68 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../testing/testing.hpp" +#include + +using namespace tensorwrapper; + +using test_types = std::tuple; + +TEMPLATE_LIST_TEST_CASE("DSL", "", test_types) { + using object_type = TestType; + + test_types scalar_values{test_tensorwrapper::smooth_scalar()}; + test_types matrix_values{test_tensorwrapper::smooth_matrix()}; + auto value0 = std::get(scalar_values); + auto value2 = std::get(matrix_values); + + SECTION("assignment") { + value0("i,j") = value2("i,j"); + REQUIRE(value0 == value2); + } + + SECTION("permutation") { + value0("j,i") = value2("i,j"); + + object_type corr{}; + corr.permute_assignment("i,j", value2("j,i")); + REQUIRE(corr.are_equal(value0)); + } + + SECTION("addition") { + value0("i,j") = value2("i,j") + value2("i,j"); + + object_type corr{}; + corr.addition_assignment("i,j", value2("i,j"), value2("i,j")); + REQUIRE(corr.are_equal(value0)); + } + + SECTION("subtraction") { + value0("i,j") = value2("i,j") - value2("i,j"); + + object_type corr{}; + corr.subtraction_assignment("i,j", value2("i,j"), value2("i,j")); + REQUIRE(corr.are_equal(value0)); + } + + SECTION("multiplication") { + value0("i,j") = value2("i,j") * value2("i,j"); + + object_type corr{}; + corr.multiplication_assignment("i,j", value2("i,j"), value2("i,j")); + REQUIRE(corr.are_equal(value0)); + } +} \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp index e279ebf0..efa42ab2 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp @@ -72,6 +72,20 @@ TEST_CASE("DummyIndices") { REQUIRE(dummy_indices_type("i,i").has_repeated_indices()); } + SECTION("is_permutation") { + REQUIRE(scalar.is_permutation(scalar)); + REQUIRE_FALSE(scalar.is_permutation(vector)); + + REQUIRE(vector.is_permutation(vector)); + REQUIRE_FALSE(vector.is_permutation(scalar)); + REQUIRE_FALSE(vector.is_permutation(dummy_indices_type("j"))); + + REQUIRE(matrix.is_permutation(matrix)); + REQUIRE(matrix.is_permutation(dummy_indices_type("j,i"))); + REQUIRE_FALSE(matrix.is_permutation(scalar)); + REQUIRE_FALSE(matrix.is_permutation(dummy_indices_type("i,k"))); + } + SECTION("permutation") { using offset_vector = typename dummy_indices_type::offset_vector; @@ -132,30 +146,98 @@ TEST_CASE("DummyIndices") { REQUIRE(dummy_indices_type("i,i").find("i") == offset_vector{0, 1}); } - SECTION("comparison") { + SECTION("count") { + REQUIRE(defaulted.count("") == 0); + + REQUIRE(scalar.count("") == 0); + + REQUIRE(vector.count("") == 0); + REQUIRE(vector.count("i") == 1); + REQUIRE(vector.count("j") == 0); + + REQUIRE(matrix.count("") == 0); + REQUIRE(matrix.count("i") == 1); + REQUIRE(matrix.count("j") == 1); + REQUIRE(dummy_indices_type("i,i").count("i") == 2); + } + + SECTION("operator==") { // Default construction is indistinguishable from scalar indices REQUIRE(defaulted == scalar); + REQUIRE(defaulted == ""); // Different ranks are different REQUIRE_FALSE(defaulted == vector); + REQUIRE_FALSE(defaulted == "i"); // Same vector indices REQUIRE(vector == dummy_indices_type("i")); + REQUIRE(vector == "i"); // Different vector indices REQUIRE_FALSE(vector == dummy_indices_type("j")); + REQUIRE_FALSE(vector == "j"); // Same matrix indices REQUIRE(matrix == dummy_indices_type("i,j")); + REQUIRE(matrix == "i,j"); // Spaces aren't significant REQUIRE(matrix == dummy_indices_type("i, j")); + REQUIRE(matrix == "i, j"); REQUIRE(matrix == dummy_indices_type(" i , j ")); + REQUIRE(matrix == " i , j "); // Are case sensitive REQUIRE_FALSE(matrix == dummy_indices_type("I,j")); + REQUIRE_FALSE(matrix == "I,j"); // Permutations are different REQUIRE_FALSE(matrix == dummy_indices_type("j,i")); + REQUIRE_FALSE(matrix == "j,i"); + } + + SECTION("operator!=") { + // Just negates operator== so spot checking is fine + REQUIRE_FALSE(vector != dummy_indices_type("i")); + REQUIRE_FALSE(vector != "i"); + REQUIRE(vector != dummy_indices_type("j")); + REQUIRE(vector != "j"); + } + + SECTION("concatenation") { + dummy_indices_type matrix2("k,l"); + REQUIRE(scalar.concatenation(scalar) == dummy_indices_type("")); + REQUIRE(scalar.concatenation(vector) == dummy_indices_type("i")); + REQUIRE(scalar.concatenation(matrix) == dummy_indices_type("i,j")); + REQUIRE(scalar.concatenation(matrix2) == dummy_indices_type("k,l")); + + REQUIRE(vector.concatenation(scalar) == dummy_indices_type("i")); + REQUIRE(vector.concatenation(vector) == dummy_indices_type("i,i")); + REQUIRE(vector.concatenation(matrix) == dummy_indices_type("i,i,j")); + REQUIRE(vector.concatenation(matrix2) == dummy_indices_type("i,k,l")); + + REQUIRE(matrix.concatenation(scalar) == dummy_indices_type("i,j")); + REQUIRE(matrix.concatenation(vector) == dummy_indices_type("i,j,i")); + REQUIRE(matrix.concatenation(matrix) == dummy_indices_type("i,j,i,j")); + REQUIRE(matrix.concatenation(matrix2) == dummy_indices_type("i,j,k,l")); + } + + SECTION("intersection") { + dummy_indices_type matrix2("k,l"); + REQUIRE(scalar.intersection(scalar) == dummy_indices_type("")); + REQUIRE(scalar.intersection(vector) == dummy_indices_type("")); + REQUIRE(scalar.intersection(matrix) == dummy_indices_type("")); + REQUIRE(scalar.intersection(matrix2) == dummy_indices_type("")); + + REQUIRE(vector.intersection(scalar) == dummy_indices_type("")); + REQUIRE(vector.intersection(vector) == dummy_indices_type("i")); + REQUIRE(vector.intersection(matrix) == dummy_indices_type("i")); + REQUIRE(vector.intersection(matrix2) == dummy_indices_type("")); + + REQUIRE(matrix.intersection(scalar) == dummy_indices_type("")); + REQUIRE(matrix.intersection(vector) == dummy_indices_type("i")); + REQUIRE(matrix.intersection(matrix) == dummy_indices_type("i,j")); + REQUIRE(matrix.intersection(matrix2) == dummy_indices_type("")); } } diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/labeled.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/labeled.cpp index 95788bf3..be49f8e1 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/labeled.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/labeled.cpp @@ -19,59 +19,127 @@ using namespace tensorwrapper; -using test_types = std::tuple; +using test_types = std::tuple; TEMPLATE_LIST_TEST_CASE("Labeled", "", test_types) { - using object_type = TestType; - using labeled_type = dsl::Labeled; - using labels_type = typename labeled_type::label_type; + using object_type = TestType; + using labeled_type = dsl::Labeled; + using const_labeled_type = dsl::Labeled; + using labels_type = typename labeled_type::label_type; + test_types defaulted_values{shape::Smooth{}}; + test_types values{test_tensorwrapper::smooth_matrix()}; + + labels_type scalar; labels_type ij("i,j"); - object_type defaulted{}; - labeled_type labeled_default(defaulted, ij); + auto defaulted = std::get(defaulted_values); + auto value = std::get(values); + + labeled_type labeled_default(defaulted, scalar); + labeled_type labeled_value(value, ij); + const_labeled_type clabeled_default(defaulted, scalar); + const_labeled_type clabeled_value(value, ij); SECTION("Ctor") { SECTION("Value") { - REQUIRE(labeled_default.lhs() == defaulted); - REQUIRE(labeled_default.rhs() == ij); + // Taking label_type object + REQUIRE(&labeled_default.object() == &defaulted); + REQUIRE(labeled_default.labels() == scalar); + + REQUIRE(&clabeled_default.object() == &defaulted); + REQUIRE(clabeled_default.labels() == scalar); + + REQUIRE(&labeled_value.object() == &value); + REQUIRE(labeled_value.labels() == ij); + + REQUIRE(&clabeled_value.object() == &value); + REQUIRE(clabeled_value.labels() == ij); + + // Taking string literal object + labeled_type labeled2(value, "i,j"); + REQUIRE(&labeled2.object() == &value); + REQUIRE(labeled2.labels() == ij); } - SECTION("to const") { - using const_labeled_type = dsl::Labeled; + SECTION("mutable to const conversion") { const_labeled_type const_labeled_default(labeled_default); - REQUIRE(const_labeled_default.lhs() == defaulted); - REQUIRE(const_labeled_default.rhs() == ij); + REQUIRE(&const_labeled_default.object() == &defaulted); + REQUIRE(const_labeled_default.labels() == scalar); } - } - SECTION("operator=") { - // At present this operator just calls Parser dispatch. We know that - // works from other tests so here we just spot check. - Tensor t; - - SECTION("scalar") { - Tensor scalar(testing::smooth_scalar()); - auto labeled_t = t(""); - auto plabeled_t = &(labeled_t = scalar("") + scalar("")); - REQUIRE(plabeled_t == &labeled_t); - - auto buffer = testing::eigen_scalar(); - buffer.value()() = 84.0; - Tensor corr(scalar.logical_layout(), std::move(buffer)); - REQUIRE(t == corr); + // N.b., there is no default ctor so we can't use the testing helpers + SECTION("Copy ctor") { + labeled_type labeled_copy(labeled_default); + REQUIRE(&labeled_copy.object() == &defaulted); + REQUIRE(&labeled_copy.labels() != &labeled_default.labels()); + REQUIRE(labeled_copy.labels() == scalar); } - SECTION("Vector") { - Tensor vector(testing::smooth_vector()); - auto labeled_t = t("i"); - auto plabeled_t = &(labeled_t = vector("i") + vector("i")); - REQUIRE(plabeled_t == &labeled_t); - - auto buffer = testing::eigen_vector(); - for(std::size_t i = 0; i < 5; ++i) buffer.value()(i) = i + i; - Tensor corr(t.logical_layout(), std::move(buffer)); - REQUIRE(t == corr); + SECTION("Move ctor") { + labeled_type labeled_move(std::move(labeled_value)); + REQUIRE(&labeled_move.object() == &value); + REQUIRE(labeled_move.labels() == ij); } } + + SECTION("evaluation, i.e., operator=") { + // copy-assignment-like operation + labeled_type other(defaulted, "i,j"); + auto pother = &(other = labeled_value); + REQUIRE(pother == &other); + REQUIRE(other.object().are_equal(value)); + REQUIRE(&other.labels() != &labeled_value.labels()); + REQUIRE(other.labels() == "i,j"); + } + + SECTION("object()") { + REQUIRE(labeled_default.object().are_equal(defaulted)); + REQUIRE(clabeled_default.object().are_equal(defaulted)); + REQUIRE(labeled_value.object().are_equal(value)); + REQUIRE(clabeled_value.object().are_equal(value)); + } + + SECTION("object() const") { + REQUIRE(std::as_const(labeled_default).object().are_equal(defaulted)); + REQUIRE(std::as_const(clabeled_default).object().are_equal(defaulted)); + REQUIRE(std::as_const(labeled_value).object().are_equal(value)); + REQUIRE(std::as_const(clabeled_value).object().are_equal(value)); + } + + SECTION("labels()") { + REQUIRE(labeled_default.labels() == scalar); + REQUIRE(clabeled_default.labels() == scalar); + REQUIRE(labeled_value.labels() == ij); + REQUIRE(clabeled_value.labels() == ij); + } + + SECTION("labels() const") { + REQUIRE(std::as_const(labeled_default).labels() == scalar); + REQUIRE(std::as_const(clabeled_default).labels() == scalar); + REQUIRE(std::as_const(labeled_value).labels() == ij); + REQUIRE(std::as_const(clabeled_value).labels() == ij); + } + + SECTION("operator==") { + // Same values and const-ness + REQUIRE(labeled_default == labeled_type(defaulted, scalar)); + + // Same values different const-ness + REQUIRE(labeled_default == clabeled_default); + REQUIRE(clabeled_default == labeled_default); + + // Different object, same labels + auto value2 = test_tensorwrapper::smooth_matrix(20, 10); + REQUIRE_FALSE(labeled_value == labeled_type(value2, ij)); + + // Same object, different labels + REQUIRE_FALSE(labeled_value == labeled_type(value, "j,i")); + } + + SECTION("operator!=") { + // Just negates operator== so spot checking is fine + REQUIRE_FALSE(labeled_default != clabeled_default); + REQUIRE(labeled_value != labeled_type(value, "j,i")); + } } \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp index 5c8ca0d5..1c82fb60 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp @@ -13,41 +13,92 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "../testing/testing.hpp" -#include using namespace tensorwrapper; -TEST_CASE("PairwiseParser") { - Tensor scalar(testing::smooth_scalar()); - Tensor vector(testing::smooth_vector()); +using test_types = std::tuple; + +TEMPLATE_LIST_TEST_CASE("PairwiseParser", "", test_types) { + using object_type = TestType; + + test_types scalar_values{test_tensorwrapper::smooth_scalar()}; + test_types matrix_values{test_tensorwrapper::smooth_matrix()}; + + auto value0 = std::get(scalar_values); + auto value2 = std::get(matrix_values); + + dsl::PairwiseParser p; + + SECTION("assignment") { + object_type rv{}; + object_type corr{}; + SECTION("scalar") { + p.dispatch(rv(""), value0("")); + corr.permute_assignment("", value0("")); + REQUIRE(corr.are_equal(rv)); + } + + SECTION("matrix") { + p.dispatch(rv("i,j"), value2("i,j")); + corr.permute_assignment("i,j", value2("i,j")); + REQUIRE(corr.are_equal(rv)); + } + } - dsl::PairwiseParser p; + SECTION("addition") { + object_type rv{}; + object_type corr{}; + SECTION("scalar") { + p.dispatch(rv(""), value0("") + value0("")); + corr.addition_assignment("", value0(""), value0("")); + REQUIRE(corr.are_equal(rv)); + } - SECTION("add") { - Tensor t; + SECTION("matrix") { + p.dispatch(rv("i,j"), value2("i,j") + value2("i,j")); + corr.addition_assignment("i,j", value2("i,j"), value2("i,j")); + REQUIRE(corr.are_equal(rv)); + } + } + SECTION("subtraction") { + object_type rv{}; + object_type corr{}; SECTION("scalar") { - auto rv = p.dispatch(t(""), scalar("") + scalar("")); - REQUIRE(&rv.lhs() == &t); - REQUIRE(rv.rhs() == ""); + p.dispatch(rv(""), value0("") - value0("")); + corr.subtraction_assignment("", value0(""), value0("")); + REQUIRE(corr.are_equal(rv)); + } - auto buffer = testing::eigen_scalar(); - buffer.value()() = 84.0; - Tensor corr(scalar.logical_layout(), std::move(buffer)); - REQUIRE(t == corr); + SECTION("matrix") { + p.dispatch(rv("i,j"), value2("i,j") - value2("i,j")); + corr.subtraction_assignment("i,j", value2("i,j"), value2("i,j")); + REQUIRE(corr.are_equal(rv)); } + } - SECTION("Vector") { - auto rv = p.dispatch(t("i"), vector("i") + vector("i")); - REQUIRE(&rv.lhs() == &t); - REQUIRE(rv.rhs() == "i"); + SECTION("multiplication") { + object_type rv{}; + object_type corr{}; + SECTION("scalar") { + p.dispatch(rv(""), value0("") * value0("")); + corr.multiplication_assignment("", value0(""), value0("")); + REQUIRE(corr.are_equal(rv)); + } - auto buffer = testing::eigen_vector(); - for(std::size_t i = 0; i < 5; ++i) buffer.value()(i) = i + i; - Tensor corr(t.logical_layout(), std::move(buffer)); - REQUIRE(t == corr); + SECTION("matrix") { + p.dispatch(rv("i,j"), value2("i,j") * value2("i,j")); + corr.multiplication_assignment("i,j", value2("i,j"), value2("i,j")); + REQUIRE(corr.are_equal(rv)); } } + + SECTION("scalar_multiplication") { + // N.b., only tensor and buffer will override so here we're checking + // that other objects throw + using error_t = std::runtime_error; + + REQUIRE_THROWS_AS(p.dispatch(value0(""), value0("") * 1.0), error_t); + } } \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/shape/smooth.cpp b/tests/cxx/unit_tests/tensorwrapper/shape/smooth.cpp index 9bdff9bd..0df4c693 100644 --- a/tests/cxx/unit_tests/tensorwrapper/shape/smooth.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/shape/smooth.cpp @@ -109,6 +109,187 @@ TEST_CASE("Smooth") { REQUIRE(scalar.are_equal(Smooth{})); REQUIRE_FALSE(vector.are_equal(matrix)); } + + SECTION("addition_assignment_") { + // Works by calling permute_assignment_ so just spot check + Smooth matrix2{}; + auto mij = matrix("i,j"); + auto pmatrix2 = &(matrix2.addition_assignment("i,j", mij, mij)); + REQUIRE(pmatrix2 == &matrix2); + REQUIRE(matrix2 == matrix); + } + + SECTION("subtraction_assignment_") { + // Works by calling permute_assignment_ so just spot check + Smooth matrix2{}; + auto mij = matrix("i,j"); + auto pmatrix2 = &(matrix2.subtraction_assignment("i,j", mij, mij)); + REQUIRE(pmatrix2 == &matrix2); + REQUIRE(matrix2 == matrix); + } + + SECTION("multiplication_assignment_") { + // N.b., these are not exhaustive because there are a lot of + // possibilities. + + Smooth scalar2{}; + auto s = scalar(""); + auto vi = vector("i"); + auto mij = matrix("i,j"); + + SECTION("Scalar times scalar") { + auto pscalar2 = &(scalar2.multiplication_assignment("", s, s)); + REQUIRE(pscalar2 == &scalar2); + REQUIRE(scalar2 == scalar); + } + + SECTION("Scalar times vector") { + scalar2.multiplication_assignment("i", s, vi); + REQUIRE(scalar2 == vector); + + scalar2.multiplication_assignment("i", vi, s); + REQUIRE(scalar2 == vector); + + scalar2.multiplication_assignment("", vi, s); + REQUIRE(scalar2 == scalar); + } + + SECTION("Scalar times matrix") { + scalar2.multiplication_assignment("i,j", s, mij); + REQUIRE(scalar2 == matrix); + + scalar2.multiplication_assignment("i,j", mij, s); + REQUIRE(scalar2 == matrix); + + scalar2.multiplication_assignment("j,i", mij, s); + REQUIRE(scalar2 == Smooth{3, 2}); + + scalar2.multiplication_assignment("j,i", s, mij); + REQUIRE(scalar2 == Smooth{3, 2}); + + scalar2.multiplication_assignment("i", s, mij); + REQUIRE(scalar2 == Smooth{2}); + + scalar2.multiplication_assignment("i", mij, s); + REQUIRE(scalar2 == Smooth{2}); + + scalar2.multiplication_assignment("j", s, mij); + REQUIRE(scalar2 == Smooth{3}); + + scalar2.multiplication_assignment("j", mij, s); + REQUIRE(scalar2 == Smooth{3}); + + scalar2.multiplication_assignment("", mij, s); + REQUIRE(scalar2 == scalar); + } + + SECTION("Vector times vector") { + scalar2.multiplication_assignment("i", vi, vi); + REQUIRE(scalar2 == vector); + + scalar2.multiplication_assignment("i,j", vi, vector("j")); + REQUIRE(scalar2 == Smooth{1, 1}); + + scalar2.multiplication_assignment("", vi, vi); + REQUIRE(scalar2 == scalar); + } + + SECTION("Vector times matrix") { + Smooth vector2{2}; + + scalar2.multiplication_assignment("i,j,k", vector2("k"), mij); + REQUIRE(scalar2 == Smooth{2, 3, 2}); + + scalar2.multiplication_assignment("i,j", vector2("i"), mij); + REQUIRE(scalar2 == matrix); + + scalar2.multiplication_assignment("j,i", mij, vector2("i")); + REQUIRE(scalar2 == Smooth{3, 2}); + + scalar2.multiplication_assignment("i", vector2("i"), mij); + REQUIRE(scalar2 == Smooth{2}); + + scalar2.multiplication_assignment("i", mij, vector2("i")); + REQUIRE(scalar2 == Smooth{2}); + + scalar2.multiplication_assignment("j", vector2("i"), mij); + REQUIRE(scalar2 == Smooth{3}); + + scalar2.multiplication_assignment("j", mij, vector2("i")); + REQUIRE(scalar2 == Smooth{3}); + + scalar2.multiplication_assignment("", mij, vector2("i")); + REQUIRE(scalar2 == scalar); + } + + SECTION("Matrix times matrix") { + Smooth matrix2{2, 3}; + auto mkl = matrix2("k,l"); + + scalar2.multiplication_assignment("i,j,k,l", mkl, mij); + REQUIRE(scalar2 == Smooth{2, 3, 2, 3}); + + scalar2.multiplication_assignment("i,j,k", mkl, mij); + REQUIRE(scalar2 == Smooth{2, 3, 2}); + + scalar2.multiplication_assignment("i,j,l", mkl, mij); + REQUIRE(scalar2 == Smooth{2, 3, 3}); + + scalar2.multiplication_assignment("i,k,l", mkl, mij); + REQUIRE(scalar2 == Smooth{2, 2, 3}); + + scalar2.multiplication_assignment("j,k,l", mkl, mij); + REQUIRE(scalar2 == Smooth{3, 2, 3}); + + scalar2.multiplication_assignment("j,k", matrix2("i,k"), mij); + REQUIRE(scalar2 == Smooth{3, 3}); + } + } + + SECTION("permute_assignment_") { + SECTION("assign to empty") { + Smooth scalar2{}; + auto pscalar2 = &(scalar2.permute_assignment("", scalar(""))); + REQUIRE(pscalar2 == &scalar2); + REQUIRE(scalar2 == scalar); + + Smooth vector2{}; + auto pvector2 = &(vector2.permute_assignment("i", vector("i"))); + REQUIRE(pvector2 == &vector2); + REQUIRE(vector2 == vector); + + Smooth matrix2{}; + auto mij = matrix("i,j"); + auto pmatrix2 = &(matrix2.permute_assignment("i,j", mij)); + REQUIRE(pmatrix2 == &matrix2); + REQUIRE(matrix2 == matrix); + + Smooth tensor2{}; + auto tijk = tensor("i,j,k"); + auto ptensor2 = &(tensor2.permute_assignment("i,j,k", tijk)); + REQUIRE(ptensor2 == &tensor2); + REQUIRE(tensor2 == tensor); + } + + SECTION("assign with permute") { + Smooth matrix2{10, 10}; // Will double check it overwrites + auto mij = matrix("i,j"); // n.b., it's a 2 by 3 + auto pmatrix2 = &(matrix2.permute_assignment("j,i", mij)); + Smooth corr{3, 2}; + REQUIRE(pmatrix2 == &matrix2); + REQUIRE(matrix2 == corr); + + Smooth tensor2{}; + auto tijk = tensor("i,j,k"); // n.b., it's 3 by 4 by 5 + auto ptensor2 = &(tensor2.permute_assignment("k,j,i", tijk)); + REQUIRE(ptensor2 == &tensor2); + REQUIRE(tensor2 == Smooth{5, 4, 3}); + } + + // Requesting a trace + REQUIRE_THROWS_AS(scalar.permute_assignment("", vector("i")), + std::runtime_error); + } } SECTION("Utility methods") { diff --git a/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp b/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp index 25a9fe5d..05a59f96 100644 --- a/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp @@ -119,24 +119,6 @@ TEST_CASE("Tensor") { REQUIRE_THROWS_AS(const_defaulted.buffer(), std::runtime_error); } - SECTION("operator(std::string)") { - auto labeled_scalar = scalar(""); - auto labeled_vector = vector("i"); - - using labeled_tensor_type = Tensor::labeled_tensor_type; - REQUIRE(labeled_scalar == labeled_tensor_type(scalar, "")); - REQUIRE(labeled_vector == labeled_tensor_type(vector, "i")); - } - - SECTION("operator(std::string) const") { - auto labeled_scalar = std::as_const(scalar)(""); - auto labeled_vector = std::as_const(vector)("i"); - - using const_labeled_tensor_type = Tensor::const_labeled_tensor_type; - REQUIRE(labeled_scalar == const_labeled_tensor_type(scalar, "")); - REQUIRE(labeled_vector == const_labeled_tensor_type(vector, "i")); - } - SECTION("swap") { Tensor scalar_copy(scalar); Tensor vector_copy(vector); @@ -179,27 +161,4 @@ TEST_CASE("Tensor") { REQUIRE_FALSE(scalar != other_scalar); REQUIRE(scalar != vector); } - - SECTION("DSL") { - // These are just spot checks to make sure the DSL works on the user - // side - SECTION("Scalar") { - Tensor rv; - rv("") = scalar("") + scalar(""); - auto buffer = testing::eigen_scalar(); - buffer.value()() = 84.0; - Tensor corr(scalar.logical_layout(), std::move(buffer)); - REQUIRE(rv == corr); - } - - SECTION("Vector") { - Tensor rv; - rv("i") = vector("i") + vector("i"); - - auto buffer = testing::eigen_vector(); - for(std::size_t i = 0; i < 5; ++i) buffer.value()(i) = i + i; - Tensor corr(vector.logical_layout(), std::move(buffer)); - REQUIRE(rv == corr); - } - } } diff --git a/tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp b/tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp new file mode 100644 index 00000000..1e25ec0c --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp @@ -0,0 +1,42 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @file shapes.hpp + * + * This file contains some already made shape objects to facilitate unit + * testing of TensorWrapper. + */ +#pragma once +#include + +namespace test_tensorwrapper { + +inline auto smooth_scalar() { return tensorwrapper::shape::Smooth{}; } + +inline auto smooth_vector(std::size_t i = 10) { + return tensorwrapper::shape::Smooth{i}; +} + +inline auto smooth_matrix(std::size_t i = 10, std::size_t j = 10) { + return tensorwrapper::shape::Smooth{i, j}; +} + +inline auto smooth_tensor(std::size_t i = 10, std::size_t j = 10, + std::size_t k = 10) { + return tensorwrapper::shape::Smooth{i, j, k}; +} + +} // namespace test_tensorwrapper \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/testing/testing.hpp b/tests/cxx/unit_tests/tensorwrapper/testing/testing.hpp index f5b5e683..47e6de70 100644 --- a/tests/cxx/unit_tests/tensorwrapper/testing/testing.hpp +++ b/tests/cxx/unit_tests/tensorwrapper/testing/testing.hpp @@ -17,4 +17,5 @@ #pragma once #include "../helpers.hpp" #include "../inputs.hpp" -#include "eigen_buffers.hpp" \ No newline at end of file +#include "eigen_buffers.hpp" +#include "shapes.hpp" \ No newline at end of file