diff --git a/include/tensorwrapper/buffer/mdbuffer.hpp b/include/tensorwrapper/buffer/mdbuffer.hpp
index 72f5c765..ab5e1cc0 100644
--- a/include/tensorwrapper/buffer/mdbuffer.hpp
+++ b/include/tensorwrapper/buffer/mdbuffer.hpp
@@ -15,6 +15,9 @@
  */
 #pragma once
+#include
+#include
+#include
 #include
 
 namespace tensorwrapper::buffer {
@@ -23,42 +26,312 @@ namespace tensorwrapper::buffer {
  *
  *  This class is a dense multidimensional buffer of floating-point values.
  */
-class MDBuffer {
+class MDBuffer : public Replicated {
 private:
+    /// Type *this derives from
+    using my_base_type = Replicated;
+
+    /// Type defining the types for the public API of *this
     using traits_type = types::ClassTraits<MDBuffer>;
 
+    /// Type of *this
+    using my_type = MDBuffer;
+
 public:
-    /// Add types to public API
+    /// Add types from traits_type to public API
     ///@{
-    using buffer_type   = typename traits_type::buffer_type;
-    using pimpl_type    = typename traits_type::pimpl_type;
-    using pimpl_pointer = typename traits_type::pimpl_pointer;
-    using rank_type     = typename traits_type::rank_type;
-    using shape_type    = typename traits_type::shape_type;
+    using value_type        = typename traits_type::value_type;
+    using reference         = typename traits_type::reference;
+    using const_reference   = typename traits_type::const_reference;
+    using buffer_type       = typename traits_type::buffer_type;
+    using buffer_view       = typename traits_type::buffer_view;
+    using const_buffer_view = typename traits_type::const_buffer_view;
+    using pimpl_type        = typename traits_type::pimpl_type;
+    using pimpl_pointer     = typename traits_type::pimpl_pointer;
+    using rank_type         = typename traits_type::rank_type;
+    using shape_type        = typename traits_type::shape_type;
+    using const_shape_view  = typename traits_type::const_shape_view;
+    using size_type         = typename traits_type::size_type;
    ///@}
 
+    /// Type of a multidimensional (coordinate) index into *this
+    using index_vector = std::vector<size_type>;
+
+    using typename my_base_type::label_type;
+
+    using string_type = std::string;
+
+    // -------------------------------------------------------------------------
+    // -- Ctors, assignment, and dtor
+    // -------------------------------------------------------------------------
+
+    /** @brief Creates an empty multi-dimensional buffer.
+     *
+     *  The resulting buffer has a rank 0 shape, but a size of 0. Thus the
+     *  buffer can NOT be used to store any elements (including treating
+     *  *this as a scalar). The resulting buffer can be populated by copy or
+     *  move assignment.
+     *
+     *  @throw None No throw guarantee.
+     */
     MDBuffer() noexcept;
 
-    template<typename T>
-    MDBuffer(shape_type shape, std::vector<T> elements) {
-        MDBuffer(std::move(shape), buffer_type(std::move(elements)));
-    }
+    /** @brief Treats allocated memory like a multi-dimensional buffer.
+     *
+     *  @tparam T The type of the elements in the buffer. Must satisfy the
+     *            FloatingPoint concept.
+     *
+     *  This ctor will use @p elements to create a buffer_type object and
+     *  then pass that, along with @p shape, to the main ctor.
+     *
+     *  @param[in] elements The elements to be used as the backing store.
+     *  @param[in] shape The shape of *this.
+     *
+     *  @throw std::invalid_argument if the size of @p elements does not
+     *                               match the size implied by @p shape.
+     *                               Strong throw guarantee.
+     *  @throw std::bad_alloc if there is a problem allocating memory for the
+     *                        internal state. Strong throw guarantee.
+     */
+    template<concepts::FloatingPoint T>
+    MDBuffer(std::vector<T> elements, shape_type shape) :
+      MDBuffer(buffer_type(std::move(elements)), std::move(shape)) {}
+
+    /** @brief The main ctor.
+ * + * This ctor will create *this using @p buffer as the backing store and + * @p shape to describe the geometry of the multidimensional array. + * + * All other ctors (aside from copy and move) delegate to this one. + * + * @param[in] buffer The buffer to be used as the backing store. + * @param[in] shape The shape of *this. + * + * @throw std::invalid_argument if the size of @p buffer does not match + * the size implied by @p shape. Strong throw + * guarantee. + * @throw std::bad_alloc if there is a problem allocating memory for the + * internal state. Strong throw guarantee. + */ + MDBuffer(buffer_type buffer, shape_type shape); + + /** @brief Initializes *this to a deep copy of @p other. + * + * This ctor will initialize *this to be a deep copy of @p other. + * + * @param[in] other The MDBuffer to copy. + * + * @throw std::bad_alloc if there is a problem allocating memory for the + * internal state. Strong throw guarantee. + */ + MDBuffer(const MDBuffer& other) = default; + + /** @brief Move ctor. + * + * This ctor will initialize *this by taking the state from @p other. + * After this ctor is called @p other is left in a valid but unspecified + * state. + * + * @param[in,out] other The MDBuffer to move from. + * + * @throw None No throw guarantee. + */ + MDBuffer(MDBuffer&& other) noexcept = default; + + /** @brief Copy assignment. + * + * This operator will make *this a deep copy of @p other. + * + * @param[in] other The MDBuffer to copy. + * + * @return *this after the assignment. + * + * @throw std::bad_alloc if there is a problem allocating memory for the + * internal state. Strong throw guarantee. + */ + MDBuffer& operator=(const MDBuffer& other) = default; + + /** @brief Move assignment. + * + * This operator will make *this take the state from @p other. After + * this operator is called @p other is left in a valid but unspecified + * state. + * + * @param[in,out] other The MDBuffer to move from. + * + * @return *this after the assignment. + * + * @throw None No throw guarantee. + */ + MDBuffer& operator=(MDBuffer&& other) noexcept = default; + + /** @brief Defaulted dtor. + * + * @throw None No throw guarantee. + */ + ~MDBuffer() override = default; + + // ------------------------------------------------------------------------- + // -- State Accessors + // ------------------------------------------------------------------------- + + /** @brief Returns (a view of) the shape of *this. + * + * The shape of *this describes the geometry of the underlying + * multidimensional array. + * + * @return A view of the shape of *this. + * + * @throw std::bad_alloc if there is a problem allocating memory for the + * returned view. Strong throw guarantee. + */ + const_shape_view shape() const; + + /** @brief The total number of elements in *this. + * + * The total number of elements is the product of the extents of each + * mode of *this. + * + * @return The total number of elements in *this. + * + * @throw None No throw guarantee. + */ + size_type size() const noexcept; + + /** @brief Returns the element with the offsets specified by @p index. + * + * This method will retrieve a const reference to the element at the + * offsets specified by @p index. The length of @p index must be equal + * to the rank of *this and each entry in @p index must be less than the + * extent of the corresponding mode of *this. + * + * This method can only be used to retrieve elements from *this. To modify + * elements use set_elem(). 
+ * + * @param[in] index The offsets into each mode of *this for the desired + * element. + * + * @return A const reference to the element at the specified offsets. + */ + const_reference get_elem(index_vector index) const; + + /** @brief Sets the specified element to @p new_value. + * + * This method will set the element at the offsets specified by @p index. + * The length of @p index must be equal to the rank of *this and each + * entry in @p index must be less than the extent of the corresponding + * mode of *this. + * + * @param[in] index The offsets into each mode of *this for the desired + * element. + * @param[in] new_value The new value for the specified element. + * + * @throw std::out_of_range if any entry in @p index is invalid. Strong + * throw guarantee. + */ + void set_elem(index_vector index, value_type new_value); + + /** @brief Returns a view of the data. + * + * This method is deprecated. Use set_slice instead. + */ + [[deprecated]] buffer_view get_mutable_data(); - MDBuffer(shape_type shape, buffer_type buffer); + /** @brief Returns a read-only view of the data. + * + * This method is deprecated. Use get_slice instead. + */ + [[deprecated]] const_buffer_view get_immutable_data() const; - rank_type rank() const; + // ------------------------------------------------------------------------- + // -- Utility Methods + // ------------------------------------------------------------------------- + + /** @brief Compares two MDBuffer objects for exact equality. + * + * Two MDBuffer objects are exactly equal if they have the same shape and + * if all of their corresponding elements are bitwise identical. + * In practice, the implementation stores a hash of the elements in the + * tensor and compares the hashes for equality rather than checking each + * element individually. + * + * @param[in] rhs The MDBuffer to compare against. + * + * @return True if *this and @p rhs are exactly equal and false otherwise. + * + * @throw None No throw guarantee. 
+ */ + bool operator==(const my_type& rhs) const noexcept; + +protected: + /// Makes a deep polymorphic copy of *this + buffer_base_pointer clone_() const override; + + /// Implements are_equal by checking that rhs is an MDBuffer and then + /// calling operator== + bool are_equal_(const_buffer_base_reference rhs) const noexcept override; + + dsl_reference addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + dsl_reference subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + dsl_reference multiplication_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) override; + + dsl_reference permute_assignment_(label_type this_labels, + const_labeled_reference rhs) override; + + dsl_reference scalar_multiplication_(label_type this_labels, double scalar, + const_labeled_reference rhs) override; + + /// Calls add_to_stream_ on a stringstream to implement + string_type to_string_() const override; + + /// Uses Eigen's printing capabilities to add to stream + std::ostream& add_to_stream_(std::ostream& os) const override; private: - explicit MDBuffer(pimpl_pointer pimpl) noexcept; + /// Type for storing the hash of *this + using hash_type = std::size_t; + + /// Logic for validating that an index is within the bounds of the shape + void check_index_(const index_vector& index) const; + + /// Converts a coordinate index to a linear (ordinal) index + size_type coordinate_to_ordinal_(index_vector index) const; + + /// Returns the hash for the current state of *this, computing first if + /// needed. + hash_type get_hash_() const { + if(m_recalculate_hash_ or !m_hash_caching_) update_hash_(); + return m_hash_; + } + + /// Computes the hash for the current state of *this + void update_hash_() const; + + /// Designates that the state may have changed and to recalculate the hash. + /// This function is really just for readability and clarity. + void mark_for_rehash_() const { m_recalculate_hash_ = true; } + + /// Designates that state changes are not trackable and we should + /// recalculate the hash each time. + void turn_off_hash_caching_() const { m_hash_caching_ = false; } + + /// Tracks whether the hash needs to be redetermined + mutable bool m_recalculate_hash_ = true; - bool has_pimpl_() const noexcept; + /// Tracks whether hash caching has been turned off + mutable bool m_hash_caching_ = true; - void assert_pimpl_() const; + /// Holds the computed hash value for this instance's state + mutable hash_type m_hash_ = 0; - pimpl_type& pimpl_(); - const pimpl_type& pimpl_() const; + /// How the hyper-rectangular array is shaped + shape_type m_shape_; - pimpl_pointer m_pimpl_; + /// The flat buffer holding the elements of *this + buffer_type m_buffer_; }; } // namespace tensorwrapper::buffer diff --git a/include/tensorwrapper/concepts/floating_point.hpp b/include/tensorwrapper/concepts/floating_point.hpp new file mode 100644 index 00000000..d95588d0 --- /dev/null +++ b/include/tensorwrapper/concepts/floating_point.hpp @@ -0,0 +1,26 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include
+
+namespace tensorwrapper::concepts {
+
+using wtf::concepts::ConstFloatingPoint;
+using wtf::concepts::FloatingPoint;
+using wtf::concepts::UnmodifiedFloatingPoint;
+
+} // namespace tensorwrapper::concepts
diff --git a/include/tensorwrapper/forward_declarations.hpp b/include/tensorwrapper/forward_declarations.hpp
index 16c51064..e030b3a9 100644
--- a/include/tensorwrapper/forward_declarations.hpp
+++ b/include/tensorwrapper/forward_declarations.hpp
@@ -28,6 +28,8 @@ class MDBuffer;
 } // namespace buffer
 
 namespace shape {
+template<typename SmoothType>
+class SmoothView;
 
 class Smooth;
diff --git a/include/tensorwrapper/shape/smooth.hpp b/include/tensorwrapper/shape/smooth.hpp
index 32d167de..fd6cc86e 100644
--- a/include/tensorwrapper/shape/smooth.hpp
+++ b/include/tensorwrapper/shape/smooth.hpp
@@ -39,6 +39,8 @@ class Smooth : public ShapeBase {
     // -------------------------------------------------------------------------
     // -- Ctors, assignment, and dtor
     // -------------------------------------------------------------------------
 
+    Smooth() noexcept = default;
+
     /** @brief Constructs *this with a statically specified number of extents.
      *
      * This ctor is used to create a Smooth object by explicitly providing
diff --git a/include/tensorwrapper/types/floating_point.hpp b/include/tensorwrapper/types/floating_point.hpp
index fb37346a..46bf8464 100644
--- a/include/tensorwrapper/types/floating_point.hpp
+++ b/include/tensorwrapper/types/floating_point.hpp
@@ -17,6 +17,7 @@
 #pragma once
 #include
 #include
+#include
 #ifdef ENABLE_SIGMA
 #include
 #endif
@@ -47,6 +48,10 @@ T fabs(T value) {
     MACRO_IN(double);        \
     MACRO_IN(types::ufloat); \
     MACRO_IN(types::udouble)
+} // namespace tensorwrapper::types
+
+WTF_REGISTER_FP_TYPE(tensorwrapper::types::ufloat);
+WTF_REGISTER_FP_TYPE(tensorwrapper::types::udouble);
 
 #else
 
 using ufloat = float;
@@ -66,6 +71,5 @@ T fabs(T value) {
     MACRO_IN(float);  \
     MACRO_IN(double)
 
-#endif
-
 } // namespace tensorwrapper::types
+#endif
diff --git a/include/tensorwrapper/types/mdbuffer_traits.hpp b/include/tensorwrapper/types/mdbuffer_traits.hpp
index 27c74421..aa60a608 100644
--- a/include/tensorwrapper/types/mdbuffer_traits.hpp
+++ b/include/tensorwrapper/types/mdbuffer_traits.hpp
@@ -28,7 +28,9 @@ struct MDBufferTraitsCommon {
     using buffer_type       = wtf::buffer::FloatBuffer;
     using const_buffer_view = wtf::buffer::BufferView<const buffer_type>;
     using shape_type        = shape::Smooth;
-    using rank_type         = typename shape_type::rank_type;
+    using const_shape_view  = shape::SmoothView<const shape_type>;
+    using rank_type         = typename ClassTraits<shape_type>::rank_type;
+    using size_type         = typename ClassTraits<shape_type>::size_type;
     using pimpl_type        = tensorwrapper::buffer::detail_::MDBufferPIMPL;
     using pimpl_pointer     = std::unique_ptr<pimpl_type>;
 };
diff --git a/src/tensorwrapper/backends/eigen/eigen_tensor_impl.cpp b/src/tensorwrapper/backends/eigen/eigen_tensor_impl.cpp
index 53eb677e..28d13020 100644
--- a/src/tensorwrapper/backends/eigen/eigen_tensor_impl.cpp
+++ b/src/tensorwrapper/backends/eigen/eigen_tensor_impl.cpp
@@ -76,7 +76,11 @@ auto EIGEN_TENSOR::to_string_() const -> string_type {
 
 TPARAMS
 std::ostream& EIGEN_TENSOR::add_to_stream_(std::ostream& os) const {
     os << std::fixed << std::setprecision(16);
-    return os << m_tensor_.format(Eigen::TensorIOFormat::Numpy());
+    if constexpr(Rank > 0) {
+        return os << m_tensor_.format(Eigen::TensorIOFormat::Numpy());
+    } else {
+        return os << m_tensor_;
+    }
 }
 
 TPARAMS
diff --git a/src/tensorwrapper/buffer/detail_/addition_visitor.hpp b/src/tensorwrapper/buffer/detail_/addition_visitor.hpp
deleted file mode 100644
index 4e021e8a..00000000
--- a/src/tensorwrapper/buffer/detail_/addition_visitor.hpp
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright 2025 NWChemEx-Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-#include
-
-namespace tensorwrapper::buffer::detail_ {
-
-/** @brief Dispatches to the appropriate backend based on the FP type.
- *
- *
- *
- */
-class AdditionVisitor {
-public:
-    // AdditionVisitor(shape, permutation, shape, permutation)
-    template<typename T>
-    void operator()(std::span<T> lhs, std::span<T> rhs) {
-        // auto lhs_wrapped = backends::eigen::wrap_span(lhs);
-        // auto rhs_wrapped = backends::eigen::wrap_span(rhs);
-        for(std::size_t i = 0; i < lhs.size(); ++i) lhs[i] += rhs[i];
-    }
-};
-
-} // namespace tensorwrapper::buffer::detail_
diff --git a/src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp b/src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp
new file mode 100644
index 00000000..25363fd1
--- /dev/null
+++ b/src/tensorwrapper/buffer/detail_/binary_operation_visitor.hpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright 2025 NWChemEx-Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "../../backends/eigen/eigen_tensor_impl.hpp"
+#include "unary_operation_visitor.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace tensorwrapper::buffer::detail_ {
+
+/** @brief Dispatches to the appropriate backend based on the FP type.
+ *
+ *  This visitor is intended to be used with WTF's buffer visitation
+ *  mechanism. This base class implements the logic common to all binary
+ *  operations and lets the derived classes implement the operation-specific
+ *  logic.
+ *
+ *  @note This class derives from UnaryOperationVisitor to reuse some of its
+ *        functionality. The inheritance is private because it does not make
+ *        sense to use a BinaryOperationVisitor as a UnaryOperationVisitor.
+ */
+class BinaryOperationVisitor : private UnaryOperationVisitor {
+private:
+    using base_class = UnaryOperationVisitor;
+
+public:
+    /// Pull in types from the base class
+    ///@{
+    using typename base_class::buffer_type;
+    using typename base_class::const_shape_view;
+    using typename base_class::label_type;
+    using typename base_class::shape_type;
+    using typename base_class::string_type;
+    ///@}
+
+    BinaryOperationVisitor(buffer_type& this_buffer, label_type this_labels,
+                           shape_type this_shape, label_type lhs_labels,
+                           shape_type lhs_shape, label_type rhs_labels,
+                           shape_type rhs_shape) :
+      UnaryOperationVisitor(this_buffer, this_labels, this_shape, lhs_labels,
+                            lhs_shape),
+      m_rhs_labels_(std::move(rhs_labels)),
+      m_rhs_shape_(std::move(rhs_shape)) {}
+
+    using base_class::this_labels;
+    using base_class::this_shape;
+
+    const auto& lhs_shape() const { return other_shape(); }
+    const auto& rhs_shape() const { return m_rhs_shape_; }
+
+    const auto& lhs_labels() const { return other_labels(); }
+    const auto& rhs_labels() const { return m_rhs_labels_; }
+
+    template<typename LHSType, typename RHSType>
+        requires(!std::is_same_v<std::decay_t<LHSType>, std::decay_t<RHSType>>)
+    void operator()(std::span<LHSType>, std::span<RHSType>) const {
+        throw std::runtime_error(
+          "BinaryOperationVisitor: Mixed types not supported");
+    }
+
+protected:
+    using base_class::make_this_eigen_tensor_;
+
+    template<typename T>
+    auto make_lhs_eigen_tensor_(std::span<T> data) {
+        return base_class::make_other_eigen_tensor_(data);
+    }
+
+    template<typename T>
+    auto make_rhs_eigen_tensor_(std::span<T> data) {
+        /// XXX: Ideally we would not need to const_cast here, but we didn't
+        /// code EigenTensor correctly...
+
+        using clean_type = std::decay_t<T>;
+        auto* pdata      = const_cast<clean_type*>(data.data());
+        std::span<clean_type> non_const_data(pdata, data.size());
+        return backends::eigen::make_eigen_tensor(non_const_data,
+                                                  m_rhs_shape_);
+    }
+
+private:
+    label_type m_rhs_labels_;
+    shape_type m_rhs_shape_;
+};
+
+/// Visitor that calls addition_assignment
+class AdditionVisitor : public BinaryOperationVisitor {
+public:
+    using BinaryOperationVisitor::BinaryOperationVisitor;
+    using BinaryOperationVisitor::operator();
+
+    template<typename T>
+    void operator()(std::span<T> lhs, std::span<T> rhs) {
+        using clean_t = std::decay_t<T>;
+        auto pthis    = this->make_this_eigen_tensor_<clean_t>();
+        auto plhs     = this->make_lhs_eigen_tensor_(lhs);
+        auto prhs     = this->make_rhs_eigen_tensor_(rhs);
+
+        pthis->addition_assignment(this_labels(), lhs_labels(), rhs_labels(),
+                                   *plhs, *prhs);
+    }
+};
+
+/// Visitor that calls subtraction_assignment
+class SubtractionVisitor : public BinaryOperationVisitor {
+public:
+    using BinaryOperationVisitor::BinaryOperationVisitor;
+    using BinaryOperationVisitor::operator();
+
+    template<typename T>
+    void operator()(std::span<T> lhs, std::span<T> rhs) {
+        using clean_t = std::decay_t<T>;
+        auto pthis    = this->make_this_eigen_tensor_<clean_t>();
+        auto plhs     = this->make_lhs_eigen_tensor_(lhs);
+        auto prhs     = this->make_rhs_eigen_tensor_(rhs);
+
+        pthis->subtraction_assignment(this_labels(), lhs_labels(),
+                                      rhs_labels(), *plhs, *prhs);
+    }
+};
+
+} // namespace tensorwrapper::buffer::detail_
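Reviewer note: the intended span-level call pattern for the new binary visitors, condensed from the AdditionVisitor unit test added later in this diff (same aliases as that test, with TestType = double); a sketch, not normative API documentation:

    using namespace tensorwrapper;
    using Visitor     = buffer::detail_::AdditionVisitor;
    using buffer_type = Visitor::buffer_type;

    std::vector<double> lhs{4.0, 3.0, 2.0, 1.0};
    std::vector<double> rhs{1.0, 1.0, 1.0, 1.0};
    buffer_type out(std::vector<double>{0.0, 0.0, 0.0, 0.0});

    Visitor::label_type i("i");
    Visitor::shape_type vec({4});

    // this, lhs, and rhs all share label "i" and shape {4}, so the Eigen
    // backend performs an elementwise add.
    Visitor visitor(out, i, vec, i, vec, i, vec);
    visitor(std::span<double>(lhs), std::span<double>(rhs));
    // out now holds {5.0, 4.0, 3.0, 2.0}.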
diff --git a/src/tensorwrapper/buffer/detail_/hash_utilities.hpp b/src/tensorwrapper/buffer/detail_/hash_utilities.hpp
index 021b291e..a9c35cdb 100644
--- a/src/tensorwrapper/buffer/detail_/hash_utilities.hpp
+++ b/src/tensorwrapper/buffer/detail_/hash_utilities.hpp
@@ -68,4 +68,21 @@ void hash_input(hash_type& seed, const sigma::Uncertain& value) {
 
 #endif
 
+class HashVisitor {
+public:
+    HashVisitor(hash_type seed = 0) : m_seed_(seed) {}
+
+    hash_type get_hash() const { return m_seed_; }
+
+    template<typename T>
+    void operator()(std::span<T> data) {
+        for(std::size_t i = 0; i < data.size(); ++i) {
+            hash_input(m_seed_, data[i]);
+        }
+    }
+
+private:
+    hash_type m_seed_;
+};
+
 } // namespace tensorwrapper::buffer::detail_::hash_utilities
diff --git a/src/tensorwrapper/buffer/detail_/mdbuffer_pimpl.hpp b/src/tensorwrapper/buffer/detail_/mdbuffer_pimpl.hpp
deleted file mode 100644
index 6f410098..00000000
--- a/src/tensorwrapper/buffer/detail_/mdbuffer_pimpl.hpp
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright 2025 NWChemEx-Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-#include
-#include
-
-namespace tensorwrapper::buffer::detail_ {
-
-class MDBufferPIMPL {
-public:
-    using parent_type = tensorwrapper::buffer::MDBuffer;
-    using traits_type = tensorwrapper::types::ClassTraits<parent_type>;
-
-    /// Add types to public API
-    ///@{
-    using value_type  = typename traits_type::value_type;
-    using rank_type   = typename traits_type::rank_type;
-    using buffer_type = typename traits_type::buffer_type;
-    using shape_type  = typename traits_type::shape_type;
-    ///@}
-
-    MDBufferPIMPL(shape_type shape, buffer_type buffer) noexcept :
-        m_shape_(std::move(shape)), m_buffer_(std::move(buffer)) {}
-
-    auto& shape() noexcept { return m_shape_; }
-
-    const auto& shape() const noexcept { return m_shape_; }
-
-    auto& buffer() noexcept { return m_buffer_; }
-
-    const auto& buffer() const noexcept { return m_buffer_; }
-
-private:
-    shape_type m_shape_;
-
-    buffer_type m_buffer_;
-};
-
-} // namespace tensorwrapper::buffer::detail_
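Reviewer note: how the new HashVisitor is meant to be driven; this mirrors what MDBuffer::update_hash_ does further down. Here `buf` is a hypothetical, non-empty wtf::buffer::FloatBuffer standing in for MDBuffer's backing store:

    // Fold every element of `buf` into a running hash; the seed parameter
    // lets callers chain hashes across multiple buffers.
    tensorwrapper::buffer::detail_::hash_utilities::HashVisitor visitor; // seed = 0
    wtf::buffer::visit_contiguous_buffer(visitor, buf);
    auto hash = visitor.get_hash();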
diff --git a/src/tensorwrapper/buffer/detail_/unary_operation_visitor.hpp b/src/tensorwrapper/buffer/detail_/unary_operation_visitor.hpp
new file mode 100644
index 00000000..4a99c003
--- /dev/null
+++ b/src/tensorwrapper/buffer/detail_/unary_operation_visitor.hpp
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2025 NWChemEx-Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "../../backends/eigen/eigen_tensor_impl.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace tensorwrapper::buffer::detail_ {
+
+/** @brief Dispatches to the appropriate backend based on the FP type.
+ *
+ *  This visitor is intended to be used with WTF's buffer visitation
+ *  mechanism. This base class implements the logic common to all unary
+ *  operations and lets the derived classes implement the operation-specific
+ *  logic.
+ */
+class UnaryOperationVisitor {
+public:
+    /// Type of the WTF buffer
+    using buffer_type = wtf::buffer::FloatBuffer;
+
+    /// Type that the labels use for representing indices
+    using string_type = std::string;
+
+    /// Type of a set of labels
+    using label_type = dsl::DummyIndices<string_type>;
+
+    /// Type describing the shape of the tensors
+    using shape_type = shape::Smooth;
+
+    /// Type describing a read-only view acting like shape_type
+    using const_shape_view = shape::SmoothView<const shape_type>;
+
+    UnaryOperationVisitor(buffer_type& this_buffer, label_type this_labels,
+                          shape_type this_shape, label_type other_labels,
+                          shape_type other_shape) :
+      m_pthis_buffer_(&this_buffer),
+      m_this_labels_(std::move(this_labels)),
+      m_this_shape_(std::move(this_shape)),
+      m_other_labels_(std::move(other_labels)),
+      m_other_shape_(std::move(other_shape)) {}
+
+    const auto& this_shape() const { return m_this_shape_; }
+    const auto& other_shape() const { return m_other_shape_; }
+
+    const auto& this_labels() const { return m_this_labels_; }
+    const auto& other_labels() const { return m_other_labels_; }
+
+protected:
+    template<typename T>
+    auto make_eigen_tensor_(std::span<T> data, const_shape_view shape) {
+        return backends::eigen::make_eigen_tensor(data, shape);
+    }
+
+    template<typename T>
+    auto make_this_eigen_tensor_() {
+        if(m_pthis_buffer_->size() != m_this_shape_.size()) {
+            std::vector<T> temp_buffer(m_this_shape_.size());
+            *m_pthis_buffer_ = buffer_type(std::move(temp_buffer));
+        }
+        auto this_span =
+          wtf::buffer::contiguous_buffer_cast<T>(*m_pthis_buffer_);
+        return backends::eigen::make_eigen_tensor(this_span, m_this_shape_);
+    }
+
+    template<typename T>
+    auto make_other_eigen_tensor_(std::span<T> data) {
+        /// XXX: Ideally we would not need to const_cast here, but we didn't
+        /// code EigenTensor correctly...
+
+        using clean_type = std::decay_t<T>;
+        auto* pdata      = const_cast<clean_type*>(data.data());
+        std::span<clean_type> non_const_data(pdata, data.size());
+        return backends::eigen::make_eigen_tensor(non_const_data,
+                                                  m_other_shape_);
+    }
+
+private:
+    buffer_type* m_pthis_buffer_;
+    label_type m_this_labels_;
+    shape_type m_this_shape_;
+
+    label_type m_other_labels_;
+    shape_type m_other_shape_;
+};
+
+class PermuteVisitor : public UnaryOperationVisitor {
+public:
+    using UnaryOperationVisitor::UnaryOperationVisitor;
+
+    template<typename T>
+    void operator()(std::span<T> other) {
+        using clean_t = std::decay_t<T>;
+        auto pthis    = this->make_this_eigen_tensor_<clean_t>();
+        auto pother   = this->make_other_eigen_tensor_(other);
+
+        pthis->permute_assignment(this->this_labels(), other_labels(),
+                                  *pother);
+    }
+};
+
+class ScalarMultiplicationVisitor : public UnaryOperationVisitor {
+public:
+    using scalar_type = wtf::fp::Float;
+
+    ScalarMultiplicationVisitor(buffer_type& this_buffer,
+                                label_type this_labels, shape_type this_shape,
+                                label_type other_labels,
+                                shape_type other_shape, scalar_type scalar) :
+      UnaryOperationVisitor(this_buffer, this_labels, this_shape,
+                            other_labels, other_shape),
+      m_scalar_(scalar) {}
+
+    template<typename T>
+    void operator()(std::span<T> other) {
+        using clean_t = std::decay_t<T>;
+        auto pthis    = this->make_this_eigen_tensor_<clean_t>();
+        auto pother   = this->make_other_eigen_tensor_(other);
+
+        // TODO: Change when public API changes to support other FP types
+        auto scalar = wtf::fp::float_cast<clean_t>(m_scalar_);
+        pthis->scalar_multiplication(this->this_labels(), other_labels(),
+                                     scalar, *pother);
+    }
+
+private:
+    scalar_type m_scalar_;
+};
+
+} // namespace tensorwrapper::buffer::detail_
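Reviewer note: a usage sketch for the unary visitors, condensed from the PermuteVisitor unit test later in this diff; it also shows that an unallocated output buffer is resized on demand by make_this_eigen_tensor_:

    using namespace tensorwrapper;
    using Visitor = buffer::detail_::PermuteVisitor;

    // Transpose a 3x2 tensor labeled "j,i" into a 2x3 result labeled "i,j".
    Visitor::buffer_type out; // empty; allocated on first use to shape {2,3}
    Visitor::label_type out_labels("i,j"), in_labels("j,i");
    Visitor::shape_type out_shape({2, 3}), in_shape({3, 2});

    std::vector<double> in{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; // [[1,2],[3,4],[5,6]]
    Visitor visitor(out, out_labels, out_shape, in_labels, in_shape);
    visitor(std::span<const double>(in));
    // out is now {1, 3, 5, 2, 4, 6}, i.e., the row-major transpose.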
diff --git a/src/tensorwrapper/buffer/mdbuffer.cpp b/src/tensorwrapper/buffer/mdbuffer.cpp
index fe92be9c..dc016b47 100644
--- a/src/tensorwrapper/buffer/mdbuffer.cpp
+++ b/src/tensorwrapper/buffer/mdbuffer.cpp
@@ -14,43 +14,250 @@
  * limitations under the License.
  */
 
-#include "detail_/addition_visitor.hpp"
-#include "detail_/mdbuffer_pimpl.hpp"
+#include "../backends/eigen/eigen_tensor_impl.hpp"
+#include "detail_/binary_operation_visitor.hpp"
+#include "detail_/hash_utilities.hpp"
 #include
 #include
 
 namespace tensorwrapper::buffer {
+namespace {
 
-MDBuffer::MDBuffer() noexcept : m_pimpl_(nullptr) {}
+template<typename T>
+const MDBuffer& downcast(T&& object) {
+    auto* pobject = dynamic_cast<const MDBuffer*>(&object);
+    if(pobject == nullptr) {
+        throw std::invalid_argument("The provided buffer must be an MDBuffer.");
+    }
+    return *pobject;
+}
+} // namespace
+
+using fp_types = types::floating_point_types;
+
+MDBuffer::MDBuffer() noexcept = default;
+
+MDBuffer::MDBuffer(buffer_type buffer, shape_type shape) :
+  my_base_type(std::make_unique<shape_type>(shape), nullptr),
+  m_shape_(std::move(shape)),
+  m_buffer_() {
+    // N.b. @p shape was moved from in the member-initializer list, so the
+    // size check must go through m_shape_.
+    if(buffer.size() == m_shape_.size()) {
+        m_buffer_ = std::move(buffer);
+    } else {
+        throw std::invalid_argument(
+          "The size of the provided buffer does not match the size "
+          "implied by the provided shape.");
+    }
+}
+
+// -----------------------------------------------------------------------------
+// -- State Accessors
+// -----------------------------------------------------------------------------
+
+auto MDBuffer::shape() const -> const_shape_view { return m_shape_; }
+
+auto MDBuffer::size() const noexcept -> size_type { return m_buffer_.size(); }
+
+auto MDBuffer::get_elem(index_vector index) const -> const_reference {
+    auto ordinal_index = coordinate_to_ordinal_(index);
+    return m_buffer_.at(ordinal_index);
+}
+
+void MDBuffer::set_elem(index_vector index, value_type new_value) {
+    auto ordinal_index = coordinate_to_ordinal_(index);
+    mark_for_rehash_();
+    m_buffer_.at(ordinal_index) = new_value;
+}
+
+auto MDBuffer::get_mutable_data() -> buffer_view {
+    mark_for_rehash_();
+    return m_buffer_;
+}
+
+auto MDBuffer::get_immutable_data() const -> const_buffer_view {
+    return m_buffer_;
+}
+
+// -----------------------------------------------------------------------------
+// -- Utility Methods
+// -----------------------------------------------------------------------------
+
+bool MDBuffer::operator==(const my_type& rhs) const noexcept {
+    if(!my_base_type::operator==(rhs)) return false;
+    return get_hash_() == rhs.get_hash_();
+}
+
+// -----------------------------------------------------------------------------
+// -- Protected Methods
+// -----------------------------------------------------------------------------
+
+auto MDBuffer::clone_() const -> buffer_base_pointer {
+    return std::make_unique<my_type>(*this);
+}
+
+bool MDBuffer::are_equal_(const_buffer_base_reference rhs) const noexcept {
+    return my_base_type::template are_equal_impl_<my_type>(rhs);
+}
+
+auto MDBuffer::addition_assignment_(label_type this_labels,
+                                    const_labeled_reference lhs,
+                                    const_labeled_reference rhs)
+  -> dsl_reference {
+    const auto& lhs_down   = downcast(lhs.object());
+    const auto& rhs_down   = downcast(rhs.object());
+    const auto& lhs_labels = lhs.labels();
+    const auto& rhs_labels = rhs.labels();
+    const auto& lhs_shape  = lhs_down.m_shape_;
+    const auto& rhs_shape  = rhs_down.m_shape_;
+
+    auto labeled_lhs_shape = lhs_shape(lhs_labels);
+    auto labeled_rhs_shape = rhs_shape(rhs_labels);
-MDBuffer::MDBuffer(shape_type shape, buffer_type buffer) :
-    MDBuffer(std::make_unique<pimpl_type>(std::move(shape),
-                                          std::move(buffer))) {}
+
+    m_shape_.addition_assignment(this_labels, labeled_lhs_shape,
+                                 labeled_rhs_shape);
-MDBuffer::MDBuffer(pimpl_pointer pimpl) noexcept : m_pimpl_(std::move(pimpl)) {}
+
detail_::AdditionVisitor visitor(m_buffer_, this_labels, m_shape_, + lhs.labels(), lhs_shape, rhs.labels(), + rhs_shape); -auto MDBuffer::rank() const -> rank_type { - assert_pimpl_(); - return m_pimpl_->shape().rank(); + wtf::buffer::visit_contiguous_buffer(visitor, lhs_down.m_buffer_, + rhs_down.m_buffer_); + mark_for_rehash_(); + return *this; } -bool MDBuffer::has_pimpl_() const noexcept { return m_pimpl_ != nullptr; } +auto MDBuffer::subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) + -> dsl_reference { + const auto& lhs_down = downcast(lhs.object()); + const auto& rhs_down = downcast(rhs.object()); + const auto& lhs_labels = lhs.labels(); + const auto& rhs_labels = rhs.labels(); + const auto& lhs_shape = lhs_down.m_shape_; + const auto& rhs_shape = rhs_down.m_shape_; -void MDBuffer::assert_pimpl_() const { - if(!has_pimpl_()) { - throw std::runtime_error( - "MDBuffer has no PIMPL. Was it default constructed?"); + auto labeled_lhs_shape = lhs_shape(lhs_labels); + auto labeled_rhs_shape = rhs_shape(rhs_labels); + + m_shape_.subtraction_assignment(this_labels, labeled_lhs_shape, + labeled_rhs_shape); + + detail_::SubtractionVisitor visitor(m_buffer_, this_labels, m_shape_, + lhs.labels(), lhs_shape, rhs.labels(), + rhs_shape); + + wtf::buffer::visit_contiguous_buffer(visitor, lhs_down.m_buffer_, + rhs_down.m_buffer_); + mark_for_rehash_(); + return *this; +} + +auto MDBuffer::multiplication_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) + -> dsl_reference { + throw std::runtime_error("multiplication NYI"); +} + +auto MDBuffer::permute_assignment_(label_type this_labels, + const_labeled_reference rhs) + -> dsl_reference { + const auto& rhs_down = downcast(rhs.object()); + const auto& rhs_labels = rhs.labels(); + const auto& rhs_shape = rhs_down.m_shape_; + + auto labeled_rhs_shape = rhs_shape(rhs_labels); + + m_shape_.permute_assignment(this_labels, labeled_rhs_shape); + + detail_::PermuteVisitor visitor(m_buffer_, this_labels, m_shape_, + rhs.labels(), rhs_shape); + + wtf::buffer::visit_contiguous_buffer(visitor, rhs_down.m_buffer_); + mark_for_rehash_(); + return *this; +} + +auto MDBuffer::scalar_multiplication_(label_type this_labels, double scalar, + const_labeled_reference rhs) + -> dsl_reference { + const auto& rhs_down = downcast(rhs.object()); + const auto& rhs_labels = rhs.labels(); + const auto& rhs_shape = rhs_down.m_shape_; + + auto labeled_rhs_shape = rhs_shape(rhs_labels); + + m_shape_.permute_assignment(this_labels, labeled_rhs_shape); + + detail_::ScalarMultiplicationVisitor visitor( + m_buffer_, this_labels, m_shape_, rhs.labels(), rhs_shape, scalar); + + wtf::buffer::visit_contiguous_buffer(visitor, rhs_down.m_buffer_); + mark_for_rehash_(); + return *this; +} + +auto MDBuffer::to_string_() const -> string_type { + std::stringstream ss; + add_to_stream_(ss); + return ss.str(); +} + +std::ostream& MDBuffer::add_to_stream_(std::ostream& os) const { + /// XXX: EigenTensor should handle aliasing a const buffer correctly. That's + /// a lot of work, just to get this to work though... 
+
+    if(m_buffer_.size() == 0) return os;
+    auto lambda = [&](auto&& span) {
+        using clean_type = typename std::decay_t<decltype(span)>::value_type;
+        auto data_ptr    = const_cast<clean_type*>(span.data());
+        std::span<clean_type> data_span(data_ptr, span.size());
+        auto ptensor = backends::eigen::make_eigen_tensor(data_span, m_shape_);
+        ptensor->add_to_stream(os);
+    };
+    wtf::buffer::visit_contiguous_buffer(lambda, m_buffer_);
+    return os;
+}
+
+// -----------------------------------------------------------------------------
+// -- Private Methods
+// -----------------------------------------------------------------------------
+
+void MDBuffer::check_index_(const index_vector& index) const {
+    if(index.size() != m_shape_.rank()) {
+        throw std::out_of_range(
+          "The length of the provided index does not match the rank of "
+          "*this.");
+    }
+    for(rank_type i = 0; i < m_shape_.rank(); ++i) {
+        if(index[i] >= m_shape_.extent(i)) {
+            throw std::out_of_range(
+              "An index provided is out of bounds for the corresponding "
+              "dimension.");
+        }
+    }
+}
 
-auto MDBuffer::pimpl_() -> pimpl_type& {
-    assert_pimpl_();
-    return *m_pimpl_;
+auto MDBuffer::coordinate_to_ordinal_(index_vector index) const -> size_type {
+    check_index_(index);
+    using size_type   = typename decltype(index)::size_type;
+    size_type ordinal = 0;
+    size_type stride  = 1;
+    for(rank_type i = shape().rank(); i-- > 0;) {
+        ordinal += index[i] * stride;
+        stride *= shape().extent(i);
+    }
+    return ordinal;
 }
 
-auto MDBuffer::pimpl_() const -> const pimpl_type& {
-    assert_pimpl_();
-    return *m_pimpl_;
+void MDBuffer::update_hash_() const {
+    buffer::detail_::hash_utilities::HashVisitor visitor;
+    if(m_buffer_.size()) {
+        wtf::buffer::visit_contiguous_buffer(visitor, m_buffer_);
+        m_hash_ = visitor.get_hash();
+    }
+    m_recalculate_hash_ = false;
 }
 } // namespace tensorwrapper::buffer
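Reviewer note: the row-major convention in coordinate_to_ordinal_, worked through on a small example; this sketch uses only the public API introduced in this diff:

    #include <cassert>

    using namespace tensorwrapper;
    // A 2x3 buffer laid out row-major: {1,2,3} is row 0, {4,5,6} is row 1.
    buffer::MDBuffer m(std::vector<double>{1, 2, 3, 4, 5, 6},
                       shape::Smooth({2, 3}));

    // Strides build right-to-left: stride(j) = 1, stride(i) = 3, so
    // ordinal({1, 2}) = 1 * 3 + 2 * 1 = 5, the last element.
    assert(m.get_elem({1, 2}) == 6.0);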
diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/addition_visitor.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/addition_visitor.cpp
deleted file mode 100644
index d7b46618..00000000
--- a/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/addition_visitor.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright 2025 NWChemEx-Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// #include
-// #include
-
-// using namespace tensorwrapper;
-
-// TEMPLATE_LIST_TEST_CASE("AdditionVisitor", "[buffer][detail_]",
-//                         types::floating_point_types) {
-//     using VisitorType = buffer::detail_::AdditionVisitor;
-
-//     VisitorType visitor;
-
-//     SECTION("vectors") {
-//         std::vector lhs{1.0, 2.0, 3.0};
-//         std::vector rhs{4.0, 5.0, 6.0};
-
-//         visitor(std::span(lhs), std::span(rhs));
-
-//         REQUIRE(lhs[0] == Approx(5.0).epsilon(1e-10));
-//         REQUIRE(lhs[1] == Approx(7.0).epsilon(1e-10));
-//         REQUIRE(lhs[2] == Approx(9.0).epsilon(1e-10));
-//     }
-// }
diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/binary_operation_visitor.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/binary_operation_visitor.cpp
new file mode 100644
index 00000000..bbcc0599
--- /dev/null
+++ b/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/binary_operation_visitor.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2025 NWChemEx-Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../../testing/testing.hpp"
+#include
+#include
+using namespace tensorwrapper;
+
+/* Testing notes:
+ *
+ * In testing the derived classes we assume that the backends have been
+ * exhaustively tested. Therefore, we simply ensure that each overload works
+ * correctly and that the correct backend is dispatched to.
+ */
+TEMPLATE_LIST_TEST_CASE("BinaryOperationVisitor", "[buffer][detail_]",
+                        types::floating_point_types) {
+    using VisitorType = buffer::detail_::BinaryOperationVisitor;
+    using buffer_type = typename VisitorType::buffer_type;
+    using label_type  = typename VisitorType::label_type;
+    using shape_type  = typename VisitorType::shape_type;
+
+    buffer_type this_buffer(std::vector(6, TestType(0.0)));
+
+    label_type this_labels("i,j");
+    shape_type this_shape({2, 3});
+
+    label_type lhs_labels("i,k");
+    shape_type lhs_shape({2, 4});
+
+    label_type rhs_labels("k,j");
+    shape_type rhs_shape({4, 3});
+
+    VisitorType visitor(this_buffer, this_labels, this_shape, lhs_labels,
+                        lhs_shape, rhs_labels, rhs_shape);
+
+    REQUIRE(visitor.this_shape() == this_shape);
+    REQUIRE(visitor.lhs_shape() == lhs_shape);
+    REQUIRE(visitor.rhs_shape() == rhs_shape);
+
+    REQUIRE(visitor.this_labels() == this_labels);
+    REQUIRE(visitor.lhs_labels() == lhs_labels);
+    REQUIRE(visitor.rhs_labels() == rhs_labels);
+
+    std::span<double> dspan;
+    std::span<float> fspan;
+    REQUIRE_THROWS_AS(visitor(dspan, fspan), std::runtime_error);
+}
+
+TEMPLATE_LIST_TEST_CASE("AdditionVisitor", "[buffer][detail_]",
+                        types::floating_point_types) {
+    using VisitorType = buffer::detail_::AdditionVisitor;
+    using buffer_type = typename VisitorType::buffer_type;
+    using label_type  = typename VisitorType::label_type;
+    using shape_type  = typename VisitorType::shape_type;
+
+    TestType one{1.0}, two{2.0}, three{3.0}, four{4.0};
+    std::vector this_data{one, two, three, four};
+    std::vector lhs_data{four, three, two, one};
+    std::vector rhs_data{one, one, one, one};
+    shape_type shape({4});
+    label_type labels("i");
+
+    std::span lhs_span(lhs_data.data(), lhs_data.size());
+    std::span clhs_span(lhs_data.data(), lhs_data.size());
+    std::span rhs_span(rhs_data.data(), rhs_data.size());
+    std::span crhs_span(rhs_data.data(), rhs_data.size());
+
+    SECTION("existing buffer") {
+        buffer_type this_buffer(this_data);
+        VisitorType visitor(this_buffer, labels, shape, labels, shape, labels,
+                            shape);
+
+        visitor(lhs_span, rhs_span);
+        REQUIRE(this_buffer.at(0) == TestType(5.0));
+        REQUIRE(this_buffer.at(1) == TestType(4.0));
+        REQUIRE(this_buffer.at(2) == TestType(3.0));
+        REQUIRE(this_buffer.at(3) == TestType(2.0));
+    }
+
+    SECTION("non-existing buffer") {
+        buffer_type empty_buffer;
+        VisitorType visitor(empty_buffer, labels, shape, labels, shape,
+                            labels, shape);
+
+        visitor(clhs_span, crhs_span);
+        REQUIRE(empty_buffer.at(0) == TestType(5.0));
+        REQUIRE(empty_buffer.at(1) == TestType(4.0));
+        REQUIRE(empty_buffer.at(2) == TestType(3.0));
+        REQUIRE(empty_buffer.at(3) == TestType(2.0));
+    }
+}
+
+TEMPLATE_LIST_TEST_CASE("SubtractionVisitor", "[buffer][detail_]",
+                        types::floating_point_types) {
+    using VisitorType = buffer::detail_::SubtractionVisitor;
+    using buffer_type = typename VisitorType::buffer_type;
+    using label_type  = typename VisitorType::label_type;
+    using shape_type  = typename VisitorType::shape_type;
+
+    TestType one{1.0}, two{2.0}, three{3.0}, four{4.0};
+    std::vector this_data{one, two, three, four};
+    std::vector lhs_data{four, three, two, one};
+    std::vector rhs_data{one, one, one, one};
+    shape_type shape({4});
+    label_type labels("i");
+
+    std::span lhs_span(lhs_data.data(), lhs_data.size());
+    std::span clhs_span(lhs_data.data(), lhs_data.size());
+    std::span rhs_span(rhs_data.data(), rhs_data.size());
+    std::span crhs_span(rhs_data.data(), rhs_data.size());
+
+    SECTION("existing buffer") {
+        buffer_type this_buffer(this_data);
+        VisitorType
visitor(this_buffer, labels, shape, labels, shape, labels, + shape); + + visitor(lhs_span, rhs_span); + REQUIRE(this_buffer.at(0) == TestType(3.0)); + REQUIRE(this_buffer.at(1) == TestType(2.0)); + REQUIRE(this_buffer.at(2) == TestType(1.0)); + REQUIRE(this_buffer.at(3) == TestType(0.0)); + } + + SECTION("non-existing buffer") { + buffer_type empty_buffer; + VisitorType visitor(empty_buffer, labels, shape, labels, shape, labels, + shape); + + visitor(clhs_span, crhs_span); + REQUIRE(empty_buffer.at(0) == TestType(3.0)); + REQUIRE(empty_buffer.at(1) == TestType(2.0)); + REQUIRE(empty_buffer.at(2) == TestType(1.0)); + REQUIRE(empty_buffer.at(3) == TestType(0.0)); + } +} diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/unary_operation_visitor.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/unary_operation_visitor.cpp new file mode 100644 index 00000000..83d9b675 --- /dev/null +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/detail_/unary_operation_visitor.cpp @@ -0,0 +1,151 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../../testing/testing.hpp" +#include +#include +using namespace tensorwrapper; + +/* Testing notes: + * + * In testing the derived classes we assume that the backends have been + * exhaustively tested. Therefore, we simply ensure that each overload works + * correctly and that the correct backend is dispatched to. 
+ */ +TEMPLATE_LIST_TEST_CASE("UnaryOperationVisitor", "[buffer][detail_]", + types::floating_point_types) { + using VisitorType = buffer::detail_::UnaryOperationVisitor; + using buffer_type = typename VisitorType::buffer_type; + using label_type = typename VisitorType::label_type; + using shape_type = typename VisitorType::shape_type; + + buffer_type this_buffer(std::vector(6, TestType(0.0))); + + label_type this_labels("i,j"); + shape_type this_shape({2, 3}); + + label_type other_labels("i,k"); + shape_type other_shape({2, 4}); + + VisitorType visitor(this_buffer, this_labels, this_shape, other_labels, + other_shape); + + REQUIRE(visitor.this_shape() == this_shape); + REQUIRE(visitor.other_shape() == other_shape); + + REQUIRE(visitor.this_labels() == this_labels); + REQUIRE(visitor.other_labels() == other_labels); +} + +TEMPLATE_LIST_TEST_CASE("PermuteVisitor", "[buffer][detail_]", + types::floating_point_types) { + using VisitorType = buffer::detail_::PermuteVisitor; + using buffer_type = typename VisitorType::buffer_type; + using label_type = typename VisitorType::label_type; + using shape_type = typename VisitorType::shape_type; + + label_type this_labels("i,j"); + shape_type this_shape({2, 3}); + + label_type other_labels("j,i"); + shape_type other_shape({3, 2}); + + std::vector other_data = {TestType(1.0), TestType(2.0), + TestType(3.0), TestType(4.0), + TestType(5.0), TestType(6.0)}; + std::span other_span(other_data.data(), other_data.size()); + std::span cother_span(other_data.data(), other_data.size()); + + SECTION("Buffer is allocated") { + buffer_type this_buffer(std::vector(6, TestType(0.0))); + VisitorType visitor(this_buffer, this_labels, this_shape, other_labels, + other_shape); + visitor(other_span); + + REQUIRE(this_buffer.at(0) == TestType(1.0)); + REQUIRE(this_buffer.at(1) == TestType(3.0)); + REQUIRE(this_buffer.at(2) == TestType(5.0)); + REQUIRE(this_buffer.at(3) == TestType(2.0)); + REQUIRE(this_buffer.at(4) == TestType(4.0)); + REQUIRE(this_buffer.at(5) == TestType(6.0)); + } + + SECTION("Buffer is not allocated") { + buffer_type this_buffer; + VisitorType visitor(this_buffer, this_labels, this_shape, other_labels, + other_shape); + visitor(cother_span); + + REQUIRE(this_buffer.at(0) == TestType(1.0)); + REQUIRE(this_buffer.at(1) == TestType(3.0)); + REQUIRE(this_buffer.at(2) == TestType(5.0)); + REQUIRE(this_buffer.at(3) == TestType(2.0)); + REQUIRE(this_buffer.at(4) == TestType(4.0)); + REQUIRE(this_buffer.at(5) == TestType(6.0)); + } +} + +TEMPLATE_LIST_TEST_CASE("ScalarMultiplicationVisitor", "[buffer][detail_]", + types::floating_point_types) { + using VisitorType = buffer::detail_::ScalarMultiplicationVisitor; + using buffer_type = typename VisitorType::buffer_type; + using label_type = typename VisitorType::label_type; + using shape_type = typename VisitorType::shape_type; + + label_type this_labels("i,j"); + shape_type this_shape({2, 3}); + + label_type other_labels("j,i"); + shape_type other_shape({3, 2}); + + std::vector other_data = {TestType(1.0), TestType(2.0), + TestType(3.0), TestType(4.0), + TestType(5.0), TestType(6.0)}; + std::span other_span(other_data.data(), other_data.size()); + std::span cother_span(other_data.data(), other_data.size()); + + // TODO: when public API of MDBuffer supports other FP types, test them here + double scalar_{2.0}; + TestType scalar(scalar_); + + SECTION("Buffer is allocated") { + buffer_type this_buffer(std::vector(6, TestType(0.0))); + VisitorType visitor(this_buffer, this_labels, this_shape, other_labels, + 
other_shape, scalar_);
+        visitor(other_span);
+
+        REQUIRE(this_buffer.at(0) == TestType(1.0) * scalar);
+        REQUIRE(this_buffer.at(1) == TestType(3.0) * scalar);
+        REQUIRE(this_buffer.at(2) == TestType(5.0) * scalar);
+        REQUIRE(this_buffer.at(3) == TestType(2.0) * scalar);
+        REQUIRE(this_buffer.at(4) == TestType(4.0) * scalar);
+        REQUIRE(this_buffer.at(5) == TestType(6.0) * scalar);
+    }
+
+    SECTION("Buffer is not allocated") {
+        buffer_type this_buffer;
+        VisitorType visitor(this_buffer, this_labels, this_shape,
+                            other_labels, other_shape, scalar_);
+        visitor(cother_span);
+
+        REQUIRE(this_buffer.at(0) == TestType(1.0) * scalar);
+        REQUIRE(this_buffer.at(1) == TestType(3.0) * scalar);
+        REQUIRE(this_buffer.at(2) == TestType(5.0) * scalar);
+        REQUIRE(this_buffer.at(3) == TestType(2.0) * scalar);
+        REQUIRE(this_buffer.at(4) == TestType(4.0) * scalar);
+        REQUIRE(this_buffer.at(5) == TestType(6.0) * scalar);
+    }
+}
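Reviewer note: putting the pieces together, this is the end-to-end MDBuffer workflow the next test file exercises, condensed into one sketch (double elements; labeled calls exactly as in the test bodies):

    using namespace tensorwrapper;

    // 2x2 matrix over a flat, row-major backing store.
    buffer::MDBuffer m(std::vector<double>{1.0, 2.0, 3.0, 4.0},
                       shape::Smooth({2, 2}));

    m.set_elem({0, 1}, 7.0);     // element at row 0, column 1
    auto x = m.get_elem({0, 1}); // 7.0

    // Labeled, DSL-driven arithmetic: sum("i,j") = m("i,j") + m("i,j").
    buffer::MDBuffer::label_type ij("i,j");
    buffer::MDBuffer sum;
    sum.addition_assignment(ij, m(ij), m(ij));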
diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp
new file mode 100644
index 00000000..11dd8081
--- /dev/null
+++ b/tests/cxx/unit_tests/tensorwrapper/buffer/mdbuffer.cpp
@@ -0,0 +1,417 @@
+/*
+ * Copyright 2025 NWChemEx-Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../testing/testing.hpp"
+#include
+#include
+
+using namespace tensorwrapper;
+
+/* Testing notes:
+ *
+ * The various operations (addition_assignment, etc.) are not exhaustively
+ * tested here. These operations are implemented via visitors that dispatch
+ * to various backends. The visitors themselves are tested in their own unit
+ * tests. Here we assume the visitors work and spot check a couple of
+ * operations to help catch integration issues.
+ */
+
+TEMPLATE_LIST_TEST_CASE("MDBuffer", "", types::floating_point_types) {
+    using buffer::MDBuffer;
+    using buffer_type = MDBuffer::buffer_type;
+    using shape_type  = typename MDBuffer::shape_type;
+    using label_type  = typename MDBuffer::label_type;
+
+    TestType one(1.0), two(2.0), three(3.0), four(4.0);
+    std::vector data = {one, two, three, four};
+
+    shape_type scalar_shape({});
+    shape_type vector_shape({4});
+    shape_type matrix_shape({2, 2});
+
+    MDBuffer defaulted;
+    MDBuffer scalar(std::vector{one}, scalar_shape);
+    MDBuffer vector(data, vector_shape);
+    MDBuffer matrix(data, matrix_shape);
+
+    SECTION("Ctors and assignment") {
+        SECTION("Default ctor") {
+            REQUIRE(defaulted.size() == 0);
+            REQUIRE(defaulted.shape() == shape_type());
+        }
+
+        SECTION("vector ctor") {
+            REQUIRE(scalar.size() == 1);
+            REQUIRE(scalar.shape() == scalar_shape);
+            REQUIRE(scalar.get_elem({}) == one);
+
+            REQUIRE(vector.size() == 4);
+            REQUIRE(vector.shape() == vector_shape);
+            REQUIRE(vector.get_elem({0}) == one);
+            REQUIRE(vector.get_elem({1}) == two);
+            REQUIRE(vector.get_elem({2}) == three);
+            REQUIRE(vector.get_elem({3}) == four);
+
+            REQUIRE(matrix.size() == 4);
+            REQUIRE(matrix.shape() == matrix_shape);
+            REQUIRE(matrix.get_elem({0, 0}) == one);
+            REQUIRE(matrix.get_elem({0, 1}) == two);
+            REQUIRE(matrix.get_elem({1, 0}) == three);
+            REQUIRE(matrix.get_elem({1, 1}) == four);
+
+            REQUIRE_THROWS_AS(MDBuffer(data, scalar_shape),
+                              std::invalid_argument);
+        }
+
+        SECTION("FloatBuffer ctor") {
+            buffer_type buf(data);
+
+            MDBuffer vector_buf(buf, vector_shape);
+            REQUIRE(vector_buf == vector);
+
+            MDBuffer matrix_buf(buf, matrix_shape);
+            REQUIRE(matrix_buf == matrix);
+
+            REQUIRE_THROWS_AS(MDBuffer(buf, scalar_shape),
+                              std::invalid_argument);
+        }
+
+        SECTION("Copy ctor") {
+            MDBuffer defaulted_copy(defaulted);
+            REQUIRE(defaulted_copy == defaulted);
+
+            MDBuffer scalar_copy(scalar);
+            REQUIRE(scalar_copy == scalar);
+
+            MDBuffer vector_copy(vector);
+            REQUIRE(vector_copy == vector);
+
+            MDBuffer matrix_copy(matrix);
+            REQUIRE(matrix_copy == matrix);
+        }
+
+        SECTION("Move ctor") {
+            MDBuffer defaulted_temp(defaulted);
+            MDBuffer defaulted_move(std::move(defaulted_temp));
+            REQUIRE(defaulted_move == defaulted);
+
+            MDBuffer scalar_temp(scalar);
+            MDBuffer scalar_move(std::move(scalar_temp));
+            REQUIRE(scalar_move == scalar);
+
+            MDBuffer vector_temp(vector);
+            MDBuffer vector_move(std::move(vector_temp));
+            REQUIRE(vector_move == vector);
+
+            MDBuffer matrix_temp(matrix);
+            MDBuffer matrix_move(std::move(matrix_temp));
+            REQUIRE(matrix_move == matrix);
+        }
+
+        SECTION("Copy assignment") {
+            MDBuffer defaulted_copy;
+            auto pdefaulted_copy = &(defaulted_copy = defaulted);
+            REQUIRE(defaulted_copy == defaulted);
+            REQUIRE(pdefaulted_copy == &defaulted_copy);
+
+            MDBuffer scalar_copy;
+            auto pscalar_copy = &(scalar_copy = scalar);
+            REQUIRE(scalar_copy == scalar);
+            REQUIRE(pscalar_copy == &scalar_copy);
+
+            MDBuffer vector_copy;
+            auto pvector_copy = &(vector_copy = vector);
+            REQUIRE(vector_copy == vector);
+            REQUIRE(pvector_copy == &vector_copy);
+
+            MDBuffer matrix_copy;
+            auto pmatrix_copy = &(matrix_copy = matrix);
+            REQUIRE(matrix_copy == matrix);
+            REQUIRE(pmatrix_copy == &matrix_copy);
+        }
+
+        SECTION("Move assignment") {
+            MDBuffer defaulted_temp(defaulted);
+            MDBuffer defaulted_move;
+            auto pdefaulted_move =
+              &(defaulted_move = std::move(defaulted_temp));
+            REQUIRE(defaulted_move == defaulted);
+            REQUIRE(pdefaulted_move == &defaulted_move);
+
+            MDBuffer scalar_temp(scalar);
+            MDBuffer
scalar_move; + auto pscalar_move = &(scalar_move = std::move(scalar_temp)); + REQUIRE(scalar_move == scalar); + REQUIRE(pscalar_move == &scalar_move); + + MDBuffer vector_temp(vector); + MDBuffer vector_move; + auto pvector_move = &(vector_move = std::move(vector_temp)); + REQUIRE(vector_move == vector); + REQUIRE(pvector_move == &vector_move); + + MDBuffer matrix_temp(matrix); + MDBuffer matrix_move; + auto pmatrix_move = &(matrix_move = std::move(matrix_temp)); + REQUIRE(matrix_move == matrix); + REQUIRE(pmatrix_move == &matrix_move); + } + } + + SECTION("shape") { + REQUIRE(defaulted.shape() == shape_type()); + REQUIRE(scalar.shape() == scalar_shape); + REQUIRE(vector.shape() == vector_shape); + REQUIRE(matrix.shape() == matrix_shape); + } + + SECTION("size") { + REQUIRE(defaulted.size() == 0); + REQUIRE(scalar.size() == 1); + REQUIRE(vector.size() == 4); + REQUIRE(matrix.size() == 4); + } + + SECTION("get_elem") { + REQUIRE_THROWS_AS(defaulted.get_elem({}), std::out_of_range); + + REQUIRE(scalar.get_elem({}) == one); + REQUIRE_THROWS_AS(scalar.get_elem({0}), std::out_of_range); + + REQUIRE(vector.get_elem({0}) == one); + REQUIRE(vector.get_elem({1}) == two); + REQUIRE(vector.get_elem({2}) == three); + REQUIRE(vector.get_elem({3}) == four); + REQUIRE_THROWS_AS(vector.get_elem({4}), std::out_of_range); + + REQUIRE(matrix.get_elem({0, 0}) == one); + REQUIRE(matrix.get_elem({0, 1}) == two); + REQUIRE(matrix.get_elem({1, 0}) == three); + REQUIRE(matrix.get_elem({1, 1}) == four); + REQUIRE_THROWS_AS(matrix.get_elem({2, 0}), std::out_of_range); + } + + SECTION("set_elem") { + REQUIRE_THROWS_AS(defaulted.set_elem({}, one), std::out_of_range); + + REQUIRE(scalar.get_elem({}) != two); + scalar.set_elem({}, two); + REQUIRE(scalar.get_elem({}) == two); + + REQUIRE(vector.get_elem({2}) != four); + vector.set_elem({2}, four); + REQUIRE(vector.get_elem({2}) == four); + + REQUIRE(matrix.get_elem({1, 0}) != one); + matrix.set_elem({1, 0}, one); + REQUIRE(matrix.get_elem({1, 0}) == one); + } + + SECTION("operator==") { + // Same object + REQUIRE(defaulted == defaulted); + + MDBuffer scalar_copy(std::vector{one}, scalar_shape); + REQUIRE(scalar == scalar_copy); + + MDBuffer vector_copy(data, vector_shape); + REQUIRE(vector == vector_copy); + + MDBuffer matrix_copy(data, matrix_shape); + REQUIRE(matrix == matrix_copy); + + // Different ranks + REQUIRE_FALSE(scalar == vector); + REQUIRE_FALSE(vector == matrix); + REQUIRE_FALSE(scalar == matrix); + + // Different shapes + shape_type matrix_shape2({4, 1}); + REQUIRE_FALSE(scalar == MDBuffer(data, matrix_shape2)); + + // Different values + std::vector diff_data = {two, three, four, one}; + MDBuffer scalar_diff(std::vector{two}, scalar_shape); + REQUIRE_FALSE(scalar == scalar_diff); + REQUIRE_FALSE(vector == MDBuffer(diff_data, vector_shape)); + REQUIRE_FALSE(matrix == MDBuffer(diff_data, matrix_shape)); + } + + SECTION("addition_assignment_") { + SECTION("scalar") { + label_type labels(""); + MDBuffer result; + result.addition_assignment(labels, scalar(labels), scalar(labels)); + REQUIRE(result.shape() == scalar_shape); + REQUIRE(result.get_elem({}) == TestType(2.0)); + } + + SECTION("vector") { + label_type labels("i"); + MDBuffer result; + result.addition_assignment(labels, vector(labels), vector(labels)); + REQUIRE(result.shape() == vector_shape); + REQUIRE(result.get_elem({0}) == TestType(2.0)); + REQUIRE(result.get_elem({1}) == TestType(4.0)); + REQUIRE(result.get_elem({2}) == TestType(6.0)); + REQUIRE(result.get_elem({3}) == TestType(8.0)); + } + + 
SECTION("matrix") { + label_type labels("i,j"); + MDBuffer result; + result.addition_assignment(labels, matrix(labels), matrix(labels)); + REQUIRE(result.shape() == matrix_shape); + REQUIRE(result.get_elem({0, 0}) == TestType(2.0)); + REQUIRE(result.get_elem({0, 1}) == TestType(4.0)); + REQUIRE(result.get_elem({1, 0}) == TestType(6.0)); + REQUIRE(result.get_elem({1, 1}) == TestType(8.0)); + } + } + + SECTION("subtraction_assignment_") { + SECTION("scalar") { + label_type labels(""); + MDBuffer result; + result.subtraction_assignment(labels, scalar(labels), + scalar(labels)); + REQUIRE(result.shape() == scalar_shape); + REQUIRE(result.get_elem({}) == TestType(0.0)); + } + + SECTION("vector") { + label_type labels("i"); + MDBuffer result; + result.subtraction_assignment(labels, vector(labels), + vector(labels)); + REQUIRE(result.shape() == vector_shape); + REQUIRE(result.get_elem({0}) == TestType(0.0)); + REQUIRE(result.get_elem({1}) == TestType(0.0)); + REQUIRE(result.get_elem({2}) == TestType(0.0)); + REQUIRE(result.get_elem({3}) == TestType(0.0)); + } + + SECTION("matrix") { + label_type labels("i,j"); + MDBuffer result; + result.subtraction_assignment(labels, matrix(labels), + matrix(labels)); + REQUIRE(result.shape() == matrix_shape); + REQUIRE(result.get_elem({0, 0}) == TestType(0.0)); + REQUIRE(result.get_elem({0, 1}) == TestType(0.0)); + REQUIRE(result.get_elem({1, 0}) == TestType(0.0)); + REQUIRE(result.get_elem({1, 1}) == TestType(0.0)); + } + } + + SECTION("scalar_multiplication_") { + // TODO: Test with other scalar types when public API supports it + using scalar_type = double; + scalar_type scalar_value_{2.0}; + TestType scalar_value(scalar_value_); + SECTION("scalar") { + label_type labels(""); + MDBuffer result; + result.scalar_multiplication(labels, scalar_value_, scalar(labels)); + REQUIRE(result.shape() == scalar_shape); + REQUIRE(result.get_elem({}) == TestType(1.0) * scalar_value); + } + + SECTION("vector") { + label_type labels("i"); + MDBuffer result; + result.scalar_multiplication(labels, scalar_value_, vector(labels)); + REQUIRE(result.shape() == vector_shape); + REQUIRE(result.get_elem({0}) == TestType(1.0) * scalar_value); + REQUIRE(result.get_elem({1}) == TestType(2.0) * scalar_value); + REQUIRE(result.get_elem({2}) == TestType(3.0) * scalar_value); + REQUIRE(result.get_elem({3}) == TestType(4.0) * scalar_value); + } + + SECTION("matrix") { + label_type rhs_labels("i,j"); + label_type lhs_labels("j,i"); + MDBuffer result; + result.scalar_multiplication(lhs_labels, scalar_value_, + matrix(rhs_labels)); + REQUIRE(result.shape() == matrix_shape); + REQUIRE(result.get_elem({0, 0}) == TestType(1.0) * scalar_value); + REQUIRE(result.get_elem({0, 1}) == TestType(3.0) * scalar_value); + REQUIRE(result.get_elem({1, 0}) == TestType(2.0) * scalar_value); + REQUIRE(result.get_elem({1, 1}) == TestType(4.0) * scalar_value); + } + } + + SECTION("permute_assignment_") { + SECTION("scalar") { + label_type labels(""); + MDBuffer result; + result.permute_assignment(labels, scalar(labels)); + REQUIRE(result.shape() == scalar_shape); + REQUIRE(result.get_elem({}) == TestType(1.0)); + } + + SECTION("vector") { + label_type labels("i"); + MDBuffer result; + result.permute_assignment(labels, vector(labels)); + REQUIRE(result.shape() == vector_shape); + REQUIRE(result.get_elem({0}) == TestType(1.0)); + REQUIRE(result.get_elem({1}) == TestType(2.0)); + REQUIRE(result.get_elem({2}) == TestType(3.0)); + REQUIRE(result.get_elem({3}) == TestType(4.0)); + } + + SECTION("matrix") { + label_type 
rhs_labels("i,j"); + label_type lhs_labels("j,i"); + MDBuffer result; + result.permute_assignment(lhs_labels, matrix(rhs_labels)); + REQUIRE(result.shape() == matrix_shape); + REQUIRE(result.get_elem({0, 0}) == TestType(1.0)); + REQUIRE(result.get_elem({0, 1}) == TestType(3.0)); + REQUIRE(result.get_elem({1, 0}) == TestType(2.0)); + REQUIRE(result.get_elem({1, 1}) == TestType(4.0)); + } + } + + SECTION("to_string") { + REQUIRE(defaulted.to_string().empty()); + REQUIRE_FALSE(scalar.to_string().empty()); + REQUIRE_FALSE(vector.to_string().empty()); + REQUIRE_FALSE(matrix.to_string().empty()); + } + + SECTION("add_to_stream") { + std::stringstream ss; + SECTION("defaulted") { + defaulted.add_to_stream(ss); + REQUIRE(ss.str().empty()); + } + SECTION("scalar") { + scalar.add_to_stream(ss); + REQUIRE_FALSE(ss.str().empty()); + } + SECTION("vector") { + vector.add_to_stream(ss); + REQUIRE_FALSE(ss.str().empty()); + } + SECTION("matrix") { + matrix.add_to_stream(ss); + REQUIRE_FALSE(ss.str().empty()); + } + } +}